package gov.sandia.cognition.learning.data;

import gov.sandia.cognition.collection.DefaultMultiCollection;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorUtil;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.util.DefaultPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:gov/sandia/cognition/learning/data/DatasetUtil.class */
/**
 * Static helpers for reshaping and summarizing datasets of vectors and
 * input-output pairs.
 */
public class DatasetUtil {
    /**
     * Appends a bias element of {@code 1.0} to every vector in the collection.
     *
     * @param collection the vectors to extend
     * @return a new list holding, in iteration order, each vector with a
     *     {@code 1.0} stacked onto its end
     */
    public static ArrayList<Vector> appendBias(Collection<? extends Vector> collection) {
        return appendBias(collection, 1.0d);
    }

    /**
     * Appends a constant bias element to every vector in the collection.
     *
     * @param collection the vectors to extend
     * @param d the bias value appended to each vector
     * @return a new list holding, in iteration order, the result of
     *     {@code vector.stack(bias)} for each vector, where {@code bias} is a
     *     vector built from the single value {@code d}
     */
    public static ArrayList<Vector> appendBias(Collection<? extends Vector> collection, double d) {
        final ArrayList<Vector> result = new ArrayList<>(collection.size());
        // Built once and shared by every stack call.
        final Vector bias = VectorFactory.getDefault().copyValues(d);
        for (Vector vector : collection) {
            result.add(vector.stack(bias));
        }
        return result;
    }

    /**
     * Splits a dataset of vector input-output pairs into one scalar
     * input-output dataset per dimension.
     *
     * <p>Pairs that are {@link WeightedInputOutputPair}s produce
     * {@link DefaultWeightedInputOutputPair}s carrying the same weight; all
     * other pairs produce {@link DefaultInputOutputPair}s.
     *
     * @param collection the pairs to decouple; every input and output vector
     *     must have the dimensionality of the first pair's input
     * @return a list with one entry per dimension {@code i}, holding the pair
     *     (input[i], output[i]) for every pair in iteration order; an empty list
     *     when {@code collection} is empty
     * @throws IllegalArgumentException if any input or output vector has a
     *     different dimensionality than the first input
     */
    public static ArrayList<ArrayList<InputOutputPair<Double, Double>>> decoupleVectorPairDataset(Collection<? extends InputOutputPair<? extends Vector, ? extends Vector>> collection) {
        if (collection.isEmpty()) {
            // Without a first pair there is no dimensionality to split along.
            return new ArrayList<>();
        }
        final int size = collection.size();
        final int dimensionality = collection.iterator().next().getInput().getDimensionality();
        final ArrayList<ArrayList<InputOutputPair<Double, Double>>> result = new ArrayList<>(dimensionality);
        for (int i = 0; i < dimensionality; i++) {
            result.add(new ArrayList<InputOutputPair<Double, Double>>(size));
        }
        for (InputOutputPair<? extends Vector, ? extends Vector> pair : collection) {
            final Vector input = pair.getInput();
            final Vector output = pair.getOutput();
            if (input.getDimensionality() != dimensionality || output.getDimensionality() != dimensionality) {
                throw new IllegalArgumentException("All input-output Vectors must have same dimension!");
            }
            // Hoisted out of the per-dimension loop: the pair kind and its
            // weight are the same for every element.
            final boolean weighted = pair instanceof WeightedInputOutputPair;
            final double weight = getWeight(pair);
            for (int i = 0; i < dimensionality; i++) {
                final Double x = Double.valueOf(input.getElement(i));
                final Double y = Double.valueOf(output.getElement(i));
                final InputOutputPair<Double, Double> scalarPair;
                if (weighted) {
                    scalarPair = new DefaultWeightedInputOutputPair<>(x, y, weight);
                } else {
                    scalarPair = new DefaultInputOutputPair<>(x, y);
                }
                result.get(i).add(scalarPair);
            }
        }
        return result;
    }

    /**
     * Splits a dataset of vectors into one scalar dataset per dimension.
     *
     * @param collection the vectors to decouple; all must share the
     *     dimensionality of the first vector
     * @return a list with one entry per dimension {@code i}, holding element
     *     {@code i} of every vector in iteration order; an empty list when
     *     {@code collection} is empty
     * @throws IllegalArgumentException if the vectors do not all have the same
     *     dimensionality
     */
    public static ArrayList<ArrayList<Double>> decoupleVectorDataset(Collection<? extends Vector> collection) {
        if (collection.isEmpty()) {
            // Without a first vector there is no dimensionality to split along.
            return new ArrayList<>();
        }
        final int size = collection.size();
        final int dimensionality = collection.iterator().next().getDimensionality();
        final ArrayList<ArrayList<Double>> result = new ArrayList<>(dimensionality);
        for (int i = 0; i < dimensionality; i++) {
            result.add(new ArrayList<Double>(size));
        }
        for (Vector vector : collection) {
            if (dimensionality != vector.getDimensionality()) {
                throw new IllegalArgumentException("All vectors in the dataset must be the same size");
            }
            for (int i = 0; i < dimensionality; i++) {
                result.get(i).add(Double.valueOf(vector.getElement(i)));
            }
        }
        return result;
    }

