/*
 * Decompiled with CFR 0.152.
 */
package org.esa.snap.cluster;

import com.bc.ceres.core.ProgressMonitor;
import com.bc.ceres.core.SubProgressMonitor;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Rectangle;
import java.util.Arrays;
import java.util.stream.Stream;
import org.esa.snap.cluster.PixelIter;
import org.esa.snap.cluster.PrincipalComponentAnalysis;
import org.esa.snap.cluster.Roi;
import org.esa.snap.core.datamodel.Band;
import org.esa.snap.core.datamodel.FlagCoding;
import org.esa.snap.core.datamodel.Mask;
import org.esa.snap.core.datamodel.MetadataAttribute;
import org.esa.snap.core.datamodel.MetadataElement;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.datamodel.ProductData;
import org.esa.snap.core.datamodel.ProductNode;
import org.esa.snap.core.datamodel.RasterDataNode;
import org.esa.snap.core.datamodel.SampleCoding;
import org.esa.snap.core.gpf.Operator;
import org.esa.snap.core.gpf.OperatorException;
import org.esa.snap.core.gpf.OperatorSpi;
import org.esa.snap.core.gpf.Tile;
import org.esa.snap.core.gpf.annotations.OperatorMetadata;
import org.esa.snap.core.gpf.annotations.Parameter;
import org.esa.snap.core.gpf.annotations.SourceProduct;
import org.esa.snap.core.gpf.annotations.TargetProduct;
import org.esa.snap.core.image.ImageManager;
import org.esa.snap.core.util.ProductUtils;
import org.esa.snap.core.util.StringUtils;
import org.esa.snap.core.util.math.MathUtils;

/**
 * GPF operator that performs a Principal Component Analysis (PCA) on a set of source bands.
 * <p>
 * The target product has the raster size of the (single-sized) source bands. It contains:
 * <ul>
 *   <li>{@code component_1 .. component_N} bands;</li>
 *   <li>a {@code response} band;</li>
 *   <li>an {@code error} band;</li>
 *   <li>a {@code flags} band whose {@code PCA_ROI_PIXEL} flag marks the pixels that fed the analysis.</li>
 * </ul>
 * The PCA basis is computed in {@link #doExecute(ProgressMonitor)}. Tile computation reads that
 * basis, so it relies on {@code doExecute} having run before any tile is computed.
 */
@OperatorMetadata(alias="PCA", category="Raster/Image Analysis", version="1.0", authors="Norman Fomferra", copyright="(c) 2013 by Brockmann Consult", description="Performs a Principal Component Analysis.")
public class PrincipalComponentAnalysisOp extends Operator {

    /** Name of the single flag stored in the {@code flags} band. */
    private static final String ROI_FLAG_NAME = "PCA_ROI_PIXEL";
    /** Band-maths expression selecting the pixels that were used for the PCA. */
    private static final String ROI_PIXEL_EXPRESSION = "flags." + ROI_FLAG_NAME;
    private static final String ROI_PIXEL_DESCRIPTION = "Pixel has been used to perform the PCA.";
    /** Number of points the extraction buffer is initially sized for; it grows on demand. */
    private static final int INITIAL_POINT_CAPACITY = 10000;

    @SourceProduct(alias="source", label="Source product", description="The source product.")
    private Product sourceProduct;
    @TargetProduct
    private Product targetProduct;
    @Parameter(label="Source band names", description="The names of the bands being used for the cluster analysis.", rasterDataNodeType=Band.class)
    private String[] sourceBandNames;
    @Parameter(label="Maximum component count", description="The maximum number of principal components to compute.", defaultValue="-1")
    private int componentCount;
    @Parameter(label="ROI mask name", description="The name of the ROI mask that should be used.", defaultValue="", rasterDataNodeType=Mask.class)
    private String roiMaskName;
    @Parameter(label="Remove non-ROI pixels", description="Removes all non-ROI pixels in the target product.", defaultValue="false")
    private boolean removeNonRoiPixels;

    // Runtime state, derived in initialize() / doExecute().
    private transient Roi roi;
    private transient Band[] sourceBands;
    private transient PrincipalComponentAnalysis pca;
    private transient Band[] componentBands;
    private transient Band responseBand;
    private transient Band errorBand;
    private transient Band flagsBand;

