/* * File: SumSquaredErrorCostFunctionTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Jul 4, 2008, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.function.cost; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.DefaultWeightedTargetEstimatePair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.data.TargetEstimatePair; import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant; import gov.sandia.cognition.math.RingAccumulator; import gov.sandia.cognition.math.matrix.Matrix; import gov.sandia.cognition.math.matrix.Vector; import java.util.Collection; import java.util.LinkedList; /** * JUnit tests for class SumSquaredErrorCostFunctionTest * @author Kevin R. Dixon */ public class SumSquaredErrorCostFunctionTest extends SupervisedCostFunctionTestHarness<Vector,Vector> { /** * Entry point for JUnit tests for class SumSquaredErrorCostFunctionTest * @param testName name of this test */ public SumSquaredErrorCostFunctionTest( String testName) { super(testName); } @Override public SumSquaredErrorCostFunction createInstance() { return new SumSquaredErrorCostFunction( this.createRandomCostParameters() ); } @Override public Collection<? extends InputOutputPair<Vector, Vector>> createRandomCostParameters() { return this.createVectorCostParameters(); } @Override public MultivariateDiscriminant createEvaluator() { return this.createVectorFunction(); } /** * Test of clone method, of class SumSquaredErrorCostFunction. */ public void testConstructors() { System.out.println( "constructors" ); SumSquaredErrorCostFunction sse = new SumSquaredErrorCostFunction(); assertNull( sse.getCostParameters() ); Collection<? extends InputOutputPair<Vector,Vector>> data = this.createRandomCostParameters(); sse = new SumSquaredErrorCostFunction(data); assertNotNull(sse.getCostParameters()); assertSame(data, sse.getCostParameters()); } /** * Test of evaluatePerformance method, of class SumSquaredErrorCostFunction. */ @Override public void testKnownValues() { System.out.println( "evaluatePerformance" ); SumSquaredErrorCostFunction instance = this.createInstance(); Evaluator<? super Vector, ? extends Vector> f = this.createEvaluator(); Collection<TargetEstimatePair<Vector,Vector>> ted = this.createTargetEstimatePairs(instance, f); double actual = 0.0; double weightSum = 0.0; for( InputOutputPair<? extends Vector,Vector> sample : instance.getCostParameters() ) { Vector input = sample.getInput(); Vector output = sample.getOutput(); Vector estimate = f.evaluate( input ); double weight = DatasetUtil.getWeight(sample); actual += weight * output.minus( estimate ).norm2Squared(); weightSum += weight; ted.add(DefaultWeightedTargetEstimatePair.create( output, estimate, weight) ); } actual /= (weightSum*2.0); assertEquals( actual, instance.evaluatePerformance( ted ), TOLERANCE ); assertEquals( instance.evaluate( f ), instance.evaluatePerformance( ted ), TOLERANCE ); SumSquaredErrorCostFunction dummy = new SumSquaredErrorCostFunction( new LinkedList<InputOutputPair<Vector,Vector>>()); assertEquals( 0.0, dummy.evaluate(f) ); } /** * Test of computeParameterGradient method, of class SumSquaredErrorCostFunction. */ public void testComputeParameterGradient() { System.out.println( "computeParameterGradient" ); SumSquaredErrorCostFunction instance = this.createInstance(); RingAccumulator<Vector> actual = new RingAccumulator<Vector>(); double weightSum = 0.0; MultivariateDiscriminant fhat = this.createEvaluator(); for( InputOutputPair<? extends Vector,Vector> sample : instance.getCostParameters() ) { Vector input = sample.getInput(); Vector output = sample.getOutput(); Vector estimate = fhat.evaluate( input ); double weight = DatasetUtil.getWeight(sample); weightSum += weight; Vector deriv = output.minus( estimate ).scale( weight ); Matrix grad = fhat.computeParameterGradient( input ); actual.accumulate( grad.transpose().times( deriv ) ); } weightSum *= 2.0; assertEquals( actual.getSum().scale( -1.0/weightSum ), instance.computeParameterGradient( fhat ) ); SumSquaredErrorCostFunction dummy = new SumSquaredErrorCostFunction( new LinkedList<InputOutputPair<Vector,Vector>>()); try { dummy.computeParameterGradient(fhat); fail( "Should fail without data!" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } /** * Test of computeParameterGradient method, of class SumSquaredErrorCostFunction. */ public void testCache() { System.out.println( "cache" ); SumSquaredErrorCostFunction instance = this.createInstance(); MultivariateDiscriminant fhat = this.createEvaluator(); SumSquaredErrorCostFunction.Cache cache = SumSquaredErrorCostFunction.Cache.compute( fhat, instance.getCostParameters() ); assertEquals( instance.evaluate( fhat ), cache.parameterCost, TOLERANCE ); assertEquals( instance.computeParameterGradient( fhat ), cache.Jte ); } }