/*
* File: SupervisedCostFunctionTestHarness.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 8, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.cost;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedTargetEstimatePair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import java.util.ArrayList;
import java.util.Collection;
import junit.framework.TestCase;
import java.util.Random;
/**
 * Abstract JUnit test harness for subclasses of
 * {@link AbstractSupervisedCostFunction}. Concrete test classes supply the
 * cost function under test, random cost parameters, and an evaluator; this
 * harness then checks the common contract shared by all supervised cost
 * functions (cloning, evaluate/evaluatePerformance/summarize agreement, and
 * the zero-cost result on empty data).
 *
 * @param <InputType> Input type
 * @param <TargetType> Target type
 * @author krdixon
 */
public abstract class SupervisedCostFunctionTestHarness<InputType,TargetType>
    extends TestCase
{

    /**
     * Random number generator with a fixed seed. It is an instance field, so
     * every instance of the test case starts from the same seed, which keeps
     * each test deterministic.
     */
    public final Random RANDOM = new Random( 1 );

    /**
     * Default tolerance of the regression tests, {@value}.
     */
    public static final double TOLERANCE = 1e-5;

    /**
     * Number of samples to generate in {@link #createVectorCostParameters()}.
     */
    public static final int NUM_SAMPLES = 10;

    /**
     * Creates a new test harness instance.
     * @param testName Name of the test.
     */
    public SupervisedCostFunctionTestHarness(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates the cost function under test.
     * @return cost function
     */
    public abstract AbstractSupervisedCostFunction<InputType,TargetType> createInstance();

    /**
     * Creates a non-empty collection of random cost parameters (labeled data).
     * @return cost parameters
     */
    public abstract Collection<? extends InputOutputPair<InputType,TargetType>> createRandomCostParameters();

    /**
     * Creates the evaluator whose cost is measured against the cost parameters.
     * @return evaluator
     */
    public abstract Evaluator<? super InputType, ? extends TargetType> createEvaluator();

    /**
     * Tests against known values.
     */
    public abstract void testKnownValues();

    /**
     * Tests the constructors of the cost function under test.
     */
    public abstract void testConstructors();

    /**
     * Creates a random 3x3 linear vector function, with entries drawn
     * uniformly from [-1, 1] using {@link #RANDOM}.
     * @return VectorFunction
     */
    protected MultivariateDiscriminant createVectorFunction()
    {
        return new MultivariateDiscriminant(
            MatrixFactory.getDefault().createUniformRandom(3, 3, -1.0, 1.0, RANDOM) );
    }

    /**
     * Creates {@link #NUM_SAMPLES} weighted vector input-output pairs. Each
     * output is produced by a single random function from
     * {@link #createVectorFunction()}, and each weight is drawn from
     * {@link #RANDOM}.
     * @return vector data
     */
    protected ArrayList<WeightedInputOutputPair<Vector,Vector>> createVectorCostParameters()
    {
        ArrayList<WeightedInputOutputPair<Vector,Vector>> dataset =
            new ArrayList<WeightedInputOutputPair<Vector,Vector>>( NUM_SAMPLES );
        VectorFunction f = this.createVectorFunction();
        for( int i = 0; i < NUM_SAMPLES; i++ )
        {
            Vector input = Vector3.createRandom(RANDOM);
            Vector output = f.evaluate(input);
            double weight = RANDOM.nextDouble();
            WeightedInputOutputPair<Vector,Vector> pair =
                new DefaultWeightedInputOutputPair<Vector,Vector>( input, output, weight );
            dataset.add( pair );
        }
        return dataset;
    }

    /**
     * Creates target-estimate pairs from the cost function's parameters. For
     * each labeled pair, the evaluator is applied to the input. The result is
     * a weighted pair of the labeled output (target) and the evaluator's
     * output (estimate), weighted by {@code DatasetUtil.getWeight(pair)}.
     * @param costFunction cost function whose cost parameters are used
     * @param f evaluator producing the estimates
     * @return target estimate pairs, in iteration order of the cost parameters
     */
    public ArrayList<TargetEstimatePair<TargetType, TargetType>> createTargetEstimatePairs(
        AbstractSupervisedCostFunction<InputType,TargetType> costFunction,
        Evaluator<? super InputType, ? extends TargetType> f )
    {
        Collection<? extends InputOutputPair<? extends InputType, TargetType>> data =
            costFunction.getCostParameters();
        ArrayList<TargetEstimatePair<TargetType, TargetType>> tedata =
            new ArrayList<TargetEstimatePair<TargetType, TargetType>>( data.size() );
        for( InputOutputPair<? extends InputType,TargetType> pair : data )
        {
            TargetType estimate = f.evaluate( pair.getInput() );
            tedata.add(DefaultWeightedTargetEstimatePair.create(
                pair.getOutput(), estimate, DatasetUtil.getWeight(pair) ) );
        }
        return tedata;
    }

    /**
     * Tests that the subclass's abstract factory methods produce usable
     * objects.
     */
    public void testAbstractMethods()
    {
        System.out.println( "Abstract methods" );

        AbstractSupervisedCostFunction<InputType,TargetType> instance =
            this.createInstance();
        assertNotNull( instance );
        assertNotNull( instance.getCostParameters() );

        Collection<? extends InputOutputPair<InputType,TargetType>> data =
            this.createRandomCostParameters();
        assertNotNull( data );
        assertFalse( data.isEmpty() );

        assertNotNull( this.createEvaluator() );
    }

    /**
     * Test of clone method, of class AbstractSupervisedCostFunction.
     */
    public void testClone()
    {
        System.out.println("clone");

        AbstractSupervisedCostFunction<InputType,TargetType> instance =
            this.createInstance();
        instance.setCostParameters( this.createRandomCostParameters() );
        AbstractSupervisedCostFunction<InputType,TargetType> clone =
            instance.clone();
        assertNotNull( clone );
        assertNotSame( instance, clone );
        assertNotNull( clone.getCostParameters() );
        assertNotSame( instance.getCostParameters(), clone.getCostParameters() );
    }

    /**
     * Test of evaluatePerformance method, of class AbstractSupervisedCostFunction.
     */
    public void testEvaluatePerformance()
    {
        System.out.println("evaluatePerformance");

        Collection<? extends InputOutputPair<InputType,TargetType>> data =
            this.createRandomCostParameters();
        AbstractSupervisedCostFunction<InputType,TargetType> instance =
            this.createInstance();
        instance.setCostParameters(data);
        Evaluator<? super InputType, ? extends TargetType> f = this.createEvaluator();
        ArrayList<TargetEstimatePair<TargetType, TargetType>> tedata =
            createTargetEstimatePairs(instance, f);

        assertEquals( instance.evaluate(f), instance.evaluatePerformance(tedata), TOLERANCE );
        assertEquals( instance.summarize(tedata), instance.evaluatePerformance(tedata) );

        // No data should always be 0.0
        tedata.clear();
        assertEquals( 0.0, instance.evaluatePerformance(tedata) );
    }

    /**
     * Test of evaluate method, of class AbstractSupervisedCostFunction.
     */
    public void testEvaluate()
    {
        System.out.println("evaluate");

        Evaluator<? super InputType, ? extends TargetType> evaluator =
            this.createEvaluator();
        AbstractSupervisedCostFunction<InputType,TargetType> instance =
            this.createInstance();
        instance.setCostParameters( this.createRandomCostParameters() );
        ArrayList<TargetEstimatePair<TargetType, TargetType>> tedata =
            createTargetEstimatePairs(instance, evaluator);
        assertEquals( instance.evaluatePerformance(tedata), instance.evaluate(evaluator), TOLERANCE );
        assertEquals( instance.summarize(tedata), instance.evaluate(evaluator), TOLERANCE );

        // No data should always be 0.0. Install a fresh empty collection
        // instead of clear()-ing the subclass-provided one, which may be
        // unmodifiable (clear() is an optional Collection operation).
        instance.setCostParameters(
            new ArrayList<InputOutputPair<InputType,TargetType>>() );
        assertEquals( 0.0, instance.evaluate(evaluator) );
    }

    /**
     * Test of getCostParameters method, of class AbstractSupervisedCostFunction.
     */
    public void testGetCostParameters()
    {
        System.out.println("getCostParameters");

        AbstractSupervisedCostFunction<InputType,TargetType> instance = this.createInstance();
        instance.setCostParameters(this.createRandomCostParameters());
        assertNotNull( instance.getCostParameters() );
    }

    /**
     * Test of setCostParameters method, of class AbstractSupervisedCostFunction.
     * The setter is expected to store the given collection by reference.
     */
    public void testSetCostParameters()
    {
        System.out.println("setCostParameters");

        Collection<? extends InputOutputPair<? extends InputType, TargetType>> costParameters =
            this.createRandomCostParameters();
        AbstractSupervisedCostFunction<InputType,TargetType> instance = this.createInstance();
        instance.setCostParameters(costParameters);
        assertSame( costParameters, instance.getCostParameters() );
    }

    /**
     * Test of summarize method, of class AbstractSupervisedCostFunction.
     */
    public void testSummarize()
    {
        System.out.println("summarize");

        Evaluator<? super InputType, ? extends TargetType> evaluator =
            this.createEvaluator();
        AbstractSupervisedCostFunction<InputType,TargetType> instance =
            this.createInstance();
        instance.setCostParameters( this.createRandomCostParameters() );
        ArrayList<TargetEstimatePair<TargetType, TargetType>> tedata =
            createTargetEstimatePairs(instance, evaluator);
        assertEquals( instance.summarize(tedata), instance.evaluatePerformance(tedata) );
        assertEquals( instance.summarize(tedata), instance.evaluate(evaluator), TOLERANCE );

        // No data should always be 0.0
        tedata.clear();
        assertEquals( 0.0, instance.summarize(tedata) );
    }

}