/*
 * Decompiled with CFR 0.152.
 */
package org.esa.snap.classification.gpf;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/**
 * Runs k-fold cross-validation of a java-ml {@link Classifier} and collects per-class
 * confusion counts, plus RMSE/bias statistics for instances whose class value is numeric.
 *
 * <p>The RMSE and bias accessors report the figures from the most recent
 * {@code crossValidation} call; each call resets them before accumulating.
 */
public class CrossValidation {
    private final Classifier classifier;
    // Sum of squared (observed - predicted) over numerically scored instances.
    private double rmseSum;
    // Number of instances that contributed to the RMSE/bias sums.
    private int rmseTotal;
    // Sum of predicted values over numerically scored instances.
    private double biasSum1;
    // Sum of observed values over numerically scored instances.
    private double biasSum2;

    /**
     * Creates a cross-validator for the given classifier.
     *
     * @param classifier the classifier that is rebuilt on every training split
     */
    public CrossValidation(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Returns the root-mean-square error of the last cross-validation run.
     *
     * @return sqrt(mean((observed - predicted)^2)) over numerically scored instances,
     *     or {@code 0.0} if none were scored
     */
    public double getRMSE() {
        return this.rmseTotal == 0 ? 0.0 : Math.sqrt(this.rmseSum / (double) this.rmseTotal);
    }

    /**
     * Returns the bias of the last cross-validation run.
     *
     * @return mean(predicted) - mean(observed) over numerically scored instances,
     *     or {@code 0.0} if none were scored
     */
    public double getBias() {
        return this.rmseTotal == 0
                ? 0.0
                : this.biasSum1 / (double) this.rmseTotal - this.biasSum2 / (double) this.rmseTotal;
    }

    /**
     * Performs k-fold cross-validation.
     *
     * <p>The data is split with {@link Dataset#folds(int, Random)}. For each fold, the
     * classifier is trained on all other folds and evaluated on that fold. Every class
     * from {@code data.classes()} receives a {@link PerformanceMeasure} with
     * tp/fp/fn/tn counts. When an instance's class value is a {@code Double} and the
     * prediction is a {@link Number}, the squared error and bias sums are also updated.
     * Predictions that cannot be interpreted numerically are skipped for those
     * statistics rather than aborting the run.
     *
     * @param data the labelled dataset
     * @param numFolds number of folds to split {@code data} into
     * @param rg random source used to shuffle instances into folds
     * @return per-class performance measures keyed by class value
     * @throws NullPointerException if an evaluated instance has a {@code null} class value
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data, int numFolds, Random rg) {
        Dataset[] folds = data.folds(numFolds, rg);
        Map<Object, PerformanceMeasure> out = new HashMap<>();
        for (Object cls : data.classes()) {
            out.put(cls, new PerformanceMeasure());
        }
        this.rmseSum = 0.0;
        this.rmseTotal = 0;
        this.biasSum1 = 0.0;
        this.biasSum2 = 0.0;

        // Iterate over the folds actually produced rather than trusting numFolds.
        for (int i = 0; i < folds.length; ++i) {
            Dataset training = new DefaultDataset();
            for (int j = 0; j < folds.length; ++j) {
                if (j != i) {
                    training.addAll(folds[j]);
                }
            }
            this.classifier.buildClassifier(training);

            for (Instance instance : folds[i]) {
                Object prediction = this.classifier.classify(instance);
                Object actual = instance.classValue();
                accumulateRegression(actual, prediction);
                accumulateConfusion(out, actual, prediction);
            }
        }
        return out;
    }

    /**
     * Adds one (observed, predicted) pair to the RMSE/bias sums when both are numeric.
     * A null or non-numeric prediction is ignored instead of throwing on unboxing/cast.
     */
    private void accumulateRegression(Object actual, Object prediction) {
        if (!(actual instanceof Double) || !(prediction instanceof Number)) {
            return;
        }
        double observed = (Double) actual;
        double predicted = ((Number) prediction).doubleValue();
        double diff = observed - predicted;
        this.rmseSum += diff * diff;
        ++this.rmseTotal;
        this.biasSum1 += predicted;
        this.biasSum2 += observed;
    }

    /**
     * Updates the one-vs-rest confusion counts of every class for a single prediction.
     * A correct prediction yields a tp for the actual class and a tn for all others.
     * A wrong one yields an fp for the predicted class, an fn for the actual class and
     * a tn for the rest.
     */
    private static void accumulateConfusion(
            Map<Object, PerformanceMeasure> out, Object actual, Object prediction) {
        boolean correct = actual.equals(prediction);
        for (Map.Entry<Object, PerformanceMeasure> entry : out.entrySet()) {
            Object cls = entry.getKey();
            PerformanceMeasure measure = entry.getValue();
            if (correct) {
                if (cls.equals(actual)) {
                    measure.tp += 1.0;
                } else {
                    measure.tn += 1.0;
                }
            } else if (cls.equals(prediction)) {
                measure.fp += 1.0;
            } else if (cls.equals(actual)) {
                measure.fn += 1.0;
            } else {
                measure.tn += 1.0;
            }
        }
    }

    /**
     * Performs k-fold cross-validation with a time-seeded random fold assignment.
     *
     * @param data the labelled dataset
     * @param folds number of folds
     * @return per-class performance measures keyed by class value
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data, int folds) {
        return this.crossValidation(data, folds, new Random(System.currentTimeMillis()));
    }

    /**
     * Performs 10-fold cross-validation with a time-seeded random fold assignment.
     *
     * @param data the labelled dataset
     * @return per-class performance measures keyed by class value
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data) {
        return this.crossValidation(data, 10);
    }
}

