package org.esa.snap.classification.gpf;

import com.bc.ceres.core.ProgressMonitor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/* loaded from: input_file:org/esa/snap/classification/gpf/ClassifierAttributeEvaluation.class */
/**
 * Estimates how much each attribute of a {@link Dataset} contributes to a classifier's scores by
 * replacing that attribute with random values and measuring how the scores change.
 *
 * <p>The classifier is first scored on the unmodified dataset (the baseline). Then, for every
 * attribute, {@code numPerturbations} copies of the dataset are made in which that attribute's
 * value in every instance is overwritten with {@link Random#nextDouble()} (a value in
 * {@code [0, 1)}). The classifier is scored on each copy, the scores are averaged, and the
 * average is subtracted from the baseline.
 *
 * <p>The classifier is used as-is: this class only calls {@link Classifier#classify} and never
 * trains it. NOTE(review): the uniform {@code [0, 1)} noise presumably assumes attribute values
 * are normalised to a comparable range — confirm against callers.
 */
public class ClassifierAttributeEvaluation {
    private final Classifier classifier;
    private final int numPerturbations;
    private final Random rg;

    /**
     * Aggregated classification scores.
     *
     * <p>When produced for a single evaluation, each field is a summary over all classes (see the
     * field docs). In the array returned by {@link ClassifierAttributeEvaluation#performEvaluation}
     * each field instead holds the baseline value minus the mean perturbed value.
     */
    public static class FeatureScore {
        /** Sum of true positives divided by the sum of (true positives + false negatives) over all classes. */
        public double tp;
        /** Unweighted mean of the per-class accuracy. */
        public double accuracy;
        /** Unweighted mean of the per-class precision; NaN per-class values contribute zero. */
        public double precision;
        /** Unweighted mean of the per-class correlation; NaN per-class values contribute zero. */
        public double correlation;
        /** Unweighted mean of the per-class error rate. */
        public double errorRate;
        /** Unweighted mean of the per-class cost; NaN per-class values contribute zero. */
        public double cost;
        /**
         * Unweighted mean of the per-class total. Not populated (left at 0) in the results of
         * {@link ClassifierAttributeEvaluation#performEvaluation}.
         */
        public double total;
    }

    /**
     * Creates an evaluator.
     *
     * @param classifier       the classifier to score; used as-is, never trained here
     * @param numPerturbations number of randomly perturbed dataset copies scored per attribute;
     *                         perturbed scores are averaged over this count, so 0 yields NaN
     *                         differences
     * @param rg               source of the random replacement values
     */
    public ClassifierAttributeEvaluation(Classifier classifier, int numPerturbations, Random rg) {
        this.classifier = classifier;
        this.numPerturbations = numPerturbations;
        this.rg = rg;
    }

    /**
     * Returns the pooled true-positive ratio: the sum of {@code tp} over all classes divided by
     * the sum of {@code tp + fn}. For maps built by {@link #testDataset} this is the fraction of
     * instances that were classified correctly.
     *
     * <p>Accumulates in {@code double} so the ratio is not truncated by integer division.
     *
     * @param map per-class confusion counts
     * @return the pooled ratio, or NaN when the map is empty
     */
    private static double getTPPercentage(Map<Object, PerformanceMeasure> map) {
        double truePositives = 0.0d;
        double actualPositives = 0.0d;
        for (PerformanceMeasure measure : map.values()) {
            truePositives += measure.tp;
            actualPositives += measure.tp + measure.fn;
        }
        return truePositives / actualPositives;
    }

    /**
     * Collapses per-class confusion counts into a single {@link FeatureScore}.
     *
     * <p>{@code tp} is the pooled ratio from {@link #getTPPercentage}. Every other field is an
     * unweighted mean over the classes in {@code map`}. NaN precision, correlation and cost values
     * are left out of the sum but still counted in the denominator. An empty map yields NaN for
     * every field.
     *
     * @param map per-class confusion counts, e.g. from {@link #testDataset}
     * @return a new score object
     */
    private static FeatureScore getFeatureScores(Map<Object, PerformanceMeasure> map) {
        double errorRateSum = 0.0d;
        double accuracySum = 0.0d;
        double precisionSum = 0.0d;
        double correlationSum = 0.0d;
        double costSum = 0.0d;
        double totalSum = 0.0d;
        int classCount = map.size();
        for (PerformanceMeasure measure : map.values()) {
            errorRateSum += measure.getErrorRate();
            accuracySum += measure.getAccuracy();
            double precision = measure.getPrecision();
            if (!Double.isNaN(precision)) {
                precisionSum += precision;
            }
            double correlation = measure.getCorrelation();
            if (!Double.isNaN(correlation)) {
                correlationSum += correlation;
            }
            double cost = measure.getCost();
            if (!Double.isNaN(cost)) {
                costSum += cost;
            }
            totalSum += measure.getTotal();
        }
        FeatureScore score = new FeatureScore();
        score.tp = getTPPercentage(map);
        score.accuracy = accuracySum / classCount;
        score.precision = precisionSum / classCount;
        score.correlation = correlationSum / classCount;
        score.errorRate = errorRateSum / classCount;
        score.cost = costSum / classCount;
        score.total = totalSum / classCount;
        return score;
    }

