package org.esa.snap.cluster;

import com.bc.ceres.core.ProgressMonitor;
import com.bc.ceres.core.SubProgressMonitor;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Rectangle;
import java.util.Arrays;
import java.util.stream.Stream;
import org.esa.snap.core.datamodel.Band;
import org.esa.snap.core.datamodel.FlagCoding;
import org.esa.snap.core.datamodel.Mask;
import org.esa.snap.core.datamodel.MetadataAttribute;
import org.esa.snap.core.datamodel.MetadataElement;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.datamodel.ProductData;
import org.esa.snap.core.datamodel.RasterDataNode;
import org.esa.snap.core.gpf.Operator;
import org.esa.snap.core.gpf.OperatorException;
import org.esa.snap.core.gpf.OperatorSpi;
import org.esa.snap.core.gpf.Tile;
import org.esa.snap.core.gpf.annotations.OperatorMetadata;
import org.esa.snap.core.gpf.annotations.Parameter;
import org.esa.snap.core.gpf.annotations.SourceProduct;
import org.esa.snap.core.gpf.annotations.TargetProduct;
import org.esa.snap.core.image.ImageManager;
import org.esa.snap.core.util.ProductUtils;
import org.esa.snap.core.util.StringUtils;
import org.esa.snap.core.util.math.MathUtils;

/**
 * GPF operator that performs a Principal Component Analysis (PCA) on a set of source bands.
 * <p>
 * The PCA basis is computed once in {@link #doExecute(ProgressMonitor)} from the source pixels
 * selected by the region of interest ({@code Roi}). The target product contains:
 * <ul>
 * <li>one band per principal component ({@code component_1 .. component_N}),</li>
 * <li>a {@code response} band,</li>
 * <li>an {@code error} band,</li>
 * <li>a {@code flags} band marking the pixels used for the analysis.</li>
 * </ul>
 * The PCA result (mean vector and basis vectors) is also written to the target product metadata.
 */
@OperatorMetadata(alias = "PCA", category = "Raster/Image Analysis", version = "1.0", authors = "Norman Fomferra", copyright = "(c) 2013 by Brockmann Consult", description = "Performs a Principal Component Analysis.")
public class PrincipalComponentAnalysisOp extends Operator {

    /** Initial number of pixel vectors reserved for the PCA sample buffer (grown on demand). */
    private static final int INITIAL_PIXEL_CAPACITY = 10000;

    /** Expression selecting the pixels that have been used for the PCA. */
    private static final String ROI_PIXEL_EXPRESSION = "flags.PCA_ROI_PIXEL";

    @SourceProduct(alias = "source", label = "Source product", description = "The source product.")
    private Product sourceProduct;

    @TargetProduct
    private Product targetProduct;

    @Parameter(label = "Source band names", description = "The names of the bands being used for the cluster analysis.", rasterDataNodeType = Band.class)
    private String[] sourceBandNames;

    @Parameter(label = "Maximum component count", description = "The maximum number of principal components to compute.", defaultValue = "-1")
    private int componentCount;

    @Parameter(label = "ROI mask name", description = "The name of the ROI mask that should be used.", defaultValue = "", rasterDataNodeType = Mask.class)
    private String roiMaskName;

    @Parameter(label = "Remove non-ROI pixels", description = "Removes all non-ROI pixels in the target product.", defaultValue = "false")
    private boolean removeNonRoiPixels;

    private transient Roi roi;
    private transient Band[] sourceBands;
    private transient PrincipalComponentAnalysis pca;
    private transient Band[] componentBands;
    private transient Band responseBand;
    private transient Band errorBand;
    private transient Band flagsBand;

    /** Service provider interface registering this operator with GPF. */
    public static class Spi extends OperatorSpi {
        public Spi() {
            super(PrincipalComponentAnalysisOp.class);
        }
    }

