/* * File: LeaveOneOutFoldCreatorTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright September 26, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.experiment; import gov.sandia.cognition.learning.data.PartitionedDataset; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Random; import junit.framework.TestCase; /** * This class implements JUnit tests for the following classes: * * @author Justin Basilico * @since 2.0 */ public class LeaveOneOutFoldCreatorTest extends TestCase { Random random = new Random(1); public LeaveOneOutFoldCreatorTest( String testName) { super(testName); } /** * Test of createFolds method, of class gov.sandia.cognition.learning.experiment.LeaveOneOutFoldCreator. */ public void testCreateFolds() { Collection<Double> data = new ArrayList<Double>(); LeaveOneOutFoldCreator<Double> instance = new LeaveOneOutFoldCreator<Double>(); List<PartitionedDataset<Double>> folds = null; int count = 25; for (int i = 0; i < count; i++) { data.add(this.random.nextDouble()); if ( i > 1 ) { folds = instance.createFolds(data); checkFolds(data, folds); } } data.clear(); boolean exceptionThrown = false; try { folds = instance.createFolds(data); } catch ( IllegalArgumentException e ) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } data.add(this.random.nextDouble()); exceptionThrown = false; try { folds = instance.createFolds(data); } catch ( IllegalArgumentException e ) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } public void checkFolds( Collection<Double> data, List<PartitionedDataset<Double>> folds) { int count = data.size(); assertEquals(count, folds.size()); for ( PartitionedDataset<Double> fold : folds ) { assertEquals(count - 1, fold.getTrainingSet().size()); assertEquals(1, fold.getTestingSet().size()); } for ( Double value : data ) { int trainCount = 0; int testCount = 0; for ( PartitionedDataset<Double> fold : folds ) { boolean inTrain = fold.getTrainingSet().contains(value); boolean inTest = fold.getTestingSet().contains(value); assertTrue(inTrain ^ inTest); if ( inTrain ) { trainCount++; } if ( inTest ) { testCount++; } } assertEquals(count - 1, trainCount); assertEquals(1, testCount); } } }