package org.openimaj.experiment.dataset.split;

import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.experiment.validation.cross.CrossValidationIterable;

/* loaded from: input_file:org/openimaj/experiment/dataset/split/GroupedRandomSplitter.class */
/**
 * Randomly splits a {@link GroupedDataset} into per-group training, validation
 * and testing subsets.
 * <p>
 * For every group, {@code numTraining} instances go to the training split and
 * {@code numValidation} instances go to the validation split. The testing split
 * receives up to {@code numTesting} further instances, or fewer if the group is
 * too small. Indices are drawn with
 * {@link RandomData#getUniqueRandomInts(int, int, int)}, so presumably no
 * instance lands in more than one subset of the same group.
 * <p>
 * Each call to {@link #recomputeSubsets()} builds brand-new split datasets rather
 * than mutating the previous ones. Datasets obtained from the getters before a
 * recompute therefore keep their contents.
 *
 * @param <KEY> type of the group keys
 * @param <INSTANCE> type of the instances in each group
 */
public class GroupedRandomSplitter<KEY, INSTANCE> implements
        TrainSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        TestSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        ValidateSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {

    private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
    private final int numTraining;
    private final int numValidation;
    private final int numTesting;

    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> trainingSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validationSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> testingSplit;

    /**
     * Construct the splitter and immediately compute an initial random split.
     *
     * @param groupedDataset the dataset to split; must not be null
     * @param numTraining number of training instances taken from each group
     * @param numValidation number of validation instances taken from each group
     * @param numTesting maximum number of testing instances taken from each group
     * @throws NullPointerException if {@code groupedDataset} is null
     * @throws IllegalArgumentException if any of the counts is negative
     * @throws RuntimeException if a group is too small (see {@link #recomputeSubsets()})
     */
    public GroupedRandomSplitter(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> groupedDataset,
            int numTraining, int numValidation, int numTesting) {
        this.dataset = Objects.requireNonNull(groupedDataset, "groupedDataset");
        // Negative counts would index outside the drawn ids or make subsets overlap.
        if (numTraining < 0 || numValidation < 0 || numTesting < 0) {
            throw new IllegalArgumentException("Split sizes must be non-negative: numTraining=" + numTraining
                    + ", numValidation=" + numValidation + ", numTesting=" + numTesting);
        }
        this.numTraining = numTraining;
        this.numValidation = numValidation;
        this.numTesting = numTesting;
        recomputeSubsets();
    }

    /**
     * Draw a fresh random split of every group. The current training, validation
     * and testing datasets are replaced with newly created ones only once every
     * group has been split successfully.
     *
     * @throws RuntimeException if some group has at most {@code numTraining}
     *         instances, or at most {@code numTraining + numValidation} instances
     */
    public void recomputeSubsets() {
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> training = new MapBackedDataset<>();
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validation = new MapBackedDataset<>();
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> testing = new MapBackedDataset<>();

        for (final Map.Entry<KEY, ? extends ListDataset<INSTANCE>> entry : dataset.entrySet()) {
            final KEY key = entry.getKey();
            final ListDataset<INSTANCE> group = entry.getValue();
            final int size = group.size();

            // long arithmetic: very large requested counts must not wrap around and skip the check
            if (size < (long) numTraining + 1) {
                throw new RuntimeException("Too many training examples; none would be available for validation or testing.");
            }
            // NOTE(review): this also fires when numTesting == 0 (as used by createCrossValidationData),
            // i.e. at least one instance per group is always left over - confirm that is intended.
            if (size < (long) numTraining + numValidation + 1) {
                throw new RuntimeException("Too many training and validation instances; none would be available for testing.");
            }

            // The testing split takes whatever part of numTesting still fits in the group;
            // clamp in long so e.g. numTesting == Integer.MAX_VALUE does not overflow.
            final int numSelected = (int) Math.min((long) numTraining + numValidation + numTesting, size);
            final int[] ids = RandomData.getUniqueRandomInts(numSelected, 0, size);

            // Cannot overflow: the checks above guarantee numTraining + numValidation < size.
            final int validationEnd = numTraining + numValidation;
            training.put(key, select(group, ids, 0, numTraining));
            validation.put(key, select(group, ids, numTraining, validationEnd));
            testing.put(key, select(group, ids, validationEnd, ids.length));
        }

        // Commit only after every group succeeded, so a failure leaves the previous split intact.
        this.trainingSplit = training;
        this.validationSplit = validation;
        this.testingSplit = testing;
    }

    /** Copy the instances of {@code source} at positions {@code ids[from]..ids[to - 1]} into a new list dataset. */
    private static <T> ListDataset<T> select(ListDataset<T> source, int[] ids, int from, int to) {
        final ListDataset<T> subset = new ListBackedDataset<>();
        for (int i = from; i < to; i++) {
            subset.add(source.get(ids[i]));
        }
        return subset;
    }

    /**
     * @return the testing split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTestDataset() {
        return this.testingSplit;
    }

    /**
     * @return the training split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
        return this.trainingSplit;
    }

    /**
     * @return the validation split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
        return this.validationSplit;
    }

    /**
     * Create an iterable of {@code numIterations} independent random
     * training/validation splits of {@code groupedDataset}. No testing data is
     * requested (the underlying splitter is built with {@code numTesting = 0}).
     * <p>
     * Each element returned by the iterator is a snapshot: it keeps returning the
     * same datasets even after later iterations draw new splits. Iterators from
     * the same iterable share one underlying splitter.
     *
     * @param groupedDataset the dataset to split
     * @param numTraining number of training instances per group in each iteration
     * @param numValidation number of validation instances per group in each iteration
     * @param numIterations number of splits the iterator yields
     * @return the cross-validation iterable
     * @throws RuntimeException if a group is too small, detected eagerly when the
     *         iterable is created
     */
    public static <KEY, INSTANCE> CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createCrossValidationData(
            final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> groupedDataset,
            final int numTraining, final int numValidation, final int numIterations) {
        return new CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {
            // Built eagerly so invalid group sizes fail at creation time; the first
            // next() draws a fresh split anyway.
            private final GroupedRandomSplitter<KEY, INSTANCE> splits =
                    new GroupedRandomSplitter<>(groupedDataset, numTraining, numValidation, 0);

            @Override
            public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
                return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                    private int current = 0;

                    @Override
                    public boolean hasNext() {
                        return current < numIterations;
                    }

                    @Override
                    public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                        if (!hasNext()) {
                            throw new NoSuchElementException("All " + numIterations + " iterations have been consumed");
                        }
                        splits.recomputeSubsets();
                        current++;

                        // Capture this iteration's datasets; recomputeSubsets() replaces (not mutates)
                        // them, so these references stay valid after later calls to next().
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> training = splits.getTrainingDataset();
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validation = splits.getValidationDataset();

                        return new ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {
                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
                                return training;
                            }

                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
                                return validation;
                            }
                        };
                    }

                    @Override
                    public void remove() {
                        throw new UnsupportedOperationException("Removal not supported");
                    }
                };
            }

            @Override
            public int numberIterations() {
                return numIterations;
            }
        };
    }
}