    /**
     * Validates the parameters and sets up the target product structure.
     * <p>
     * The target product gets its component, response, error and flags bands, plus the ROI masks.
     * The PCA itself is not computed here.
     *
     * @throws OperatorException if there are no source bands, a named source band or the ROI mask
     *                           is missing, or the rasters differ in size
     */
    @Override
    public void initialize() throws OperatorException {
        collectSourceBands();
        if (sourceBands.length == 0) {
            throw new OperatorException("No source bands available for the PCA.");
        }
        if (componentCount <= 0 || componentCount > sourceBands.length) {
            componentCount = sourceBands.length;
        }

        // Resolve and validate the ROI mask *before* using it. An unknown mask name then yields a
        // meaningful error instead of a NullPointerException inside the raster size check.
        if (!StringUtils.isNullOrEmpty(roiMaskName)) {
            Mask roiMask = sourceProduct.getMaskGroup().get(roiMaskName);
            if (roiMask == null) {
                throw new OperatorException("Missing required mask '" + roiMaskName + "' in source product.");
            }
            ensureSingleRasterSize(Stream.concat(Arrays.stream(sourceBands), Stream.of(roiMask)).toArray(Band[]::new));
        } else {
            ensureSingleRasterSize(sourceBands);
        }

        int width = sourceBands[0].getRasterWidth();
        int height = sourceBands[0].getRasterHeight();
        targetProduct = new Product(sourceProduct.getName() + "_PCA", sourceProduct.getProductType() + "_PCA", width, height);
        // Tie-point grids and geo-coding refer to the scene raster, so only copy them if the sizes match.
        if (sourceProduct.getSceneRasterSize().equals(targetProduct.getSceneRasterSize())) {
            ProductUtils.copyTiePointGrids(sourceProduct, targetProduct);
            ProductUtils.copyGeoCoding(sourceProduct, targetProduct);
        }
        targetProduct.setStartTime(sourceProduct.getStartTime());
        targetProduct.setEndTime(sourceProduct.getEndTime());

        componentBands = new Band[componentCount];
        for (int i = 0; i < componentCount; i++) {
            Band componentBand = targetProduct.addBand("component_" + (i + 1), ProductData.TYPE_FLOAT32);
            ProductUtils.copySpectralBandProperties(sourceBands[i], componentBand);
            componentBands[i] = componentBand;
        }
        responseBand = targetProduct.addBand("response", ProductData.TYPE_FLOAT32);
        errorBand = targetProduct.addBand("error", ProductData.TYPE_FLOAT32);
        flagsBand = targetProduct.addBand("flags", ProductData.TYPE_UINT8);

        FlagCoding flagCoding = new FlagCoding("flags");
        flagCoding.addFlag("PCA_ROI_PIXEL", 1, "Pixel has been used to perform the PCA.");
        flagsBand.setSampleCoding(flagCoding);
        targetProduct.getFlagCodingGroup().add(flagCoding);
        targetProduct.addMask("pca_roi_pixel", "flags.PCA_ROI_PIXEL", "Pixel has been used to perform the PCA.", Color.RED, 0.5d);
        targetProduct.addMask("pca_non_roi_pixel", "!flags.PCA_ROI_PIXEL", "Pixel has not been used to perform the PCA.", Color.BLACK, 0.5d);

        if (removeNonRoiPixels) {
            for (Band componentBand : componentBands) {
                componentBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
            }
            responseBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
            errorBand.setValidPixelExpression(ROI_PIXEL_EXPRESSION);
        }

        roi = new Roi(sourceProduct, sourceBands, roiMaskName);
        setTargetProduct(targetProduct);
    }

    /**
     * Computes the PCA basis and stores it as {@code PCA_RESULT} in the target metadata.
     * <p>
     * An existing {@code PCA_RESULT} element is replaced at its current position.
     */
    @Override
    public void doExecute(ProgressMonitor progressMonitor) {
        initPca(progressMonitor);
        MetadataElement pcaMetadata = createPcaMetadata();
        MetadataElement metadataRoot = targetProduct.getMetadataRoot();
        MetadataElement existing = metadataRoot.getElement(pcaMetadata.getName());
        if (existing != null) {
            // Replace the stale result in place, keeping the element order of the metadata tree.
            int index = metadataRoot.getElementIndex(existing);
            metadataRoot.removeElement(existing);
            metadataRoot.addElementAt(pcaMetadata, index);
        } else {
            metadataRoot.addElement(pcaMetadata);
        }
    }

