package org.nd4j.linalg.dataset;

import org.apache.commons.io.FileUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.CachingDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.SamplingDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.cache.DataSetCache;
import org.nd4j.linalg.dataset.api.iterator.cache.InFileDataSetCache;
import org.nd4j.linalg.dataset.api.iterator.cache.InMemoryDataSetCache;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

import static org.junit.Assert.*;

/**
 * Tests for {@link CachingDataSetIterator} run against both the in-memory and the
 * file-backed {@link DataSetCache} implementations.
 *
 * <p>Each scenario checks three things: the cache is only reported complete once the
 * iterator is exhausted, the pre-processor is not invoked again on later passes, and
 * later passes return the originally cached data even after the source data changed.
 *
 * Created by anton on 7/18/16.
 */
@RunWith(Parameterized.class)
public class CachingDataSetIteratorTest extends BaseNd4jTest {

    /** Tolerance used when comparing array sums as doubles. */
    private static final double EPSILON = 1e-6;

    public CachingDataSetIteratorTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'f';
    }

    /** Runs the shared scenario against {@link InMemoryDataSetCache}. */
    @Test
    public void testInMemory() {
        DataSetCache cache = new InMemoryDataSetCache();

        runDataSetTest(cache);
    }

    /**
     * Runs the shared scenario against {@link InFileDataSetCache} backed by a fresh
     * temporary directory.
     *
     * @throws IOException if the temporary directory cannot be created or deleted
     */
    @Test
    public void testInFile() throws IOException {
        Path cacheDir = Files.createTempDirectory("nd4j-data-set-cache-test");
        try {
            DataSetCache cache = new InFileDataSetCache(cacheDir);

            runDataSetTest(cache);
        } finally {
            // Always remove the temporary directory, even when an assertion fails,
            // so failed runs do not leave cache files behind.
            FileUtils.deleteDirectory(cacheDir.toFile());
        }
    }

    /**
     * Shared scenario: wraps a {@link SamplingDataSetIterator} in a
     * {@link CachingDataSetIterator} using the given cache and runs all checks.
     *
     * @param cache the cache implementation under test
     */
    private void runDataSetTest(DataSetCache cache) {
        int rows = 500;
        int inputColumns = 100;
        int outputColumns = 2;

        // Features are all ones and labels all zeros; the data check relies on these sums.
        DataSet dataSet = new DataSet(Nd4j.ones(rows, inputColumns), Nd4j.zeros(rows, outputColumns));

        int batchSize = 10;
        int totalNumberOfSamples = 50;
        int expectedNumberOfDataSets = totalNumberOfSamples / batchSize;

        DataSetIterator it = new SamplingDataSetIterator(dataSet, batchSize, totalNumberOfSamples);

        String namespace = "test-namespace";

        CachingDataSetIterator cachedIt = new CachingDataSetIterator(it, cache, namespace);
        PreProcessor preProcessor = new PreProcessor();
        cachedIt.setPreProcessor(preProcessor);

        assertDataSetCacheGetsCompleted(cache, namespace, cachedIt);

        assertPreProcessingGetsCached(expectedNumberOfDataSets, it, cachedIt, preProcessor);

        assertCachingDataSetIteratorHasAllTheData(rows, inputColumns, outputColumns, dataSet, it, cachedIt);
    }

    /**
     * Iterates the caching iterator to exhaustion, asserting that the cache reports the
     * namespace as incomplete before every {@code next()} call and complete afterwards.
     */
    private void assertDataSetCacheGetsCompleted(DataSetCache cache, String namespace, DataSetIterator cachedIt) {
        while (cachedIt.hasNext()) {
            assertFalse(cache.isComplete(namespace));
            cachedIt.next();
        }

        assertTrue(cache.isComplete(namespace));
    }

    /**
     * Asserts that both iterators expose the same pre-processor instance and that two
     * further full passes over the caching iterator leave the call count at
     * {@code expectedNumberOfDataSets}, i.e. the count does not grow on repeated passes.
     */
    private void assertPreProcessingGetsCached(int expectedNumberOfDataSets, DataSetIterator it,
                    CachingDataSetIterator cachedIt, PreProcessor preProcessor) {

        assertSame(preProcessor, cachedIt.getPreProcessor());
        assertSame(preProcessor, it.getPreProcessor());

        cachedIt.reset();
        it.reset();

        while (cachedIt.hasNext()) {
            cachedIt.next();
        }

        assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());

        cachedIt.reset();
        it.reset();

        while (cachedIt.hasNext()) {
            cachedIt.next();
        }

        // A second pass must not raise the call count.
        assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());
    }

    /**
     * Swaps the source data (features become zeros, labels become ones) and then walks
     * both iterators in lockstep: the caching iterator must still yield the original
     * values, while the wrapped iterator yields the new ones. Both must finish together.
     */
    private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
                    DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
        cachedIt.reset();
        it.reset();

        dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
        dataSet.setLabels(Nd4j.ones(rows, outputColumns));

        while (it.hasNext()) {
            assertTrue(cachedIt.hasNext());

            // Sums are compared as doubles so the assertion does not depend on which
            // boxed Number type sumNumber() happens to return.
            DataSet cachedDs = cachedIt.next();
            assertEquals(1000.0, cachedDs.getFeatureMatrix().sumNumber().doubleValue(), EPSILON);
            assertEquals(0.0, cachedDs.getLabels().sumNumber().doubleValue(), EPSILON);

            DataSet ds = it.next();
            assertEquals(0.0, ds.getFeatureMatrix().sumNumber().doubleValue(), EPSILON);
            assertEquals(20.0, ds.getLabels().sumNumber().doubleValue(), EPSILON);
        }

        assertFalse(cachedIt.hasNext());
        assertFalse(it.hasNext());
    }

    /**
     * Pre-processor stub that leaves the data untouched and only counts how many times
     * it has been invoked.
     */
    private static class PreProcessor implements DataSetPreProcessor {

        private int callCount;

        @Override
        public void preProcess(org.nd4j.linalg.dataset.api.DataSet toPreProcess) {
            callCount++;
        }

        /** @return the number of {@code preProcess} invocations so far */
        public int getCallCount() {
            return callCount;
        }
    }
}