package org.nd4j.linalg.dataset;
import org.apache.commons.io.FileUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
/**
 * Tests for {@link MiniBatchFileDataSetIterator}, and for reading its output back through
 * {@link ExistingMiniBatchDataSetIterator}.
 *
 * <p>Created by agibsonccc on 9/10/15.
 */
@RunWith(Parameterized.class)
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTest {

    /**
     * Creates one test instance per backend.
     *
     * @param backend the ND4J backend supplied by the {@link Parameterized} runner
     */
    public MiniBatchFileDataSetIteratorTest(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Splits the full 150-example Iris data set into mini-batches of 10 and checks every batch.
     *
     * <p>The same check is then applied to the batches read back from the iterator's root
     * directory. The directory is removed when the test finishes, whether it passes or fails.
     */
    @Test
    public void testMiniBatches() throws Exception {
        DataSet load = new IrisDataSetIterator(150, 150).next();
        final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false);
        try {
            while (iter.hasNext()) {
                assertEquals(10, iter.next().numExamples());
            }
            if (iter.getRootDir() == null) {
                // Nothing was written to a directory, so there is nothing to re-read.
                return;
            }
            DataSetIterator existing = new ExistingMiniBatchDataSetIterator(iter.getRootDir());
            // Loop on `existing`, not `iter`. `iter` is already exhausted above, so looping on
            // it meant this check never ran.
            // NOTE(review): this assumes ExistingMiniBatchDataSetIterator can read the files
            // MiniBatchFileDataSetIterator writes. That path was never exercised before; confirm.
            while (existing.hasNext()) {
                assertEquals(10, existing.next().numExamples());
            }
        } finally {
            // Best-effort cleanup. It runs even when an assertion fails, and it must not throw
            // here, because that would mask the original failure. deleteQuietly accepts null.
            FileUtils.deleteQuietly(iter.getRootDir());
        }
    }

    /**
     * Returns the array ordering ({@code 'f'}) this test asks {@link BaseNd4jTest} to use.
     */
    @Override
    public char ordering() {
        return 'f';
    }
}