    /**
     * Partitions the inputs of a boolean-labeled dataset by their output.
     *
     * @param <DataType> the input type
     * @param collection the labeled pairs; a {@code null} output causes a
     *     {@code NullPointerException} when it is unboxed
     * @return a pair whose first list holds the inputs labeled {@code true} and
     *     whose second list holds the inputs labeled {@code false}, each in
     *     iteration order
     */
    public static <DataType> DefaultPair<LinkedList<DataType>, LinkedList<DataType>> splitDatasets(Collection<? extends InputOutputPair<? extends DataType, Boolean>> collection) {
        final LinkedList<DataType> positives = new LinkedList<>();
        final LinkedList<DataType> negatives = new LinkedList<>();
        for (InputOutputPair<? extends DataType, Boolean> pair : collection) {
            if (pair.getOutput().booleanValue()) {
                positives.add(pair.getInput());
            } else {
                negatives.add(pair.getInput());
            }
        }
        return DefaultPair.create(positives, negatives);
    }

    /**
     * Groups the inputs of a dataset by their output value.
     *
     * @param <InputType> the input type
     * @param <CategoryType> the output (category) type
     * @param iterable the pairs to group
     * @return an insertion-ordered map from each distinct output to the list of
     *     inputs that had that output, in iteration order
     */
    public static <InputType, CategoryType> Map<CategoryType, List<InputType>> splitOnOutput(Iterable<? extends InputOutputPair<? extends InputType, ? extends CategoryType>> iterable) {
        final Map<CategoryType, List<InputType>> result = new LinkedHashMap<>();
        for (InputOutputPair<? extends InputType, ? extends CategoryType> pair : iterable) {
            final CategoryType category = pair.getOutput();
            List<InputType> members = result.get(category);
            if (members == null) {
                members = new ArrayList<>();
                result.put(category, members);
            }
            members.add(pair.getInput());
        }
        return result;
    }

    /**
     * Computes the sum of outer products of the given vectors.
     *
     * <p>Entry (r, c) of the result is the sum over all vectors {@code v} of
     * {@code v[r] * v[c]}. The matrix is symmetric, so each off-diagonal sum is
     * computed once and written to both (r, c) and (c, r). Only non-zero sums
     * are written into the sparse result.
     *
     * @param arrayList the vectors; all must have the same dimensionality
     * @return a square sparse matrix whose size equals the vectors'
     *     dimensionality
     * @throws IllegalArgumentException if {@code arrayList} is empty
     */
    public static Matrix computeOuterProductDataMatrix(ArrayList<? extends Vector> arrayList) {
        if (arrayList.isEmpty()) {
            throw new IllegalArgumentException("Cannot compute the outer-product matrix of an empty dataset");
        }
        final int dimensionality = arrayList.get(0).getDimensionality();
        // Reject mismatched vectors up front. Otherwise longer vectors would be
        // silently truncated and shorter ones would fail mid-computation.
        for (Vector vector : arrayList) {
            vector.assertDimensionalityEquals(dimensionality);
        }
        final SparseMatrix result = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(dimensionality, dimensionality);
        for (int col = 0; col < dimensionality; col++) {
            for (int row = col; row < dimensionality; row++) {
                double sum = 0.0d;
                for (Vector vector : arrayList) {
                    sum += vector.getElement(row) * vector.getElement(col);
                }
                if (sum != 0.0d) {
                    result.setElement(row, col, sum);
                    if (row != col) {
                        // Floating-point multiplication is commutative, so the
                        // mirrored entry has exactly the same value.
                        result.setElement(col, row, sum);
                    }
                }
            }
        }
        return result;
    }

    /**
     * Computes the unweighted mean of the outputs.
     *
     * @param collection the pairs whose outputs are averaged; may be
     *     {@code null}
     * @return the arithmetic mean of the outputs, or {@code 0.0} if
     *     {@code collection} is {@code null} or empty
     */
    public static double computeOutputMean(Collection<? extends InputOutputPair<?, ? extends Number>> collection) {
        if (collection == null) {
            return 0.0d;
        }
        double sum = 0.0d;
        int count = 0;
        for (InputOutputPair<?, ? extends Number> pair : collection) {
            sum += pair.getOutput().doubleValue();
            count++;
        }
        if (count <= 0) {
            return 0.0d;
        }
        return sum / count;
    }

