package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviewResponse;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableDifferentiableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.Collection;

/**
 * A vector function that multiplies its input by an internal matrix:
 * {@code f(x) = A x}, where {@code A} is {@link #getInternalMatrix()}.
 * <p>
 * The input dimensionality is the number of columns of {@code A} and the
 * output dimensionality is the number of rows of {@code A}. The entries of
 * {@code A} are the parameters of this function. They are read and written as
 * a single vector through {@link #convertToVector()} and
 * {@link #convertFromVector(Vector)}. {@link #computeParameterGradient(Vector)}
 * gives the derivative of the output with respect to those parameters.
 */
@CodeReview(reviewer = {"Justin Basilico"}, date = "2006-10-06", changesNeeded = true, comments = {"Can you just add a comment for why the differentiation code is correct?", "Otherwise, class looks fine."}, response = {@CodeReviewResponse(respondent = "Kevin R. Dixon", date = "2006-10-06", moreChangesNeeded = false, comments = {"Added in-code comment describing the derivation of the differentiation formulae."})})
public class MatrixMultiplyVectorFunction extends AbstractCloneableSerializable implements VectorizableDifferentiableVectorFunction, VectorInputEvaluator<Vector, Vector>, VectorOutputEvaluator<Vector, Vector>, GradientDescendable {

    /** The matrix {@code A} in {@code f(x) = A x}; stored by reference. */
    private Matrix internalMatrix;

    /**
     * Batch learner that fits the matrix of a
     * {@code MatrixMultiplyVectorFunction} in closed form.
     * <p>
     * The inputs are stacked as the columns of {@code X} and the outputs as the
     * columns of {@code Y}. The learner then solves {@code A X = Y} for
     * {@code A}.
     */
    public static class ClosedFormSolver extends AbstractCloneableSerializable implements SupervisedBatchLearner<Vector, Vector, MatrixMultiplyVectorFunction> {

        /**
         * Learns a matrix-multiply function from a collection of input-output
         * pairs.
         * <p>
         * The input and output dimensionalities are taken from the first pair.
         * If a pair's weight is not 1, both its input and its output are
         * scaled by that weight before solving.
         *
         * @param collection the input-output pairs to fit; must be non-empty
         * @return a new function whose matrix solves {@code A X = Y}
         * @throws IllegalArgumentException if {@code collection} is empty
         */
        @Override
        public MatrixMultiplyVectorFunction learn(Collection<? extends InputOutputPair<? extends Vector, Vector>> collection) {
            if (collection.isEmpty()) {
                // Without at least one pair the input/output dimensionality
                // cannot be determined.
                throw new IllegalArgumentException("Cannot learn from an empty collection of input-output pairs.");
            }

            final InputOutputPair<? extends Vector, Vector> first = collection.iterator().next();
            final int outputDimensionality = first.getOutput().getDimensionality();
            final int inputDimensionality = first.getInput().getDimensionality();
            final int count = collection.size();

            // Each example becomes one column of X (inputs) and of Y (outputs).
            final Matrix outputs = MatrixFactory.getDefault().createMatrix(outputDimensionality, count);
            final Matrix inputs = MatrixFactory.getDefault().createMatrix(inputDimensionality, count);
            int column = 0;
            for (InputOutputPair<? extends Vector, Vector> pair : collection) {
                Vector input = pair.getInput();
                Vector output = pair.getOutput();
                final double weight = DatasetUtil.getWeight(pair);
                if (weight != 1.0d) {
                    // Scaling both sides by w scales the residual (y - A x) by w.
                    // NOTE(review): if solve() is a least-squares solve, each
                    // squared residual is effectively weighted by w^2, not w
                    // (scaling by sqrt(w) would give weight w). Confirm which
                    // behaviour is intended.
                    input = input.scale(weight);
                    output = output.scale(weight);
                }
                inputs.setColumn(column, input);
                outputs.setColumn(column, output);
                column++;
            }
            return learn(inputs, outputs);
        }

