/*
* File: SimplifiedSequentialMinimalOptimizationTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright July 19, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.svm;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.LinearKernel;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Random;
import junit.framework.TestCase;
/**
 * Unit tests for class SimplifiedSequentialMinimalOptimization.
 *
 * @author  Justin Basilico
 * @since   3.0
 */
public class SimplifiedSequentialMinimalOptimizationTest
    extends TestCase
{

    /**
     * Seed for the random number generator used by {@link #testLearn()}.
     * A fixed seed makes both the generated data set and the learner's
     * random choices repeatable, so a failure can be reproduced.
     */
    private static final long RANDOM_SEED = 211L;

    /**
     * Creates a new test.
     *
     * @param   testName The test name.
     */
    public SimplifiedSequentialMinimalOptimizationTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Test of learn method, of class SimplifiedSequentialMinimalOptimization.
     *
     * Generates a data set that is linearly separable by construction: each
     * input is labeled by the sign of its dot product with a random target
     * vector, and inputs whose dot product falls within [-1.0, +1.0] are
     * rejected. The learner is then trained with a linear kernel and must
     * categorize every training example with its own label.
     */
    public void testLearn()
    {
        // Seeded so that the data and the learner's behavior are repeatable.
        final Random random = new Random(RANDOM_SEED);

        final int d = 300;
        final int pointsToGenerate = 100;
        final ArrayList<InputOutputPair<Vector, Boolean>> data =
            new ArrayList<InputOutputPair<Vector, Boolean>>(pointsToGenerate);

        // The target defines the separating direction used to label inputs.
        final Vector target = VectorFactory.getDenseDefault()
            .createUniformRandom(d, -1.0, 1.0, random);

        while (data.size() < pointsToGenerate)
        {
            // Build an input with d / 10 randomly chosen elements set to
            // 1.0. Indices may repeat, so fewer distinct elements can be set.
            final Vector input =
                VectorFactory.getSparseDefault().createVector(d, 0.0);
            for (int i = 0; i < d / 10; i++)
            {
                final int index = random.nextInt(d);
                input.setElement(index, 1.0);
            }

            final double dotProduct = input.dotProduct(target);
            if (dotProduct < -1.0)
            {
                data.add(DefaultInputOutputPair.create(input, false));
            }
            else if (dotProduct > +1.0)
            {
                data.add(DefaultInputOutputPair.create(input, true));
            }
            // else - The dot product was between -1.0 and +1.0, try again.
        }

        final SimplifiedSequentialMinimalOptimization<Vector> instance =
            new SimplifiedSequentialMinimalOptimization<Vector>();
        instance.setKernel(new LinearKernel());
        instance.setRandom(random);
        instance.setMaxIterations(1000);
        instance.setMaxPenalty(100.0);

        final KernelBinaryCategorizer<Vector, ?> result = instance.learn(data);

        // The learner must hand back the same object it exposes as its result.
        assertSame(result, instance.getResult());

        // Print the learned categorizer to help diagnose a failure.
        System.out.println("Result " + result);
        for (WeightedValue<?> support : result.getExamples())
        {
            System.out.println(" " + support.getWeight() + " " + support.getValue());
        }
        System.out.println("Bias: " + result.getBias());

        // Every training example must be categorized with its own label.
        for (InputOutputPair<Vector, Boolean> example : data)
        {
            assertEquals("Misclassified training input " + example.getInput(),
                example.getOutput(), result.evaluate(example.getInput()));
        }
    }

    /**
     * Test of getResult method, of class SimplifiedSequentialMinimalOptimization.
     */
    public void testGetResult()
    {
        // Tested by testLearn.
    }

    /**
     * Test of getRandom method, of class SimplifiedSequentialMinimalOptimization.
     */
    public void testGetRandom()
    {
        // The getter is exercised together with the setter.
        this.testSetRandom();
    }

    /**
     * Test of setRandom method, of class SimplifiedSequentialMinimalOptimization.
     *
     * Checks that a new instance has a non-null random number generator and
     * that getRandom returns exactly the object passed to setRandom,
     * including null.
     */
    public void testSetRandom()
    {
        SimplifiedSequentialMinimalOptimization<String> instance =
            new SimplifiedSequentialMinimalOptimization<String>();
        assertNotNull(instance.getRandom());

        Random random = new Random();
        instance.setRandom(random);
        assertSame(random, instance.getRandom());

        random = new Random();
        instance.setRandom(random);
        assertSame(random, instance.getRandom());

        // Null is accepted and returned as-is.
        random = null;
        instance.setRandom(random);
        assertSame(random, instance.getRandom());

        random = new Random();
        instance.setRandom(random);
        assertSame(random, instance.getRandom());
    }

}