package org.esa.snap.binning.aggregators;

import org.esa.snap.binning.BinContext;
import org.esa.snap.binning.MyVariableContext;
import org.esa.snap.binning.operator.TestUtils;
import org.esa.snap.binning.support.VectorImpl;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link AggregatorAverageML}.
 *
 * <p>Covers the feature-name metadata (with and without the output-sums flag). Also runs a complete
 * spatial &rarr; temporal &rarr; output aggregation on a small fixed data set.
 */
public class AggregatorAverageMLTest {

    /** Absolute tolerance for double-precision comparisons of intermediate features. */
    private static final double EPS = 1.0E-5;

    /** Absolute tolerance for single-precision comparisons of final output features. */
    private static final float EPS_F = 1.0E-5f;

    /** Expected temporal "sum" feature: plain sum of the spatial-sum inputs fed to aggregateTemporal. */
    private static final double EXPECTED_TEMPORAL_SUM = 0.3 + 0.1 + 0.2 + 0.1;

    /** Expected temporal "sum_sq" feature: plain sum of the spatial-sum_sq inputs fed to aggregateTemporal. */
    private static final double EXPECTED_TEMPORAL_SUM_SQ = 0.09 + 0.01 + 0.04 + 0.01;

    /** Expected temporal "weights" feature: sum of sqrt of the integer counts passed to aggregateTemporal. */
    private static final double EXPECTED_TEMPORAL_WEIGHTS =
            Math.sqrt(3.0) + Math.sqrt(2.0) + Math.sqrt(1.0) + Math.sqrt(7.0);

    BinContext ctx;

    @Before
    public void setUp() throws Exception {
        this.ctx = AggregatorTestUtils.createCtx();
    }

    @Test
    public void testMetadata_noSums() {
        AggregatorAverageML agg = new AggregatorAverageML(new MyVariableContext("b"), "b", "b", 0.5d, false);
        Assert.assertEquals("AVG_ML", agg.getName());
        assertSpatialAndTemporalFeatureNames(agg);
        Assert.assertArrayEquals(new String[]{"b_mean", "b_sigma", "b_median", "b_mode"},
                                 agg.getOutputFeatureNames());
    }

    @Test
    public void testMetadata_withSums() {
        AggregatorAverageML agg = new AggregatorAverageML(new MyVariableContext("b"), "b", "b", 0.5d, true);
        Assert.assertEquals("AVG_ML", agg.getName());
        assertSpatialAndTemporalFeatureNames(agg);
        // With output sums enabled, the output features mirror the temporal features.
        Assert.assertArrayEquals(new String[]{"b_sum", "b_sum_sq", "b_weights"},
                                 agg.getOutputFeatureNames());
    }

    @Test
    public void testAggregatorAverageML() {
        AggregatorAverageML agg = new AggregatorAverageML(new MyVariableContext("b"), "b", 0.5d);
        VectorImpl temporal = runSpatialAndTemporal(agg);

        VectorImpl output = AggregatorTestUtils.vec(Float.NaN, Float.NaN, Float.NaN, Float.NaN);
        agg.computeOutput(temporal, output);
        Assert.assertEquals(1.114932f, output.get(0), EPS_F);  // b_mean
        Assert.assertEquals(0.119713f, output.get(1), EPS_F);  // b_sigma
        Assert.assertEquals(1.10856f, output.get(2), EPS_F);   // b_median
        Assert.assertEquals(1.095926f, output.get(3), EPS_F);  // b_mode
    }

    @Test
    public void testAggregatorAverageML_WithSums() {
        AggregatorAverageML agg = new AggregatorAverageML(new MyVariableContext("b"), "b", "b", 0.5d, true);
        VectorImpl temporal = runSpatialAndTemporal(agg);

        VectorImpl output = AggregatorTestUtils.vec(Float.NaN, Float.NaN, Float.NaN);
        agg.computeOutput(temporal, output);
        // With output sums enabled, the raw temporal accumulators are emitted unchanged.
        Assert.assertEquals(EXPECTED_TEMPORAL_SUM, output.get(0), EPS);
        Assert.assertEquals(EXPECTED_TEMPORAL_SUM_SQ, output.get(1), EPS);
        Assert.assertEquals(EXPECTED_TEMPORAL_WEIGHTS, output.get(2), EPS);
    }

