package org.esa.snap.classification.gpf;

import be.abeel.util.MTRandom;
import com.bc.ceres.core.ProgressMonitor;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.featureselection.scoring.GainRatio;
import org.esa.snap.classification.gpf.BaseClassifier;
import org.esa.snap.classification.gpf.ClassifierAttributeEvaluation;
import org.esa.snap.core.gpf.OperatorException;

/* loaded from: input_file:org/esa/snap/classification/gpf/Evaluator.class */
public class Evaluator {
    private static final Logger LOG = Logger.getLogger(Evaluator.class.getName());

    private final Classifier mlClassifier;
    private final ClassifierReport classifierReport;
    private final Score score = new Score();
    /** How many times each feature is perturbed during feature evaluation. */
    private static final int NUM_PERTURBATIONS = 3;
    /** Passed to CrossValidation#crossValidation; presumably the number of folds — confirm. */
    private static final int NUM_CROSS_VALIDATION_FOLDS = 5;
    private static final boolean doCrossValidation = true;

    /** Results of an evaluation run, filled in by the evaluate methods of {@link Evaluator}. */
    public static class Score {
        /** Not assigned by {@link Evaluator}. */
        public double classifierPercent;
        /** Fraction (0..1, despite the name) of correct predictions from cross validation. */
        public double crossValidationPercent;
        /** Feature band name mapped to a formatted importance-score summary. */
        public Map<String, String> featureScoreMap = new HashMap<>();
    }

    /**
     * @param classifier       classifier to evaluate
     * @param classifierReport report that receives the textual evaluation results
     */
    public Evaluator(Classifier classifier, ClassifierReport classifierReport) {
        this.mlClassifier = classifier;
        this.classifierReport = classifierReport;
    }

    /** Formats a number left-aligned with 4 decimals in a 6-character field. */
    private static String f(double d) {
        return String.format("%-6.4f", d);
    }

    /** @return the score object that the evaluate methods write into */
    public Score getScore() {
        return this.score;
    }

    /**
     * Cross-validates the classifier on {@code dataset}, writes the per-class results and the
     * class distribution of {@code instances} to the report, and records the accuracy.
     *
     * @param classLabels class value mapped to a human-readable label (report text only)
     * @param instances   instances whose class distribution is reported
     * @param dataset     dataset used for cross validation
     * @param datasetName name of the dataset, used in the report
     * @return this evaluator's score (same object as {@link #getScore()})
     */
    public Score evaluateClassifier(Map<Double, String> classLabels, List<Instance> instances, Dataset dataset,
                                    String datasetName) {
        final Map<Object, Integer> classDistribution = getClassDistributionInDataset(instances, dataset);
        if (doCrossValidation) {
            try {
                final Map<Object, PerformanceMeasure> measures = new CrossValidation(this.mlClassifier)
                        .crossValidation(dataset, NUM_CROSS_VALIDATION_FOLDS, new Random());
                this.score.crossValidationPercent = printEvaluation("Cross Validation", classLabels, dataset,
                        datasetName, classDistribution, measures);
            } catch (Exception e) {
                // Best effort: a failed cross validation must not prevent the distribution report.
                LOG.log(Level.WARNING, "Cross validation failed for dataset " + datasetName, e);
            }
        }
        printDistribution(classLabels, classDistribution);
        return this.score;
    }

    /** Appends the count and percentage of every class to the report. */
    private void printDistribution(Map<Double, String> classLabels, Map<Object, Integer> classDistribution) {
        final StringBuilder sb = new StringBuilder(512);
        final Object[] classValues = ClassifierAttributeEvaluation.getSortedObjects(classDistribution.keySet());
        sb.append("Distribution:\n");
        int total = 0;
        for (Object classValue : classValues) {
            total += classDistribution.get(classValue);
        }
        for (Object classValue : classValues) {
            final int count = classDistribution.get(classValue);
            // Avoid NaN when no instances were counted at all.
            final double percent = total == 0 ? 0.0 : (100.0d * count) / total;
            sb.append("   class ").append(classValue).append(": ")
              .append(String.format("%-25s", classLabels.get(classValue)))
              .append("  ").append(count).append("\t (").append(f(percent)).append("%)\n");
        }
        this.classifierReport.addClassifierEvaluation(sb.toString());
    }

