package org.nd4j.linalg.dataset;

import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link BalanceMinibatches}: each test runs {@code balance()} over the
 * Iris data set and reads the result back from {@code getRootSaveDir()} through an
 * {@link ExistingMiniBatchDataSetIterator}.
 *
 * NOTE(review): both tests use the relative directories {@code minibatches} and
 * {@code minibatchessave} and never clean them up, so they share on-disk state
 * with each other and with earlier runs -- confirm this is intended.
 *
 * Created by agibsonccc on 6/24/16.
 */
public class BalanceMinibatchesTest extends BaseNd4jTest {

    /**
     * @param backend the backend handed to {@link BaseNd4jTest}
     */
    public BalanceMinibatchesTest(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Smoke test: after balancing, every minibatch read back from the save
     * directory must report at least one label count.
     */
    @Test
    public void testBalance() {
        DataSetIterator iterator = new IrisDataSetIterator(10, 150);
        BalanceMinibatches balanceMinibatches = BalanceMinibatches.builder()
                .dataSetIterator(iterator)
                .miniBatchSize(10)
                .numLabels(3)
                .rootDir(new File("minibatches"))
                .rootSaveDir(new File("minibatchessave"))
                .build();
        balanceMinibatches.balance();

        DataSetIterator balanced = new ExistingMiniBatchDataSetIterator(balanceMinibatches.getRootSaveDir());
        while (balanced.hasNext()) {
            assertTrue("balanced minibatch reported no label counts",
                    balanced.next().labelCounts().size() > 0);
        }
    }

    /**
     * Checks that the leading minibatches produced by {@code balance()} are
     * balanced: for as many batches as every class can fill its share of, each
     * class must be present with at least {@code miniBatchSize / totalOutcomes}
     * examples (integer division).
     */
    @Test
    public void testMiniBatchBalanced() {
        int miniBatchSize = 10;
        DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
        int numOutcomes = iterator.totalOutcomes();
        // Collections.min below needs a non-empty list, and we divide by numOutcomes.
        assertTrue("iterator must report at least one outcome class", numOutcomes > 0);

        BalanceMinibatches balanceMinibatches = BalanceMinibatches.builder()
                .dataSetIterator(iterator)
                .miniBatchSize(miniBatchSize)
                .numLabels(numOutcomes)
                .rootDir(new File("minibatches"))
                .rootSaveDir(new File("minibatchessave"))
                .build();
        balanceMinibatches.balance();
        DataSetIterator balanced = new ExistingMiniBatchDataSetIterator(balanceMinibatches.getRootSaveDir());

        assertTrue(iterator.resetSupported()); // this is testing the Iris dataset more than anything
        iterator.reset();

        // Total number of examples per class over the whole source iterator.
        // The list must be pre-filled: ArrayList(int) only reserves capacity, it
        // does not create elements, so get/set on it would throw otherwise.
        List<Double> totalCounts = new ArrayList<Double>(Collections.nCopies(numOutcomes, 0.0));
        while (iterator.hasNext()) {
            Map<Integer, Double> outcomes = iterator.next().labelCounts();
            for (int i = 0; i < numOutcomes; i++) {
                if (outcomes.containsKey(i)) {
                    totalCounts.set(i, totalCounts.get(i) + outcomes.get(i));
                }
            }
        }

        // For each class: how many batches it can supply its share of.
        List<Integer> fullBatches = new ArrayList<Integer>(totalCounts.size());
        for (int i = 0; i < totalCounts.size(); i++) {
            fullBatches.add(totalCounts.get(i).intValue() * numOutcomes / miniBatchSize);
        }

        // this is the number of batches for which we can balance every class
        int fullyBalanceableBatches = Collections.min(fullBatches);
        int minPerClass = miniBatchSize / numOutcomes;

        // check the first few batches are actually balanced
        for (int b = 0; b < fullyBalanceableBatches; b++) {
            assertTrue("expected at least " + fullyBalanceableBatches
                    + " balanced minibatches but ran out at batch " + b, balanced.hasNext());
            Map<Integer, Double> balancedCounts = balanced.next().labelCounts();
            for (int i = 0; i < numOutcomes; i++) {
                assertTrue("batch " + b + " is not balanced for class " + i + ": " + balancedCounts,
                        balancedCounts.containsKey(i) && balancedCounts.get(i) >= minPerClass);
            }
        }
    }

    /**
     * The ordering for this test.
     * This test will only be invoked for
     * the given ordering and ignored for others.
     *
     * @return the ordering for this test ({@code 'c'})
     */
    @Override
    public char ordering() {
        return 'c';
    }
}