/*
* File: SequentialMinimalOptimizationTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright August 04, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.svm;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.LinearKernel;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import java.util.ArrayList;
import java.util.Random;
import junit.framework.TestCase;
/**
 * Unit tests for class SequentialMinimalOptimization.
 *
 * @author Justin Basilico
 * @since 3.1
 */
public class SequentialMinimalOptimizationTest
    extends TestCase
{

    /** Random number generator with a fixed seed so the tests are repeatable. */
    protected Random random = new Random(211);

    /**
     * Creates a new test.
     *
     * @param testName The test name.
     */
    public SequentialMinimalOptimizationTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Test of learn method, of class SequentialMinimalOptimization.
     * Trains a linear-kernel SVM on synthetic, linearly separable data and
     * checks that every training example is categorized correctly.
     */
    public void testLearn()
    {
        // Generate some data using the example synthetic data from Platt's
        // original SMO paper: sparse 0/1 vectors labeled by the sign of their
        // dot product with a random target vector. Points whose dot product
        // falls inside [-1, +1] are discarded, so the kept data is separable
        // by the target vector with a margin.
        int d = 300;
        int pointsToGenerate = 100;
        final ArrayList<InputOutputPair<Vector, Boolean>> data =
            new ArrayList<InputOutputPair<Vector, Boolean>>(pointsToGenerate);
        Vector target = VectorFactory.getDenseDefault().createUniformRandom(
            d, -1.0, 1.0, random);
        while (data.size() < pointsToGenerate)
        {
            Vector input = VectorFactory.getSparseDefault().createVector(d, 0.0);
            // Turn on d / 10 randomly chosen elements. Indices may repeat, so
            // some points end up with fewer active elements.
            for (int i = 0; i < d / 10; i++)
            {
                int index = random.nextInt(d);
                input.setElement(index, 1.0);
            }

            double dotProduct = input.dotProduct(target);
            if (dotProduct < -1.0)
            {
                data.add(DefaultInputOutputPair.create(input, false));
            }
            else if (dotProduct > +1.0)
            {
                data.add(DefaultInputOutputPair.create(input, true));
            }
            // else - The dot product was between -1.0 and +1.0, try again.
        }

        SequentialMinimalOptimization<Vector> instance =
            new SequentialMinimalOptimization<Vector>();
        instance.setKernel(new LinearKernel());
        instance.setRandom(random);
        instance.setMaxIterations(1000);
        instance.setMaxPenalty(100.0);
        // NOTE(review): a kernel cache size of 0 presumably disables kernel
        // caching, exercising the uncached code path — confirm.
        instance.setKernelCacheSize(0);

        final KernelBinaryCategorizer<Vector, ?> result = instance.learn(data);
        assertSame(result, instance.getResult());

        // The training data is separable by construction, so the learned
        // categorizer is expected to reproduce every training label.
        for (InputOutputPair<Vector, Boolean> example : data)
        {
            assertEquals(
                "Misclassified training input: " + example.getInput(),
                example.getOutput(), result.evaluate(example.getInput()));
        }
    }

    /**
     * Test of getResult method, of class SequentialMinimalOptimization.
     */
    public void testGetResult()
    {
        // Tested by testLearn.
    }

    /**
     * Test of getRandom method, of class SequentialMinimalOptimization.
     */
    public void testGetRandom()
    {
        this.testSetRandom();
    }

    /**
     * Test of setRandom method, of class SequentialMinimalOptimization.
     * Checks that a default generator exists and that each value passed to
     * setRandom, including null, is returned as-is by getRandom.
     */
    public void testSetRandom()
    {
        SequentialMinimalOptimization<String> instance =
            new SequentialMinimalOptimization<String>();
        assertNotNull(instance.getRandom());

        // Named distinctly so it does not shadow the seeded "random" field.
        Random newRandom = new Random();
        instance.setRandom(newRandom);
        assertSame(newRandom, instance.getRandom());

        newRandom = new Random();
        instance.setRandom(newRandom);
        assertSame(newRandom, instance.getRandom());

        instance.setRandom(null);
        assertNull(instance.getRandom());

        newRandom = new Random();
        instance.setRandom(newRandom);
        assertSame(newRandom, instance.getRandom());
    }

}