    /**
     * Resolves the source bands and validates the ROI mask and raster sizes. It then sets up the
     * target product with its component, response, error and flags bands and the two ROI masks.
     *
     * @throws OperatorException if no source band is available, a named source band or the named
     *                           ROI mask is missing, or the involved rasters differ in size
     */
    public void initialize() throws OperatorException {
        collectSourceBands();
        if (sourceBands.length == 0) {
            throw new OperatorException("No source bands available to perform the PCA.");
        }
        // A non-positive or too large component count means "compute all components".
        if (componentCount <= 0 || componentCount > sourceBands.length) {
            componentCount = sourceBands.length;
        }
        ensureUniformRasterSize();

        Band referenceBand = sourceBands[0];
        targetProduct = new Product(sourceProduct.getName() + "_PCA",
                                    sourceProduct.getProductType() + "_PCA",
                                    referenceBand.getRasterWidth(),
                                    referenceBand.getRasterHeight());
        // Scene-level tie-point grids and geo-coding are only copied when the target matches the source scene size.
        if (sourceProduct.getSceneRasterSize().equals(targetProduct.getSceneRasterSize())) {
            ProductUtils.copyTiePointGrids(sourceProduct, targetProduct);
            ProductUtils.copyGeoCoding(sourceProduct, targetProduct);
        }
        targetProduct.setStartTime(sourceProduct.getStartTime());
        targetProduct.setEndTime(sourceProduct.getEndTime());

        addTargetBands();
        addRoiFlagsAndMasks();

        roi = new Roi(sourceProduct, sourceBands, roiMaskName);
        setTargetProduct(targetProduct);
    }

    /**
     * Fails unless the source bands, and the ROI mask if one is named, share a single raster size.
     * The mask's existence is checked first, so a missing mask yields a descriptive error instead
     * of a failure inside the size check.
     */
    private void ensureUniformRasterSize() {
        if (StringUtils.isNullOrEmpty(roiMaskName)) {
            ensureSingleRasterSize(sourceBands);
            return;
        }
        Mask roiMask = sourceProduct.getMaskGroup().get(roiMaskName);
        if (roiMask == null) {
            throw new OperatorException("Missing required mask '" + roiMaskName + "' in source product.");
        }
        ensureSingleRasterSize(Stream.concat(Arrays.stream(sourceBands), Stream.of(roiMask))
                                     .toArray(RasterDataNode[]::new));
    }

    /** Adds the float32 component, response and error bands and the uint8 flags band to the target. */
    private void addTargetBands() {
        componentBands = new Band[componentCount];
        for (int i = 0; i < componentCount; i++) {
            Band componentBand = targetProduct.addBand("component_" + (i + 1), ProductData.TYPE_FLOAT32);
            // NOTE(review): a principal component combines all source bands, so copying the spectral
            // properties of source band i onto component i may be misleading - confirm this is intended.
            ProductUtils.copySpectralBandProperties(sourceBands[i], componentBand);
            componentBands[i] = componentBand;
        }
        responseBand = targetProduct.addBand("response", ProductData.TYPE_FLOAT32);
        errorBand = targetProduct.addBand("error", ProductData.TYPE_FLOAT32);
        flagsBand = targetProduct.addBand("flags", ProductData.TYPE_UINT8);
    }

    /**
     * Attaches the ROI flag coding to the flags band and adds the ROI / non-ROI masks. If requested,
     * it also restricts the valid pixels of the value bands to ROI pixels.
     */
    private void addRoiFlagsAndMasks() {
        FlagCoding flagCoding = new FlagCoding("flags");
        flagCoding.addFlag(ROI_FLAG_NAME, 1, ROI_PIXEL_DESCRIPTION);
        flagsBand.setSampleCoding(flagCoding);
        targetProduct.getFlagCodingGroup().add(flagCoding);
        targetProduct.addMask("pca_roi_pixel", ROI_PIXEL_EXPRESSION, ROI_PIXEL_DESCRIPTION, Color.RED, 0.5);
        targetProduct.addMask("pca_non_roi_pixel", "!" + ROI_PIXEL_EXPRESSION, "Pixel has not been used to perform the PCA.", Color.BLACK, 0.5);
        if (removeNonRoiPixels) {
            for (Band componentBand : componentBands) {
                componentBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
            }
            responseBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
            errorBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
        }
    }

    /**
     * Computes the PCA basis and records it in the target product metadata.
     * <p>
     * The mean vector and the basis vectors are stored as a {@code PCA_RESULT} element in the
     * metadata root. An existing element of that name is replaced in place, keeping its position.
     *
     * @param pm progress monitor for the basis computation
     */
    public void doExecute(ProgressMonitor pm) {
        initPca(pm);
        MetadataElement pcaMetadata = createPcaMetadata();
        MetadataElement metadataRoot = targetProduct.getMetadataRoot();
        MetadataElement existing = metadataRoot.getElement(pcaMetadata.getName());
        if (existing != null) {
            int index = metadataRoot.getElementIndex(existing);
            metadataRoot.removeElement(existing);
            metadataRoot.addElementAt(pcaMetadata, index);
        } else {
            metadataRoot.addElement(pcaMetadata);
        }
    }

