package org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix;

import gov.sandia.cognition.learning.data.DefaultTargetEstimatePair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.performance.categorization.ConfusionMatrixPerformanceEvaluator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openimaj.experiment.evaluation.classification.ClassificationAnalyser;
import org.openimaj.experiment.evaluation.classification.ClassificationResult;

/* loaded from: input_file:org/openimaj/experiment/evaluation/classification/analysers/confusionmatrix/CMAnalyser.class */
/**
 * A {@link ClassificationAnalyser} that builds a confusion matrix from the predicted and
 * ground-truth (actual) classes of a collection of objects.
 * <p>
 * Each object's predicted and actual class sets are converted into target/estimate pairs
 * by a {@link Strategy}. The target of every pair comes from the actual set and the
 * estimate from the predicted set. The pairs are then passed to a
 * {@link ConfusionMatrixPerformanceEvaluator}.
 *
 * @param <OBJECT> type of the objects that were classified
 * @param <CLASS> type of the classes
 */
public class CMAnalyser<OBJECT, CLASS> implements ClassificationAnalyser<CMResult<CLASS>, CLASS, OBJECT> {
    /** Strategy used to turn each object's predicted/actual class sets into pairs. */
    protected Strategy strategy;

    /** Evaluator that turns the accumulated target/estimate pairs into a confusion matrix. */
    ConfusionMatrixPerformanceEvaluator<?, CLASS> eval = new ConfusionMatrixPerformanceEvaluator<>();

    /**
     * Strategies for converting the predicted and actual class sets of a single object into
     * target/estimate pairs.
     * <p>
     * In every strategy the pair's target is taken from the actual (ground-truth) set and
     * its estimate from the predicted set. A {@code null} on either side means "no class".
     */
    public enum Strategy {
        /**
         * Uses only the first class (in iteration order) of each set. An empty set
         * contributes {@code null}. Exactly one pair is produced per object.
         */
        SINGLE {
            @Override
            protected <CLASS> void add(List<TargetEstimatePair<CLASS, CLASS>> list, Set<CLASS> predicted,
                    Set<CLASS> actual) {
                final CLASS target = actual.isEmpty() ? null : actual.iterator().next();
                final CLASS estimate = predicted.isEmpty() ? null : predicted.iterator().next();
                list.add(DefaultTargetEstimatePair.create(target, estimate));
            }
        },

        /**
         * Produces one pair per class in the union of the predicted and actual sets.
         * A class that is missing from one side is paired with {@code null} on that side.
         */
        MULTIPLE {
            @Override
            protected <CLASS> void add(List<TargetEstimatePair<CLASS, CLASS>> list, Set<CLASS> predicted,
                    Set<CLASS> actual) {
                final Set<CLASS> union = new HashSet<>(predicted);
                union.addAll(actual);
                for (final CLASS clazz : union) {
                    final CLASS target = actual.contains(clazz) ? clazz : null;
                    final CLASS estimate = predicted.contains(clazz) ? clazz : null;
                    list.add(DefaultTargetEstimatePair.create(target, estimate));
                }
            }
        },

        /**
         * Pairs the i-th actual class with the i-th predicted class.
         * <p>
         * Both sets must be {@link LinkedHashSet}s, so that their iteration order is
         * meaningful; any other set type causes a {@link ClassCastException}. Both sets
         * must also have the same size, otherwise a {@link RuntimeException} is thrown.
         */
        MULTIPLE_ORDERED {
            @Override
            protected <CLASS> void add(List<TargetEstimatePair<CLASS, CLASS>> list, Set<CLASS> predicted,
                    Set<CLASS> actual) {
                // The casts enforce the ordered-set contract described above.
                final LinkedHashSet<CLASS> orderedPredicted = (LinkedHashSet<CLASS>) predicted;
                final LinkedHashSet<CLASS> orderedActual = (LinkedHashSet<CLASS>) actual;
                if (orderedPredicted.size() != orderedActual.size()) {
                    throw new RuntimeException("Sets are not the same size!");
                }
                final Iterator<CLASS> predictedIter = orderedPredicted.iterator();
                final Iterator<CLASS> actualIter = orderedActual.iterator();
                while (predictedIter.hasNext()) {
                    // Target = actual, estimate = predicted, matching SINGLE and MULTIPLE.
                    // Swapping these would transpose the resulting confusion matrix.
                    final CLASS target = actualIter.next();
                    final CLASS estimate = predictedIter.next();
                    list.add(DefaultTargetEstimatePair.create(target, estimate));
                }
            }
        };

        /**
         * Appends the target/estimate pairs for one object to {@code list}.
         *
         * @param list the accumulator that the pairs are appended to
         * @param predicted the classes predicted for the object
         * @param actual the ground-truth classes of the object
         */
        protected abstract <CLASS> void add(List<TargetEstimatePair<CLASS, CLASS>> list, Set<CLASS> predicted,
                Set<CLASS> actual);
    }

    /**
     * Creates an analyser that uses the given pairing strategy.
     *
     * @param strategy how predicted/actual class sets are converted into pairs
     */
    public CMAnalyser(Strategy strategy) {
        this.strategy = strategy;
    }

    /**
     * Builds a confusion matrix covering every object in {@code predicted}.
     *
     * @param predicted the classification result for each object
     * @param actual the ground-truth classes for each object; must contain every key of
     *        {@code predicted}
     * @return the confusion-matrix result
     * @throws NullPointerException if an object in {@code predicted} has no ground-truth
     *         entry in {@code actual}
     */
    @Override
    public CMResult<CLASS> analyse(Map<OBJECT, ClassificationResult<CLASS>> predicted, Map<OBJECT, Set<CLASS>> actual) {
        final List<TargetEstimatePair<CLASS, CLASS>> pairs = new ArrayList<>(predicted.size());
        for (final Map.Entry<OBJECT, ClassificationResult<CLASS>> entry : predicted.entrySet()) {
            final Set<CLASS> actualClasses = actual.get(entry.getKey());
            if (actualClasses == null) {
                // Every strategy would NPE on a null set anyway; fail with a useful message.
                throw new NullPointerException("No ground-truth classes for object: " + entry.getKey());
            }
            this.strategy.add(pairs, entry.getValue().getPredictedClasses(), actualClasses);
        }
        return new CMResult<>(this.eval.evaluatePerformance(pairs));
    }
}
