/*
 * Decompiled with CFR 0.152.
 */
package org.esa.snap.cluster;

import com.bc.ceres.core.ProgressMonitor;
import com.bc.ceres.core.SubProgressMonitor;
import java.awt.Rectangle;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;
import java.util.stream.Stream;
import org.esa.snap.cluster.ClusterMetaDataUtils;
import org.esa.snap.cluster.EMCluster;
import org.esa.snap.cluster.EMClusterer;
import org.esa.snap.cluster.ProbabilityCalculator;
import org.esa.snap.cluster.Roi;
import org.esa.snap.core.datamodel.Band;
import org.esa.snap.core.datamodel.IndexCoding;
import org.esa.snap.core.datamodel.Mask;
import org.esa.snap.core.datamodel.MetadataElement;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.datamodel.ProductData;
import org.esa.snap.core.datamodel.ProductNode;
import org.esa.snap.core.datamodel.RasterDataNode;
import org.esa.snap.core.datamodel.SampleCoding;
import org.esa.snap.core.gpf.Operator;
import org.esa.snap.core.gpf.OperatorException;
import org.esa.snap.core.gpf.OperatorSpi;
import org.esa.snap.core.gpf.Tile;
import org.esa.snap.core.gpf.annotations.OperatorMetadata;
import org.esa.snap.core.gpf.annotations.Parameter;
import org.esa.snap.core.gpf.annotations.SourceProduct;
import org.esa.snap.core.gpf.annotations.TargetProduct;
import org.esa.snap.core.util.ProductUtils;

/**
 * Performs an expectation-maximization (EM) cluster analysis on a set of source bands.
 * <p>
 * The clustering is computed once, on the first call to {@link #computeTileStack}, from all
 * pixels inside the region of interest. Every pixel is then assigned the index of the cluster
 * with the highest posterior probability (band {@code class_indices}); optionally, the posterior
 * probabilities themselves are written as one band per cluster. Pixels outside the region of
 * interest receive {@link #NO_DATA_VALUE}.
 */
