package org.nd4j.linalg.dataset; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.KFoldIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertEquals; /** * Created by susaneraly on 11/4/16. */ @RunWith(Parameterized.class) public class KFoldIteratorTest extends BaseNd4jTest { public KFoldIteratorTest(Nd4jBackend backend) { super(backend); } @Test public void checkFolds() { randomDataSet randomDS = new randomDataSet(new int[] {2, 3}, new int[] {3, 3, 3, 2}); DataSet allData = randomDS.getAllFolds(); KFoldIterator kiter = new KFoldIterator(4, allData); int i = 0; while (kiter.hasNext()) { DataSet now = kiter.next(); DataSet test = kiter.testFold(); assertEquals(now.getFeatures(), randomDS.getFoldbutk(i, true)); assertEquals(now.getLabels(), randomDS.getFoldbutk(i, false)); assertEquals(test.getFeatures(), randomDS.getfoldK(i, true)); assertEquals(test.getLabels(), randomDS.getfoldK(i, false)); i++; System.out.println("Fold " + i + " passed"); } assertEquals(i, 4); } /* //this will throw illegal argument exception @Test public void checkCornerCaseA() { randomDataSet randomDS = new randomDataSet(new int[] {2,3},new int []{3}); DataSet allData = randomDS.getAllFolds(); KFoldIterator kiter = new KFoldIterator(1,allData); int i = 0; while (kiter.hasNext()) { DataSet now = kiter.next(); DataSet test = kiter.testFold(); assertEquals(now.getFeatures(),randomDS.getFoldbutk(i,true)); assertEquals(now.getLabels(),randomDS.getFoldbutk(i,false)); assertEquals(test.getFeatures(),randomDS.getfoldK(i,true)); assertEquals(test.getLabels(),randomDS.getfoldK(i,false)); i++; System.out.println("Fold "+i+" passed"); } assertEquals(i,1); } */ @Test public void checkCornerCaseA() { randomDataSet randomDS = new randomDataSet(new int[] {2, 3}, new int[] {2, 1}); DataSet allData = randomDS.getAllFolds(); KFoldIterator kiter = new KFoldIterator(2, allData); int i = 0; while (kiter.hasNext()) { DataSet now = kiter.next(); DataSet test = kiter.testFold(); assertEquals(now.getFeatures(), randomDS.getFoldbutk(i, true)); assertEquals(now.getLabels(), randomDS.getFoldbutk(i, false)); assertEquals(test.getFeatures(), randomDS.getfoldK(i, true)); assertEquals(test.getLabels(), randomDS.getfoldK(i, false)); i++; System.out.println("Fold " + i + " passed"); } assertEquals(i, 2); } public class randomDataSet { //only one label private int[] dataShape; private int dataRank; private int dataElementCount; private int[] ks; private DataSet allFolds; private INDArray allFeatures; private INDArray allLabels; private INDArray[] kfoldFeats; private INDArray[] kfoldLabels; public randomDataSet(int[] dataShape, int[] ks) { this.dataShape = dataShape; this.dataRank = this.dataShape.length; this.ks = ks; this.dataElementCount = 1; int[] eachFoldSize = new int[dataRank + 1]; eachFoldSize[0] = 0; kfoldFeats = new INDArray[ks.length]; kfoldLabels = new INDArray[ks.length]; for (int i = 0; i < dataRank; i++) { this.dataElementCount *= dataShape[i]; eachFoldSize[i + 1] = dataShape[i]; } for (int i = 0; i < ks.length; i++) { eachFoldSize[0] = ks[i]; INDArray currentFoldF = Nd4j.rand(eachFoldSize); INDArray currentFoldL = Nd4j.rand(ks[i], 1); kfoldFeats[i] = currentFoldF; kfoldLabels[i] = currentFoldL; if (i == 0) { allFeatures = currentFoldF.dup(); allLabels = currentFoldL.dup(); } else { allFeatures = Nd4j.vstack(allFeatures, currentFoldF).dup(); allLabels = Nd4j.vstack(allLabels, currentFoldL).dup(); } } allFolds = new DataSet(allFeatures, allLabels.reshape(allFeatures.size(0), 1)); } public DataSet getAllFolds() { return allFolds; } public INDArray getfoldK(int k, boolean feat) { return feat == true ? kfoldFeats[k] : kfoldLabels[k]; } public INDArray getFoldbutk(int k, boolean feat) { INDArray iFold = null; boolean notInit = true; for (int i = 0; i < ks.length; i++) { if (i == k) continue; if (notInit) { iFold = getfoldK(i, feat); notInit = false; } else { iFold = Nd4j.vstack(iFold, getfoldK(i, feat)).dup(); } } return iFold; } } @Override public char ordering() { return 'c'; } }