/* * File: KernelWeightedRobustRegressionTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Dec 2, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.algorithm.regression; import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.function.cost.MeanSquaredErrorCostFunction; import gov.sandia.cognition.learning.function.kernel.Kernel; import gov.sandia.cognition.learning.function.kernel.RadialBasisKernel; import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant; import gov.sandia.cognition.learning.function.scalar.PolynomialFunction; import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant; import gov.sandia.cognition.learning.function.vector.ScalarBasisSet; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.math.matrix.VectorFunction; import gov.sandia.cognition.util.ObjectUtil; import java.util.LinkedList; import junit.framework.TestCase; /** * JUnit tests for class KernelWeightedRobustRegressionTest * @author Kevin R. Dixon */ public class KernelWeightedRobustRegressionTest extends TestCase { /** * Entry point for JUnit tests for class KernelWeightedRobustRegressionTest * @param testName name of this test */ public KernelWeightedRobustRegressionTest( String testName ) { super( testName ); } public KernelWeightedRobustRegression<Vector, Vector> createInstance() { return new KernelWeightedRobustRegression<Vector, Vector>( new MultivariateLinearRegression(), new RadialBasisKernel() ); } /** * Creates an uncorrupted dataset * @return */ public LinkedList<InputOutputPair<Vector, Vector>> createDataset1() { LinkedList<InputOutputPair<Vector, Vector>> d = new LinkedList<InputOutputPair<Vector, Vector>>(); for (int i = 1; i < 4; i++) { d.add( new DefaultInputOutputPair<Vector, Vector>( VectorFactory.getDefault().copyValues( i ), VectorFactory.getDefault().copyValues( 2 * i ) ) ); } return d; } /** * Creates a dataset with an outlier * @return */ public LinkedList<InputOutputPair<Vector, Vector>> createDataset2() { LinkedList<InputOutputPair<Vector, Vector>> d = new LinkedList<InputOutputPair<Vector, Vector>>(); ScalarBasisSet<Double> polynomials = new ScalarBasisSet<Double>( PolynomialFunction.createPolynomials( 0.0, 1.0, 2.0 ) ); VectorFunctionLinearDiscriminant<Double> f = new VectorFunctionLinearDiscriminant<Double>( polynomials, new LinearDiscriminant( VectorFactory.getDefault().copyValues(-5.0, 2.0, 0.0 ) ) ); for (double i = 1; i <= 10; i++) { d.add(new DefaultInputOutputPair<Vector, Vector>( polynomials.evaluate(i), VectorFactory.getDefault().copyValues(f.evaluate(i) ) ) ); } double j = 2.5; d.add(new DefaultInputOutputPair<Vector, Vector>( polynomials.evaluate(j), VectorFactory.getDefault().copyValues(30+f.evaluate(j) )) ); return d; } /** * Test of getKernelWeightingFunction method, of class KernelWeightedRobustRegression. */ public void testGetKernelWeightingFunction() { System.out.println( "getKernelWeightingFunction" ); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); assertNotNull( instance.getKernelWeightingFunction() ); } /** * Test of setKernelWeightingFunction method, of class KernelWeightedRobustRegression. */ public void testSetKernelWeightingFunction() { System.out.println( "setKernelWeightingFunction" ); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); Kernel<? super Vector> kernel = instance.getKernelWeightingFunction(); assertNotNull( kernel ); instance.setKernelWeightingFunction( null ); assertNull( instance.getKernelWeightingFunction() ); instance.setKernelWeightingFunction( kernel ); assertSame( kernel, instance.getKernelWeightingFunction() ); } /** * Test of getTolerance method, of class KernelWeightedRobustRegression. */ public void testGetTolerance() { System.out.println( "getTolerance" ); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); assertTrue( instance.getTolerance() > 0.0 ); } /** * Test of setTolerance method, of class KernelWeightedRobustRegression. */ public void testSetTolerance() { System.out.println( "setTolerance" ); double tolerance = Math.random(); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); assertTrue( instance.getTolerance() > 0.0 ); instance.setTolerance( tolerance ); assertEquals( tolerance, instance.getTolerance() ); try { instance.setTolerance( 0.0 ); fail( "Tolerance must be > 0.0" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } /** * Test of getIterationLearner method, of class KernelWeightedRobustRegression. */ public void testGetIterationLearner() { System.out.println( "getIterationLearner" ); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); SupervisedBatchLearner<Vector,Vector, ?> learner = instance.getIterationLearner(); assertNotNull( learner ); } /** * Test of setIterationLearner method, of class KernelWeightedRobustRegression. */ public void testSetIterationLearner() { System.out.println( "setIterationLearner" ); KernelWeightedRobustRegression<Vector, Vector> instance = this.createInstance(); SupervisedBatchLearner<Vector,Vector, ?> learner = instance.getIterationLearner(); assertNotNull( learner ); instance.setIterationLearner( null ); assertNull( instance.getIterationLearner() ); instance.setIterationLearner( learner ); assertSame( learner, instance.getIterationLearner() ); } public void testLearning1() { System.out.println( "learn1" ); LinkedList<InputOutputPair<Vector, Vector>> d1 = this.createDataset1(); KernelWeightedRobustRegression<Vector, Vector> r1 = this.createInstance(); VectorFunction f2 = (VectorFunction) r1.learn( d1 ); System.out.println( "Learner:\n" + ObjectUtil.inspectFieldValues( r1 ) ); // Since there are no outliers, I should have only iterated once assertEquals( 1, r1.getIteration() ); // I should have no mean-squared error, either. MeanSquaredErrorCostFunction cost = new MeanSquaredErrorCostFunction( d1 ); assertEquals( 0.0, cost.evaluate( f2 ), 1e-5 ); } public void testLearning2() { System.out.println( "learn2" ); LinkedList<InputOutputPair<Vector, Vector>> d1 = this.createDataset2(); KernelWeightedRobustRegression<Vector, Vector> r1 = this.createInstance(); r1.setKernelWeightingFunction( new RadialBasisKernel( 10.0 ) ); VectorFunction f2 = (VectorFunction) r1.learn( d1 ); System.out.println( "Learner:\n" + ObjectUtil.inspectFieldValues( r1 ) ); // Since we have an outlier, it will take a few iterations to arrive // at a stable solution assertTrue( 1 < r1.getIteration() ); // I should have no mean-squared error, either. MeanSquaredErrorCostFunction cost = new MeanSquaredErrorCostFunction( d1 ); double outlier = d1.getLast().getOutput().getElement( 0 ); double expected = outlier * outlier / d1.size(); assertEquals( expected, cost.evaluate( f2 ), outlier / d1.size() ); } }