    /**
     * Computes the weighted mean of the outputs.
     *
     * <p>Each pair's weight comes from {@link #getWeight(InputOutputPair)}, so
     * unweighted pairs count with weight {@code 1.0}.
     *
     * @param collection the pairs whose outputs are averaged; may be
     *     {@code null}
     * @return the weighted mean of the outputs, or {@code 0.0} if
     *     {@code collection} is {@code null}, empty, or has a total weight of
     *     exactly zero
     */
    public static double computeWeightedOutputMean(Collection<? extends InputOutputPair<?, ? extends Number>> collection) {
        if (collection == null || collection.isEmpty()) {
            return 0.0d;
        }
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (InputOutputPair<?, ? extends Number> pair : collection) {
            final double weight = getWeight(pair);
            weightedSum += weight * pair.getOutput().doubleValue();
            totalWeight += weight;
        }
        if (totalWeight == 0.0d) {
            return 0.0d;
        }
        return weightedSum / totalWeight;
    }

    /**
     * Computes the unweighted variance of the outputs, dividing by the number of
     * pairs (not the number of pairs minus one).
     *
     * @param collection the pairs whose outputs are measured; may be
     *     {@code null}
     * @return the mean squared deviation of the outputs from their mean, or
     *     {@code 0.0} if {@code collection} is {@code null} or empty
     */
    public static double computeOutputVariance(Collection<? extends InputOutputPair<?, ? extends Number>> collection) {
        if (collection == null || collection.isEmpty()) {
            return 0.0d;
        }
        final int size = collection.size();
        final double mean = computeOutputMean(collection);
        double sumSquares = 0.0d;
        for (InputOutputPair<?, ? extends Number> pair : collection) {
            final double delta = pair.getOutput().doubleValue() - mean;
            sumSquares += delta * delta;
        }
        return sumSquares / size;
    }

    /**
     * Collects the distinct output values of a dataset.
     *
     * @param <OutputType> the output type
     * @param iterable the pairs to scan; may be {@code null}
     * @return a new insertion-ordered set of the distinct outputs; empty if
     *     {@code iterable} is {@code null}
     */
    public static <OutputType> Set<OutputType> findUniqueOutputs(Iterable<? extends InputOutputPair<?, ? extends OutputType>> iterable) {
        final Set<OutputType> result = new LinkedHashSet<>();
        if (iterable != null) {
            for (InputOutputPair<?, ? extends OutputType> pair : iterable) {
                result.add(pair.getOutput());
            }
        }
        return result;
    }

    /**
     * Builds a histogram counting how often each output value occurs.
     *
     * @param <OutputType> the output type
     * @param iterable the pairs to count; may be {@code null}
     * @return a new histogram with one {@code add} per pair; empty if
     *     {@code iterable} is {@code null}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <OutputType> DataHistogram<OutputType> countOutputValues(Iterable<? extends InputOutputPair<?, ? extends OutputType>> iterable) {
        // NOTE(review): constructed raw because MapBasedDataHistogram's generic
        // signature is not visible here; parameterize once confirmed.
        final MapBasedDataHistogram histogram = new MapBasedDataHistogram();
        if (iterable != null) {
            for (InputOutputPair<?, ? extends OutputType> pair : iterable) {
                histogram.add(pair.getOutput());
            }
        }
        return histogram;
    }

    /**
     * Extracts the inputs of a dataset.
     *
     * @param <InputType> the input type
     * @param iterable the pairs to read; may be {@code null}
     * @return a new list of the inputs in iteration order; empty if
     *     {@code iterable} is {@code null}
     */
    public static <InputType> List<InputType> inputsList(Iterable<? extends InputOutputPair<? extends InputType, ?>> iterable) {
        final List<InputType> result = new ArrayList<>();
        if (iterable != null) {
            for (InputOutputPair<? extends InputType, ?> pair : iterable) {
                result.add(pair.getInput());
            }
        }
        return result;
    }

    /**
     * Views a collection as a {@link MultiCollection}.
     *
     * @param <EntryType> the element type
     * @param collection the collection to view
     * @return {@code collection} itself if it already is a
     *     {@code MultiCollection}; otherwise a new {@code DefaultMultiCollection}
     *     whose only sub-collection is {@code collection}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <EntryType> MultiCollection<EntryType> asMultiCollection(Collection<EntryType> collection) {
        if (collection instanceof MultiCollection) {
            return (MultiCollection) collection;
        }
        return new DefaultMultiCollection(Collections.singletonList(collection));
    }

    /**
     * Converts each element of a collection to a {@link Vector}.
     *
     * @param collection the vectorizable elements
     * @return a new list holding {@code convertToVector()} of each element, in
     *     iteration order
     */
    public static Collection<Vector> asVectorCollection(Collection<? extends Vectorizable> collection) {
        final ArrayList<Vector> result = new ArrayList<>(collection.size());
        for (Vectorizable item : collection) {
            result.add(item.convertToVector());
        }
        return result;
    }

