package hex;
import static org.junit.Assert.assertArrayEquals;
import static water.util.Utils.nfold;
import java.util.Arrays;
import junit.framework.Assert;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.util.Utils;
public class NFoldFrameExtractorTest extends TestUtil {
@Test
public void testNFoldSplitUtility() {
// 10-fold
for (int i=0; i<10; i++) {
assertArrayEquals("10 fold of 10 elements : "+i+"-th failed!", ar(i,1L), nfold(10, 10, i));
}
// 10-fold
for (int i=0; i<9; i++) {
assertArrayEquals("10 fold of 11 elements : "+i+"-th failed!", ar(i,1L), nfold(11, 10, i));
}
assertArrayEquals("10 fold of 11 elements : 9-th failed!", ar(9,2L), nfold(11, 10, 9));
}
@Test
public void testEspcSplit() {
NFoldFrameExtractor fe = null;
long [][] espc = null;
// N-fold split - test on the chunk boundary split - start/end are at chunk boundaries
for (int i=0; i<3; i++) {
fe = new NFoldFrameExtractor(null, 3, i, null, null);
espc = fe.computeEspcPerSplit(ar(0,2,4,6,8,10,12), 12L);
assertArrayEquals(ar(0L, 2L, 4L, 6L, 8L), espc[0]);
assertArrayEquals(ar(0L, 2L, 4L), espc[1]);
}
// Split inside chunk
fe = new NFoldFrameExtractor(null, 3, 1, null, null);
espc = fe.computeEspcPerSplit(ar(0,2,4,6,8,10), 10L);
assertArrayEquals(ar(0L, 2L, 3L, 5L, 7L), espc[0]);
assertArrayEquals(ar(0L, 1L, 3L), espc[1]);
// Split inside chunk
fe = new NFoldFrameExtractor(null, 3, 0, null, null);
espc = fe.computeEspcPerSplit(ar(0,3,6), 6L);
assertArrayEquals(ar(0L, 1L, 4L), espc[0]);
assertArrayEquals(ar(0L, 2L), espc[1]);
fe = new NFoldFrameExtractor(null, 3, 1, null, null);
espc = fe.computeEspcPerSplit(ar(0,3,6), 6L);
assertArrayEquals(ar(0L, 2L, 4L), espc[0]);
assertArrayEquals(ar(0L, 1L, 2L), espc[1]);
fe = new NFoldFrameExtractor(null, 3, 2, null, null);
espc = fe.computeEspcPerSplit(ar(0,3,6), 6L);
assertArrayEquals(ar(0L, 3L, 4L), espc[0]);
assertArrayEquals(ar(0L, 2L), espc[1]);
// Test scenario that fold split one chunk into 3 parts
fe = new NFoldFrameExtractor(null, 3, 0 , null, null);
espc = fe.computeEspcPerSplit(ar(0,6), 6L);
assertArrayEquals(ar(0L, 4L), espc[0]);
assertArrayEquals(ar(0L, 2L), espc[1]);
fe = new NFoldFrameExtractor(null, 3, 1 , null, null);
espc = fe.computeEspcPerSplit(ar(0,6), 6L);
assertArrayEquals(ar(0L, 2L, 4L), espc[0]);
assertArrayEquals(ar(0L, 2L), espc[1]);
fe = new NFoldFrameExtractor(null, 3, 2 , null, null);
espc = fe.computeEspcPerSplit(ar(0,6), 6L);
assertArrayEquals(ar(0L, 4L), espc[0]);
assertArrayEquals(ar(0L, 2L), espc[1]);
}
@Test
public void testIris() {
Key key = Key.make("iris.hex");
Frame fr = parseFrame(key, "./smalldata/iris/iris.csv");
int[] nfolds = new int[] {2,3,10,11};
long nrows = fr.numRows();
try {
for (int i=0; i<nfolds.length; i++) {
int n = nfolds[i];
for (int f=0; f<n; f++) {
Frame[] splits = null;
try {
NFoldFrameExtractor nffe = new NFoldFrameExtractor(fr, n, f, null, null);
H2O.submitTask(nffe);
splits = nffe.getResult();
Arrays.deepToString(splits);
Assert.assertEquals("N-Fold extract should always produce 2 frames!", 2, splits.length);
Assert.assertEquals("N-Fold extract should not modify input frame!", nrows, fr.numRows());
Assert.assertEquals(nrows, splits[0].numRows()+splits[1].numRows());
} finally {
if (splits!=null) for (Frame fs: splits) fs.delete();
}
}
}
} finally {
if (fr!=null) fr.delete();
}
}
}