/* * File: RegressionTreeLearnerTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright November 30, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.algorithm.tree; import gov.sandia.cognition.learning.algorithm.regression.KernelBasedIterativeRegression; import gov.sandia.cognition.learning.algorithm.regression.LinearRegression; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.math.matrix.Vectorizable; import gov.sandia.cognition.math.matrix.mtj.Vector3; import gov.sandia.cognition.util.CloneableSerializable; import java.util.ArrayList; import junit.framework.TestCase; /** * This class implements JUnit tests for the following classes: RegressionTreeLearner * * @author Justin Basilico * @since 2.0 */ public class RegressionTreeLearnerTest extends TestCase { public RegressionTreeLearnerTest( String testName) { super(testName); } public void testConstants() { assertEquals(4, RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD); } public void testConstructors() { RegressionTreeLearner<Vectorizable> instance = new RegressionTreeLearner<Vectorizable>(); assertNull(instance.getDeciderLearner()); assertNull(instance.getRegressionLearner()); assertEquals(RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD, instance.getLeafCountThreshold()); VectorThresholdVarianceLearner deciderLearner = new VectorThresholdVarianceLearner(); instance = new RegressionTreeLearner<Vectorizable>( deciderLearner); assertSame(deciderLearner, instance.getDeciderLearner()); assertNull(instance.getRegressionLearner()); assertEquals(RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD, instance.getLeafCountThreshold()); KernelBasedIterativeRegression<Vectorizable> regressionLearner = new KernelBasedIterativeRegression<Vectorizable>(); instance = new RegressionTreeLearner<Vectorizable>( deciderLearner, regressionLearner); assertSame(deciderLearner, instance.getDeciderLearner()); assertSame(regressionLearner, instance.getRegressionLearner()); assertEquals(RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD, instance.getLeafCountThreshold()); int leafCountThreshold = RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD + 1; int maxDepth = 10; instance = new RegressionTreeLearner<Vectorizable>( deciderLearner, regressionLearner, leafCountThreshold, maxDepth); assertSame(deciderLearner, instance.getDeciderLearner()); assertSame(regressionLearner, instance.getRegressionLearner()); assertEquals(leafCountThreshold, instance.getLeafCountThreshold()); assertEquals(maxDepth, instance.getMaxDepth()); } /** * Tests of clone */ public void testClone() { System.out.println( "Clone" ); RegressionTreeLearner<?> instance = new RegressionTreeLearner<Vectorizable>(); CloneableSerializable clone = instance.clone(); assertNotNull( clone ); assertNotSame( instance, clone ); } /** * Test of learn method, of class gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner. */ public void testLearn() { VectorThresholdVarianceLearner deciderLearner = new VectorThresholdVarianceLearner(); LinearRegression regressionLearner = new LinearRegression(); RegressionTreeLearner<Vectorizable> instance = new RegressionTreeLearner<Vectorizable>(deciderLearner, regressionLearner); double epsilon = 0.001; RegressionTree<Vectorizable> result = instance.learn(null); assertNull(result); ArrayList<InputOutputPair<Vector3, Double>> data = new ArrayList<InputOutputPair<Vector3, Double>>(); result = instance.learn(data); assertNull(result.getRootNode()); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 2.0), 14.0)); result = instance.learn(data); assertNotNull(result.getRootNode()); assertTrue(result.getRootNode().isLeaf()); assertEquals(14.0, result.evaluate(data.get(0).getInput()), epsilon); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 2.0), 14.1)); result = instance.learn(data); assertNotNull(result.getRootNode()); assertTrue(result.getRootNode().isLeaf()); assertEquals(14.0, result.evaluate(data.get(0).getInput()), epsilon); assertEquals(14.1, result.evaluate(data.get(1).getInput()), epsilon); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 10.0, 3.0), 1.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 9.0, 4.0), 0.5)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 8.0, 2.0), 0.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 0.0, 2.0), 14.2)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 7.0, 2.0), -0.5)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 6.0, 2.0), -1.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 5.0, 2.0), -1.5)); result = instance.learn(data); assertNotNull(result.getRootNode()); assertFalse(result.getRootNode().isLeaf()); for ( InputOutputPair<Vector3, Double> example : data ) { assertEquals(example.getOutput(), result.evaluate(example.getInput()), 0.1); } data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 100.0), 4.6)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 100.0), 4.8)); result = instance.learn(data); assertEquals(4.7, result.evaluate(new Vector3(1.0, 1.0, 100.0)), 0.1); data.clear(); instance.setLeafCountThreshold(1); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 100.0), 4.6)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 100.0), 4.8)); result = instance.learn(data); assertEquals(4.7, result.evaluate(new Vector3(1.0, 1.0, 100.0)), 0.1); data.clear(); // This is XOR. instance.setLeafCountThreshold(1); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(0.0, 0.0, 0.0), 1.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 0.0, 0.0), -1.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(0.0, 1.0, 0.0), -1.0)); data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 0.0), 1.0)); result = instance.learn(data); for ( InputOutputPair<Vector3, Double> example : data ) { assertEquals(example.getOutput(), result.evaluate(example.getInput()), 0.1); } } /** * Test of getRegressionLearner method, of class gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner. */ public void testGetRegressionLearner() { this.testSetRegressionLearner(); } /** * Test of setRegressionLearner method, of class gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner. */ public void testSetRegressionLearner() { RegressionTreeLearner<Vectorizable> instance = new RegressionTreeLearner<Vectorizable>(); assertNull(instance.getRegressionLearner()); KernelBasedIterativeRegression<Vectorizable> regressionLearner = new KernelBasedIterativeRegression<Vectorizable>(); instance.setRegressionLearner(regressionLearner); assertSame(regressionLearner, instance.getRegressionLearner()); instance.setRegressionLearner(null); assertNull(instance.getRegressionLearner()); } /** * Test of getLeafCountThreshold method, of class gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner. */ public void testGetLeafCountThreshold() { this.testSetLeafCountThreshold(); } /** * Test of setLeafCountThreshold method, of class gov.sandia.cognition.learning.algorithm.tree.RegressionTreeLearner. */ public void testSetLeafCountThreshold() { RegressionTreeLearner<Vector3> instance = new RegressionTreeLearner<Vector3>(); assertEquals(RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD, instance.getLeafCountThreshold()); int leafCountThreshold = RegressionTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD + 1; instance.setLeafCountThreshold(leafCountThreshold); assertEquals(leafCountThreshold, instance.getLeafCountThreshold()); leafCountThreshold = 1; instance.setLeafCountThreshold(leafCountThreshold); assertEquals(leafCountThreshold, instance.getLeafCountThreshold()); leafCountThreshold = 0; instance.setLeafCountThreshold(leafCountThreshold); assertEquals(leafCountThreshold, instance.getLeafCountThreshold()); boolean exceptionThrown = false; try { instance.setLeafCountThreshold(-1); } catch ( IllegalArgumentException e ) { exceptionThrown = true; } finally { assertTrue(exceptionThrown); } } }