package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;

/**
 * Batch Baum-Welch re-estimation of a {@link HiddenMarkovModel} from one or
 * more observation sequences.
 *
 * <p>The training data is viewed as a {@link MultiCollection}; each
 * sub-collection is treated as one observation sequence. Every call to
 * {@link #step()} computes per-sequence forward/backward quantities with the
 * current model, re-estimates the initial probabilities (optional), the
 * transition matrix and the emission functions, and keeps the candidate model
 * according to the acceptance rule documented on {@link #step()}.
 *
 * <p>All fields declared here are transient scratch state that is created in
 * {@link #initializeAlgorithm()} and released in {@link #cleanupAlgorithm()}.
 *
 * @param <ObservationType> type of a single observation
 */
@PublicationReference(author = {"Lawrence R. Rabiner"}, title = "A tutorial on hidden Markov models and selected applications in speech recognition", type = PublicationType.Journal, year = 1989, publication = "Proceedings of the IEEE", pages = {257, 286}, url = "http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf", notes = {"Rabiner's transition matrix is transposed from mine."})
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/hmm/BaumWelchAlgorithm.class */
public class BaumWelchAlgorithm<ObservationType> extends AbstractBaumWelchAlgorithm<ObservationType, Collection<? extends ObservationType>> {
    /**
     * One weighted wrapper per observation, over all sequences in iteration
     * order. The weights are overwritten for each state in
     * {@link #updateProbabilityFunctions(ArrayList)}.
     */
    private transient ArrayList<DefaultWeightedValue<ObservationType>> weightedData;

    /**
     * One entry per sequence: the value is the sequence log-likelihood and the
     * weight is the factor assigned by
     * {@link #updateSequenceLogLikelihoods(HiddenMarkovModel)}.
     */
    private transient ArrayList<DefaultWeightedValue<Double>> sequenceLogLikelihoods;

    /** Total number of observations over all sequences. */
    private transient int totalNum;

    /** The training data viewed as a collection of sequences. */
    protected transient MultiCollection<? extends ObservationType> multicollection;

    /**
     * Flattened per-observation state-likelihood vectors ("gammas") of all
     * sequences, aligned index-for-index with the weighted observations.
     */
    protected transient ArrayList<Vector> sequenceGammas;

    /**
     * Creates an algorithm with no initial guess and no distribution learner,
     * with re-estimation of the initial probabilities enabled. Note that
     * {@link #initializeAlgorithm()} dereferences the initial guess, so one
     * must be supplied before learning.
     */
    public BaumWelchAlgorithm() {
        this(null, null, true);
    }

