/*
* File: LinearRegressionTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 6, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminantWithBias;
import gov.sandia.cognition.learning.function.scalar.PolynomialFunction;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Random;
import junit.framework.TestCase;
/**
*
* @author Kevin R. Dixon
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-09-02",
changesNeeded=false,
comments={
"Added more rigorous check against Justin's pathological example.",
"Looks fine now."
}
)
public class LinearRegressionTest
    extends TestCase
{

    /** The random number generator for the tests (fixed seed for repeatability). */
    public static Random random = new Random(1);

    /** Tolerance for floating-point equality assertions. */
    private final double EPS = 1e-5;

    /**
     * Creates a new test case.
     *
     * @param testName name of the test
     */
    public LinearRegressionTest(
        String testName )
    {
        super( testName );
    }

    /**
     * Creates a default LinearRegression instance used by the tests.
     *
     * @return a new LinearRegression with default configuration
     */
    public static LinearRegression createInstance()
    {
        return new LinearRegression();
    }

    /**
     * Creates a random, noise-free dataset: Gaussian scalar inputs are pushed
     * through the given basis and labeled by the given linear discriminant.
     *
     * @param basis scalar-to-vector basis function applied to each sample
     * @param f discriminant that generates the target output for each basis vector
     * @return at least 10 (basis vector, output) pairs
     */
    public static Collection<InputOutputPair<Vector, Double>> createDataset(
        Evaluator<Double, Vector> basis,
        LinearDiscriminant f )
    {
        Collection<InputOutputPair<Vector, Double>> retval =
            new LinkedList<InputOutputPair<Vector, Double>>();
        int num = random.nextInt(100) + 10;
        for (int i = 0; i < num; i++)
        {
            double x = random.nextGaussian();
            Vector phi = basis.evaluate(x);
            double y = f.evaluate( phi );
            retval.add( DefaultInputOutputPair.create(phi, y) );
        }
        return retval;
    }

    /**
     * Tests unused constructors
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );
        @SuppressWarnings("unchecked")
        LinearRegression f = new LinearRegression();
        assertNotNull( f );
    }

    /**
     * Test of clone method, of class gov.sandia.cognition.learning.regression.LinearRegression.
     */
    public void testClone()
    {
        System.out.println( "clone" );
        LinearRegression instance = LinearRegressionTest.createInstance();

        // Flip the pseudoinverse flag away from its default so we can verify
        // that clone() copies configuration state, not just identity.
        assertTrue( instance.getUsePseudoInverse() );
        instance.setUsePseudoInverse( false );
        assertFalse( instance.getUsePseudoInverse() );

        ScalarBasisSet<Double> polynomials = new ScalarBasisSet<Double>(
            PolynomialFunction.createPolynomials( 1.0, 2.0 ) );
        LinearDiscriminant f = new LinearDiscriminant(
            VectorFactory.getDefault().createUniformRandom( 2, -1, 1, random ) );
        Collection<InputOutputPair<Vector, Double>> data = LinearRegressionTest.createDataset( polynomials, f );
        instance.learn( data );

        LinearRegression clone = instance.clone();
        assertNotNull( clone );
        assertNotSame( instance, clone );

        // The original keeps its setting...
        assertFalse( instance.getUsePseudoInverse() );
        // ...and the clone must have copied it (previously only the original
        // was checked, so a clone that ignored configuration would pass).
        assertFalse( clone.getUsePseudoInverse() );
    }

    /**
     * Tests a known non-pathological function
     */
    public void testKnownClosedFormWeighted()
    {
        System.out.println( "learn known non-pathological function" );

        // Three weighted scalar points, converted to 1-dimensional vectors.
        LinkedList<WeightedInputOutputPair<Double,Double>> data =
            new LinkedList<WeightedInputOutputPair<Double,Double>>();
        data.add( new DefaultWeightedInputOutputPair<Double, Double>( 1.0, 0.0, 1.0 ) );
        data.add( new DefaultWeightedInputOutputPair<Double, Double>( 2.0, 1.0, 2.0 ) );
        data.add( new DefaultWeightedInputOutputPair<Double, Double>( 3.0, 4.0, 3.0 ) );

        ArrayList<WeightedInputOutputPair<Vector,Double>> vectorData =
            new ArrayList<WeightedInputOutputPair<Vector, Double>>( data.size() );
        for( WeightedInputOutputPair<Double,Double> pair : data )
        {
            vectorData.add(new DefaultWeightedInputOutputPair<Vector, Double>(
                VectorFactory.getDefault().copyValues( pair.getInput() ), pair.getOutput(), pair.getWeight() ) );
        }

        for( WeightedInputOutputPair<Vector,Double> d : vectorData )
        {
            System.out.println( "Input: " + d.getInput() + ", Output: " + d.getOutput() + ", Weight: " + d.getWeight() );
        }

        LinearRegression regression = new LinearRegression();
        regression.setUsePseudoInverse(false);
        LinearDiscriminantWithBias result = regression.learn(vectorData);
        System.out.println( "Weights: " + result.convertToVector() );

        // I computed this result by hand in octave
        assertEquals( 2.4210526316, result.getWeightVector().getElement(0), EPS );
        assertEquals( -3.3684210526, result.getBias(), EPS );
    }

    /**
     * Test of learn method, of class gov.sandia.cognition.learning.regression.LinearRegression.
     */
    public void testLearn()
    {
        System.out.println( "learn" );
        LinearRegression instance = LinearRegressionTest.createInstance();
        ScalarBasisSet<Double> polynomials = new ScalarBasisSet<Double>(
            PolynomialFunction.createPolynomials( 1.0, 2.0 ) );

        // Generate noise-free data from a known discriminant; the learner
        // should recover the generating weights exactly (to within EPS).
        LinearDiscriminantWithBias f = new LinearDiscriminantWithBias(
            VectorFactory.getDefault().createUniformRandom( 2, -1, 1, random ), random.nextGaussian() );
        Collection<InputOutputPair<Vector, Double>> data = LinearRegressionTest.createDataset( polynomials, f );
        LinearDiscriminantWithBias result = instance.learn( data );

        // Element-wise comparison with tolerance; on mismatch fall through to
        // assertEquals so the failure message shows both vectors.
        if (!result.convertToVector().equals( f.convertToVector(), EPS ))
        {
            assertEquals( f.convertToVector(), result.convertToVector() );
        }
    }

    /**
     * This tests Justin's pathological example that kills the LU solver.
     */
    public void testLearn2()
    {
        System.out.println( "learn2" );

        LinearRegression regressionLearner = new LinearRegression();
        regressionLearner.setUsePseudoInverse(true);

        ArrayList<InputOutputPair<Vector, Double>> data =
            new ArrayList<InputOutputPair<Vector, Double>>();

        // Make sure we're using pseudoinverse and not LU, which can't
        // handle this test case
        assertTrue( regressionLearner.getUsePseudoInverse() );

        // This is a rank-one matrix, which the third dimension has the regression equation
        // y = -0.642857*x + 1.85714 (according to Octave)
        data.add(new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(1.0, 2.0, 3.0 ), 0.0));
        data.add(new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(1.0, 2.0, 1.0 ), 1.0));
        data.add(new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(1.0, 2.0, 0.0 ), 2.0));

        LinearDiscriminantWithBias result = regressionLearner.learn( data );
        System.out.println( "SVD weights: " + result.convertToVector() );

        // These are the results, as validated by octave's backslash "\" command
        assertEquals( -0.071429, result.evaluate( data.get(0).getInput() ), EPS );
        assertEquals( 1.214286, result.evaluate( data.get(1).getInput() ), EPS );
        assertEquals( 1.857143, result.evaluate( data.get(2).getInput() ), EPS );

        // Now use LU solver, which will return numerically unstable results for this example
        regressionLearner.setUsePseudoInverse( false );
        LinearDiscriminantWithBias resultLU = regressionLearner.learn( data );
        Vector v1 = resultLU.convertToVector();
        System.out.println( "LU weights: " + v1 );

        // These results aren't quite right, especially when you see that the
        // LU weights from this unit test are ~1e14
        // This is because the LU decomposition used in the solve() method
        // doesn't handle the singular matrix as well as the pseudoinverse does
        // But this is a pathological example, and I will probably continue to
        // happily use the LU solver. -- krdixon, 2008-09-03
        for( int i = 0; i < 3; i++ )
        {
            System.out.println( i + ": " + resultLU.evaluate( data.get(i).getInput() ) );
        }
        assertEquals( -33.67857142857143, resultLU.evaluate( data.get(0).getInput() ), EPS );
        assertEquals( -32.392857142857146, resultLU.evaluate( data.get(1).getInput() ), EPS );
        assertEquals( -31.75, resultLU.evaluate( data.get(2).getInput() ), EPS );

        // Ridge regularization should stabilize the LU solve on the same
        // singular system and bring the fit back near the SVD answer.
        regressionLearner.setRegularization(1e-3);
        LinearDiscriminantWithBias resultLUr = regressionLearner.learn(data);
        Vector v2 = resultLUr.convertToVector();
        System.out.println( "LUr weights: " + v2 );
        for( int i = 0; i < 3; i++ )
        {
            System.out.println( i + ": " + resultLUr.evaluate( data.get(i).getInput() ) );
        }
        assertEquals( -0.0711990287795472, resultLUr.evaluate( data.get(0).getInput() ), EPS );
        assertEquals( 1.2142398057559096, resultLUr.evaluate( data.get(1).getInput() ), EPS );
        assertEquals( 1.856959223023638, resultLUr.evaluate( data.get(2).getInput() ), EPS );
    }

    /**
     * Builds a Statistic by comparing two discriminants on 100 shared
     * Gaussian samples: f1 supplies the targets, f2 the estimates.
     *
     * @param f1 discriminant that generates the target values
     * @param f2 discriminant that generates the estimated values
     * @return a Statistic over the paired targets/estimates with 3 parameters
     */
    public LinearRegression.Statistic createStatisticInstance(
        LinearDiscriminant f1,
        LinearDiscriminant f2 )
    {
        LinkedList<Double> targets = new LinkedList<Double>();
        LinkedList<Double> estimates = new LinkedList<Double>();
        MultivariateGaussian g = new MultivariateGaussian( f1.getInputDimensionality() );
        ArrayList<Vector> samples = g.sample(random, 100);
        for( Vector x : samples )
        {
            targets.add( f1.evaluate(x) );
            estimates.add( f2.evaluate(x) );
        }
        return new LinearRegression.Statistic( targets, estimates, 3 );
    }

    /**
     * Test of getRootMeanSquaredError method, of class gov.sandia.cognition.learning.regression.LinearRegression.Statistic.
     */
    public void testStatisticGetRootMeanSquaredError()
    {
        System.out.println( "Statistic.getRootMeanSquaredError" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        double v = instance.getRootMeanSquaredError();
        assertTrue( v > 0.0 );
        // By definition: chi-square == N * RMSE^2
        assertEquals( v * v * instance.getNumSamples(), instance.getChiSquare(), EPS );
    }

    /**
     * Statistic.clone
     */
    public void testStatisticClone()
    {
        System.out.println( "Statistic.clone" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        LinearRegression.Statistic clone = instance.clone();
        assertNotNull( clone );
        assertNotSame( instance, clone );

        // The string form captures all the derived statistics, so equal
        // toString() output is a cheap proxy for a faithful field-wise copy.
        String s1 = instance.toString();
        String s2 = clone.toString();
        System.out.println( "Instance: " + s1 );
        System.out.println( "Clone: " + s2 );
        assertEquals( s1, s2 );
    }

    /**
     * Test of getTargetEstimateCorrelation method, of class gov.sandia.cognition.learning.regression.LinearRegression.Statistic.
     */
    public void testStatisticGetTargetEstimateCorrelation()
    {
        System.out.println( "Statistic.getTargetEstimateCorrelation" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        double v = instance.getTargetEstimateCorrelation();
        assertTrue( v != 0.0 );
    }

    /**
     * Test of getUnpredictedErrorFraction method, of class gov.sandia.cognition.learning.regression.LinearRegression.Statistic.
     */
    public void testStatisticGetUnpredictedErrorFraction()
    {
        System.out.println( "Statistic.getUnpredictedErrorFraction" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        double v = instance.getUnpredictedErrorFraction();
        assertTrue( 0.0 < v );
        assertTrue( v < 1.0 );
        // Unpredicted error fraction is 1 - r^2
        double c = instance.getTargetEstimateCorrelation();
        assertEquals( 1.0 - c * c, v, EPS );
    }

    /**
     * Test of getNumSamples method, of class gov.sandia.cognition.learning.regression.LinearRegression.Statistic.
     */
    public void testStatisticGetNumSamples()
    {
        System.out.println( "Statistic.getNumSamples" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        assertTrue( instance.getNumSamples() > 0.0 );
    }

    /**
     * Test of getDegreesOfFreedom method, of class gov.sandia.cognition.learning.regression.LinearRegression.Statistic.
     */
    public void testStatisticGetDegreesOfFreedom()
    {
        System.out.println( "Statistic.getDegreesOfFreedom" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        double v = instance.getDegreesOfFreedom();
        assertTrue( v > 0.0 );
        // DOF = samples - parameters
        assertEquals( (double) (instance.getNumSamples() - instance.getNumParameters()), v );
    }

    /**
     * Statistic.getMeanL1Error
     */
    public void testStatisticGetMeanL1Error()
    {
        System.out.println( "Statistic.getMeanL1Error" );
        LinearDiscriminant f1 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearDiscriminant f2 = new LinearDiscriminant( VectorFactory.getDefault().createUniformRandom(3, -1, 1, random) );
        LinearRegression.Statistic instance = this.createStatisticInstance( f1, f2 );
        assertTrue( instance.getMeanL1Error() > 0.0 );
    }

    /**
     * Degenerate stuff
     */
    public void testStatisticsDegenerateConstructor()
    {
        System.out.println( "Statistic degenerate" );
        LinearRegression.Statistic stat;

        // Mismatched collection sizes (1 target vs 2 estimates) must be rejected.
        try
        {
            stat = new LinearRegression.Statistic(
                Arrays.asList( random.nextGaussian() ),
                Arrays.asList( random.nextGaussian(), random.nextGaussian() ),
                Arrays.asList( random.nextDouble() ),
                random.nextInt(10) + 1 );
            fail( "Collections do not match!" );
        }
        catch (Exception e)
        {
            System.out.println( "Good: " + e );
        }

        // Degenerate-but-valid case: zero weight, zero parameters.
        stat = new LinearRegression.Statistic(
            Arrays.asList( random.nextGaussian() ),
            Arrays.asList( random.nextGaussian() ),
            Arrays.asList( 0.0 ),
            0 );
        assertEquals( 0.0, stat.getMeanL1Error() );
        assertEquals( 1.0, stat.getDegreesOfFreedom() );
    }

}