/*
 * Decompiled with CFR 0.152.
 */
package org.esa.snap.classification.gpf;

import be.abeel.util.MTRandom;
import com.bc.ceres.core.ProgressMonitor;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.featureselection.scoring.GainRatio;
import org.esa.snap.classification.gpf.BaseClassifier;
import org.esa.snap.classification.gpf.ClassifierAttributeEvaluation;
import org.esa.snap.classification.gpf.ClassifierReport;
import org.esa.snap.classification.gpf.CrossValidation;
import org.esa.snap.core.gpf.OperatorException;

/**
 * Evaluates a classifier by cross validation and by feature-importance analysis.
 * Human-readable results are appended to a {@link ClassifierReport}, and summary
 * values are collected in a {@link Score}.
 */
public class Evaluator {
    private final Classifier mlClassifier;
    private final ClassifierReport classifierReport;
    private final Score score = new Score();
    /** Number of times each feature is perturbed when estimating its importance. */
    private static final int NUM_PERTURBATIONS = 3;
    /** Number of folds used for cross validation. */
    private static final int NUM_FOLDS = 5;

    /**
     * @param classifier       the classifier to evaluate
     * @param classifierReport report that evaluation text is appended to
     */
    public Evaluator(Classifier classifier, ClassifierReport classifierReport) {
        this.mlClassifier = classifier;
        this.classifierReport = classifierReport;
    }

    /** Formats a number left-aligned, 6 wide, with 4 decimals. */
    private static String f(double val) {
        return String.format("%-6.4f", val);
    }

    /** @return the score object filled in by the evaluate methods of this instance */
    public Score getScore() {
        return this.score;
    }

    /**
     * Runs {@value #NUM_FOLDS}-fold cross validation on {@code dataset} and appends the
     * per-class and overall results to the report. It then appends the class distribution
     * of {@code instanceList}. The overall fraction of correct predictions (0..1) is stored
     * in {@code score.crossValidationPercent}.
     *
     * @param labelMap     display label for each class value
     * @param instanceList instances whose class values are counted for the distribution
     * @param dataset      data to cross validate
     * @param datasetType  name of the dataset used in the report text
     * @return the same {@link Score} instance returned by {@link #getScore()}
     */
    public Score evaluateClassifier(Map<Double, String> labelMap, List<Instance> instanceList, Dataset dataset, String datasetType) {
        Map<Object, Integer> classDistribution = Evaluator.getClassDistributionInDataset(instanceList, dataset);
        CrossValidation cv = new CrossValidation(this.mlClassifier);
        try {
            Map<Object, PerformanceMeasure> cvPerfMeasure = cv.crossValidation(dataset, NUM_FOLDS, new Random());
            this.score.crossValidationPercent = this.printEvaluation("Cross Validation", labelMap, dataset, datasetType, classDistribution, cvPerfMeasure);
        }
        catch (Exception e) {
            // Best effort: a failed cross validation (including the sample-count check in
            // printEvaluation) is only printed, so the distribution below is still reported.
            e.printStackTrace();
        }
        this.printDistribution(labelMap, classDistribution);
        return this.score;
    }

    /** Appends the count and percentage of instances per class to the report. */
    private void printDistribution(Map<Double, String> labelMap, Map<Object, Integer> classDistribution) {
        StringBuilder log = new StringBuilder(512);
        Object[] sortedClassValues = ClassifierAttributeEvaluation.getSortedObjects(classDistribution.keySet());
        log.append("Distribution:\n");
        int sum = 0;
        for (Object o : sortedClassValues) {
            sum += classDistribution.get(o).intValue();
        }
        for (Object o : sortedClassValues) {
            int cntVal = classDistribution.get(o);
            String label = labelMap.get(o);
            // Guard against 0/0 (printed as NaN) when no instances were counted.
            double pct = sum == 0 ? 0.0 : 100.0 * (double)cntVal / (double)sum;
            log.append("   class " + o + ": " + String.format("%-25s", label) + "  " + cntVal + "\t (" + Evaluator.f(pct) + "%)\n");
        }
        this.classifierReport.addClassifierEvaluation(log.toString());
    }

