/* * File: TimeSeriesPredictionLearnerTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Mar 6, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.algorithm; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.algorithm.nearest.KNearestNeighborExhaustive; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import junit.framework.TestCase; import java.util.Random; /** * Unit tests for TimeSeriesPredictionLearnerTest. * * @author krdixon */ public class TimeSeriesPredictionLearnerTest extends TestCase { /** * Random number generator to use for a fixed random seed. */ public Random random = new Random( 1 ); /** * Tests for class TimeSeriesPredictionLearnerTest. * @param testName Name of the test. */ public TimeSeriesPredictionLearnerTest( String testName) { super(testName); } /** * Test of getPredictionHorizon method, of class TimeSeriesPredictionLearner. */ public void testGetPredictionHorizon() { System.out.println( "getPredictionHorizon" ); TimeSeriesPredictionLearner<?,?,?> instance = new TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>>(); assertEquals( TimeSeriesPredictionLearner.DEFAULT_PREDICTION_HORIZON, instance.getPredictionHorizon() ); } /** * Test of setPredictionHorizon method, of class TimeSeriesPredictionLearner. */ public void testSetPredictionHorizon() { System.out.println( "setPredictionHorizon" ); TimeSeriesPredictionLearner<?,?,?> instance = new TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>>(); assertEquals( TimeSeriesPredictionLearner.DEFAULT_PREDICTION_HORIZON, instance.getPredictionHorizon() ); final int p2 = 10; instance.setPredictionHorizon( p2 ); assertEquals( p2, instance.getPredictionHorizon() ); } /** * Test of getSupervisedLearner method, of class TimeSeriesPredictionLearner. */ public void testGetSupervisedLearner() { System.out.println( "getSupervisedLearner" ); TimeSeriesPredictionLearner<?,?,?> instance = new TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>>(); assertNull( instance.getSupervisedLearner() ); } /** * Test of setSupervisedLearner method, of class TimeSeriesPredictionLearner. */ @SuppressWarnings("unchecked") public void testSetSupervisedLearner() { System.out.println( "setSupervisedLearner" ); TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>> instance = new TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>>(); assertNull( instance.getSupervisedLearner() ); KNearestNeighborExhaustive.Learner<Double, Double> learner = new KNearestNeighborExhaustive.Learner<Double, Double>(); instance.setSupervisedLearner( learner ); assertSame( learner, instance.getSupervisedLearner() ); } /** * Test of learn method, of class TimeSeriesPredictionLearner. */ public void testLearn() { System.out.println( "learn" ); final int predictionHorizon = 2; KNearestNeighborExhaustive.Learner<Double,Double> learner = new KNearestNeighborExhaustive.Learner<Double,Double>(); TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>> instance = new TimeSeriesPredictionLearner<Double,Double,KNearestNeighborExhaustive<Double,Double>>( predictionHorizon, learner ); @SuppressWarnings("unchecked") final List<DefaultInputOutputPair<Double,Double>> data = Arrays.asList( new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ) ); KNearestNeighborExhaustive<Double,Double> result = instance.learn( data ); assertEquals( data.size()-predictionHorizon, result.getData().size() ); } /** * Test of createPredictionDataset method, of class TimeSeriesPredictionLearner. */ public void testCreatePredictionDataset() { System.out.println( "createPredictionDataset" ); final int predictionHorizon = 2; @SuppressWarnings("unchecked") final List<DefaultInputOutputPair<Double,Double>> data = Arrays.asList( new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ), new DefaultInputOutputPair<Double,Double>( random.nextGaussian(), random.nextGaussian() ) ); ArrayList<InputOutputPair<Double,Double>> result = TimeSeriesPredictionLearner.createPredictionDataset( predictionHorizon, data ); assertEquals( data.size()-predictionHorizon, result.size() ); for( int i = 0; i < result.size(); i++ ) { assertSame( data.get(i).getInput(), result.get(i).getInput() ); assertSame( data.get(i+predictionHorizon).getOutput(), result.get(i).getOutput() ); } try { result = TimeSeriesPredictionLearner.createPredictionDataset( -1, data ); fail( "Prediction horizon must be >= 0" ); } catch (Exception e) { System.out.println( "Good: " + e ); } result = TimeSeriesPredictionLearner.createPredictionDataset( data.size(), data ); assertEquals( 0, result.size() ); result = TimeSeriesPredictionLearner.createPredictionDataset( data.size()+1, data ); assertEquals( 0, result.size() ); } }