/* * File: SequencePredictionLearnerTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright December 23, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.algorithm; import gov.sandia.cognition.collection.CollectionUtil; import gov.sandia.cognition.collection.MultiCollection; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.nearest.KNearestNeighborExhaustive; import gov.sandia.cognition.learning.algorithm.nearest.NearestNeighbor; import gov.sandia.cognition.learning.data.InputOutputPair; import java.util.ArrayList; import java.util.Random; import junit.framework.TestCase; /** * Unit tests for class SequencePredictionLearner. * * @author Justin Basilico * @since 3.0 */ public class SequencePredictionLearnerTest extends TestCase { protected Random random; /** * Creates a new test. * * @param testName The test name. */ public SequencePredictionLearnerTest( String testName) { super(testName); this.random = new Random(); } public void testConstants() { assertEquals(1, SequencePredictionLearner.DEFAULT_PREDICTION_HORIZION); } public void testConstructors() { int predictionHorizon = SequencePredictionLearner.DEFAULT_PREDICTION_HORIZION; KNearestNeighborExhaustive.Learner<Double, Double> learner = null; SequencePredictionLearner<Double, Evaluator<Double, Double>> instance = new SequencePredictionLearner<Double, Evaluator<Double, Double>>(); assertEquals(predictionHorizon, instance.getPredictionHorizion()); assertSame(learner, instance.getLearner()); predictionHorizon = 4; learner = new KNearestNeighborExhaustive.Learner<Double, Double>(); instance = new SequencePredictionLearner<Double, Evaluator<Double, Double>>( learner, predictionHorizon); assertEquals(predictionHorizon, instance.getPredictionHorizion()); assertSame(learner, instance.getLearner()); } /** * Test of learn method, of class SequencePredictionLearner. */ public void testLearn() { SequencePredictionLearner<Double, NearestNeighbor<Double, Double>> instance = new SequencePredictionLearner<Double, NearestNeighbor<Double, Double>>( new KNearestNeighborExhaustive.Learner<Double, Double>(), 1); ArrayList<Double> data = new ArrayList<Double>(); for (int i = 0; i < 10; i++) { data.add(random.nextDouble()); } NearestNeighbor<Double, Double> result = instance.learn(data); assertEquals(9, result.getData().size()); instance.setPredictionHorizon(2); result = instance.learn(data); assertEquals(8, result.getData().size()); } /** * Test of createPredictionDataset method, of class SequencePredictionLearner. */ public void testCreatePredictionDataset() { final ArrayList<Double> data = new ArrayList<Double>(); MultiCollection<InputOutputPair<Double, Double>> result = null; result = SequencePredictionLearner.createPredictionDataset(data, 1); assertTrue(result.isEmpty()); data.add(random.nextDouble()); result = SequencePredictionLearner.createPredictionDataset(data, 1); assertTrue(result.isEmpty()); data.add(random.nextDouble()); result = SequencePredictionLearner.createPredictionDataset(data, 1); assertEquals(1, result.size()); assertEquals(data.get(0), CollectionUtil.getElement(result, 0).getInput()); assertEquals(data.get(1), CollectionUtil.getElement(result, 0).getOutput()); result = SequencePredictionLearner.createPredictionDataset(data, 2); assertTrue(result.isEmpty()); assertTrue(result.isEmpty()); data.add(random.nextDouble()); data.add(random.nextDouble()); data.add(random.nextDouble()); data.add(random.nextDouble()); result = SequencePredictionLearner.createPredictionDataset(data, 1); for (int i = 0; i < (data.size() - 1); i++) { assertEquals(data.get(i), CollectionUtil.getElement(result, i).getInput()); assertEquals(data.get(i + 1), CollectionUtil.getElement(result, i).getOutput()); } result = SequencePredictionLearner.createPredictionDataset(data, 3); for (int i = 0; i < (data.size() - 3); i++) { assertEquals(data.get(i), CollectionUtil.getElement(result, i).getInput()); assertEquals(data.get(i + 3), CollectionUtil.getElement(result, i).getOutput()); } } /** * Test of getPredictionHorizion method, of class SequencePredictionLearner. */ public void testGetPredictionHorizion() { this.testSetPredictionHorizon(); } /** * Test of setPredictionHorizon method, of class SequencePredictionLearner. */ public void testSetPredictionHorizon() { int predictionHorizon = 1; SequencePredictionLearner<Double, Evaluator<Double, Double>> instance = new SequencePredictionLearner<Double, Evaluator<Double, Double>>(); assertEquals(predictionHorizon, instance.getPredictionHorizion()); predictionHorizon = 2; instance.setPredictionHorizon(predictionHorizon); assertEquals(predictionHorizon, instance.getPredictionHorizion()); predictionHorizon = 77; instance.setPredictionHorizon(predictionHorizon); assertEquals(predictionHorizon, instance.getPredictionHorizion()); // Ensure that bad prediction horizons don't work: boolean exceptionThrown = false; try { instance.setPredictionHorizon(0); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } assertEquals(predictionHorizon, instance.getPredictionHorizion()); exceptionThrown = false; try { instance.setPredictionHorizon(-1); } catch (IllegalArgumentException e) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } assertEquals(predictionHorizon, instance.getPredictionHorizion()); } }