@OperatorMetadata(alias="EMClusterAnalysis", category="Raster/Image Analysis/Clustering", version="1.0", authors="Ralf Quast", copyright="(c) 2007 by Brockmann Consult", description="Performs an expectation-maximization (EM) cluster analysis.")
public class EMClusterOp
extends Operator {
    /** Sample value written to every target band for pixels outside the region of interest. */
    private static final int NO_DATA_VALUE = 255;
    @SourceProduct(alias="source", label="Source product", description="The source product")
    private Product sourceProduct;
    @TargetProduct
    private Product targetProduct;
    @Parameter(label="Number of clusters", description="Number of clusters", defaultValue="14", interval="(0,100]")
    private int clusterCount;
    @Parameter(label="Number of iterations", description="Number of iterations", defaultValue="30", interval="(0,10000]")
    private int iterationCount;
    @Parameter(label="Random seed", defaultValue="31415", description="Seed for the random generator, used for initialising the algorithm.")
    private int randomSeed;
    @Parameter(label="Source band names", description="The names of the bands being used for the cluster analysis.", rasterDataNodeType=Band.class)
    private String[] sourceBandNames;
    @Parameter(label="ROI-mask", description="The name of the ROI-Mask that should be used.", rasterDataNodeType=Mask.class)
    private String roiMaskName;
    @Parameter(label="Include probabilities", defaultValue="false", description="Determines whether the posterior probabilities are included as band data.")
    private boolean includeProbabilityBands;
    /** Optional ordering of the resulting clusters; {@code null} keeps the clusterer's order. */
    private transient Comparator<EMCluster> clusterComparator;
    private transient Band[] sourceBands;
    private transient Band clusterMapBand;
    private transient Band[] probabilityBands;
    /** Region of interest; created lazily in {@link #performClustering}. */
    private transient Roi roi;
    private transient MetadataElement clusterAnalysis;
    /** Result of the one-time clustering; written under the instance lock, read without it. */
    private volatile transient ProbabilityCalculator probabilityCalculator;

    public EMClusterOp() {
    }

    /**
     * Creates an operator for programmatic use.
     *
     * @param sourceProduct           the product to cluster
     * @param clusterCount            the number of clusters
     * @param iterationCount          the number of EM iterations
     * @param sourceBandNames         the bands to cluster; {@code null} or empty selects all bands
     * @param includeProbabilityBands whether to emit one posterior-probability band per cluster
     * @param clusterComparator       optional ordering of the resulting clusters, may be {@code null}
     */
    public EMClusterOp(Product sourceProduct, int clusterCount, int iterationCount, String[] sourceBandNames, boolean includeProbabilityBands, Comparator<EMCluster> clusterComparator) {
        this.sourceProduct = sourceProduct;
        this.clusterCount = clusterCount;
        this.iterationCount = iterationCount;
        this.sourceBandNames = sourceBandNames;
        this.includeProbabilityBands = includeProbabilityBands;
        this.clusterComparator = clusterComparator;
    }

    /**
     * Validates the inputs and sets up the target product: the {@code class_indices} band with its
     * index coding, the optional probability bands and the {@code Cluster_Analysis} metadata element.
     *
     * @throws OperatorException if a source band or the ROI mask cannot be found, no source band is
     *                           available, or the source bands and mask differ in raster size
     */
    @Override
    public void initialize() throws OperatorException {
        this.sourceBands = this.collectSourceBands();
        RasterDataNode[] rastersToCheck = this.sourceBands;
        if (this.roiMaskName != null) {
            Mask roiMask = this.sourceProduct.getMaskGroup().get(this.roiMaskName);
            if (roiMask == null) {
                throw new OperatorException("ROI mask not found: " + this.roiMaskName);
            }
            rastersToCheck = Stream.concat(Arrays.stream(this.sourceBands), Stream.of(roiMask)).toArray(RasterDataNode[]::new);
        }
        this.ensureSingleRasterSize(rastersToCheck);
        // The target raster follows the clustered bands, which may differ from the scene raster.
        int width = this.sourceBands[0].getRasterWidth();
        int height = this.sourceBands[0].getRasterHeight();
        String name = this.sourceProduct.getName() + "_CLUSTERS";
        String type = this.sourceProduct.getProductType() + "_CLUSTERS";
        this.targetProduct = new Product(name, type, width, height);
        // Tie-point grids and geo-coding are copied only when the clustered bands match the scene raster size.
        if (this.sourceProduct.getSceneRasterSize().equals(this.sourceBands[0].getRasterSize())) {
            ProductUtils.copyTiePointGrids(this.sourceProduct, this.targetProduct);
            ProductUtils.copyGeoCoding(this.sourceProduct, this.targetProduct);
        }
        this.targetProduct.setStartTime(this.sourceProduct.getStartTime());
        this.targetProduct.setEndTime(this.sourceProduct.getEndTime());
        // Request the whole raster as a single tile.
        this.targetProduct.setPreferredTileSize(width, height);
        if (this.includeProbabilityBands) {
            this.createProbabilityBands();
        }
        this.clusterMapBand = new Band("class_indices", ProductData.TYPE_UINT8, width, height);
        this.clusterMapBand.setDescription("Class_indices");
        this.clusterMapBand.setNoDataValue(NO_DATA_VALUE);
        this.clusterMapBand.setNoDataValueUsed(true);
        this.targetProduct.addBand(this.clusterMapBand);
        // Class index i (0-based) is labelled "class_<i+1>".
        IndexCoding indexCoding = new IndexCoding("Cluster_classes");
        for (int i = 0; i < this.clusterCount; ++i) {
            indexCoding.addIndex("class_" + (i + 1), i, "Cluster " + (i + 1));
        }
        this.targetProduct.getIndexCodingGroup().add(indexCoding);
        this.clusterMapBand.setSampleCoding(indexCoding);
        this.clusterAnalysis = new MetadataElement("Cluster_Analysis");
        this.targetProduct.getMetadataRoot().addElement(this.clusterAnalysis);
        this.setTargetProduct(this.targetProduct);
    }

    /** Releases the clustering result and delegates to the base implementation. */
    @Override
    public void dispose() {
        this.probabilityCalculator = null;
        super.dispose();
    }

    /**
     * Resolves the configured source band names, or all bands of the source product if none are
     * configured.
     *
     * @return the bands to cluster, never empty
     * @throws OperatorException if a named band does not exist or no band is available at all
     */
    private Band[] collectSourceBands() {
        Band[] srcBands;
        if (this.sourceBandNames != null && this.sourceBandNames.length > 0) {
            srcBands = new Band[this.sourceBandNames.length];
            for (int i = 0; i < this.sourceBandNames.length; ++i) {
                Band sourceBand = this.sourceProduct.getBand(this.sourceBandNames[i]);
                if (sourceBand == null) {
                    throw new OperatorException("Source band not found: " + this.sourceBandNames[i]);
                }
                srcBands[i] = sourceBand;
            }
        } else {
            srcBands = this.sourceProduct.getBands();
        }
        if (srcBands.length == 0) {
            // Every later step indexes sourceBands[0]; fail early with a clear message instead.
            throw new OperatorException("No source bands available for the cluster analysis.");
        }
        return srcBands;
    }

    /** Adds one float32 band {@code probability_<i>} (0-based) per cluster to the target product. */
    private void createProbabilityBands() {
        this.probabilityBands = new Band[this.clusterCount];
        for (int i = 0; i < this.clusterCount; ++i) {
            Band targetBand = this.targetProduct.addBand("probability_" + i, ProductData.TYPE_FLOAT32);
            targetBand.setUnit("dl");
            targetBand.setDescription("Cluster posterior probabilities");
            targetBand.setNoDataValue(NO_DATA_VALUE);
            targetBand.setNoDataValueUsed(true);
            this.probabilityBands[i] = targetBand;
        }
    }

    /**
     * Classifies every pixel of {@code targetRectangle}. The first call (on any thread) runs the
     * clustering; concurrent callers block on the instance lock until it is available.
     */
    @Override
    public void computeTileStack(Map<Band, Tile> targetTileMap, Rectangle targetRectangle, ProgressMonitor pm) throws OperatorException {
        int totalWork = targetRectangle.height;
        if (this.probabilityCalculator == null) {
            totalWork += 100;
        }
        pm.beginTask("Computing clusters...", totalWork);
        try {
            synchronized (this) {
                if (this.probabilityCalculator == null) {
                    this.probabilityCalculator = this.performClustering(SubProgressMonitor.create(pm, 100));
                }
            }
            // Read the volatile field once rather than once per pixel.
            ProbabilityCalculator calculator = this.probabilityCalculator;
            Tile[] sourceTiles = new Tile[this.sourceBands.length];
            for (int i = 0; i < sourceTiles.length; ++i) {
                sourceTiles[i] = this.getSourceTile(this.sourceBands[i], targetRectangle);
            }
            Tile clusterMapTile = targetTileMap.get(this.clusterMapBand);
            Tile[] probabilityTiles = new Tile[this.clusterCount];
            if (this.includeProbabilityBands) {
                for (int i = 0; i < probabilityTiles.length; ++i) {
                    probabilityTiles[i] = targetTileMap.get(this.probabilityBands[i]);
                }
            }
            double[] point = new double[sourceTiles.length];
            double[] posteriors = new double[this.clusterCount];
            int xEnd = targetRectangle.x + targetRectangle.width;
            int yEnd = targetRectangle.y + targetRectangle.height;
            for (int y = targetRectangle.y; y < yEnd; ++y) {
                this.checkForCancellation();
                for (int x = targetRectangle.x; x < xEnd; ++x) {
                    if (this.roi == null || this.roi.contains(x, y)) {
                        for (int i = 0; i < sourceTiles.length; ++i) {
                            point[i] = sourceTiles[i].getSampleDouble(x, y);
                        }
                        calculator.calculate(point, posteriors);
                        if (this.includeProbabilityBands) {
                            for (int i = 0; i < this.clusterCount; ++i) {
                                probabilityTiles[i].setSample(x, y, posteriors[i]);
                            }
                        }
                        clusterMapTile.setSample(x, y, findMaxIndex(posteriors));
                    } else {
                        if (this.includeProbabilityBands) {
                            for (int i = 0; i < this.clusterCount; ++i) {
                                probabilityTiles[i].setSample(x, y, NO_DATA_VALUE);
                            }
                        }
                        clusterMapTile.setSample(x, y, NO_DATA_VALUE);
                    }
                }
                pm.worked(1);
            }
        } finally {
            pm.done();
        }
    }

    /**
     * Returns the index of the largest value; ties resolve to the lowest index, and NaN values are
     * never selected unless they sit at index 0.
     */
    private static int findMaxIndex(double[] samples) {
        int index = 0;
        for (int i = 1; i < samples.length; ++i) {
            if (samples[i] > samples[index]) {
                index = i;
            }
        }
        return index;
    }

    /**
     * Runs the EM clustering on the pixels inside the region of interest, records the cluster
     * centres and EM statistics in the target metadata, and builds the per-pixel calculator.
     *
     * @throws OperatorException on cancellation, too few ROI pixels, or any other failure
     */
    private synchronized ProbabilityCalculator performClustering(ProgressMonitor pm) {
        try {
            pm.beginTask("Extracting data points...", this.iterationCount + 100);
            this.roi = new Roi(this.sourceProduct, this.sourceBands, this.roiMaskName);
            EMClusterer clusterer = this.createClusterer(SubProgressMonitor.create(pm, 100));
            for (int i = 0; i < this.iterationCount; ++i) {
                this.checkForCancellation();
                clusterer.iterate();
                pm.worked(1);
            }
            EMCluster[] clusters = this.clusterComparator == null ? clusterer.getClusters() : clusterer.getClusters(this.clusterComparator);
            double[][] means = new double[this.clusterCount][];
            double[][][] covariances = new double[this.clusterCount][][];
            double[] priorProbabilities = new double[this.clusterCount];
            for (int i = 0; i < this.clusterCount; ++i) {
                means[i] = clusters[i].getMean();
                covariances[i] = clusters[i].getCovariances();
                priorProbabilities[i] = clusters[i].getPriorProbability();
            }
            ClusterMetaDataUtils.addCenterToIndexCoding(this.clusterMapBand.getIndexCoding(), this.sourceBands, means);
            ClusterMetaDataUtils.addCenterToMetadata(this.clusterAnalysis, this.sourceBands, means);
            ClusterMetaDataUtils.addEMInfoToMetadata(this.clusterAnalysis, covariances, priorProbabilities);
            return EMClusterer.createProbabilityCalculator(clusters);
        } catch (OperatorException e) {
            // Already carries a meaningful message/type (e.g. cancellation); do not wrap it again.
            throw e;
        } catch (Exception e) {
            throw new OperatorException(e);
        } finally {
            pm.done();
        }
    }

    /**
     * Extracts all source samples inside the region of interest and creates the EM clusterer.
     * Dimensions come from the clustered bands (the same raster used for the target product and
     * for {@link #computeTileStack}), not from the scene raster, so multi-size products work.
     *
     * @throws OperatorException if fewer ROI pixels than clusters are available, or on cancellation
     */
    private EMClusterer createClusterer(ProgressMonitor pm) {
        int width = this.sourceBands[0].getRasterWidth();
        int height = this.sourceBands[0].getRasterHeight();
        int roiSize = 0;
        if (this.roi == null) {
            roiSize = width * height;
        } else {
            for (int y = 0; y < height; ++y) {
                for (int x = 0; x < width; ++x) {
                    if (this.roi.contains(x, y)) {
                        ++roiSize;
                    }
                }
            }
        }
        if (roiSize < this.clusterCount) {
            throw new OperatorException("The combination of ROI and valid pixel masks contain " + roiSize + " pixel. These are too few to initialize the clustering.");
        }
        // points[pixel][band]: one row per ROI pixel, in row-major scan order.
        double[][] points = new double[roiSize][this.sourceBands.length];
        try {
            pm.beginTask("Extracting data points...", this.sourceBands.length * height);
            for (int band = 0; band < this.sourceBands.length; ++band) {
                int index = 0;
                for (int y = 0; y < height; ++y) {
                    this.checkForCancellation();
                    Tile sourceTile = this.getSourceTile(this.sourceBands[band], new Rectangle(0, y, width, 1));
                    for (int x = 0; x < width; ++x) {
                        if (this.roi == null || this.roi.contains(x, y)) {
                            points[index++][band] = sourceTile.getSampleDouble(x, y);
                        }
                    }
                    pm.worked(1);
                }
            }
        } finally {
            pm.done();
        }
        return new EMClusterer(points, this.clusterCount, this.randomSeed);
    }

    /** Service provider registering {@link EMClusterOp} with the GPF framework. */
    public static class Spi
    extends OperatorSpi {
        public Spi() {
            super(EMClusterOp.class);
        }
    }
}