    /**
     * Appends per-class performance measures and the overall correct-prediction rate to
     * the report.
     *
     * @return total true positives divided by {@code dataset.size()} (a fraction, not a percent)
     * @throws OperatorException if the summed tp+fn over all classes differs from the dataset size
     */
    private double printEvaluation(String title, Map<Double, String> labelMap, Dataset dataset, String datasetType, Map<Object, Integer> classDistribution, Map<Object, PerformanceMeasure> performanceMeasureMap) {
        StringBuilder log = new StringBuilder(512);
        log.append(title + '\n');
        log.append("Number of classes = " + performanceMeasureMap.size() + '\n');
        int sum = 0;
        int totalTP = 0;
        int totalSamples = 0;
        for (Object o : ClassifierAttributeEvaluation.getSortedObjects(performanceMeasureMap.keySet())) {
            PerformanceMeasure perMea = performanceMeasureMap.get(o);
            totalTP = (int)((double)totalTP + perMea.tp);
            // Every sample of class o is either a true positive or a false negative for o.
            totalSamples = (int)((double)totalSamples + (perMea.tp + perMea.fn));
            int cntVal = classDistribution.get(o);
            sum += cntVal;
            String label = labelMap.get(o);
            log.append("   class " + o + ": " + String.format("%-25s", label) + '\n');
            log.append("    accuracy = " + Evaluator.f(perMea.getAccuracy()) + " precision = " + Evaluator.f(perMea.getPrecision()) + " correlation = " + Evaluator.f(perMea.getCorrelation()) + " errorRate = " + Evaluator.f(perMea.getErrorRate()) + '\n');
            log.append("    TruePositives = " + Evaluator.f(perMea.tp) + " FalsePositives = " + Evaluator.f(perMea.fp) + " TrueNegatives = " + Evaluator.f(perMea.tn) + " FalseNegatives = " + Evaluator.f(perMea.fn) + '\n');
        }
        if (totalSamples != dataset.size()) {
            throw new OperatorException("totalSamples = " + totalSamples + " dataset size = " + dataset.size());
        }
        double tpPct = (double)totalTP / (double)dataset.size();
        log.append("\nUsing " + datasetType + " dataset, % correct predictions = " + Evaluator.f(tpPct * 100.0) + '\n');
        log.append("Total samples = " + sum + '\n');
        this.classifierReport.addClassifierEvaluation(log.toString());
        return tpPct;
    }

    /**
     * Scores every feature by perturbation and appends a ranking to the report, from
     * highest to lowest {@code tp} score. Also stores each feature's score string in
     * {@code score.featureScoreMap}, keyed by band name.
     *
     * @param featureInfoList features to rank; index i corresponds to feature score i
     * @param dataset         data used for the evaluation
     * @param datasetType     name of the dataset used in the report text
     * @param pm              progress monitor passed to the attribute evaluation
     */
    public void evaluateFeatures(BaseClassifier.FeatureInfo[] featureInfoList, Dataset dataset, String datasetType, ProgressMonitor pm) {
        StringBuilder log = new StringBuilder(512);
        log.append(datasetType + " feature importance score:\n");
        ClassifierAttributeEvaluation eval = new ClassifierAttributeEvaluation(this.mlClassifier, NUM_PERTURBATIONS, (Random)new MTRandom());
        final ClassifierAttributeEvaluation.FeatureScore[] fs = eval.performEvaluation(dataset, pm);

        // Rank feature indices by score, highest first. A stable sort is used instead of a
        // score-keyed map so that features with equal scores are all kept (ties stay in
        // index order) rather than overwriting each other.
        Integer[] ranking = new Integer[featureInfoList.length];
        for (int i = 0; i < ranking.length; ++i) {
            ranking[i] = i;
        }
        Arrays.sort(ranking, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(fs[b].tp, fs[a].tp);
            }
        });

        GainRatio gr = new GainRatio();
        gr.build(dataset);
        log.append("Each feature is perturbed " + NUM_PERTURBATIONS + " times and the % correct predictions are averaged\n");
        log.append("The importance score is the original % correct prediction - average\n");
        for (int r = 0; r < ranking.length; ++r) {
            int i = ranking[r];
            int rank = r + 1;
            String featureName = featureInfoList[i].featureBand.getName();
            String scoreStr = "score: tp=" + Evaluator.f(fs[i].tp) + " accuracy=" + Evaluator.f(fs[i].accuracy) + " precision=" + Evaluator.f(fs[i].precision) + " correlation=" + Evaluator.f(fs[i].correlation) + " errorRate=" + Evaluator.f(fs[i].errorRate) + " cost=" + Evaluator.f(fs[i].cost) + " GainRatio = " + Evaluator.f(gr.score(i));
            this.score.featureScoreMap.put(featureName, scoreStr);
            log.append("   rank " + String.format("%-3d", rank) + " feature " + String.format("%-3d", i + 1) + ": " + String.format("%-25s", featureName) + scoreStr + '\n');
        }
        log.append('\n');
        this.classifierReport.addFeatureEvaluation(log.toString());
    }

    /**
     * Counts instances per class value. Every class in {@code dataset.classes()} starts at 0.
     * Class values that appear only in {@code instanceList} are also counted, rather than
     * causing a NullPointerException.
     */
    private static Map<Object, Integer> getClassDistributionInDataset(List<Instance> instanceList, Dataset dataset) {
        Map<Object, Integer> cnt = new HashMap<Object, Integer>();
        for (Object o : dataset.classes()) {
            cnt.put(o, 0);
        }
        for (Instance inst : instanceList) {
            Object o = inst.classValue();
            Integer oldCnt = cnt.get(o);
            cnt.put(o, oldCnt == null ? 1 : oldCnt + 1);
        }
        return cnt;
    }

    /** Summary results collected by an {@link Evaluator}. */
    public static class Score {
        /** Not populated by Evaluator itself. */
        public double classifierPercent;
        /** Cross-validation fraction of correct predictions (0..1), despite the name. */
        public double crossValidationPercent;
        /** Feature band name mapped to its formatted importance score string. */
        public Map<String, String> featureScoreMap = new HashMap<String, String>();
    }
}

