/*
* File: BallseptronTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright March 08, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.perceptron;
import gov.sandia.cognition.learning.algorithm.perceptron.KernelizableBinaryCategorizerOnlineLearner;
import gov.sandia.cognition.learning.algorithm.perceptron.Ballseptron;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.ObjectUtil;
import org.junit.Test;
import static org.junit.Assert.*;
/**
 * Unit test for class Ballseptron.
 *
 * @author  Justin Basilico
 * @since   3.3.0
 */
public class BallseptronTest
    extends KernelizableBinaryCategorizerOnlineLearnerTestHarness
{

    /**
     * Creates a new test.
     */
    public BallseptronTest()
    {
    }

    /**
     * Supplies the learner under test to the shared test harness.
     *
     * @return
     *      A new Ballseptron created with its default constructor.
     */
    @Override
    protected KernelizableBinaryCategorizerOnlineLearner createLinearInstance()
    {
        return new Ballseptron();
    }

    /**
     * Test of constructors of class Ballseptron.
     */
    @Test
    public void testConstructors()
    {
        // The default constructor is expected to use the default radius.
        double radius = Ballseptron.DEFAULT_RADIUS;
        Ballseptron instance = new Ballseptron();
        assertEquals(radius, instance.getRadius(), 0.0);

        // The radius constructor is expected to store the given value.
        radius = this.random.nextDouble();
        instance = new Ballseptron(radius);
        assertEquals(radius, instance.getRadius(), 0.0);
    }

    /**
     * Test of update method, of class Ballseptron.
     */
    @Test
    public void testUpdate()
    {
        final double epsilon = 1E-5;
        Ballseptron instance = new Ballseptron();
        LinearBinaryCategorizer result = instance.createInitialLearnedObject();
        assertNull(result.getWeights());
        assertEquals(0.0, result.getBias(), 0.0);

        // A short fixed sequence: after each update the categorizer is
        // expected to classify the example it was just given correctly.
        updateAndAssertLearned(instance, result, new Vector2(2.0, 3.0), true);
        updateAndAssertLearned(instance, result, new Vector2(4.0, 4.0), true);
        updateAndAssertLearned(instance, result, new Vector2(1.0, -1.0), false);
        updateAndAssertLearned(instance, result, new Vector2(1.0, -1.0), false);
        updateAndAssertLearned(instance, result, new Vector2(2.0, 3.0), true);

        // Start over and feed examples sampled from two Gaussians, checking
        // the weights after every update.
        result = instance.createInitialLearnedObject();

        MultivariateGaussian positive = new MultivariateGaussian(2);
        positive.setMean(new Vector2(1.0, 1.0));
        positive.getCovariance().setElement(0, 0, 0.2);
        positive.getCovariance().setElement(1, 1, 2.0);

        MultivariateGaussian negative = new MultivariateGaussian(2);
        negative.setMean(new Vector2(-1.0, -1.0));
        negative.getCovariance().setElement(0, 0, 0.2);
        negative.getCovariance().setElement(1, 1, 2.0);

        for (int i = 0; i < 4000; i++)
        {
            Boolean output = random.nextBoolean();
            Vector input = (output ? positive : negative).sample(random);

            // Capture the state before the update so the change made by the
            // update can be checked afterwards.
            Vector oldWeights = ObjectUtil.cloneSafe(result.getWeights());
            double prediction = result.evaluateAsDouble(input);
            double actual = (output ? +1 : -1);
            double margin = prediction * actual;

            instance.update(result, DefaultInputOutputPair.create(input, output));

            if (oldWeights == null)
            {
                // No weights existed before this update: the test expects the
                // new weights to equal the input without being the same
                // object.
                assertEquals(input, result.getWeights());
                assertNotSame(input, result.getWeights());
            }
            else if (margin <= 0.0)
            {
                // Non-positive margin: the weights are expected to move by
                // the input scaled by the label sign. A mismatch fails the
                // test and reports both vectors.
                Vector expectedWeights = oldWeights.plus(input.scale(actual));
                assertTrue(
                    "Expected: " + expectedWeights
                        + " Actual: " + result.getWeights(),
                    expectedWeights.equals(result.getWeights(), epsilon));
            }
            else if (margin / oldWeights.norm2() <= instance.getRadius())
            {
                // Positive margin that, divided by the norm of the old
                // weights, is within the radius. The resulting weights are
                // not verified by this test.
            }
            else
            {
                // Margin beyond the radius: weights are expected unchanged.
                assertEquals(oldWeights, result.getWeights());
            }
        }
    }

    /**
     * Applies one update with the given example and asserts that the
     * categorizer then evaluates the input to the given output.
     *
     * @param   instance
     *      The learner whose update method is called.
     * @param   result
     *      The categorizer being updated and then evaluated.
     * @param   input
     *      The input vector of the example.
     * @param   output
     *      The label of the example, also the expected evaluation result.
     */
    private static void updateAndAssertLearned(
        final Ballseptron instance,
        final LinearBinaryCategorizer result,
        final Vector input,
        final Boolean output)
    {
        instance.update(result, DefaultInputOutputPair.create(input, output));
        assertEquals(output, result.evaluate(input));
    }

    /**
     * Test of getRadius method, of class Ballseptron.
     */
    @Test
    public void testGetRadius()
    {
        // The getter is exercised together with the setter.
        this.testSetRadius();
    }

    /**
     * Test of setRadius method, of class Ballseptron.
     */
    @Test
    public void testSetRadius()
    {
        double radius = Ballseptron.DEFAULT_RADIUS;
        Ballseptron instance = new Ballseptron();
        assertEquals(radius, instance.getRadius(), 0.0);

        radius = 2.0 * radius;
        instance.setRadius(radius);
        assertEquals(radius, instance.getRadius(), 0.0);

        radius = 1.0;
        instance.setRadius(radius);
        assertEquals(radius, instance.getRadius(), 0.0);

        radius = 0.00000001;
        instance.setRadius(radius);
        assertEquals(radius, instance.getRadius(), 0.0);

        // Zero and negative values must be rejected and must leave the
        // previously set radius untouched.
        double[] badValues = {0.0, -0.1, -1.0};
        for (double badValue : badValues)
        {
            try
            {
                instance.setRadius(badValue);
                fail("Expected IllegalArgumentException for radius "
                    + badValue);
            }
            catch (IllegalArgumentException expected)
            {
                // Expected; any other exception type propagates unmasked.
            }
            assertEquals(radius, instance.getRadius(), 0.0);
        }
    }

}