        /**
         * Solves for the matrix {@code A} such that {@code A X = Y}.
         * <p>
         * The solution is computed as {@code transpose(solve(X^T, Y^T))}, i.e.
         * by solving the transposed system {@code X^T A^T = Y^T}. Whether the
         * result is exact or a least-squares fit depends on
         * {@code Matrix.solve}; presumably it is least-squares when
         * {@code X^T} is not square (confirm).
         *
         * @param inputs the matrix {@code X}, whose columns are the inputs
         * @param outputs the matrix {@code Y}, whose columns are the matching
         *        outputs
         * @return a new function wrapping the solved matrix
         */
        public static MatrixMultiplyVectorFunction learn(Matrix inputs, Matrix outputs) {
            final Matrix solvedTranspose = inputs.transpose().solve(outputs.transpose());
            return new MatrixMultiplyVectorFunction(solvedTranspose.transpose());
        }
    }

    /** Creates a function whose matrix is a 1-by-1 identity matrix. */
    public MatrixMultiplyVectorFunction() {
        this(1, 1);
    }

    /**
     * Creates a function whose matrix comes from
     * {@code MatrixFactory.getDefault().createIdentity(outputDimensionality, inputDimensionality)}.
     * <p>
     * The argument order assumes {@code createIdentity} takes
     * {@code (numRows, numColumns)}.
     *
     * @param inputDimensionality number of columns of the matrix (size of x)
     * @param outputDimensionality number of rows of the matrix (size of f(x))
     */
    public MatrixMultiplyVectorFunction(int inputDimensionality, int outputDimensionality) {
        this(MatrixFactory.getDefault().createIdentity(outputDimensionality, inputDimensionality));
    }

    /**
     * Creates a function that wraps the given matrix.
     *
     * @param matrix the matrix {@code A}; stored by reference, not copied
     */
    public MatrixMultiplyVectorFunction(Matrix matrix) {
        setInternalMatrix(matrix);
    }

    /**
     * Copy constructor. The new function wraps a clone of the other
     * function's matrix rather than sharing it.
     *
     * @param other the function to copy
     */
    public MatrixMultiplyVectorFunction(MatrixMultiplyVectorFunction other) {
        this(other.getInternalMatrix().mo557clone());
    }

    /**
     * Clones this function. The copy receives its own clone of the internal
     * matrix instead of sharing this function's matrix.
     * <p>
     * This is {@code clone()}. The name was changed by the decompiler and is
     * kept because the overridden method and callers use it.
     *
     * @return the cloned function
     */
    @Override
    public MatrixMultiplyVectorFunction mo557clone() {
        final MatrixMultiplyVectorFunction copy = (MatrixMultiplyVectorFunction) super.mo557clone();
        copy.setInternalMatrix(getInternalMatrix().mo557clone());
        return copy;
    }

    /**
     * Returns the internal matrix {@code A} itself, not a copy.
     *
     * @return the internal matrix
     */
    public Matrix getInternalMatrix() {
        return this.internalMatrix;
    }

    /**
     * Replaces the internal matrix. The argument is stored by reference.
     *
     * @param matrix the new matrix {@code A}
     */
    protected void setInternalMatrix(Matrix matrix) {
        this.internalMatrix = matrix;
    }

    /**
     * Returns the parameters (the entries of the matrix) as a vector produced
     * by the matrix's own {@code convertToVector()}.
     *
     * @return the parameter vector
     */
    @Override
    public Vector convertToVector() {
        return this.internalMatrix.convertToVector();
    }

    /**
     * Overwrites the entries of the internal matrix in place from the given
     * parameter vector.
     *
     * @param vector the parameter vector
     */
    @Override
    public void convertFromVector(Vector vector) {
        this.internalMatrix.convertFromVector(vector);
    }

    /**
     * Evaluates {@code f(x) = A x}.
     *
     * @param vector the input {@code x}
     * @return the product of the internal matrix and {@code x}
     */
    @Override
    public Vector evaluate(Vector vector) {
        return this.internalMatrix.times(vector);
    }