    /**
     * Creates an algorithm; all three arguments are forwarded unchanged to the
     * superclass constructor.
     *
     * @param hiddenMarkovModel initial guess for the model
     * @param batchLearner learner that fits an emission distribution to
     *        weighted observations
     * @param z whether the initial probabilities are re-estimated
     *        (this is the value the no-argument constructor passes as
     *        {@code true}; presumably surfaced through
     *        {@code getReestimateInitialProbabilities()} - confirm in the
     *        superclass)
     */
    public BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> hiddenMarkovModel, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> batchLearner, boolean z) {
        super(hiddenMarkovModel, batchLearner, z);
    }

    /**
     * Returns a copy of this algorithm, as produced by the superclass clone.
     *
     * @return the clone, narrowed to this class
     */
    @Override // gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm, gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner, gov.sandia.cognition.algorithm.AbstractIterativeAlgorithm, gov.sandia.cognition.util.AbstractCloneableSerializable
    /* renamed from: clone */
    public BaumWelchAlgorithm<ObservationType> mo557clone() {
        return (BaumWelchAlgorithm) super.mo557clone();
    }

    /**
     * Convenience overload that learns from data that is already organised as
     * a {@link MultiCollection} of sequences.
     *
     * @param multiCollection the observation sequences
     * @return the model returned by the superclass {@code learn}
     */
    public HiddenMarkovModel<ObservationType> learn(MultiCollection<ObservationType> multiCollection) {
        // The data must be handed to the superclass as the Collection data type
        // declared in the class header. Casting it to BaumWelchAlgorithm (as the
        // decompiled code did) fails with ClassCastException for any real data.
        return (HiddenMarkovModel<ObservationType>) super.learn((Collection<? extends ObservationType>) multiCollection);
    }

    /**
     * Builds the per-run scratch state: the multi-collection view of the data,
     * one log-likelihood slot per sequence, one weighted wrapper and one gamma
     * slot per observation, a clone of the initial guess as the working model,
     * and the starting log-likelihood.
     *
     * @return {@code true} if the working model is non-null
     * @throws NullPointerException if {@code getInitialGuess()} returns
     *         {@code null}, because the guess is cloned before the check
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    public boolean initializeAlgorithm() {
        this.multicollection = DatasetUtil.asMultiCollection((Collection) this.data);
        this.data = null;

        // First pass: one likelihood slot per sequence and the total
        // observation count, which is used to presize the flat lists below.
        this.sequenceLogLikelihoods = new ArrayList<>(this.multicollection.getSubCollectionsCount());
        this.totalNum = 0;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            this.sequenceLogLikelihoods.add(new DefaultWeightedValue<>());
            this.totalNum += sequence.size();
        }

        // Second pass: wrap every observation and reserve its gamma slot so
        // that step() can overwrite the slots by index.
        this.weightedData = new ArrayList<>(this.totalNum);
        this.sequenceGammas = new ArrayList<>(this.totalNum);
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            for (ObservationType observation : sequence) {
                this.weightedData.add(new DefaultWeightedValue<>(observation));
                this.sequenceGammas.add(null);
            }
        }

        this.result = getInitialGuess().mo557clone();
        this.lastLogLikelihood = updateSequenceLogLikelihoods(this.result);
        return this.result != null;
    }

    /**
     * Performs one re-estimation pass.
     *
     * <p>When {@code getMaxIterations() <= 1} the current model is updated in
     * place and {@code true} is returned. Otherwise a clone of the current
     * model receives the new parameters, its total log-likelihood is computed,
     * and it replaces the current model only if that log-likelihood is greater
     * than the last one or {@code getIteration() <= 1}.
     *
     * @return {@code true} if the re-estimated parameters were adopted
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        final int numSequences = this.multicollection.getSubCollectionsCount();
        final boolean reestimateInitial = getReestimateInitialProbabilities();

        Pair<ArrayList<ArrayList<Vector>>, ArrayList<Matrix>> sequenceParameters = computeSequenceParameters();
        ArrayList<ArrayList<Vector>> gammasBySequence = sequenceParameters.getFirst();
        ArrayList<Matrix> transitionsBySequence = sequenceParameters.getSecond();

        // Flatten the per-sequence gammas into sequenceGammas and, if needed,
        // collect the first gamma of every sequence for the initial
        // probability update.
        ArrayList<Vector> firstGammas = reestimateInitial ? new ArrayList<>(numSequences) : null;
        int flatIndex = 0;
        for (int sequenceIndex = 0; sequenceIndex < numSequences; sequenceIndex++) {
            ArrayList<Vector> gammas = gammasBySequence.get(sequenceIndex);
            if (reestimateInitial) {
                firstGammas.add(gammas.get(0));
            }
            final int sequenceLength = gammas.size();
            for (int t = 0; t < sequenceLength; t++) {
                this.sequenceGammas.set(flatIndex, gammas.get(t));
                flatIndex++;
            }
        }

        Vector initialProbability = this.result.getInitialProbability();
        if (reestimateInitial) {
            initialProbability = updateInitialProbabilities(firstGammas);
        }
        Matrix transitionProbability = updateTransitionMatrix(transitionsBySequence);
        ArrayList<ProbabilityFunction<ObservationType>> emissionFunctions = updateProbabilityFunctions(this.sequenceGammas);

        if (getMaxIterations() <= 1) {
            // Single-iteration run: no likelihood comparison, update in place.
            this.result.emissionFunctions = emissionFunctions;
            this.result.initialProbability = initialProbability;
            this.result.transitionProbability = transitionProbability;
            return true;
        }

        HiddenMarkovModel<ObservationType> candidate = this.result.mo557clone();
        candidate.emissionFunctions = emissionFunctions;
        candidate.initialProbability = initialProbability;
        candidate.transitionProbability = transitionProbability;
        final double candidateLogLikelihood = updateSequenceLogLikelihoods(candidate);
        final boolean accepted = candidateLogLikelihood > this.lastLogLikelihood || getIteration() <= 1;
        if (accepted) {
            this.result = candidate;
            this.lastLogLikelihood = candidateLogLikelihood;
        }
        return accepted;
    }

    /**
     * Releases the scratch state created by {@link #initializeAlgorithm()}.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        this.multicollection = null;
        this.weightedData = null;
        this.sequenceLogLikelihoods = null;
        // Holds one vector per observation; drop it with the other scratch
        // fields so a finished learner does not keep it alive.
        this.sequenceGammas = null;
        this.totalNum = 0;
    }

    /**
     * Computes, for every sequence and with the current model, the list of
     * state-observation likelihood vectors (gammas) and the transition matrix
     * returned by {@code computeTransitions}. The sequence weight stored in
     * {@code sequenceLogLikelihoods} is passed to the gamma computation and,
     * when it differs from 1, scales the transition matrix in place.
     *
     * @return pair of (gammas per sequence, transition matrix per sequence),
     *         both in sub-collection iteration order
     */
    protected Pair<ArrayList<ArrayList<Vector>>, ArrayList<Matrix>> computeSequenceParameters() {
        final int numSequences = this.multicollection.getSubCollectionsCount();
        ArrayList<ArrayList<Vector>> gammasBySequence = new ArrayList<>(numSequences);
        ArrayList<Matrix> transitionsBySequence = new ArrayList<>(numSequences);
        int sequenceIndex = 0;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            final double sequenceWeight = this.sequenceLogLikelihoods.get(sequenceIndex).getWeight();
            ArrayList<Vector> observationLikelihoods = this.result.computeObservationLikelihoods(sequence);
            ArrayList<WeightedValue<Vector>> alphas = this.result.computeForwardProbabilities(observationLikelihoods, true);
            ArrayList<WeightedValue<Vector>> betas = this.result.computeBackwardProbabilities(observationLikelihoods, alphas);
            ArrayList<Vector> gammas = this.result.computeStateObservationLikelihood(alphas, betas, sequenceWeight);
            gammasBySequence.add(gammas);
            Matrix transitions = this.result.computeTransitions(alphas, betas, observationLikelihoods);
            if (sequenceWeight != 1.0d) {
                transitions.scaleEquals(sequenceWeight);
            }
            transitionsBySequence.add(transitions);
            sequenceIndex++;
        }
        return DefaultPair.create(gammasBySequence, transitionsBySequence);
    }

    /**
     * Re-estimates one emission function per state. For each state, the weight
     * of every observation is set to that state's element of the matching
     * vector, and the distribution learner is run on the weighted data.
     *
     * @param arrayList one vector per observation, aligned with the weighted
     *        observations built in {@link #initializeAlgorithm()}
     * @return the learned probability functions, indexed by state
     */
    protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> arrayList) {
        final int numStates = this.result.getNumStates();
        final int numObservations = arrayList.size();
        ArrayList<ProbabilityFunction<ObservationType>> emissionFunctions = new ArrayList<>(numStates);
        for (int state = 0; state < numStates; state++) {
            // The same weighted wrappers are reused for every state; only the
            // weights change between learner invocations.
            for (int n = 0; n < numObservations; n++) {
                this.weightedData.get(n).setWeight(arrayList.get(n).getElement(state));
            }
            emissionFunctions.add(this.distributionLearner.learn(this.weightedData).getProbabilityFunction());
        }
        return emissionFunctions;
    }

    /**
     * Sums the per-sequence transition matrices and normalizes the sum with
     * the current model's {@code normalizeTransitionMatrix}.
     *
     * @param arrayList per-sequence transition matrices
     * @return the normalized sum
     */
    protected Matrix updateTransitionMatrix(ArrayList<Matrix> arrayList) {
        Matrix transitionSum = new RingAccumulator<Matrix>(arrayList).getSum();
        this.result.normalizeTransitionMatrix(transitionSum);
        return transitionSum;
    }

    /**
     * Sums the given vectors and rescales the sum by the reciprocal of its
     * 1-norm.
     *
     * @param arrayList first gamma vector of every sequence
     * @return the rescaled sum
     */
    protected Vector updateInitialProbabilities(ArrayList<Vector> arrayList) {
        RingAccumulator<Vector> accumulator = new RingAccumulator<Vector>();
        for (Vector firstGamma : arrayList) {
            // Accumulate the vector itself; it is not a RingAccumulator and
            // must not be cast to one.
            accumulator.accumulate(firstGamma);
        }
        Vector initialProbability = accumulator.getSum();
        initialProbability.scaleEquals(1.0d / initialProbability.norm1());
        return initialProbability;
    }

    /**
     * Recomputes the log-likelihood of every sequence under the given model
     * and refreshes the per-sequence weights.
     *
     * <p>Each sequence's value is set to its log-likelihood {@code ll}; its
     * weight is then set to {@code 1 / exp(ll - max)}, where {@code max} is
     * the largest log-likelihood over all sequences.
     *
     * @param hiddenMarkovModel model used to score the sequences
     * @return the sum of the sequence log-likelihoods
     */
    protected double updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hiddenMarkovModel) {
        int sequenceIndex = 0;
        double maxLogLikelihood = Double.NEGATIVE_INFINITY;
        double totalLogLikelihood = 0.0d;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            final double sequenceLogLikelihood = hiddenMarkovModel.computeObservationLogLikelihood(sequence);
            if (maxLogLikelihood < sequenceLogLikelihood) {
                maxLogLikelihood = sequenceLogLikelihood;
            }
            this.sequenceLogLikelihoods.get(sequenceIndex).setValue(Double.valueOf(sequenceLogLikelihood));
            totalLogLikelihood += sequenceLogLikelihood;
            sequenceIndex++;
        }

        // The weights depend on the maximum, so they can only be assigned
        // after every sequence has been scored.
        final int numSequences = this.multicollection.getSubCollectionsCount();
        for (int s = 0; s < numSequences; s++) {
            DefaultWeightedValue<Double> entry = this.sequenceLogLikelihoods.get(s);
            final double relativeLogLikelihood = entry.getValue().doubleValue() - maxLogLikelihood;
            entry.setWeight(1.0d / Math.exp(relativeLogLikelihood));
        }
        return totalLogLikelihood;
    }
}