    /**
     * Builds the {@code PCA_RESULT} metadata element. It contains:
     * <ul>
     *   <li>a {@code MEAN_VECTOR} child with one attribute per source band (named after the band);</li>
     *   <li>a {@code BASIS_VECTORS} child with one attribute per computed component.</li>
     * </ul>
     */
    private MetadataElement createPcaMetadata() {
        MetadataElement meanVectorElement = new MetadataElement("MEAN_VECTOR");
        double[] meanVector = pca.getMeanVector();
        // The mean has one entry per source band, independent of how many components were requested.
        int meanCount = Math.min(meanVector.length, sourceBands.length);
        for (int i = 0; i < meanCount; i++) {
            meanVectorElement.addAttribute(new MetadataAttribute(sourceBands[i].getName(),
                    ProductData.createInstance(new double[]{meanVector[i]}), true));
        }
        MetadataElement basisVectorsElement = new MetadataElement("BASIS_VECTORS");
        for (int i = 0; i < componentCount; i++) {
            basisVectorsElement.addAttribute(new MetadataAttribute("component_" + (i + 1),
                    ProductData.createInstance(pca.getBasisVector(i)), true));
        }
        MetadataElement pcaAnalysisMD = new MetadataElement("PCA_RESULT");
        pcaAnalysisMD.addElement(meanVectorElement);
        pcaAnalysisMD.addElement(basisVectorsElement);
        return pcaAnalysisMD;
    }

    /**
     * Resolves {@link #sourceBands} from {@code sourceBandNames}, or uses all bands of the source
     * product if no names are given.
     *
     * @return the resolved source bands (also stored in {@link #sourceBands})
     * @throws OperatorException if a named band does not exist in the source product
     */
    private Band[] collectSourceBands() {
        if (sourceBandNames != null && sourceBandNames.length > 0) {
            sourceBands = new Band[sourceBandNames.length];
            for (int i = 0; i < sourceBandNames.length; i++) {
                Band sourceBand = sourceProduct.getBand(sourceBandNames[i]);
                if (sourceBand == null) {
                    throw new OperatorException("source band not found: " + sourceBandNames[i]);
                }
                sourceBands[i] = sourceBand;
            }
        } else {
            sourceBands = sourceProduct.getBands();
        }
        return sourceBands;
    }

    /**
     * Computes one tile of a target band.
     * <p>
     * Each target band is filled as follows:
     * <ul>
     *   <li>component bands: the respective coordinate from {@code pca.sampleToEigenSpace};</li>
     *   <li>the response band: {@code pca.response};</li>
     *   <li>the error band: {@code pca.errorMembership};</li>
     *   <li>the flags band: 1 for pixels contained in the ROI, 0 otherwise.</li>
     * </ul>
     */
    public void computeTile(Band targetBand, Tile targetTile, ProgressMonitor pm) throws OperatorException {
        Rectangle rectangle = targetTile.getRectangle();
        pm.beginTask("Computing component...", rectangle.height);
        try {
            int componentIndex = getComponentIndex(targetBand);
            boolean isFlagsBand = targetBand == flagsBand;
            boolean needsSamples = componentIndex >= 0 || targetBand == responseBand || targetBand == errorBand;
            // The flags band depends only on the ROI, so source tiles are fetched only when needed.
            Tile[] sourceTiles = needsSamples ? getSourceTiles(rectangle) : new Tile[0];
            double[] point = new double[sourceTiles.length];
            for (int y = rectangle.y; y < rectangle.y + rectangle.height; y++) {
                checkForCancellation();
                for (int x = rectangle.x; x < rectangle.x + rectangle.width; x++) {
                    if (isFlagsBand) {
                        targetTile.setSample(x, y, roi.contains(x, y) ? 1 : 0);
                    } else if (needsSamples) {
                        for (int i = 0; i < sourceTiles.length; i++) {
                            point[i] = sourceTiles[i].getSampleDouble(x, y);
                        }
                        targetTile.setSample(x, y, computeSample(targetBand, componentIndex, point));
                    }
                }
                pm.worked(1);
            }
        } finally {
            pm.done();
        }
    }

    /** Computes the value of a component, response or error band for one source pixel vector. */
    private double computeSample(Band targetBand, int componentIndex, double[] point) {
        if (componentIndex >= 0) {
            return pca.sampleToEigenSpace(point)[componentIndex];
        }
        if (targetBand == responseBand) {
            return pca.response(point);
        }
        return pca.errorMembership(point);
    }

