package quickml.supervised.crossValidation.data;
import com.google.common.collect.Lists;
import org.junit.Before;
import org.junit.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.testng.AssertJUnit.assertFalse;
public class FoldedDataTest {
private List<Integer> instances;
@Before
public void setUp() throws Exception {
instances = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
}
@Test
public void testCycleThrough5Of10Folds() throws Exception {
FoldedData<Integer> foldedData = new FoldedData<Integer>(instances, 10, 5);
assertEquals(new Integer(1), foldedData.getValidationSet().get(0));
foldedData.nextCycle();
assertEquals(new Integer(2), foldedData.getValidationSet().get(0));
foldedData.nextCycle();
assertEquals(new Integer(3), foldedData.getValidationSet().get(0));
foldedData.nextCycle();
assertEquals(new Integer(4), foldedData.getValidationSet().get(0));
foldedData.nextCycle();
assertEquals(new Integer(5), foldedData.getValidationSet().get(0));
foldedData.nextCycle();
assertEquals(9, foldedData.getTrainingSet().size());
assertFalse(foldedData.hasMore());
}
@Test(expected = IllegalArgumentException.class)
public void testNumFoldsIsZero() throws Exception {
new FoldedData<>(instances, 0, 4);
}
@Test(expected = IllegalArgumentException.class)
public void testFoldsUsedIs0() throws Exception {
new FoldedData<>(instances, 4, 0);
}
@Test(expected = IllegalArgumentException.class)
public void testFoldsUsedMustBeLessThanFolds() throws Exception {
new FoldedData<>(instances, 2, 4);
}
@Test(expected = IllegalArgumentException.class)
public void testNoEmptySet() throws Exception {
new FoldedData<>(Lists.newArrayList(), 0, 0);
}
}