    /**
     * Finds the input dimensionality of a dataset from its first usable pair.
     *
     * <p>Pairs that are {@code null}, have a {@code null} input, or whose input
     * converts to a {@code null} vector are skipped.
     *
     * @param iterable the pairs to scan; may be {@code null}
     * @return the dimensionality of the first usable input vector, or
     *     {@code -1} if there is none or {@code iterable} is {@code null}
     */
    public static int getInputDimensionality(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> iterable) {
        if (iterable == null) {
            return -1;
        }
        for (InputOutputPair<? extends Vectorizable, ?> pair : iterable) {
            if (pair == null) {
                continue;
            }
            final Vectorizable input = pair.getInput();
            if (input == null) {
                continue;
            }
            final Vector vector = input.convertToVector();
            if (vector != null) {
                return vector.getDimensionality();
            }
        }
        return -1;
    }

    /**
     * Asserts that every usable input in the dataset has the dimensionality of
     * the first usable input, as found by
     * {@link #getInputDimensionality(Iterable)}.
     *
     * @param iterable the pairs to check; may be {@code null}
     */
    public static void assertInputDimensionalitiesAllEqual(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> iterable) {
        assertInputDimensionalitiesAllEqual(iterable, getInputDimensionality(iterable));
    }

    /**
     * Asserts that every usable input in the dataset has the given
     * dimensionality.
     *
     * <p>Pairs that are {@code null}, have a {@code null} input, or whose input
     * converts to a {@code null} vector are skipped. The check itself is
     * delegated to {@code Vector.assertDimensionalityEquals}.
     *
     * @param iterable the pairs to check; a {@code null} iterable is a no-op
     * @param i the required input dimensionality
     */
    public static void assertInputDimensionalitiesAllEqual(Iterable<? extends InputOutputPair<? extends Vectorizable, ?>> iterable, int i) {
        if (iterable == null) {
            return;
        }
        for (InputOutputPair<? extends Vectorizable, ?> pair : iterable) {
            if (pair == null) {
                continue;
            }
            final Vectorizable input = pair.getInput();
            if (input == null) {
                continue;
            }
            final Vector vector = input.convertToVector();
            if (vector != null) {
                vector.assertDimensionalityEquals(i);
            }
        }
    }

    /**
     * Finds the dimensionality of the first usable element.
     *
     * <p>Elements that are {@code null} or convert to a {@code null} vector are
     * skipped.
     *
     * @param iterable the elements to scan; may be {@code null}
     * @return the dimensionality of the first usable vector, or {@code -1} if
     *     there is none or {@code iterable} is {@code null}
     */
    public static int getDimensionality(Iterable<? extends Vectorizable> iterable) {
        if (iterable == null) {
            return -1;
        }
        for (Vectorizable item : iterable) {
            if (item == null) {
                continue;
            }
            final Vector vector = item.convertToVector();
            if (vector != null) {
                return vector.getDimensionality();
            }
        }
        return -1;
    }

    /**
     * Asserts that all elements share the dimensionality of the first usable
     * element, delegating the check to
     * {@code VectorUtil.assertDimensionalitiesAllEqual}.
     *
     * @param iterable the elements to check
     */
    public static void assertDimensionalitiesAllEqual(Iterable<? extends Vectorizable> iterable) {
        VectorUtil.assertDimensionalitiesAllEqual(iterable, getDimensionality(iterable));
    }

    /**
     * Returns the weight of an input-output pair.
     *
     * @param inputOutputPair the pair
     * @return the pair's weight if it is a {@link WeightedInputOutputPair},
     *     otherwise {@code 1.0}
     */
    public static double getWeight(InputOutputPair<?, ?> inputOutputPair) {
        if (inputOutputPair instanceof WeightedInputOutputPair) {
            return ((WeightedInputOutputPair) inputOutputPair).getWeight();
        }
        return 1.0d;
    }

    /**
     * Returns the weight of a target-estimate pair.
     *
     * @param targetEstimatePair the pair
     * @return the pair's weight if it is a {@link WeightedTargetEstimatePair},
     *     otherwise {@code 1.0}
     */
    public static double getWeight(TargetEstimatePair<?, ?> targetEstimatePair) {
        if (targetEstimatePair instanceof WeightedTargetEstimatePair) {
            return ((WeightedTargetEstimatePair) targetEstimatePair).getWeight();
        }
        return 1.0d;
    }
}
