/*
* File: PrimalEstimatedSubGradientTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright October 06, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.svm;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Random;
import junit.framework.TestCase;
/**
* Unit tests for class PrimalEstimatedSubGradient.
*
* @author Justin Basilico
* @since 3.1
*/
public class PrimalEstimatedSubGradientTest
extends TestCase
{
protected Random random = new Random(211);
/**
* Creates a new test.
*
* @param testName The test name.
*/
public PrimalEstimatedSubGradientTest(
String testName)
{
super(testName);
}
/**
* Test of learn method, of class PrimalEstimatedSubGradient.
*/
public void testLearn()
{
// Generate some data using the example synthetic data from Platt's
// original SMO paper.
int d = 300;
int pointsToGenerate = 100;
final ArrayList<InputOutputPair<Vector, Boolean>> data =
new ArrayList<InputOutputPair<Vector, Boolean>>(pointsToGenerate);
Vector target = VectorFactory.getDenseDefault().createUniformRandom(d, -1.0, 1.0, random);
while (data.size() < pointsToGenerate)
{
Vector input = VectorFactory.getSparseDefault().createVector(d, 0.0);
for (int i = 0; i < d / 10; i++)
{
int index = random.nextInt(d);
input.setElement(index, 1.0);
}
double dotProduct = input.dotProduct(target);
if (dotProduct < -1.0)
{
data.add(DefaultInputOutputPair.create(input, false));
}
else if (dotProduct > +1.0)
{
data.add(DefaultInputOutputPair.create(input, true));
}
// else - The dot product wsa between -1.0 and +1.0, try again.
}
final PrimalEstimatedSubGradient instance = new PrimalEstimatedSubGradient(
10, 0.01, 200, random);
// Plug in a little algorithm listener to print out the performance
// information.
instance.addIterativeAlgorithmListener(new AbstractIterativeAlgorithmListener()
{
@Override
public void stepEnded(
final IterativeAlgorithm algorithm)
{
final LinearBinaryCategorizer result = instance.getResult();
// Compute the loss of the data as well as the number of
// errors.
double loss = 0.0;
int errorCount = 0;
for (InputOutputPair<? extends Vectorizable, Boolean> example
: instance.getData())
{
final double predicted = result.evaluateAsDouble(
example.getInput());
final double actual = example.getOutput() ? 1.0 : -1.0;
loss += Math.max(0, 1.0 - actual * predicted);
if (predicted * actual <= 0.0)
{
errorCount += 1;
}
}
loss /= instance.getData().size();
// Compute the regularization term.
final double regularization =
instance.getRegularizationWeight() / 2.0
* result.getWeights().norm2Squared();
final double objective = loss + regularization;
System.out.println(
"Iteration: " + instance.getIteration()
+ " Objective: " + objective
+ " Loss: " + loss
+ " Regularization: " + regularization
+ " Errors: " + errorCount);
}
});
final LinearBinaryCategorizer result = instance.learn(data);
assertSame(result, instance.getResult());
// Make sure there is perfect learning on this example.
for (InputOutputPair<Vector, Boolean> example : data)
{
// System.out.println("" + example.getInput() + " -> " + example.getOutput());
assertEquals(example.getOutput(), result.evaluate(example.getInput()));
}
}
/**
* Test of getResult method, of class PrimalEstimatedSubGradient.
*/
public void testGetResult()
{
PrimalEstimatedSubGradient instance = new PrimalEstimatedSubGradient();
assertNull(instance.getResult());
}
/**
* Test of getSampleSize method, of class PrimalEstimatedSubGradient.
*/
public void testGetSampleSize()
{
this.testSetSampleSize();
}
/**
* Test of setRequestedSampleCount method, of class PrimalEstimatedSubGradient.
*/
public void testSetSampleSize()
{
int sampleSize = PrimalEstimatedSubGradient.DEFAULT_SAMPLE_SIZE;
PrimalEstimatedSubGradient instance = new PrimalEstimatedSubGradient();
assertEquals(sampleSize, instance.getSampleSize());
sampleSize /= 3;
instance.setSampleSize(sampleSize);
assertEquals(sampleSize, instance.getSampleSize());
sampleSize = Integer.MAX_VALUE;
instance.setSampleSize(sampleSize);
assertEquals(sampleSize, instance.getSampleSize());
sampleSize = 1;
instance.setSampleSize(sampleSize);
assertEquals(sampleSize, instance.getSampleSize());
boolean exceptionThrown = false;
try
{
instance.setSampleSize(0);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
assertEquals(sampleSize, instance.getSampleSize());
exceptionThrown = false;
try
{
instance.setSampleSize(-1);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
assertEquals(sampleSize, instance.getSampleSize());
}
/**
* Test of getRegularizationWeight method, of class PrimalEstimatedSubGradient.
*/
public void testGetRegularizationWeight()
{
this.testSetRegularizationWeight();
}
/**
* Test of setRegularizationWeight method, of class PrimalEstimatedSubGradient.
*/
public void testSetRegularizationWeight()
{
double regularizationWeight = PrimalEstimatedSubGradient.DEFAULT_REGULARIZATION_WEIGHT;
PrimalEstimatedSubGradient instance = new PrimalEstimatedSubGradient();
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
regularizationWeight *= this.random.nextDouble();
instance.setRegularizationWeight(regularizationWeight);
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
regularizationWeight = 1.0;
instance.setRegularizationWeight(regularizationWeight);
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
regularizationWeight = this.random.nextDouble();
instance.setRegularizationWeight(regularizationWeight);
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
boolean exceptionThrown = false;
try
{
instance.setRegularizationWeight(0.0);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
exceptionThrown = false;
try
{
instance.setRegularizationWeight(-0.1);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
assertEquals(regularizationWeight, instance.getRegularizationWeight(), 0.0);
}
/**
* Test of getRandom method, of class PrimalEstimatedSubGradient.
*/
public void testGetRandom()
{
this.testSetRandom();
}
/**
* Test of setRandom method, of class PrimalEstimatedSubGradient.
*/
public void testSetRandom()
{
Random random = null;
PrimalEstimatedSubGradient instance = new PrimalEstimatedSubGradient();
assertNotNull(instance.getRandom());
random = new Random();
instance.setRandom(random);
assertSame(random, instance.getRandom());
random = new Random();
instance.setRandom(random);
assertSame(random, instance.getRandom());
random = null;
instance.setRandom(random);
assertSame(random, instance.getRandom());
random = new Random();
instance.setRandom(random);
assertSame(random, instance.getRandom());
}
}