/*
 * Decompiled with CFR 0.152.
 */
package org.esa.snap.classification.gpf;

import com.bc.ceres.core.ProgressMonitor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/**
 * Estimates how strongly a {@link Classifier} depends on each attribute of a dataset.
 * <p>
 * The classifier is first evaluated on the unmodified dataset. Then, for every attribute, it is
 * re-evaluated {@code numPerturbations} times. Each time it sees a copy of the dataset in which
 * that attribute's value in every instance has been replaced by {@code rg.nextDouble()}. The
 * perturbed scores are averaged per attribute. The reported importance is the unperturbed score
 * minus that average.
 * <p>
 * NOTE(review): replacing values with uniform [0, 1) noise presumably assumes the attributes are
 * normalised to that range — confirm against callers.
 */
public class ClassifierAttributeEvaluation {
    private final Classifier classifier;
    private final int numPerturbations;
    private final Random rg;

    /**
     * @param classifier       classifier used to classify instances; it is never (re)built here
     * @param numPerturbations number of random perturbations averaged per attribute; expected to be
     *                         positive (0 produces NaN importance scores)
     * @param rg               source of the random replacement values
     */
    public ClassifierAttributeEvaluation(Classifier classifier, int numPerturbations, Random rg) {
        this.classifier = classifier;
        this.numPerturbations = numPerturbations;
        this.rg = rg;
    }

    /**
     * Returns the overall fraction of correctly classified samples.
     * <p>
     * The result is the sum of true positives over all classes divided by the sum of (tp + fn)
     * over all classes. Accumulation happens in {@code double}, because the counts are doubles;
     * truncating to {@code int} on every step is unnecessary and can overflow. An empty map
     * yields NaN.
     */
    private static double getTPPercentage(Map<Object, PerformanceMeasure> performanceMeasureMap) {
        double totalTP = 0.0;
        double totalSamples = 0.0;
        for (PerformanceMeasure perMea : performanceMeasureMap.values()) {
            totalTP += perMea.tp;
            totalSamples += perMea.tp + perMea.fn;
        }
        return totalTP / totalSamples;
    }

    /**
     * Condenses per-class performance measures into a single {@link FeatureScore}.
     * <p>
     * {@code tp} is the overall fraction of correctly classified samples. Every other field is the
     * mean of the corresponding per-class value over all classes. For precision, correlation and
     * cost, NaN per-class values are skipped in the sum but still counted in the divisor. An
     * empty map yields NaN fields.
     */
    private static FeatureScore getFeatureScores(Map<Object, PerformanceMeasure> performanceMeasureMap) {
        double totalTP = 0.0;
        double totalSamples = 0.0;
        double errorRate = 0.0;
        double accuracy = 0.0;
        double precision = 0.0;
        double correlation = 0.0;
        double cost = 0.0;
        double total = 0.0;
        final int numSets = performanceMeasureMap.size();
        for (PerformanceMeasure pm : performanceMeasureMap.values()) {
            totalTP += pm.tp;
            totalSamples += pm.tp + pm.fn;
            errorRate += pm.getErrorRate();
            accuracy += pm.getAccuracy();
            if (!Double.isNaN(pm.getPrecision())) {
                precision += pm.getPrecision();
            }
            if (!Double.isNaN(pm.getCorrelation())) {
                correlation += pm.getCorrelation();
            }
            if (!Double.isNaN(pm.getCost())) {
                cost += pm.getCost();
            }
            total += pm.getTotal();
        }
        FeatureScore fs = new FeatureScore();
        fs.tp = totalTP / totalSamples;
        fs.accuracy = accuracy / (double) numSets;
        fs.precision = precision / (double) numSets;
        fs.correlation = correlation / (double) numSets;
        fs.errorRate = errorRate / (double) numSets;
        fs.cost = cost / (double) numSets;
        fs.total = total / (double) numSets;
        return fs;
    }

    /**
     * Computes one importance score per attribute of {@code dataset}.
     * <p>
     * For attribute {@code i}, each field of the result (except {@code total}, which is left at 0)
     * is the score on the unmodified dataset minus the mean score over {@code numPerturbations}
     * perturbed copies. A large positive value means the classifier degrades when that attribute
     * is randomised. {@code dataset} itself is not modified; instances are copied before
     * perturbation.
     *
     * @param dataset labelled instances to evaluate on
     * @param pm      progress monitor; advanced by one unit per attribute and always closed, even
     *                if classification throws
     * @return array of length {@code dataset.noAttributes()}, indexed by attribute
     */
    public FeatureScore[] performEvaluation(Dataset dataset, ProgressMonitor pm) {
        final FeatureScore originalFS = getFeatureScores(testDataset(this.classifier, dataset));
        final int numAttributes = dataset.noAttributes();
        final FeatureScore[] importanceScores = new FeatureScore[numAttributes];
        pm.beginTask("Evaluating classifier... ", numAttributes);
        try {
            for (int i = 0; i < numAttributes; ++i) {
                FeatureScore sumFS = new FeatureScore();
                for (int j = 0; j < this.numPerturbations; ++j) {
                    Dataset perturbed = perturbAttribute(dataset, i);
                    accumulate(sumFS, getFeatureScores(testDataset(this.classifier, perturbed)));
                }
                importanceScores[i] = importance(originalFS, sumFS, this.numPerturbations);
                pm.worked(1);
            }
        } finally {
            // Close the task even if the classifier throws, so the monitor is not left open.
            pm.done();
        }
        return importanceScores;
    }

