/*
* File: RandomByTwoFoldCreatorTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright January 20, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*/
package gov.sandia.cognition.learning.experiment;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import junit.framework.TestCase;
/**
 * Unit tests for class RandomByTwoFoldCreator.
 *
 * @author  Justin Basilico
 * @since   3.0
 */
public class RandomByTwoFoldCreatorTest
    extends TestCase
{

    /** Random number generator used by the tests; created with a fixed seed. */
    protected Random random;

    /**
     * Creates a new test.
     *
     * @param   testName
     *      The test name.
     */
    public RandomByTwoFoldCreatorTest(
        final String testName)
    {
        super(testName);

        // A fixed seed keeps the generated values repeatable between runs.
        this.random = new Random(1);
    }

    /**
     * Test of constants of class RandomByTwoFoldCreator.
     */
    public void testConstants()
    {
        assertEquals(5, RandomByTwoFoldCreator.DEFAULT_NUM_SPLITS);
    }

    /**
     * Test of constructors of class RandomByTwoFoldCreator.
     */
    public void testConstructors()
    {
        // Default constructor: default split count and some non-null random.
        int numSplits = RandomByTwoFoldCreator.DEFAULT_NUM_SPLITS;
        RandomByTwoFoldCreator<Double> instance =
            new RandomByTwoFoldCreator<Double>();
        assertEquals(numSplits, instance.getNumSplits());
        assertNotNull(instance.getRandom());

        // Full constructor: both arguments must be stored as given.
        numSplits = numSplits * 10;
        instance = new RandomByTwoFoldCreator<Double>(numSplits, this.random);
        assertEquals(numSplits, instance.getNumSplits());
        assertSame(this.random, instance.getRandom());
    }

    /**
     * Test of createFolds method, of class RandomByTwoFoldCreator.
     */
    public void testCreateFolds()
    {
        final int numSplits = 7;
        final RandomByTwoFoldCreator<Double> instance =
            new RandomByTwoFoldCreator<Double>(numSplits, this.random);

        final Collection<Double> data = new ArrayList<Double>();

        // Grow the data one value at a time and validate the folds created
        // for each size. The guard means folds are first checked when the
        // collection holds three values.
        // NOTE(review): a two-element collection is never checked here;
        // confirm whether that omission is intentional.
        final int count = 25;
        for (int i = 0; i < count; i++)
        {
            data.add(random.nextDouble());

            if (i > 1)
            {
                final List<PartitionedDataset<Double>> folds =
                    instance.createFolds(data);
                checkFolds(data, numSplits, folds);
            }
        }

        // An empty collection must be rejected.
        data.clear();
        assertCreateFoldsRejects(instance, data);

        // A single-element collection must be rejected.
        data.add(this.random.nextDouble());
        assertCreateFoldsRejects(instance, data);
    }

    /**
     * Asserts that creating folds from the given data throws an
     * IllegalArgumentException. Any other exception propagates to the test
     * runner unchanged so that its real cause is reported.
     *
     * @param   instance
     *      The fold creator to call.
     * @param   data
     *      The data that is expected to be rejected.
     */
    private static void assertCreateFoldsRejects(
        final RandomByTwoFoldCreator<Double> instance,
        final Collection<Double> data)
    {
        try
        {
            instance.createFolds(data);
            fail("Expected IllegalArgumentException for data of size "
                + data.size());
        }
        catch (IllegalArgumentException e)
        {
            // Expected.
        }
    }

    /**
     * Checks that the folds are correct. The checks are: the number of folds
     * equals the smaller of twice the data size and twice the requested
     * number of splits; every fold has non-empty training and testing sets
     * whose sizes sum to the data size; every data value is in exactly one of
     * the two sets of each fold; and each set of a fold holds either half of
     * the data (rounded down) or one more than that.
     *
     * @param   data
     *      The data.
     * @param   numRequestedSplits
     *      The requested number of splits.
     * @param   folds
     *      The folds that were created.
     */
    public static void checkFolds(
        final Collection<Double> data,
        final int numRequestedSplits,
        final List<PartitionedDataset<Double>> folds)
    {
        final int dataSize = data.size();
        final int halfDataSize = dataSize / 2;
        final int numFolds = Math.min(2 * dataSize, 2 * numRequestedSplits);
        assertEquals("number of folds", numFolds, folds.size());

        // First pass: size-based checks on every fold.
        for (PartitionedDataset<Double> fold : folds)
        {
            final int trainSize = fold.getTrainingSet().size();
            final int testSize = fold.getTestingSet().size();
            assertTrue("training set is empty", trainSize > 0);
            assertTrue("testing set is empty", testSize > 0);
            assertEquals("fold does not cover the data",
                dataSize, trainSize + testSize);
        }

        // Second pass: membership-based checks on every fold.
        for (PartitionedDataset<Double> fold : folds)
        {
            int trainCount = 0;
            int testCount = 0;
            for (Double value : data)
            {
                final boolean inTrain = fold.getTrainingSet().contains(value);
                final boolean inTest = fold.getTestingSet().contains(value);

                // Each value must be in one set or the other, never both
                // and never neither.
                assertTrue("value must be in exactly one set: " + value,
                    inTrain ^ inTest);

                if (inTrain)
                {
                    trainCount++;
                }

                if (inTest)
                {
                    testCount++;
                }
            }

            assertEquals(dataSize, trainCount + testCount);

            // For an odd data size one side carries the extra element.
            assertTrue("unbalanced training set: " + trainCount,
                trainCount >= halfDataSize
                && trainCount <= halfDataSize + 1);
            assertTrue("unbalanced testing set: " + testCount,
                testCount >= halfDataSize
                && testCount <= halfDataSize + 1);
        }
    }

    /**
     * Test of getNumSplits method, of class RandomByTwoFoldCreator.
     */
    public void testGetNumSplits()
    {
        // The getter is exercised together with the setter.
        this.testSetNumSplits();
    }

    /**
     * Test of setNumSplits method, of class RandomByTwoFoldCreator.
     */
    public void testSetNumSplits()
    {
        int numSplits = RandomByTwoFoldCreator.DEFAULT_NUM_SPLITS;
        final RandomByTwoFoldCreator<Double> instance =
            new RandomByTwoFoldCreator<Double>();
        assertEquals(numSplits, instance.getNumSplits());

        numSplits = numSplits * 10;
        instance.setNumSplits(numSplits);
        assertEquals(numSplits, instance.getNumSplits());

        numSplits = 1;
        instance.setNumSplits(numSplits);
        assertEquals(numSplits, instance.getNumSplits());

        // Random.nextInt(bound) may return 0, which this test expects the
        // setter to reject below, so shift the range to [1, 147].
        numSplits = 1 + this.random.nextInt(147);
        instance.setNumSplits(numSplits);
        assertEquals(numSplits, instance.getNumSplits());

        // Zero must be rejected and must leave the current value untouched.
        assertSetNumSplitsRejects(instance, 0);
        assertEquals(numSplits, instance.getNumSplits());

        // Negative values must be rejected and leave the value untouched.
        assertSetNumSplitsRejects(instance, -1);
        assertEquals(numSplits, instance.getNumSplits());
    }

    /**
     * Asserts that setting the given number of splits throws an
     * IllegalArgumentException. Any other exception propagates to the test
     * runner unchanged so that its real cause is reported.
     *
     * @param   instance
     *      The fold creator to call.
     * @param   badNumSplits
     *      The value that is expected to be rejected.
     */
    private static void assertSetNumSplitsRejects(
        final RandomByTwoFoldCreator<Double> instance,
        final int badNumSplits)
    {
        try
        {
            instance.setNumSplits(badNumSplits);
            fail("Expected IllegalArgumentException for numSplits = "
                + badNumSplits);
        }
        catch (IllegalArgumentException e)
        {
            // Expected.
        }
    }

}