/*
* File: BatchMultiPerceptronTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright April 21, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.perceptron;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.ArrayList;
import java.util.Random;
import gov.sandia.cognition.math.matrix.mtj.SparseVectorFactoryMTJ;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.VectorFactory;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Unit tests for class BatchMultiPerceptron.
*
* @author Justin Basilico
* @since 3.2.0
*/
public class BatchMultiPerceptronTest
{

    /** Random number generator, seeded so that the tests are repeatable. */
    protected Random random = new Random(211);

    /**
     * Creates a new test.
     */
    public BatchMultiPerceptronTest()
    {
    }

    /**
     * Test of constructors of class BatchMultiPerceptron. Each constructor
     * overload is checked to store its arguments and to fall back to the
     * class defaults for the arguments it does not take.
     */
    @Test
    public void testConstructors()
    {
        int maxIterations = BatchMultiPerceptron.DEFAULT_MAX_ITERATIONS;
        double minMargin = BatchMultiPerceptron.DEFAULT_MIN_MARGIN;
        VectorFactory<?> vectorFactory = null;

        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minMargin, instance.getMinMargin(), 0.0);
        assertNotNull(instance.getVectorFactory());

        maxIterations = 1 + random.nextInt(10000);
        instance = new BatchMultiPerceptron<String>(maxIterations);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minMargin, instance.getMinMargin(), 0.0);
        assertNotNull(instance.getVectorFactory());

        minMargin = 10.0 * random.nextDouble();
        instance = new BatchMultiPerceptron<String>(maxIterations, minMargin);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minMargin, instance.getMinMargin(), 0.0);
        assertNotNull(instance.getVectorFactory());

        vectorFactory = new SparseVectorFactoryMTJ();
        instance = new BatchMultiPerceptron<String>(maxIterations, minMargin,
            vectorFactory);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minMargin, instance.getMinMargin(), 0.0);
        assertSame(vectorFactory, instance.getVectorFactory());
    }

    /**
     * Test of learn method, of class BatchMultiPerceptron.
     *
     * First checks that learning from null or empty data yields a null
     * result. Then generates a random three-category linear categorizer,
     * labels random training/test points with it, and requires the learned
     * categorizer to reach at least 95% accuracy on the test points.
     */
    @Test
    public void testLearn()
    {
        System.out.println("testLearn");

        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        LinearMultiCategorizer<String> learned = instance.learn(null);
        assertNull(learned);

        // Empty data must also produce no categorizer, so capture this
        // call's own result before asserting on it.
        learned = instance.learn(new ArrayList<InputOutputPair<Vector, String>>());
        assertNull(learned);

        int d = 10;
        int trainCount = 1000;
        int testCount = 100;

        // Ground-truth categorizer: one random linear prototype per category.
        String[] categories = { "a", "b", "c" };
        LinearMultiCategorizer<String> real = new LinearMultiCategorizer<String>();
        for (String category : categories)
        {
            real.getPrototypes().put(category,
                new LinearBinaryCategorizer(
                    VectorFactory.getDenseDefault().createUniformRandom(
                        d, -1, +1, random),
                    0.0));
        }

        ArrayList<InputOutputPair<Vector, String>> trainData =
            new ArrayList<InputOutputPair<Vector, String>>(trainCount);
        for (int i = 0; i < trainCount; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -1, +1, random);
            String output = real.evaluate(input);
            trainData.add(DefaultInputOutputPair.create(input, output));
        }

        ArrayList<InputOutputPair<Vector, String>> testData =
            new ArrayList<InputOutputPair<Vector, String>>(testCount);
        for (int i = 0; i < testCount; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -1, +1, random);
            String actual = real.evaluate(input);
            testData.add(DefaultInputOutputPair.create(input, actual));
        }

        learned = instance.learn(trainData);
        assertNotNull(learned);

        int correctCount = 0;
        for (InputOutputPair<Vector, String> example : testData)
        {
            String actual = example.getOutput();
            String predicted = learned.evaluate(example.getInput());
            if (actual.equals(predicted))
            {
                correctCount++;
            }
        }

        double accuracy = (double) correctCount / testData.size();
        System.out.println("Accuracy: " + accuracy);
        assertTrue("Accuracy too low: " + accuracy, accuracy >= 0.95);
    }

    /**
     * Test of learn method, of class BatchMultiPerceptron, on a two-category
     * (Boolean) linearly separable problem.
     *
     * Requires at least 95% test accuracy. It also requires the learned
     * prototype for {@code true} to point in nearly the same direction as
     * the generating weight vector (cosine of at least 0.95).
     */
    @Test
    public void testLearnBinarySeparable()
    {
        System.out.println("testLearnBinarySeparable");

        int d = 10;
        int trainCount = 1000;
        int testCount = 100;

        LinearBinaryCategorizer real = new LinearBinaryCategorizer(
            VectorFactory.getDenseDefault().createUniformRandom(d, -1, +1, random), 0.0);

        ArrayList<InputOutputPair<Vector, Boolean>> trainData =
            new ArrayList<InputOutputPair<Vector, Boolean>>(trainCount);
        for (int i = 0; i < trainCount; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -1, +1, random);
            boolean output = real.evaluate(input);
            trainData.add(DefaultInputOutputPair.create(input, output));
        }

        ArrayList<InputOutputPair<Vector, Boolean>> testData =
            new ArrayList<InputOutputPair<Vector, Boolean>>(testCount);
        for (int i = 0; i < testCount; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -1, +1, random);
            boolean actual = real.evaluate(input);
            testData.add(DefaultInputOutputPair.create(input, actual));
        }

        BatchMultiPerceptron<Boolean> instance = new BatchMultiPerceptron<Boolean>();
        LinearMultiCategorizer<Boolean> learned = instance.learn(trainData);
        assertNotNull(learned);

        int correctCount = 0;
        for (InputOutputPair<Vector, Boolean> example : testData)
        {
            boolean actual = example.getOutput();
            boolean predicted = learned.evaluate(example.getInput());
            if (actual == predicted)
            {
                correctCount++;
            }
        }

        double accuracy = (double) correctCount / testData.size();
        System.out.println("Accuracy: " + accuracy);
        assertTrue("Accuracy too low: " + accuracy, accuracy >= 0.95);

        // The prototype for the positive category should align with the
        // generating hyperplane's normal vector.
        double cosine = learned.getPrototypes().get(true).getWeights().unitVector().cosine(
            real.getWeights().unitVector());
        System.out.println("Cosine: " + cosine);
        assertTrue("Cosine too low: " + cosine, cosine >= 0.95);
    }

    /**
     * Test of getResult method, of class BatchMultiPerceptron.
     */
    @Test
    public void testGetResult()
    {
        LinearMultiCategorizer<String> result = null;
        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertSame(result, instance.getResult());

        result = new LinearMultiCategorizer<String>();
        instance.setResult(result);
        assertSame(result, instance.getResult());
    }

    /**
     * Test of getMinMargin method, of class BatchMultiPerceptron.
     */
    @Test
    public void testGetMinMargin()
    {
        this.testSetMinMargin();
    }

    /**
     * Test of setMinMargin method, of class BatchMultiPerceptron.
     *
     * Non-negative margins must be accepted. Negative margins must be
     * rejected with an IllegalArgumentException and leave the stored value
     * unchanged.
     */
    @Test
    public void testSetMinMargin()
    {
        double minMargin = BatchMultiPerceptron.DEFAULT_MIN_MARGIN;
        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertEquals(minMargin, instance.getMinMargin(), 0.0);

        double[] goodValues = { 0.0, 0.1, 1.0, 104.1 };
        for (double goodValue : goodValues)
        {
            minMargin = goodValue;
            instance.setMinMargin(minMargin);
            assertEquals(minMargin, instance.getMinMargin(), 0.0);
        }

        double[] badValues = { -0.1, -0.5, -1.0, -2.0 };
        for (double badValue : badValues)
        {
            boolean exceptionThrown = false;
            try
            {
                instance.setMinMargin(badValue);
            }
            catch (IllegalArgumentException e)
            {
                exceptionThrown = true;
            }

            // Asserting outside a finally block lets any unexpected exception
            // propagate with its own stack trace instead of being masked.
            assertTrue("Expected IllegalArgumentException for minMargin "
                + badValue, exceptionThrown);
            assertEquals(minMargin, instance.getMinMargin(), 0.0);
        }
    }

    /**
     * Test of getVectorFactory method, of class BatchMultiPerceptron.
     */
    @Test
    public void testGetVectorFactory()
    {
        this.testSetVectorFactory();
    }

    /**
     * Test of setVectorFactory method, of class BatchMultiPerceptron.
     *
     * Checks the default factory, a custom factory, null, and the dense
     * default factory.
     */
    @Test
    public void testSetVectorFactory()
    {
        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertSame(VectorFactory.getDefault(), instance.getVectorFactory());

        VectorFactory<?> vectorFactory = new SparseVectorFactoryMTJ();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = null;
        instance.setVectorFactory(vectorFactory);
        assertNull(instance.getVectorFactory());

        // Setting a non-null factory after null must store the new one.
        vectorFactory = VectorFactory.getDenseDefault();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());
    }

    /**
     * Test of getErrorCount method, of class BatchMultiPerceptron.
     */
    @Test
    public void testGetErrorCount()
    {
        int errorCount = 0;
        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertEquals(errorCount, instance.getErrorCount());

        errorCount = 40;
        instance.setErrorCount(errorCount);
        assertEquals(errorCount, instance.getErrorCount());
    }

    /**
     * Test of getPerformance method, of class BatchMultiPerceptron. The
     * reported performance value is expected to track the error count.
     */
    @Test
    public void testGetPerformance()
    {
        int errorCount = 0;
        BatchMultiPerceptron<String> instance = new BatchMultiPerceptron<String>();
        assertEquals(errorCount, instance.getPerformance().getValue(), 0.0);

        errorCount = 40;
        instance.setErrorCount(errorCount);
        assertEquals(errorCount, instance.getPerformance().getValue(), 0.0);
    }

}