    /**
     * Returns a copy of {@code dataset} in which attribute {@code attribute} of every instance is
     * replaced by {@code rg.nextDouble()}. Random values are drawn in dataset iteration order.
     */
    private Dataset perturbAttribute(Dataset dataset, int attribute) {
        DefaultDataset perturbed = new DefaultDataset();
        for (Instance inst : dataset) {
            Instance copy = inst.copy();
            copy.put(attribute, this.rg.nextDouble());
            perturbed.add(copy);
        }
        return perturbed;
    }

    /**
     * Adds the scores of {@code addend} into {@code sum}. {@code total} is deliberately not
     * accumulated, because it is not part of the importance result.
     */
    private static void accumulate(FeatureScore sum, FeatureScore addend) {
        sum.tp += addend.tp;
        sum.accuracy += addend.accuracy;
        sum.precision += addend.precision;
        sum.correlation += addend.correlation;
        sum.errorRate += addend.errorRate;
        sum.cost += addend.cost;
    }

    /**
     * Returns {@code original} minus the mean of {@code perturbedSum} over {@code count} runs.
     * {@code total} is left at 0.
     */
    private static FeatureScore importance(FeatureScore original, FeatureScore perturbedSum, int count) {
        final double n = (double) count;
        FeatureScore fs = new FeatureScore();
        fs.tp = original.tp - perturbedSum.tp / n;
        fs.accuracy = original.accuracy - perturbedSum.accuracy / n;
        fs.precision = original.precision - perturbedSum.precision / n;
        fs.correlation = original.correlation - perturbedSum.correlation / n;
        fs.errorRate = original.errorRate - perturbedSum.errorRate / n;
        fs.cost = original.cost - perturbedSum.cost / n;
        return fs;
    }

    /**
     * Classifies every instance of {@code dataset} and builds a one-vs-rest confusion count per
     * class.
     * <p>
     * For a correct prediction, the true class gets a TP and every other class gets a TN. For a
     * wrong prediction, the predicted class gets an FP, the true class gets an FN, and every other
     * class gets a TN. Keys are the classes reported by {@code dataset.classes()}. A predicted
     * value that is not one of those classes therefore records no FP.
     *
     * @param classifier classifier used to predict each instance's class
     * @param dataset    labelled instances
     * @return map from class value to its accumulated {@link PerformanceMeasure}
     */
    public static Map<Object, PerformanceMeasure> testDataset(Classifier classifier, Dataset dataset) {
        Map<Object, PerformanceMeasure> out = new HashMap<Object, PerformanceMeasure>();
        for (Object o : dataset.classes()) {
            out.put(o, new PerformanceMeasure());
        }
        for (Instance instance : dataset) {
            final Object actual = instance.classValue();
            final Object prediction = classifier.classify(instance);
            final boolean correct = actual.equals(prediction);
            for (Map.Entry<Object, PerformanceMeasure> entry : out.entrySet()) {
                final Object cls = entry.getKey();
                final PerformanceMeasure measure = entry.getValue();
                if (correct) {
                    if (cls.equals(actual)) {
                        measure.tp += 1.0;
                    } else {
                        measure.tn += 1.0;
                    }
                } else if (cls.equals(prediction)) {
                    measure.fp += 1.0;
                } else if (cls.equals(actual)) {
                    measure.fn += 1.0;
                } else {
                    measure.tn += 1.0;
                }
            }
        }
        return out;
    }

    /**
     * Returns the elements of {@code set} as a new array sorted by natural ordering.
     *
     * @throws ClassCastException if the elements are not mutually {@link Comparable}
     */
    public static Object[] getSortedObjects(Set<Object> set) {
        Object[] a = set.toArray();
        Arrays.sort(a);
        return a;
    }

    /**
     * A set of classifier scores.
     * <p>
     * {@code tp} is the overall fraction of correctly classified samples. The other fields are
     * per-class means, as computed by {@code getFeatureScores}. In arrays returned by
     * {@code performEvaluation}, each field instead holds the drop (unperturbed minus mean
     * perturbed), and {@code total} is 0.
     */
    public static class FeatureScore {
        /** Overall fraction of correctly classified samples (or its drop). */
        public double tp;
        /** Mean per-class accuracy (or its drop). */
        public double accuracy;
        /** Mean per-class precision, NaN values counted as 0 (or its drop). */
        public double precision;
        /** Mean per-class correlation, NaN values counted as 0 (or its drop). */
        public double correlation;
        /** Mean per-class error rate (or its drop). */
        public double errorRate;
        /** Mean per-class cost, NaN values counted as 0 (or its drop). */
        public double cost;
        /** Mean per-class total count; not populated in importance scores. */
        public double total;
    }
}