    /**
     * Returns the Jacobian of {@code f(x) = A x} with respect to {@code x}.
     * <p>
     * Derivation: {@code f_r(x) = sum_c A_{r,c} x_c}, so
     * {@code df_r / dx_c = A_{r,c}}. The Jacobian is therefore {@code A}
     * itself and does not depend on {@code x}.
     * <p>
     * NOTE(review): this returns the live internal matrix, not a copy, so
     * modifying the result modifies this function.
     *
     * @param vector the input (unused, since the Jacobian is constant)
     * @return the internal matrix
     */
    @Override
    public Matrix differentiate(Vector vector) {
        return getInternalMatrix();
    }

    /**
     * Computes the derivative of the output with respect to the parameters
     * (the entries of the internal matrix) at the given input.
     *
     * @param vector the input {@code x}
     * @return the parameter gradient; see
     *         {@link #computeParameterGradient(Matrix, Vector)}
     */
    @Override
    public Matrix computeParameterGradient(Vector vector) {
        return computeParameterGradient(this.internalMatrix, vector);
    }

    /**
     * Computes the derivative of {@code y = A x} with respect to the entries
     * of {@code A}.
     * <p>
     * Derivation: {@code y_r = sum_c A_{r,c} x_c}, so
     * {@code dy_r / dA_{r',c} = x_c} if {@code r == r'}, and 0 otherwise.
     * <p>
     * Layout of the result:
     * <ul>
     * <li>One row per output.</li>
     * <li>One column per entry of {@code A}.</li>
     * <li>The value for {@code A_{r,c}} is written to row {@code r}, column
     *     {@code c * numRows + r}. That is, the columns of {@code A} are
     *     stacked one after another.</li>
     * <li>All other entries are left as {@code createMatrix} initialized them
     *     (presumably zero).</li>
     * </ul>
     * NOTE(review): this column ordering must match
     * {@code Matrix.convertToVector()} so that gradient updates reach the
     * right parameters. Presumably that is column-major; confirm.
     *
     * @param matrix the matrix {@code A}; only its dimensions are used
     * @param vector the input {@code x}; must have one element per column of
     *        {@code matrix}
     * @return a {@code numRows x (numRows * numColumns)} gradient matrix
     * @throws IllegalArgumentException if the dimensionality of
     *         {@code vector} differs from the number of columns of
     *         {@code matrix}
     */
    public static Matrix computeParameterGradient(Matrix matrix, Vector vector) {
        final int numRows = matrix.getNumRows();
        final int numColumns = matrix.getNumColumns();
        if (vector.getDimensionality() != numColumns) {
            throw new IllegalArgumentException("Input dimensionality " + vector.getDimensionality()
                + " does not match the number of matrix columns " + numColumns + ".");
        }

        final Matrix gradient = MatrixFactory.getDefault().createMatrix(numRows, numRows * numColumns);
        for (int column = 0; column < numColumns; column++) {
            final double inputElement = vector.getElement(column);
            // Parameters of column 'column' of A start at this offset.
            final int offset = column * numRows;
            for (int row = 0; row < numRows; row++) {
                gradient.setElement(row, offset + row, inputElement);
            }
        }
        return gradient;
    }

    /**
     * Returns the string form of the internal matrix.
     *
     * @return the internal matrix's {@code toString()}
     */
    @Override
    public String toString() {
        return getInternalMatrix().toString();
    }

    /**
     * Returns the expected input size, which is the number of columns of the
     * internal matrix.
     *
     * @return the input dimensionality
     */
    @Override
    public int getInputDimensionality() {
        return getInternalMatrix().getNumColumns();
    }

    /**
     * Returns the output size, which is the number of rows of the internal
     * matrix.
     *
     * @return the output dimensionality
     */
    @Override
    public int getOutputDimensionality() {
        return getInternalMatrix().getNumRows();
    }
}
