package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.PoissonDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Random;

/**
 * An incremental (online) bagging learner for categorizers, based on the online
 * bagging algorithm of Oza and Russell.
 *
 * <p>The learned object is a {@link VotingCategorizerEnsemble} with a fixed
 * number of members, each created and updated by a wrapped
 * {@link IncrementalLearner}. For every training example, each member is
 * updated a number of times drawn from a {@link PoissonDistribution} whose
 * parameter is {@link #getPercentToSample()}.
 *
 * @param <InputType> the type of input the ensemble members evaluate
 * @param <CategoryType> the type of category the ensemble members output
 * @param <MemberType> the type of ensemble member produced by the wrapped learner
 */
@PublicationReference(author = {"Nikunj C. Oza", "Stuart Russell"}, title = "Online Bagging and Boosting", year = 2001, type = PublicationType.Conference, publication = "In Artificial Intelligence and Statistics", pages = {105, 112}, url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.32.8889")
public class OnlineBaggingCategorizerLearner<InputType, CategoryType, MemberType extends Evaluator<? super InputType, ? extends CategoryType>> extends AbstractSupervisedBatchAndIncrementalLearner<InputType, CategoryType, VotingCategorizerEnsemble<InputType, CategoryType, MemberType>> implements Randomized {

    /** The default number of members in the ensemble, {@value}. */
    public static final int DEFAULT_ENSEMBLE_SIZE = 100;

    /** The default value for {@link #getPercentToSample()}, {@value}. */
    public static final double DEFAULT_PERCENT_TO_SAMPLE = 1.0d;

    /** The incremental learner used to create and update each ensemble member. */
    protected IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> learner;

    /** The number of members created by {@link #createInitialLearnedObject()}. */
    protected int ensembleSize;

    /** The parameter of the Poisson distribution used to draw per-member update counts. */
    protected double percentToSample;

    /** The random number generator used to draw per-member update counts. */
    protected Random random;

    /**
     * Creates a learner with no wrapped learner, the default ensemble size,
     * the default percent to sample, and a new {@link Random}.
     */
    public OnlineBaggingCategorizerLearner() {
        this(null);
    }

    /**
     * Creates a learner wrapping the given incremental learner, using the
     * default ensemble size, the default percent to sample, and a new
     * {@link Random}.
     *
     * @param incrementalLearner the learner used to create and update members
     */
    public OnlineBaggingCategorizerLearner(IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> incrementalLearner) {
        this(incrementalLearner, DEFAULT_ENSEMBLE_SIZE, DEFAULT_PERCENT_TO_SAMPLE, new Random());
    }

    /**
     * Creates a fully specified learner.
     *
     * @param incrementalLearner the learner used to create and update members
     * @param ensembleSize the number of ensemble members; must be positive
     * @param percentToSample the Poisson parameter for per-member update
     *        counts; must be positive
     * @param random the random number generator used for sampling
     */
    public OnlineBaggingCategorizerLearner(IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> incrementalLearner, int ensembleSize, double percentToSample, Random random) {
        setLearner(incrementalLearner);
        setEnsembleSize(ensembleSize);
        setPercentToSample(percentToSample);
        setRandom(random);
    }

    /**
     * Creates a new ensemble whose categories start as an empty,
     * insertion-ordered set.
     *
     * <p>The ensemble has {@link #getEnsembleSize()} members. Each member comes
     * from the wrapped learner's {@code createInitialLearnedObject()}.
     *
     * @return a new, untrained ensemble
     */
    @Override
    public VotingCategorizerEnsemble<InputType, CategoryType, MemberType> createInitialLearnedObject() {
        final int size = this.getEnsembleSize();
        final ArrayList<MemberType> members = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            members.add(this.getLearner().createInitialLearnedObject());
        }
        return new VotingCategorizerEnsemble<>(new LinkedHashSet<CategoryType>(), members);
    }

    /**
     * Updates the ensemble with a single labeled example given as separate
     * input and output values.
     *
     * @param target the ensemble to update
     * @param input the example's input
     * @param output the example's category label
     */
    @Override
    public void update(VotingCategorizerEnsemble<InputType, CategoryType, MemberType> target, InputType input, CategoryType output) {
        this.update(target, DefaultInputOutputPair.create(input, output));
    }

    /**
     * Updates the ensemble with a single labeled example.
     *
     * <p>The example's label is added to the ensemble's categories if it is not
     * already present. Then, for each member, a count is drawn from a Poisson
     * distribution parameterized by {@link #getPercentToSample()}, and the
     * wrapped learner updates that member with the example that many times.
     *
     * @param target the ensemble to update
     * @param data the labeled example
     */
    @Override
    public void update(VotingCategorizerEnsemble<InputType, CategoryType, MemberType> target, InputOutputPair<? extends InputType, CategoryType> data) {
        // Record every label seen during training so the ensemble knows its categories.
        final CategoryType label = data.getOutput();
        if (!target.getCategories().contains(label)) {
            target.getCategories().add(label);
        }

        // Each member sees this example a Poisson-distributed number of times,
        // which is how online bagging replaces bootstrap resampling.
        final PoissonDistribution.PMF poisson = new PoissonDistribution.PMF(this.getPercentToSample());
        for (final MemberType member : target.getMembers()) {
            final int count = ((Number) poisson.sample(this.random)).intValue();
            for (int i = 0; i < count; i++) {
                this.learner.update(member, data);
            }
        }
    }

    /**
     * Gets the incremental learner used to create and update ensemble members.
     *
     * @return the wrapped incremental learner
     */
    public IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> getLearner() {
        return this.learner;
    }

    /**
     * Sets the incremental learner used to create and update ensemble members.
     *
     * @param learner the wrapped incremental learner
     */
    public void setLearner(IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> learner) {
        this.learner = learner;
    }

    /**
     * Gets the number of members created for a new ensemble.
     *
     * @return the ensemble size
     */
    public int getEnsembleSize() {
        return this.ensembleSize;
    }

    /**
     * Sets the number of members created for a new ensemble.
     *
     * @param ensembleSize the ensemble size; rejected by
     *        {@link ArgumentChecker#assertIsPositive} if not positive
     */
    public void setEnsembleSize(int ensembleSize) {
        ArgumentChecker.assertIsPositive("ensembleSize", ensembleSize);
        this.ensembleSize = ensembleSize;
    }

    /**
     * Gets the parameter passed to the Poisson distribution that draws how
     * many times each member is updated per example.
     *
     * @return the percent to sample
     */
    public double getPercentToSample() {
        return this.percentToSample;
    }

    /**
     * Sets the parameter passed to the Poisson distribution that draws how
     * many times each member is updated per example.
     *
     * @param percentToSample the percent to sample; rejected by
     *        {@link ArgumentChecker#assertIsPositive} if not positive
     */
    public void setPercentToSample(double percentToSample) {
        ArgumentChecker.assertIsPositive("percentToSample", percentToSample);
        this.percentToSample = percentToSample;
    }

    /**
     * Gets the random number generator used for sampling update counts.
     *
     * @return the random number generator
     */
    @Override
    public Random getRandom() {
        return this.random;
    }

    /**
     * Sets the random number generator used for sampling update counts.
     *
     * @param random the random number generator
     */
    @Override
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Convenience factory equivalent to the four-argument constructor.
     *
     * @param <InputType> the type of input the ensemble members evaluate
     * @param <CategoryType> the type of category the ensemble members output
     * @param <MemberType> the type of ensemble member
     * @param learner the learner used to create and update members
     * @param ensembleSize the number of ensemble members; must be positive
     * @param percentToSample the Poisson parameter for per-member update
     *        counts; must be positive
     * @param random the random number generator used for sampling
     * @return a new online bagging learner
     */
    public static <InputType, CategoryType, MemberType extends Evaluator<? super InputType, ? extends CategoryType>> OnlineBaggingCategorizerLearner<InputType, CategoryType, MemberType> create(IncrementalLearner<? super InputOutputPair<? extends InputType, CategoryType>, MemberType> learner, int ensembleSize, double percentToSample, Random random) {
        return new OnlineBaggingCategorizerLearner<>(learner, ensembleSize, percentToSample, random);
    }
}