    /**
     * Asserts the spatial and temporal feature names.
     *
     * <p>These names are expected to be identical regardless of the output-sums flag.
     */
    private static void assertSpatialAndTemporalFeatureNames(AggregatorAverageML agg) {
        Assert.assertArrayEquals(new String[]{"b_sum", "b_sum_sq"}, agg.getSpatialFeatureNames());
        Assert.assertArrayEquals(new String[]{"b_sum", "b_sum_sq", "b_weights"}, agg.getTemporalFeatureNames());
    }

    /**
     * Asserts that the first {@code count} elements of {@code vector} are exactly zero.
     *
     * <p>The vectors start out NaN-filled, so this verifies that init* really reset them.
     */
    private static void assertZeros(VectorImpl vector, int count) {
        for (int i = 0; i < count; i++) {
            Assert.assertEquals("element " + i, 0.0, vector.get(i), 0.0);
        }
    }

    /**
     * Drives the spatial and temporal phases shared by both aggregation tests, asserting every
     * intermediate state.
     *
     * @param agg the aggregator under test
     * @return the temporal vector (sum, sum_sq, weights) after {@code completeTemporal}
     */
    private VectorImpl runSpatialAndTemporal(AggregatorAverageML agg) {
        // --- spatial phase ---
        VectorImpl spatial = AggregatorTestUtils.vec(Float.NaN, Float.NaN);
        agg.initSpatial(ctx, spatial);
        assertZeros(spatial, 2);

        agg.aggregateSpatial(ctx, AggregatorTestUtils.obsNT(1.5f), spatial);
        agg.aggregateSpatial(ctx, AggregatorTestUtils.obsNT(2.5f), spatial);
        agg.aggregateSpatial(ctx, AggregatorTestUtils.obsNT(0.5f), spatial);

        // The spatial accumulators hold the sum and the sum of squares of the natural log of each measurement.
        double log15 = Math.log(1.5);
        double log25 = Math.log(2.5);
        double log05 = Math.log(0.5);
        double logSum = log15 + log25 + log05;
        double logSqSum = log15 * log15 + log25 * log25 + log05 * log05;
        Assert.assertEquals(logSum, spatial.get(0), EPS);
        Assert.assertEquals(logSqSum, spatial.get(1), EPS);

        // completeSpatial scales both accumulators by 1/sqrt(count).
        agg.completeSpatial(ctx, 3, spatial);
        Assert.assertEquals(logSum / Math.sqrt(3.0), spatial.get(0), EPS);
        Assert.assertEquals(logSqSum / Math.sqrt(3.0), spatial.get(1), EPS);

        // --- temporal phase ---
        VectorImpl temporal = AggregatorTestUtils.vec(Float.NaN, Float.NaN, Float.NaN);
        agg.initTemporal(ctx, temporal);
        assertZeros(temporal, 3);

        agg.aggregateTemporal(ctx, AggregatorTestUtils.vec(0.3f, 0.09f), 3, temporal);
        agg.aggregateTemporal(ctx, AggregatorTestUtils.vec(0.1f, 0.01f), 2, temporal);
        agg.aggregateTemporal(ctx, AggregatorTestUtils.vec(0.2f, 0.04f), 1, temporal);
        agg.aggregateTemporal(ctx, AggregatorTestUtils.vec(0.1f, 0.01f), 7, temporal);
        Assert.assertEquals(EXPECTED_TEMPORAL_SUM, temporal.get(0), EPS);
        Assert.assertEquals(EXPECTED_TEMPORAL_SUM_SQ, temporal.get(1), EPS);
        Assert.assertEquals(EXPECTED_TEMPORAL_WEIGHTS, temporal.get(2), EPS);

        agg.completeTemporal(ctx, 4, temporal);
        return temporal;
    }
}
