/*
* File: LogisticRegressionTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Nov 27, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.learning.algorithm.regression.LogisticRegression.Function;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.LinkedList;
import java.util.Random;
import junit.framework.TestCase;
/**
* JUnit tests for class LogisticRegressionTest
* @author Kevin R. Dixon
*/
public class LogisticRegressionTest
    extends TestCase
{

    /** The random number generator for the tests. */
    public final Random random = new Random(1);

    /** Tolerance for comparing learned coefficients against published values. */
    public final double TOLERANCE = 1e-4;

    /**
     * Entry point for JUnit tests for class LogisticRegressionTest
     * @param testName name of this test
     */
    public LogisticRegressionTest(
        String testName )
    {
        super( testName );
    }

    /**
     * Tests learning on a small weighted data set against published
     * coefficients, then verifies result accessors and clone semantics.
     */
    public void testLearn()
    {
        System.out.println( "learn" );

        LogisticRegression instance = new LogisticRegression();

        // Weighted proportions from http://faculty.vassar.edu/lowry/logreg1.html
        LinkedList<InputOutputPair<Vector,Double>> data =
            new LinkedList<InputOutputPair<Vector, Double>>();
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 28.0 ), 1.0/3.0, 6 ) );
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 29.0 ), 2.0/5.0, 5 ) );
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 30.0 ), 7.0/9.0, 9 ) );
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 31.0 ), 7.0/9.0, 9 ) );
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 32.0 ), 16.0/20.0, 20 ) );
        data.add( new DefaultWeightedInputOutputPair<Vector, Double>(
            VectorFactory.getDefault().copyValues( 33.0 ), 14.0/15.0, 15 ) );

        LogisticRegression.Function f = instance.learn( data );
        Vector w = f.convertToVector();

        // One input dimension plus bias.
        assertEquals( 2, w.getDimensionality() );

        // Published slope and intercept for this data set.
        assertEquals( 0.5769, f.getFirst().getWeightVector().getElement(0), TOLERANCE );
        assertEquals( -16.7198, f.getFirst().getBias(), TOLERANCE );

        // The learner exposes its result, which is distinct from the
        // user-settable objectToOptimize.
        assertSame( f, instance.getResult() );
        assertNotSame( instance.getObjectToOptimize(), instance.getResult() );

        // Learning again from the same data must be deterministic.
        Vector wclone = w.clone();
        LogisticRegression.Function f2 = instance.learn( data );
        Vector w2 = f2.convertToVector();
        assertEquals( wclone, w2 );

        // Function clone is a deep, equal-parameter copy.
        LogisticRegression.Function fclone = f2.clone();
        assertNotNull( fclone );
        assertNotSame( f2, fclone );
        assertEquals( f2.convertToVector(), fclone.convertToVector() );

        // Learner clone deep-copies its result and objectToOptimize.
        LogisticRegression clone = instance.clone();
        assertNotNull( clone );
        assertNotSame( instance, clone );
        assertNotNull( clone.getResult() );
        assertNotSame( instance.getResult(), clone.getResult() );
        assertNotNull( clone.getObjectToOptimize() );
        assertNotSame( instance.getObjectToOptimize(), clone.getObjectToOptimize() );
    }

    /**
     * Tests that increasing the regularization term shrinks the L2 norm
     * of the learned weight vector.
     */
    public void testLearn2()
    {
        System.out.println( "learn2" );

        // Data from http://luna.cas.usf.edu/~mbrannic/files/regression/Logistic.html
        LinkedList<InputOutputPair<Vector,Double>> data =
            new LinkedList<InputOutputPair<Vector, Double>>();
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 70), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 80), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 50), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 60), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 40), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 65), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 75), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 80), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 70), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 60), 1.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 65), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 50), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 45), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 35), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 40), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(1.0, 50), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 55), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 45), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 50), 0.0 ) );
        data.add( new DefaultInputOutputPair<Vector, Double>( VectorFactory.getDefault().copyValues(0.0, 60), 0.0 ) );

        final double r1 = 0.1;
        LogisticRegression instance = new LogisticRegression(r1);
        assertEquals( r1, instance.getRegularization(), 0.0 );
        LogisticRegression.Function f = instance.learn( data );
        Vector w1 = f.convertToVector();
        System.out.println( "R1: " + w1 );

        // As we increase the regularization term, that will decrease the
        // L2 norm of the resulting weight vector.
        final double r2 = 1.0;
        instance.setRegularization(r2);
        assertEquals( r2, instance.getRegularization(), 0.0 );
        f = instance.learn(data);
        Vector w2 = f.convertToVector();
        System.out.println( "R2: " + w2 );

        // BUG FIX: the original compared w2.norm2() against w1.norm1(),
        // mixing L2 and L1 norms. The stated property is about the L2 norm,
        // so both sides must use norm2().
        assertTrue( w2.norm2() < w1.norm2() );
    }

    /**
     * Test of getObjectToOptimize method, of class LogisticRegression.
     */
    public void testGetObjectToOptimize()
    {
        System.out.println( "getObjectToOptimize" );

        // Before learning, no object to optimize has been set.
        LogisticRegression instance = new LogisticRegression();
        assertNull( instance.getObjectToOptimize() );
    }

    /**
     * Test of setObjectToOptimize method, of class LogisticRegression.
     */
    public void testSetObjectToOptimize()
    {
        System.out.println( "setObjectToOptimize" );

        Function objectToOptimize = new Function( 2 );
        LogisticRegression instance = new LogisticRegression();
        assertNull( instance.getObjectToOptimize() );
        instance.setObjectToOptimize( objectToOptimize );

        // Setter stores the same reference, no defensive copy.
        assertSame( objectToOptimize, instance.getObjectToOptimize() );
    }

    /**
     * Test of getTolerance method, of class LogisticRegression.
     */
    public void testGetTolerance()
    {
        System.out.println( "getTolerance" );

        LogisticRegression instance = new LogisticRegression();
        assertEquals( LogisticRegression.DEFAULT_TOLERANCE, instance.getTolerance(), 0.0 );
    }

    /**
     * Test of setTolerance method, of class LogisticRegression.
     */
    public void testSetTolerance()
    {
        System.out.println( "setTolerance" );

        LogisticRegression instance = new LogisticRegression();
        assertEquals( LogisticRegression.DEFAULT_TOLERANCE, instance.getTolerance(), 0.0 );

        double tolerance = instance.getTolerance() + 1.0;
        instance.setTolerance( tolerance );
        assertEquals( tolerance, instance.getTolerance(), 0.0 );
    }

}