/* * File: MatrixVectorMultiplyFunctionTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright October 5, 2006, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.function.cost; import gov.sandia.cognition.annotation.CodeReview; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant; import gov.sandia.cognition.math.RingAccumulator; import gov.sandia.cognition.math.matrix.Matrix; import java.util.ArrayList; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.data.TargetEstimatePair; import gov.sandia.cognition.math.matrix.Vector; import java.util.Collection; /** * This class implements JUnit tests for the following classes: * * MeanSquaredErrorCostFunction * * @author Kevin R. Dixon * @since 1.0 */ @CodeReview( reviewer="Justin Basilico", date="2006-10-06", changesNeeded=false, comments="Test class looks fine." ) public class MeanSquaredErrorCostFunctionTest extends SupervisedCostFunctionTestHarness<Vector,Vector> { /** * Creates a new instance of MeanSquaredErrorCostFunctionTest. * * @param testName The test name. */ public MeanSquaredErrorCostFunctionTest( String testName) { super(testName); } @Override public MeanSquaredErrorCostFunction createInstance() { return new MeanSquaredErrorCostFunction( this.createRandomCostParameters() ); } @Override public Collection<? extends InputOutputPair<Vector, Vector>> createRandomCostParameters() { return this.createVectorCostParameters(); } @Override public MultivariateDiscriminant createEvaluator() { return this.createVectorFunction(); } /** * Constructors */ @Override public void testConstructors() { MeanSquaredErrorCostFunction instance = new MeanSquaredErrorCostFunction(); assertNull(instance.getCostParameters()); instance = new MeanSquaredErrorCostFunction(this.createRandomCostParameters()); assertNotNull(instance.getCostParameters()); } /** * Test of evaluate method, of class MeanSquaredErrorCostFunction. */ @Override public void testKnownValues() { System.out.println("evaluate"); MeanSquaredErrorCostFunction costFunction = this.createInstance(); MultivariateDiscriminant estimateFunction = this.createEvaluator(); double totalSquaredError = 0.0; double weightSum = 0.0; Collection<TargetEstimatePair<Vector,Vector>> tepairs = this.createTargetEstimatePairs(costFunction, estimateFunction); for( TargetEstimatePair<Vector,Vector> pair : tepairs ) { Vector error = pair.getTarget().minus( pair.getEstimate() ); double weight = DatasetUtil.getWeight(pair); totalSquaredError += weight * error.norm2Squared(); weightSum += weight; } double expected = totalSquaredError / weightSum; double result = costFunction.evaluate( estimateFunction ); assertEquals( expected, result, TOLERANCE ); } /** * Test of differentiate method, of class MeanSquaredErrorCostFunction. */ public void testDifferentiate() { System.out.println("differentiate"); MultivariateDiscriminant targetFunction = this.createEvaluator(); MultivariateDiscriminant estimateFunction = this.createEvaluator(); ArrayList<InputOutputPair<Vector,Vector>> dataset = new ArrayList<InputOutputPair<Vector,Vector>>(); MeanSquaredErrorCostFunction dummy = new MeanSquaredErrorCostFunction( dataset ); assertNull( dummy.computeParameterGradient(targetFunction) ); MeanSquaredErrorCostFunction costFunction = this.createInstance(); RingAccumulator<Vector> totaler = new RingAccumulator<Vector>(); double weightSum = 0.0; for( InputOutputPair<? extends Vector,Vector> pair : costFunction.getCostParameters() ) { Vector input = pair.getInput(); Vector estimate = estimateFunction.evaluate(input); double weight = DatasetUtil.getWeight(pair); weightSum += weight; Vector error = pair.getOutput().minus( estimate ).scale( weight ); Matrix gradient = estimateFunction.computeParameterGradient( input ); totaler.accumulate( gradient.transpose().times( error ) ); } Vector expected = totaler.getSum().scale( -1.0 ).scale( 1.0/weightSum ); Vector result = costFunction.computeParameterGradient( estimateFunction ); assertEquals( expected, result ); } }