    /** Fetches the tiles of all source bands covering the given rectangle. */
    private Tile[] getSourceTiles(Rectangle rectangle) {
        Tile[] tiles = new Tile[sourceBands.length];
        for (int i = 0; i < tiles.length; i++) {
            tiles[i] = getSourceTile(sourceBands[i], rectangle);
        }
        return tiles;
    }

    /** Releases the PCA result and band references, then delegates to {@link Operator#dispose()}. */
    public void dispose() {
        pca = null;
        roi = null;
        sourceBands = null;
        componentBands = null;
        responseBand = null;
        errorBand = null;
        flagsBand = null;
        targetProduct = null;
        super.dispose();
    }

    /** Returns the index of {@code targetBand} among the component bands, or -1 if it is not one. */
    private int getComponentIndex(Band targetBand) {
        for (int i = 0; i < componentBands.length; i++) {
            if (componentBands[i] == targetBand) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Collects the source pixel vectors tile by tile and computes the PCA basis from them.
     * The pixel iterator is given the ROI. Progress is split evenly: one unit per tile for point
     * extraction, and as many units again for the basis computation.
     */
    private synchronized void initPca(ProgressMonitor pm) {
        Rectangle[] tileRectangles = getAllTileRectangles();
        pm.beginTask("Extracting data points...", tileRectangles.length * 2);
        try {
            int pointSize = sourceBands.length;
            double[] point = new double[pointSize];
            double[] pointData = new double[INITIAL_POINT_CAPACITY * pointSize];
            int pointCount = 0;
            for (Rectangle rectangle : tileRectangles) {
                // The sub-monitor accounts for this tile's single progress unit.
                PixelIter iter = createPixelIter(rectangle, SubProgressMonitor.create(pm, 1));
                while (iter.next(point) != null) {
                    if (pointData.length < (pointCount + 1) * pointSize) {
                        pointData = Arrays.copyOf(pointData, 2 * pointData.length);
                    }
                    System.arraycopy(point, 0, pointData, pointCount * pointSize, pointSize);
                    pointCount++;
                }
            }
            // Trim the buffer so that it holds exactly the collected points.
            if (pointData.length > pointCount * pointSize) {
                pointData = Arrays.copyOf(pointData, pointCount * pointSize);
            }
            pca = new PrincipalComponentAnalysis(pointSize);
            pca.computeBasis(pointData, componentCount);
            pm.worked(tileRectangles.length);
        } finally {
            pm.done();
        }
    }

    /**
     * Splits the source-band raster into tiles of the source product's preferred tile size.
     * Edge tiles are clipped to the raster bounds. The raster size is taken from the source bands,
     * which may differ from the scene size in multi-size products.
     */
    private Rectangle[] getAllTileRectangles() {
        Dimension tileSize = ImageManager.getPreferredTileSize(sourceProduct);
        int rasterWidth = sourceBands[0].getRasterWidth();
        int rasterHeight = sourceBands[0].getRasterHeight();
        Rectangle boundary = new Rectangle(rasterWidth, rasterHeight);
        int tileCountX = MathUtils.ceilInt((double) rasterWidth / tileSize.width);
        int tileCountY = MathUtils.ceilInt((double) rasterHeight / tileSize.height);
        Rectangle[] rectangles = new Rectangle[tileCountX * tileCountY];
        int index = 0;
        for (int tileY = 0; tileY < tileCountY; tileY++) {
            for (int tileX = 0; tileX < tileCountX; tileX++) {
                Rectangle tileRectangle = new Rectangle(tileX * tileSize.width, tileY * tileSize.height,
                                                        tileSize.width, tileSize.height);
                rectangles[index++] = boundary.intersection(tileRectangle);
            }
        }
        return rectangles;
    }

    /**
     * Fetches the source tiles for {@code rectangle}, reporting one progress unit per band.
     * It then wraps them in a {@link PixelIter} that is given the ROI.
     */
    private PixelIter createPixelIter(Rectangle rectangle, ProgressMonitor pm) {
        pm.beginTask("Extracting data points...", sourceBands.length);
        try {
            Tile[] sourceTiles = new Tile[sourceBands.length];
            for (int i = 0; i < sourceBands.length; i++) {
                sourceTiles[i] = getSourceTile(sourceBands[i], rectangle);
                pm.worked(1);
            }
            return new PixelIter(sourceTiles, roi);
        } finally {
            pm.done();
        }
    }

    /** Operator SPI for {@link PrincipalComponentAnalysisOp}. */
    public static class Spi extends OperatorSpi {
        public Spi() {
            super(PrincipalComponentAnalysisOp.class);
        }
    }
}

