/*
* File: RandomSubVectorThresholdLearnerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright December 23, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import java.util.ArrayList;
import java.util.Random;
import junit.framework.TestCase;
/**
* Unit tests for class RandomSubVectorThresholdLearner.
*
* @author Justin Basilico
* @since 3.0
*/
public class RandomSubVectorThresholdLearnerTest
    extends TestCase
{

    /**
     * Random number generator used both to generate the test data and as the
     * randomness source handed to the learners under test.
     */
    protected Random random;

    /**
     * Creates a new test.
     *
     * @param testName The test name.
     */
    public RandomSubVectorThresholdLearnerTest(
        String testName)
    {
        super(testName);
        this.random = new Random();
    }

    /**
     * Creates a two-class data set of uniformly random vectors. The first
     * {@code countPerClass} examples are labeled "a" and the next
     * {@code countPerClass} are labeled "b". Each element is drawn from
     * [1.0, 10.0] using {@link #random}.
     *
     * @param countPerClass The number of examples to create for each label.
     * @param dimensionality The dimensionality of each input vector.
     * @return A new, mutable list of labeled examples.
     */
    private ArrayList<InputOutputPair<Vector, String>> createTwoClassData(
        final int countPerClass,
        final int dimensionality)
    {
        final VectorFactory<?> vectorFactory = VectorFactory.getDefault();
        final ArrayList<InputOutputPair<Vector, String>> data =
            new ArrayList<InputOutputPair<Vector, String>>(2 * countPerClass);
        for (String label : new String[] {"a", "b"})
        {
            for (int i = 0; i < countPerClass; i++)
            {
                data.add(new DefaultInputOutputPair<Vector, String>(
                    vectorFactory.createUniformRandom(
                        dimensionality, 1.0, 10.0, this.random), label));
            }
        }
        return data;
    }

    /**
     * Test of constructors of class RandomSubVectorThresholdLearner.
     */
    public void testConstructors()
    {
        VectorThresholdInformationGainLearner<String> subLearner = null;
        double percentToSample = RandomSubVectorThresholdLearner.DEFAULT_PERCENT_TO_SAMPLE;
        VectorFactory<?> vectorFactory = VectorFactory.getDefault();
        int[] dimensionsToConsider = null;

        // Default constructor.
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>();
        assertSame(subLearner, instance.getSubLearner());
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
        assertNotNull(instance.getRandom());
        assertSame(vectorFactory, instance.getVectorFactory());

        // Sub-learner, percent, and random.
        subLearner = new VectorThresholdInformationGainLearner<String>();
        percentToSample = percentToSample / 2.0;
        instance = new RandomSubVectorThresholdLearner<String>(subLearner,
            percentToSample, random);
        assertSame(subLearner, instance.getSubLearner());
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
        assertSame(random, instance.getRandom());
        assertSame(vectorFactory, instance.getVectorFactory());

        // Sub-learner, percent, random, and vector factory.
        vectorFactory = VectorFactory.getSparseDefault();
        instance = new RandomSubVectorThresholdLearner<String>(subLearner,
            percentToSample, random, vectorFactory);
        assertSame(subLearner, instance.getSubLearner());
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
        assertSame(random, instance.getRandom());
        assertSame(vectorFactory, instance.getVectorFactory());

        // Full constructor, including explicit dimensions to consider
        // (reusing the sparse vector factory from above).
        dimensionsToConsider = new int[] {5, 12};
        instance = new RandomSubVectorThresholdLearner<String>(subLearner,
            percentToSample, dimensionsToConsider, random, vectorFactory);
        assertSame(subLearner, instance.getSubLearner());
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
        assertSame(random, instance.getRandom());
        assertSame(vectorFactory, instance.getVectorFactory());
    }

    /**
     * Test of learn method, of class RandomSubVectorThresholdLearner.
     */
    public void testLearn()
    {
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>(
            new VectorThresholdInformationGainLearner<String>(),
            0.1, random);
        ArrayList<InputOutputPair<Vector, String>> data =
            this.createTwoClassData(10, 100);

        VectorElementThresholdCategorizer result = instance.learn(data);
        assertNotNull(result);
        assertTrue(result.getIndex() >= 0);
        assertTrue(result.getIndex() < 100);

        // Restrict the candidate dimensions; the learned index must then be
        // one of {10, 20, 30, 40, 50}.
        instance.setDimensionsToConsider(new int[] {10, 20, 30, 40, 50});
        instance.setPercentToSample(0.5);
        result = instance.learn(data);
        assertNotNull(result);
        assertTrue(result.getIndex() >= 10);
        assertTrue(result.getIndex() <= 50);
        assertEquals(0, result.getIndex() % 10);
    }

    /**
     * Test of learn method, of class RandomSubVectorThresholdLearner, when
     * sampling (nearly) all of the available dimensions.
     */
    public void testLearnFullDimensions()
    {
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<>(
            new VectorThresholdInformationGainLearner<String>(),
            0.9999, random);
        ArrayList<InputOutputPair<Vector, String>> data =
            this.createTwoClassData(10, 100);

        VectorElementThresholdCategorizer result = instance.learn(data);
        assertNotNull(result);
        assertTrue(result.getIndex() >= 0);
        assertTrue(result.getIndex() < 100);

        // With a single dimension to consider, the learned index must be it.
        instance.setDimensionsToConsider(new int[] {10});
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(10, result.getIndex());
    }

    /**
     * Test of getSubDimensionality method, of class RandomSubVectorThresholdLearner.
     */
    public void testGetSubDimensionality()
    {
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>();
        instance.setPercentToSample(0.5);
        assertEquals(5, instance.getSubDimensionality(10));
        instance.setPercentToSample(0.25);
        assertEquals(2, instance.getSubDimensionality(9));
        instance.setPercentToSample(1.0);
        assertEquals(9, instance.getSubDimensionality(9));
        // A zero percentage still yields at least one dimension.
        instance.setPercentToSample(0.0);
        assertEquals(1, instance.getSubDimensionality(9));
    }

    /**
     * Test of getSubLearner method, of class RandomSubVectorThresholdLearner.
     */
    public void testGetSubLearner()
    {
        this.testSetSubLearner();
    }

    /**
     * Test of setSubLearner method, of class RandomSubVectorThresholdLearner.
     */
    public void testSetSubLearner()
    {
        VectorThresholdInformationGainLearner<String> subLearner = null;
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>();
        assertSame(subLearner, instance.getSubLearner());

        subLearner = new VectorThresholdInformationGainLearner<String>();
        instance.setSubLearner(subLearner);
        assertSame(subLearner, instance.getSubLearner());

        subLearner = new VectorThresholdInformationGainLearner<String>();
        instance.setSubLearner(subLearner);
        assertSame(subLearner, instance.getSubLearner());

        subLearner = null;
        instance.setSubLearner(subLearner);
        assertSame(subLearner, instance.getSubLearner());

        subLearner = new VectorThresholdInformationGainLearner<String>();
        instance.setSubLearner(subLearner);
        assertSame(subLearner, instance.getSubLearner());
    }

    /**
     * Test of getPercentToSample method, of class RandomSubVectorThresholdLearner.
     */
    public void testGetPercentToSample()
    {
        this.testSetPercentToSample();
    }

    /**
     * Test of setPercentToSample method, of class RandomSubVectorThresholdLearner.
     */
    public void testSetPercentToSample()
    {
        double percentToSample = RandomSubVectorThresholdLearner.DEFAULT_PERCENT_TO_SAMPLE;
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>();
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        percentToSample = percentToSample / 2.0;
        instance.setPercentToSample(percentToSample);
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        percentToSample = 1.0;
        instance.setPercentToSample(percentToSample);
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        percentToSample = 0.0;
        instance.setPercentToSample(percentToSample);
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        percentToSample = 0.47;
        instance.setPercentToSample(percentToSample);
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        // Values outside [0.0, 1.0] must be rejected and must leave the
        // previously set value in place. fail() is used instead of a flag
        // checked in a finally block so that an unexpected exception type
        // propagates instead of being masked by an assertion failure.
        try
        {
            instance.setPercentToSample(-0.1);
            fail("Expected IllegalArgumentException for percentToSample = -0.1");
        }
        catch (IllegalArgumentException e)
        {
            // Expected.
        }
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);

        try
        {
            instance.setPercentToSample(1.1);
            fail("Expected IllegalArgumentException for percentToSample = 1.1");
        }
        catch (IllegalArgumentException e)
        {
            // Expected.
        }
        assertEquals(percentToSample, instance.getPercentToSample(), 0.0);
    }

    /**
     * Test of getDimensionsToConsider method, of class RandomSubVectorThresholdLearner.
     */
    public void testGetDimensionsToConsider()
    {
        this.testSetDimensionsToConsider();
    }

    /**
     * Test of setDimensionsToConsider method, of class RandomSubVectorThresholdLearner.
     */
    public void testSetDimensionsToConsider()
    {
        int[] dimensionsToConsider = null;
        RandomSubVectorThresholdLearner<String> instance
            = new RandomSubVectorThresholdLearner<>();
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());

        dimensionsToConsider = new int[] {1, 2, 5};
        instance.setDimensionsToConsider(dimensionsToConsider);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());

        dimensionsToConsider = new int[] {0, 9, 12};
        instance.setDimensionsToConsider(dimensionsToConsider);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());

        dimensionsToConsider = null;
        instance.setDimensionsToConsider(dimensionsToConsider);
        assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
    }

    /**
     * Test of getVectorFactory method, of class RandomSubVectorThresholdLearner.
     */
    public void testGetVectorFactory()
    {
        this.testSetVectorFactory();
    }

    /**
     * Test of setVectorFactory method, of class RandomSubVectorThresholdLearner.
     */
    public void testSetVectorFactory()
    {
        VectorFactory<?> vectorFactory = VectorFactory.getDefault();
        RandomSubVectorThresholdLearner<String> instance = new RandomSubVectorThresholdLearner<String>();
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = VectorFactory.getSparseDefault();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = VectorFactory.getDefault();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = null;
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());

        vectorFactory = VectorFactory.getDefault();
        instance.setVectorFactory(vectorFactory);
        assertSame(vectorFactory, instance.getVectorFactory());
    }

}