    /**
     * Scores the importance of every attribute of {@code dataset}.
     *
     * <p>Each entry of the result holds, for {@code tp}, {@code accuracy}, {@code precision},
     * {@code correlation}, {@code errorRate} and {@code cost}, the baseline score minus the mean
     * score over {@code numPerturbations} copies in which that attribute was randomised.
     * Interpret the sign of each field according to whether higher or lower is better for that
     * metric. {@code total} is left at 0.
     *
     * @param dataset         the labelled dataset to evaluate on; instances are copied before
     *                        being perturbed
     * @param progressMonitor receives one unit of work per attribute; once the task has begun,
     *                        {@code done()} is called even if evaluation fails
     * @return one score per attribute index
     */
    public FeatureScore[] performEvaluation(Dataset dataset, ProgressMonitor progressMonitor) {
        FeatureScore baseline = getFeatureScores(testDataset(this.classifier, dataset));
        int attributeCount = dataset.noAttributes();
        FeatureScore[] importance = new FeatureScore[attributeCount];
        progressMonitor.beginTask("Evaluating classifier... ", attributeCount);
        try {
            for (int attribute = 0; attribute < attributeCount; attribute++) {
                importance[attribute] = difference(baseline, meanPerturbedScore(dataset, attribute));
                progressMonitor.worked(1);
            }
        } finally {
            // Always pair beginTask with done(), even if the classifier throws part-way through.
            progressMonitor.done();
        }
        return importance;
    }

    /**
     * Averages the scores of {@code numPerturbations} evaluations, each on a fresh copy of
     * {@code dataset} with {@code attribute} randomised. {@code total} is not accumulated.
     */
    private FeatureScore meanPerturbedScore(Dataset dataset, int attribute) {
        FeatureScore sum = new FeatureScore();
        for (int run = 0; run < this.numPerturbations; run++) {
            FeatureScore score = getFeatureScores(testDataset(this.classifier, perturbAttribute(dataset, attribute)));
            sum.tp += score.tp;
            sum.accuracy += score.accuracy;
            sum.precision += score.precision;
            sum.correlation += score.correlation;
            sum.errorRate += score.errorRate;
            sum.cost += score.cost;
        }
        FeatureScore mean = new FeatureScore();
        mean.tp = sum.tp / this.numPerturbations;
        mean.accuracy = sum.accuracy / this.numPerturbations;
        mean.precision = sum.precision / this.numPerturbations;
        mean.correlation = sum.correlation / this.numPerturbations;
        mean.errorRate = sum.errorRate / this.numPerturbations;
        mean.cost = sum.cost / this.numPerturbations;
        return mean;
    }

    /**
     * Returns a new dataset of instance copies in which {@code attribute} is set to
     * {@code rg.nextDouble()} in every instance; {@code dataset} itself is left untouched.
     */
    private Dataset perturbAttribute(Dataset dataset, int attribute) {
        Dataset perturbed = new DefaultDataset();
        for (Instance instance : dataset) {
            Instance copy = instance.copy();
            copy.put(attribute, this.rg.nextDouble());
            perturbed.add(copy);
        }
        return perturbed;
    }

    /** Returns {@code baseline - perturbed} field by field; {@code total} is left at 0. */
    private static FeatureScore difference(FeatureScore baseline, FeatureScore perturbed) {
        FeatureScore delta = new FeatureScore();
        delta.tp = baseline.tp - perturbed.tp;
        delta.accuracy = baseline.accuracy - perturbed.accuracy;
        delta.precision = baseline.precision - perturbed.precision;
        delta.correlation = baseline.correlation - perturbed.correlation;
        delta.errorRate = baseline.errorRate - perturbed.errorRate;
        delta.cost = baseline.cost - perturbed.cost;
        return delta;
    }

    /**
     * Classifies every instance of {@code dataset} and tallies one-vs-rest confusion counts per
     * class.
     *
     * <p>For an instance with actual class {@code a} and predicted class {@code p}:
     * <ul>
     *   <li>If {@code a} equals {@code p}, class {@code a} gets a true positive and every other
     *       class a true negative.</li>
     *   <li>Otherwise {@code p} gets a false positive, {@code a} a false negative, and every
     *       other class a true negative.</li>
     * </ul>
     * A prediction that is not among {@code dataset.classes()} produces no false positive.
     *
     * @param classifier the classifier to apply; used as-is
     * @param dataset    labelled instances to classify
     * @return a mutable map from each class value in {@code dataset.classes()} to its counts
     * @throws NullPointerException if an instance's {@code classValue()} is null
     */
    public static Map<Object, PerformanceMeasure> testDataset(Classifier classifier, Dataset dataset) {
        Map<Object, PerformanceMeasure> measures = new HashMap<>();
        for (Object classValue : dataset.classes()) {
            measures.put(classValue, new PerformanceMeasure());
        }
        for (Instance instance : dataset) {
            Object actual = instance.classValue();
            Object predicted = classifier.classify(instance);
            boolean correct = actual.equals(predicted);
            for (Map.Entry<Object, PerformanceMeasure> entry : measures.entrySet()) {
                Object classValue = entry.getKey();
                PerformanceMeasure measure = entry.getValue();
                if (correct) {
                    if (classValue.equals(actual)) {
                        measure.tp += 1.0d;
                    } else {
                        measure.tn += 1.0d;
                    }
                } else if (classValue.equals(predicted)) {
                    measure.fp += 1.0d;
                } else if (classValue.equals(actual)) {
                    measure.fn += 1.0d;
                } else {
                    measure.tn += 1.0d;
                }
            }
        }
        return measures;
    }

    /**
     * Returns the elements of {@code set} as a new array in natural ascending order.
     *
     * @param set the objects to sort; they must be mutually {@link Comparable}
     * @return a freshly allocated, sorted array
     * @throws ClassCastException if the elements are not mutually comparable
     */
    public static Object[] getSortedObjects(Set<Object> set) {
        Object[] sorted = set.toArray();
        Arrays.sort(sorted);
        return sorted;
    }
}