    /**
     * Appends per-class performance measures to the report and computes the overall accuracy.
     *
     * @return fraction (0..1) of correctly predicted samples; 0 for an empty dataset
     * @throws OperatorException if the per-class sample totals do not add up to the dataset size
     */
    private double printEvaluation(String title, Map<Double, String> classLabels, Dataset dataset, String datasetName,
                                   Map<Object, Integer> classDistribution, Map<Object, PerformanceMeasure> measures) {
        final StringBuilder sb = new StringBuilder(512);
        sb.append(title).append('\n');
        sb.append("Number of classes = ").append(measures.size()).append('\n');
        int totalInDistribution = 0;
        int totalCorrect = 0;
        int totalSamples = 0;
        for (Object classValue : ClassifierAttributeEvaluation.getSortedObjects(measures.keySet())) {
            final PerformanceMeasure pm = measures.get(classValue);
            totalCorrect = (int) (totalCorrect + pm.tp);
            // tp + fn is the number of samples that truly belong to this class.
            totalSamples = (int) (totalSamples + pm.tp + pm.fn);
            totalInDistribution += classDistribution.getOrDefault(classValue, 0);
            sb.append("   class ").append(classValue).append(": ")
              .append(String.format("%-25s", classLabels.get(classValue))).append('\n');
            sb.append("    accuracy = ").append(f(pm.getAccuracy()))
              .append(" precision = ").append(f(pm.getPrecision()))
              .append(" correlation = ").append(f(pm.getCorrelation()))
              .append(" errorRate = ").append(f(pm.getErrorRate())).append('\n');
            sb.append("    TruePositives = ").append(f(pm.tp))
              .append(" FalsePositives = ").append(f(pm.fp))
              .append(" TrueNegatives = ").append(f(pm.tn))
              .append(" FalseNegatives = ").append(f(pm.fn)).append('\n');
        }
        if (totalSamples != dataset.size()) {
            throw new OperatorException("totalSamples = " + totalSamples + " dataset size = " + dataset.size());
        }
        // Divide in floating point: int / int would truncate the accuracy to 0 or 1.
        final double accuracy = dataset.size() == 0 ? 0.0 : (double) totalCorrect / dataset.size();
        sb.append("\nUsing ").append(datasetName).append(" dataset, % correct predictions = ")
          .append(f(accuracy * 100.0d)).append('\n');
        sb.append("Total samples = ").append(totalInDistribution).append('\n');
        this.classifierReport.addClassifierEvaluation(sb.toString());
        return accuracy;
    }

    /**
     * Ranks features by their perturbation score (highest first), stores each feature's score text
     * in {@link Score#featureScoreMap}, and appends the ranking to the report.
     *
     * @param featureInfos features, in the same order as the dataset attributes
     * @param dataset      dataset used for the evaluation
     * @param title        heading prefix for the report section
     * @param pm           progress monitor passed to the attribute evaluation
     */
    public void evaluateFeatures(BaseClassifier.FeatureInfo[] featureInfos, Dataset dataset, String title,
                                 ProgressMonitor pm) {
        final StringBuilder sb = new StringBuilder(512);
        sb.append(title).append(" feature importance score:\n");
        final ClassifierAttributeEvaluation.FeatureScore[] featureScores =
                new ClassifierAttributeEvaluation(this.mlClassifier, NUM_PERTURBATIONS, new MTRandom())
                        .performEvaluation(dataset, pm);

        // Sort feature indices rather than keying a map by score, so tied scores keep every feature.
        final int numRanked = Math.min(featureInfos.length, featureScores.length);
        final Integer[] ranking = new Integer[numRanked];
        for (int i = 0; i < numRanked; i++) {
            ranking[i] = i;
        }
        Arrays.sort(ranking, Comparator.comparingDouble((Integer idx) -> featureScores[idx].tp).reversed());

        final GainRatio gainRatio = new GainRatio();
        gainRatio.build(dataset);
        sb.append("Each feature is perturbed " + NUM_PERTURBATIONS
                + " times and the % correct predictions are averaged\n");
        sb.append("The importance score is the original % correct prediction - average\n");

        for (int rank = 1; rank <= numRanked; rank++) {
            final int featureIdx = ranking[rank - 1];
            final ClassifierAttributeEvaluation.FeatureScore fs = featureScores[featureIdx];
            final String featureName = featureInfos[featureIdx].featureBand.getName();
            final String scoreText = "score: tp=" + f(fs.tp) + " accuracy=" + f(fs.accuracy)
                    + " precision=" + f(fs.precision) + " correlation=" + f(fs.correlation)
                    + " errorRate=" + f(fs.errorRate) + " cost=" + f(fs.cost)
                    + " GainRatio = " + f(gainRatio.score(featureIdx));
            this.score.featureScoreMap.put(featureName, scoreText);
            sb.append("   rank ").append(String.format("%-3d", rank))
              .append(" feature ").append(String.format("%-3d", featureIdx + 1)).append(": ")
              .append(String.format("%-25s", featureName)).append(scoreText).append('\n');
        }
        // Only reachable when the evaluation returned fewer scores than there are features.
        if (numRanked < featureInfos.length) {
            sb.append("Warning: rank <= featureBandList.length\n");
        }
        sb.append('\n');
        this.classifierReport.addFeatureEvaluation(sb.toString());
    }

    /**
     * Counts instances per class. Every class listed by {@code dataset.classes()} appears with at
     * least 0; class values not listed there are added rather than causing a NullPointerException.
     */
    private static Map<Object, Integer> getClassDistributionInDataset(List<Instance> instances, Dataset dataset) {
        final Map<Object, Integer> counts = new HashMap<>();
        for (Object classValue : dataset.classes()) {
            counts.put(classValue, 0);
        }
        for (Instance instance : instances) {
            counts.merge(instance.classValue(), 1, Integer::sum);
        }
        return counts;
    }
}