    /**
     * Builds the {@code PCA_RESULT} metadata element.
     * <p>
     * It contains a {@code MEAN_VECTOR} element with one attribute per source band, and a
     * {@code BASIS_VECTORS} element with one attribute per computed component.
     */
    private MetadataElement createPcaMetadata() {
        MetadataElement meanElement = new MetadataElement("MEAN_VECTOR");
        double[] meanVector = pca.getMeanVector();
        // The mean vector lives in input (band) space, so report one entry per source band,
        // not just componentCount entries.
        int meanCount = Math.min(meanVector.length, sourceBands.length);
        for (int i = 0; i < meanCount; i++) {
            meanElement.addAttribute(new MetadataAttribute(sourceBands[i].getName(), ProductData.createInstance(new double[]{meanVector[i]}), true));
        }

        MetadataElement basisElement = new MetadataElement("BASIS_VECTORS");
        for (int i = 0; i < componentCount; i++) {
            basisElement.addAttribute(new MetadataAttribute("component_" + (i + 1), ProductData.createInstance(pca.getBasisVector(i)), true));
        }

        MetadataElement resultElement = new MetadataElement("PCA_RESULT");
        resultElement.addElement(meanElement);
        resultElement.addElement(basisElement);
        return resultElement;
    }

    /**
     * Resolves the source bands to use.
     * <p>
     * These are either the bands named in {@code sourceBandNames}, or all bands of the source
     * product if no names were given.
     *
     * @return the resolved bands (also stored in {@code sourceBands})
     * @throws OperatorException if a named band does not exist in the source product
     */
    private Band[] collectSourceBands() {
        if (sourceBandNames == null || sourceBandNames.length == 0) {
            sourceBands = sourceProduct.getBands();
        } else {
            sourceBands = new Band[sourceBandNames.length];
            for (int i = 0; i < sourceBandNames.length; i++) {
                Band band = sourceProduct.getBand(sourceBandNames[i]);
                if (band == null) {
                    throw new OperatorException("source band not found: " + sourceBandNames[i]);
                }
                sourceBands[i] = band;
            }
        }
        return sourceBands;
    }

    /**
     * Computes one tile of a target band.
     * <p>
     * What is written depends on the band: the projection onto a principal component, the PCA
     * response, the PCA error membership, or the ROI flag.
     */
    @Override
    public void computeTile(Band band, Tile tile, ProgressMonitor progressMonitor) throws OperatorException {
        progressMonitor.beginTask("Computing component...", tile.getHeight());
        try {
            int componentIndex = getComponentIndex(band);
            Rectangle rectangle = tile.getRectangle();
            Tile[] sourceTiles = new Tile[sourceBands.length];
            for (int i = 0; i < sourceTiles.length; i++) {
                sourceTiles[i] = getSourceTile(sourceBands[i], rectangle);
            }
            double[] sample = new double[sourceTiles.length];
            for (int y = rectangle.y; y < rectangle.y + rectangle.height; y++) {
                checkForCancellation();
                for (int x = rectangle.x; x < rectangle.x + rectangle.width; x++) {
                    for (int i = 0; i < sourceTiles.length; i++) {
                        sample[i] = sourceTiles[i].getSampleDouble(x, y);
                    }
                    if (componentIndex >= 0) {
                        tile.setSample(x, y, pca.sampleToEigenSpace(sample)[componentIndex]);
                    } else if (band == responseBand) {
                        tile.setSample(x, y, pca.response(sample));
                    } else if (band == errorBand) {
                        tile.setSample(x, y, pca.errorMembership(sample));
                    } else if (band == flagsBand) {
                        tile.setSample(x, y, roi.contains(x, y) ? 1 : 0);
                    }
                }
                progressMonitor.worked(1);
            }
        } finally {
            progressMonitor.done();
        }
    }

