/*
* File: OnlineVotedPerceptronTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright October 20, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.perceptron;
import gov.sandia.cognition.learning.algorithm.ensemble.WeightedBinaryEnsemble;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.SparseVectorFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import junit.framework.TestCase;
/**
* Unit tests for class OnlineVotedPerceptron.
*
* @author Justin Basilico
* @since 3.1
*/
public class OnlineVotedPerceptronTest
    extends TestCase
{

    /**
     * Creates a new test.
     *
     * @param testName The test name.
     */
    public OnlineVotedPerceptronTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Test of constructors of class OnlineVotedPerceptron.
     */
    public void testConstructors()
    {
        // The default constructor should use the default vector factory.
        OnlineVotedPerceptron instance = new OnlineVotedPerceptron();
        assertSame(VectorFactory.getDefault(), instance.getVectorFactory());

        // The one-argument constructor should keep the given factory.
        final VectorFactory<?> factory = new SparseVectorFactoryMTJ();
        instance = new OnlineVotedPerceptron(factory);
        assertSame(factory, instance.getVectorFactory());
    }

    /**
     * Test of createInitialLearnedObject method, of class OnlineVotedPerceptron.
     */
    public void testCreateInitialLearnedObject()
    {
        OnlineVotedPerceptron instance = new OnlineVotedPerceptron();
        WeightedBinaryEnsemble<Vectorizable, ?> result =
            instance.createInitialLearnedObject();
        assertEquals(0, result.getMembers().size());

        // Each call should produce a fresh ensemble, not a shared instance.
        assertNotSame(result, instance.createInitialLearnedObject());
    }

    /**
     * Test of update method, of class OnlineVotedPerceptron.
     *
     * Drives the learner through a fixed sequence of labeled examples and
     * checks the ensemble after every step. A correctly classified example
     * increments the vote weight of the last member. A mistake appends a new
     * member with weight 1.
     */
    public void testUpdate()
    {
        OnlineVotedPerceptron instance = new OnlineVotedPerceptron();
        WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> result =
            instance.createInitialLearnedObject();
        assertEquals(0, result.getMembers().size());

        // First example on an empty ensemble creates the first member.
        instance.update(result, DefaultInputOutputPair.create(new Vector2(2.0, 3.0), true));
        assertEquals(1, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 1.0);

        // Correct prediction: the existing member gains a vote.
        instance.update(result, DefaultInputOutputPair.create(new Vector2(4.0, 4.0), true));
        assertEquals(1, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);

        // Mistake: a new member is appended; earlier members are unchanged.
        instance.update(result, DefaultInputOutputPair.create(new Vector2(1.0, 1.0), false));
        assertEquals(2, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);
        assertMember(result, 1, new Vector2(1.0, 2.0), 0.0, 1.0);

        // Another mistake on the same example appends a third member.
        instance.update(result, DefaultInputOutputPair.create(new Vector2(1.0, 1.0), false));
        assertEquals(3, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);
        assertMember(result, 1, new Vector2(1.0, 2.0), 0.0, 1.0);
        assertMember(result, 2, new Vector2(0.0, 1.0), -1.0, 1.0);

        // Repeated correct predictions only increase the last member's vote.
        instance.update(result, DefaultInputOutputPair.create(new Vector2(2.0, 3.0), true));
        assertEquals(3, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);
        assertMember(result, 1, new Vector2(1.0, 2.0), 0.0, 1.0);
        assertMember(result, 2, new Vector2(0.0, 1.0), -1.0, 2.0);

        instance.update(result, DefaultInputOutputPair.create(new Vector2(2.0, 3.0), true));
        assertEquals(3, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);
        assertMember(result, 1, new Vector2(1.0, 2.0), 0.0, 1.0);
        assertMember(result, 2, new Vector2(0.0, 1.0), -1.0, 3.0);

        instance.update(result, DefaultInputOutputPair.create(new Vector2(2.0, 3.0), true));
        assertEquals(3, result.getMembers().size());
        assertMember(result, 0, new Vector2(2.0, 3.0), 1.0, 2.0);
        assertMember(result, 1, new Vector2(1.0, 2.0), 0.0, 1.0);
        assertMember(result, 2, new Vector2(0.0, 1.0), -1.0, 4.0);
    }

    /**
     * Asserts the weights, bias, and ensemble vote weight of one member.
     *
     * @param ensemble The ensemble to inspect.
     * @param index The index of the member to check.
     * @param expectedWeights The expected weight vector of the member's
     *      categorizer.
     * @param expectedBias The expected bias of the member's categorizer.
     * @param expectedWeight The expected ensemble weight (vote count) of the
     *      member.
     */
    private static void assertMember(
        final WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> ensemble,
        final int index,
        final Vector2 expectedWeights,
        final double expectedBias,
        final double expectedWeight)
    {
        final LinearBinaryCategorizer member =
            ensemble.getMembers().get(index).getValue();
        assertEquals("weights of member " + index,
            expectedWeights, member.getWeights());
        // Every expected value in this test is exactly representable as a
        // double, so an exact (zero-delta) comparison is intentional.
        assertEquals("bias of member " + index,
            expectedBias, member.getBias(), 0.0);
        assertEquals("ensemble weight of member " + index,
            expectedWeight, ensemble.getMembers().get(index).getWeight(), 0.0);
    }

    /**
     * Test of getLastMember method, of class OnlineVotedPerceptron.
     */
    public void testGetLastMember()
    {
        WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> ensemble =
            new WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer>();

        // An empty ensemble has no last member.
        assertNull(OnlineVotedPerceptron.getLastMember(ensemble));

        // After each add, the most recently added categorizer is the last one.
        LinearBinaryCategorizer last = new LinearBinaryCategorizer();
        ensemble.add(last);
        assertSame(last, OnlineVotedPerceptron.getLastMember(ensemble).getValue());

        last = new LinearBinaryCategorizer();
        ensemble.add(last);
        assertSame(last, OnlineVotedPerceptron.getLastMember(ensemble).getValue());
    }

    /**
     * Test of getVectorFactory method, of class OnlineVotedPerceptron.
     *
     * The getter is exercised by the setter test, so this delegates to it.
     */
    public void testGetVectorFactory()
    {
        this.testSetVectorFactory();
    }

    /**
     * Test of setVectorFactory method, of class OnlineVotedPerceptron.
     */
    public void testSetVectorFactory()
    {
        OnlineVotedPerceptron instance = new OnlineVotedPerceptron();
        assertSame(VectorFactory.getDefault(), instance.getVectorFactory());

        VectorFactory<?> factory = new SparseVectorFactoryMTJ();
        instance.setVectorFactory(factory);
        assertSame(factory, instance.getVectorFactory());

        // The setter accepts null and the getter reflects it.
        factory = null;
        instance.setVectorFactory(factory);
        assertSame(factory, instance.getVectorFactory());

        factory = VectorFactory.getDenseDefault();
        instance.setVectorFactory(factory);
        assertSame(factory, instance.getVectorFactory());
    }

}