package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/tree/VectorThresholdVarianceLearner.class */
/**
 * A {@link DeciderLearner} that learns a {@link VectorElementThresholdCategorizer}
 * by searching every vector dimension for the threshold whose split gives the
 * largest reduction in output variance.
 * <p>
 * The gain of a split is the variance of all outputs minus the size-weighted
 * variances of the outputs on each side of the threshold. Candidate thresholds
 * are the midpoints between adjacent distinct input values of a dimension.
 * Among splits with equal gain, the more balanced split is preferred.
 */
public class VectorThresholdVarianceLearner extends AbstractCloneableSerializable implements DeciderLearner<Vectorizable, Double, Boolean, VectorElementThresholdCategorizer> {

    /**
     * Learns a threshold categorizer on the dimension whose best split has the
     * highest variance gain. When several dimensions tie, the lowest index wins.
     *
     * @param data the labeled examples, each with a real-valued output
     * @return a categorizer built from the best dimension index and threshold,
     *     or {@code null} if {@code data} is null, has fewer than two examples,
     *     or has no dimension with at least two distinct values
     */
    @Override
    public VectorElementThresholdCategorizer learn(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
        if (data == null || data.size() <= 1) {
            return null;
        }

        final double dataVariance = DatasetUtil.computeOutputVariance(data);
        final int dimensionality = getDimensionality(data);

        int bestIndex = -1;
        double bestGain = 0.0;
        double bestThreshold = 0.0;
        for (int index = 0; index < dimensionality; index++) {
            final DefaultPair<Double, Double> gainAndThreshold =
                computeBestGainThreshold(data, index, dataVariance);
            if (gainAndThreshold == null) {
                // Every example has the same value here, so no split exists.
                continue;
            }
            final double gain = gainAndThreshold.getFirst().doubleValue();
            if (bestIndex < 0 || gain > bestGain) {
                bestIndex = index;
                bestGain = gain;
                bestThreshold = gainAndThreshold.getSecond().doubleValue();
            }
        }

        if (bestIndex < 0) {
            return null;
        }
        return new VectorElementThresholdCategorizer(bestIndex, bestThreshold);
    }

    /**
     * Gets the dimensionality of the inputs from the first example.
     *
     * @param data the examples to inspect
     * @return the dimensionality of the first example's input vector, or
     *     {@code 0} if {@code data} is null or empty
     */
    protected int getDimensionality(Collection<? extends InputOutputPair<? extends Vectorizable, ?>> data) {
        if (data == null || data.isEmpty()) {
            return 0;
        }
        return data.iterator().next().getInput().convertToVector().getDimensionality();
    }

    /**
     * Finds the threshold on one dimension that maximizes the variance gain.
     * <p>
     * Examples are sorted by their value in {@code dimension}. Each boundary
     * between two adjacent distinct values is evaluated as a split, and the
     * midpoint of that boundary is the candidate threshold. Each candidate
     * rescans both sides to compute their variances, so the cost is quadratic
     * in the number of examples.
     *
     * @param data the labeled examples, each with a real-valued output
     * @param dimension the index of the input vector element to split on
     * @param dataVariance the output variance of the whole of {@code data}
     * @return a pair of (best gain, best threshold), or {@code null} if
     *     {@code data} is null, has fewer than two examples, or has the same
     *     value in {@code dimension} for every example
     */
    public DefaultPair<Double, Double> computeBestGainThreshold(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data, int dimension, double dataVariance) {
        if (data == null || data.size() <= 1) {
            return null;
        }

        final int totalCount = data.size();
        final List<DefaultPair<Double, Double>> values =
            new ArrayList<DefaultPair<Double, Double>>(totalCount);
        double sumOutput = 0.0;
        for (InputOutputPair<? extends Vectorizable, Double> example : data) {
            final Vector input = example.getInput().convertToVector();
            final Double output = example.getOutput();
            values.add(new DefaultPair<Double, Double>(Double.valueOf(input.getElement(dimension)), output));
            sumOutput += output.doubleValue();
        }

        // Ascending by input value; the sort is stable for equal values.
        Collections.sort(values, new Comparator<DefaultPair<Double, Double>>() {
            @Override
            public int compare(DefaultPair<Double, Double> left, DefaultPair<Double, Double> right) {
                return left.getFirst().compareTo(right.getFirst());
            }
        });

        if (values.get(0).getFirst().equals(values.get(totalCount - 1).getFirst())) {
            return null;
        }

        double sumBelow = 0.0;
        double sumAbove = sumOutput;
        double bestGain = 0.0;
        double bestBalance = 0.0;
        double bestThreshold = values.get(0).getFirst().doubleValue();
        double previousValue = 0.0;
        for (int i = 0; i < totalCount; i++) {
            final DefaultPair<Double, Double> entry = values.get(i);
            final double value = entry.getFirst().doubleValue();
            final double output = entry.getSecond().doubleValue();

            if (i > 0 && value != previousValue) {
                // Candidate split: indices [0, i) below, [i, totalCount) above.
                final int countBelow = i;
                final int countAbove = totalCount - i;
                final double varianceBelow =
                    computeSplitVariance(values, 0, i, sumBelow / countBelow);
                final double varianceAbove =
                    computeSplitVariance(values, i, totalCount, sumAbove / countAbove);

                // Cast before dividing: integer division would truncate both
                // proportions to zero and make every split look identical.
                final double proportionAbove = (double) countAbove / totalCount;
                final double proportionBelow = (double) countBelow / totalCount;
                final double gain = dataVariance
                    - proportionAbove * varianceAbove
                    - proportionBelow * varianceBelow;

                if (gain >= bestGain) {
                    // At equal gain, prefer the split closest to 50/50.
                    final double balance = 1.0 - Math.abs(proportionAbove - proportionBelow);
                    if (gain > bestGain || balance > bestBalance) {
                        bestGain = gain;
                        bestBalance = balance;
                        bestThreshold = (value + previousValue) / 2.0;
                    }
                }
            }

            sumAbove -= output;
            sumBelow += output;
            previousValue = value;
        }
        return new DefaultPair<Double, Double>(Double.valueOf(bestGain), Double.valueOf(bestThreshold));
    }

    /**
     * Computes the mean squared deviation from {@code mean} of the outputs
     * (second elements) of {@code values} in the range
     * {@code [fromIndex, toIndex)}.
     */
    private static double computeSplitVariance(List<DefaultPair<Double, Double>> values, int fromIndex, int toIndex, double mean) {
        double sumSquares = 0.0;
        for (int i = fromIndex; i < toIndex; i++) {
            final double delta = values.get(i).getSecond().doubleValue() - mean;
            sumSquares += delta * delta;
        }
        return sumSquares / (toIndex - fromIndex);
    }
}