    /** Releases the operator's state and the resources held by the operator context. */
    @Override
    public void dispose() {
        pca = null;
        roi = null;
        sourceBands = null;
        targetProduct = null;
        componentBands = null;
        responseBand = null;
        errorBand = null;
        flagsBand = null;
        super.dispose();
    }

    /** Returns the index of {@code band} within the component bands, or -1 if it is not one of them. */
    private int getComponentIndex(Band band) {
        for (int i = 0; i < componentBands.length; i++) {
            if (componentBands[i] == band) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Collects all ROI pixel vectors of the source bands and computes the PCA basis from them.
     * <p>
     * Half of the reported work is data extraction (one unit per tile); the other half is the
     * PCA computation.
     *
     * @throws OperatorException if no pixel could be collected
     */
    private synchronized void initPca(ProgressMonitor progressMonitor) {
        Rectangle[] tileRectangles = getAllTileRectangles();
        progressMonitor.beginTask("Extracting data points...", tileRectangles.length * 2);
        try {
            int dimension = sourceBands.length;
            double[] sample = new double[dimension];
            double[] data = new double[INITIAL_PIXEL_CAPACITY * dimension];
            int pixelCount = 0;
            for (Rectangle rectangle : tileRectangles) {
                // The sub-monitor accounts for exactly one unit of the parent's work per tile.
                PixelIter pixelIter = createPixelIter(rectangle, SubProgressMonitor.create(progressMonitor, 1));
                while (pixelIter.next(sample) != null) {
                    if (data.length < (pixelCount + 1) * dimension) {
                        data = Arrays.copyOf(data, 2 * data.length);
                    }
                    System.arraycopy(sample, 0, data, pixelCount * dimension, dimension);
                    pixelCount++;
                }
            }
            if (pixelCount == 0) {
                throw new OperatorException("No pixels available to perform the PCA.");
            }
            if (data.length > pixelCount * dimension) {
                data = Arrays.copyOf(data, pixelCount * dimension);
            }
            pca = new PrincipalComponentAnalysis(dimension);
            pca.computeBasis(data, componentCount);
            progressMonitor.worked(tileRectangles.length);
        } finally {
            progressMonitor.done();
        }
    }

    /**
     * Splits the source raster into tile rectangles of the preferred tile size.
     * <p>
     * Partial tiles at the right and bottom edges are included.
     */
    private Rectangle[] getAllTileRectangles() {
        Dimension tileSize = ImageManager.getPreferredTileSize(sourceProduct);
        // All source bands have the same raster size (checked in initialize()). For multi-size
        // products this may differ from the scene raster size.
        Rectangle bounds = new Rectangle(sourceBands[0].getRasterWidth(), sourceBands[0].getRasterHeight());
        // Floating-point division before ceiling; integer division would silently drop partial tiles.
        int tileCountX = MathUtils.ceilInt(bounds.width / (double) tileSize.width);
        int tileCountY = MathUtils.ceilInt(bounds.height / (double) tileSize.height);
        Rectangle[] rectangles = new Rectangle[tileCountX * tileCountY];
        int index = 0;
        for (int tileY = 0; tileY < tileCountY; tileY++) {
            for (int tileX = 0; tileX < tileCountX; tileX++) {
                Rectangle tileRect = new Rectangle(tileX * tileSize.width, tileY * tileSize.height, tileSize.width, tileSize.height);
                rectangles[index++] = bounds.intersection(tileRect);
            }
        }
        return rectangles;
    }

    /** Fetches the source tiles for {@code rectangle} and wraps them in a ROI-aware pixel iterator. */
    private PixelIter createPixelIter(Rectangle rectangle, ProgressMonitor progressMonitor) {
        Tile[] tiles = new Tile[sourceBands.length];
        try {
            progressMonitor.beginTask("Extracting data points...", sourceBands.length);
            for (int i = 0; i < sourceBands.length; i++) {
                tiles[i] = getSourceTile(sourceBands[i], rectangle);
                progressMonitor.worked(1);
            }
            return new PixelIter(tiles, roi);
        } finally {
            progressMonitor.done();
        }
    }
}
