/*
 * File:                VectorThresholdInformationGainLearnerTest.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright November 12, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.DefaultPair;
import java.lang.reflect.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import junit.framework.TestCase;

/**
 * This class implements JUnit tests for the following classes:
 *
 *     VectorThresholdInformationGainLearner
 *
 * @author Justin Basilico
 * @since  2.0
 */
public class VectorThresholdInformationGainLearnerTest
    extends TestCase {

    /**
     * Creates a new test.
     *
     * @param testName The name of the test, passed to the TestCase constructor.
     */
    public VectorThresholdInformationGainLearnerTest(
        String testName) {
        super(testName);
    }

    /**
     * Builds a labeled three-dimensional training example.
     *
     * @param x The value for the first vector element.
     * @param y The value for the second vector element.
     * @param z The value for the third vector element.
     * @param label The output label to pair with the vector.
     * @return A new input-output pair of the vector (x, y, z) and the label.
     */
    private static InputOutputPair<Vector3, Boolean> example(
        double x,
        double y,
        double z,
        boolean label) {
        return new DefaultInputOutputPair<Vector3, Boolean>(
            new Vector3(x, y, z), label);
    }

    /**
     * Builds a labeled two-dimensional training example.
     *
     * @param x The value for the first vector element.
     * @param y The value for the second vector element.
     * @param label The output label to pair with the vector.
     * @return A new input-output pair of the vector (x, y) and the label.
     */
    private static InputOutputPair<Vector2, Boolean> example(
        double x,
        double y,
        boolean label) {
        return new DefaultInputOutputPair<Vector2, Boolean>(
            new Vector2(x, y), label);
    }

    /**
     * Test of constructors of class VectorThresholdInformationGainLearner.
     */
    public void testConstructors() {
        int minSplitSize =
            VectorThresholdInformationGainLearner.DEFAULT_MIN_SPLIT_SIZE;
        VectorThresholdInformationGainLearner<Boolean> instance =
            new VectorThresholdInformationGainLearner<>();
        assertEquals(minSplitSize, instance.getMinSplitSize());

        minSplitSize = 6;
        instance = new VectorThresholdInformationGainLearner<>(minSplitSize);
        assertEquals(minSplitSize, instance.getMinSplitSize());
    }

    /**
     * Test of learn method, of class VectorThresholdInformationGainLearner.
     */
    public void testLearn() {
        VectorThresholdInformationGainLearner<Boolean> instance =
            new VectorThresholdInformationGainLearner<>();

        // Null and empty data are expected to produce no categorizer.
        VectorElementThresholdCategorizer result = instance.learn(null);
        assertNull(result);

        LinkedList<InputOutputPair<Vector3, Boolean>> data =
            new LinkedList<>();
        result = instance.learn(data);
        assertNull(result);

        // One example, then two identical examples: still no categorizer.
        data.add(example(1.0, 4.0, 2.0, true));
        result = instance.learn(data);
        assertNull(result);

        data.add(example(1.0, 4.0, 2.0, true));
        result = instance.learn(data);
        assertNull(result);

        // A differing value in element 1 yields a split on that element.
        data.add(example(1.0, 1.0, 2.0, true));
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(1, result.getIndex());
        assertEquals(2.5, result.getThreshold());

        data.add(example(1.0, 2.0, 3.0, false));
        data.add(example(1.0, 4.0, 4.0, false));
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(2, result.getIndex());
        assertEquals(2.5, result.getThreshold());

        data.add(example(1.0, 3.0, 2.0, false));
        data.add(example(1.0, 0.0, 2.0, true));
        data.add(example(1.0, 5.0, 2.0, false));
        data.add(example(1.0, 7.0, 2.0, false));
        data.add(example(1.0, 8.0, 2.0, false));
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(1, result.getIndex());
        assertEquals(1.5, result.getThreshold());
    }

    /**
     * Test of computeBestGainAndThreshold method, of class
     * VectorThresholdInformationGainLearner.
     */
    public void testComputeBestThreshold() {
        VectorThresholdInformationGainLearner<Boolean> instance =
            new VectorThresholdInformationGainLearner<>();
        LinkedList<InputOutputPair<Vector3, Boolean>> data =
            new LinkedList<>();

        // Empty data.
        DefaultDataDistribution<Boolean> baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);
        DefaultPair<Double, Double> result =
            instance.computeBestGainAndThreshold(data, 0, baseCounts);
        assertNull(result);

        // A single example: no result is expected on any dimension.
        data.add(example(1.0, 4.0, 2.0, true));
        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 0, baseCounts);
        assertNull(result);
        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNull(result);
        result = instance.computeBestGainAndThreshold(data, 2, baseCounts);
        assertNull(result);

        // Two examples with the same label that differ only in dimension 1.
        data.add(example(1.0, 1.0, 2.0, true));
        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 0, baseCounts);
        assertNull(result);

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(0.0, result.getFirst());
        assertEquals(2.5, result.getSecond());

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 2, baseCounts);
        assertNull(result);

        // A larger data set with both labels.
        data.add(example(1.0, 2.0, 3.0, false));
        data.add(example(1.0, 4.0, 4.0, false));
        data.add(example(1.0, 3.0, 5.0, false));
        data.add(example(1.0, 0.0, 2.0, true));
        data.add(example(1.0, 5.0, 2.0, false));
        data.add(example(1.0, 7.0, 2.0, false));
        data.add(example(1.0, 8.0, 2.0, false));

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 0, baseCounts);
        assertNull(result);

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(0.458, result.getFirst(), 0.001);
        assertEquals(1.5, result.getSecond());

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 2, baseCounts);
        assertNotNull(result);
        assertEquals(0.251, result.getFirst(), 0.001);
        assertEquals(2.5, result.getSecond());
    }

    /**
     * Test computeSplitGain() when manual priors are used.
     */
    public void testManualPriors() {
        VectorThresholdInformationGainLearner<Integer> instance =
            new VectorThresholdInformationGainLearner<>();

        // Make dummy data set w/ imbalanced class frequencies.
        HashMap<Integer, Integer> trainCounts = new HashMap<>();
        trainCounts.put(0, 50);
        trainCounts.put(1, 25);
        trainCounts.put(2, 20);
        trainCounts.put(3, 5);
        int numKlass = trainCounts.size();

        // Manually assign priors so that each class has equal prior
        // probability.
        HashMap<Integer, Double> equalPrior = new HashMap<>();
        for (int i = 0; i < numKlass; ++i) {
            equalPrior.put(i, 0.25);
        }
        instance.configure(equalPrior, trainCounts);

        DefaultDataDistribution<Integer> baseCounts =
            new DefaultDataDistribution<>();
        baseCounts.set(0, 10);
        baseCounts.set(1, 10);
        baseCounts.set(2, 8);
        baseCounts.set(3, 2);

        DefaultDataDistribution<Integer> leftCounts =
            new DefaultDataDistribution<>();
        leftCounts.set(0, 10);
        leftCounts.set(1, 0);
        leftCounts.set(2, 0);
        leftCounts.set(3, 2);

        DefaultDataDistribution<Integer> rightCounts =
            new DefaultDataDistribution<>();
        rightCounts.set(0, 0);
        rightCounts.set(1, 10);
        rightCounts.set(2, 8);
        rightCounts.set(3, 0);

        double gain =
            instance.computeSplitGain(baseCounts, rightCounts, leftCounts);
        assertEquals(0.98522, gain, 1e-3);
    }

    /**
     * Test configure() method.
     */
    public void testConfigure() {
        VectorThresholdInformationGainLearner<Integer> instance =
            new VectorThresholdInformationGainLearner<>();

        // Make dummy data set w/ imbalanced class frequencies.
        HashMap<Integer, Integer> trainCounts = new HashMap<>();
        trainCounts.put(0, 50);
        trainCounts.put(1, 25);
        trainCounts.put(2, 20);
        trainCounts.put(3, 5);
        int numKlass = trainCounts.size();

        // Make sure configure() assigns proper defaults.
        instance.configure(null, trainCounts);
        double[] expected = {0.5, 0.25, 0.2, 0.05};
        ArrayList<Integer> index = instance.categories;
        double[] priors = instance.categoryPriors;
        for (int i = 0; i < numKlass; ++i) {
            // The learner's internal position i holds category index.get(i),
            // so look up the expected prior by category, not by position.
            int klass = index.get(i);
            assertEquals(expected[klass], priors[i], 1e-5);
        }

        // Make sure configure() assigns manual priors. The manual priors
        // are the normalized inverses of the class frequencies.
        HashMap<Integer, Double> inversePriors = new HashMap<>();
        double mass = 0;
        for (int i = 0; i < numKlass; ++i) {
            expected[i] = 1.0 / expected[i];
            mass += expected[i];
        }
        for (int i = 0; i < numKlass; ++i) {
            expected[i] /= mass;
            inversePriors.put(i, expected[i]);
        }
        instance.configure(inversePriors, trainCounts);
        index = instance.categories;
        priors = instance.categoryPriors;
        for (int i = 0; i < numKlass; ++i) {
            int klass = index.get(i);
            assertEquals(expected[klass], priors[i], 1e-5);
        }
    }

    /**
     * Tests a corner-case of creating a threshold where the first split is
     * the result.
     */
    public void testThresholdBug() {
        VectorThresholdInformationGainLearner<Boolean> instance =
            new VectorThresholdInformationGainLearner<>();
        LinkedList<InputOutputPair<Vector2, Boolean>> data =
            new LinkedList<>();

        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 1.0763, false));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 1.0763, false));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 1.0763, false));
        data.add(example(0.0, 5.5639, true));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 1.0763, true));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));
        data.add(example(0.0, 5.5639, false));

        DefaultDataDistribution<Boolean> baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);
        DefaultPair<Double, Double> result =
            instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);

        // Only two distinct values exist in dimension 1, so the threshold
        // must be their midpoint.
        assertEquals((5.5639 + 1.0763) / 2.0, result.getSecond());
    }

    /**
     * Tests threshold selection when the two adjacent values are so close
     * that their midpoint cannot be represented as a distinct double. In both
     * cases below the test asserts that the returned threshold equals the
     * larger of the two values.
     */
    public void testThresholdRoundoff() {
        VectorThresholdInformationGainLearner<Boolean> instance =
            new VectorThresholdInformationGainLearner<>();
        LinkedList<InputOutputPair<Vector2, Boolean>> data =
            new LinkedList<>();

        // Values 0.0 and the smallest positive double.
        data.add(example(0.0, 0.0, true));
        data.add(example(0.0, 0.0, true));
        data.add(example(0.0, Double.MIN_VALUE, false));
        data.add(example(0.0, Double.MIN_VALUE, false));

        DefaultDataDistribution<Boolean> baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);
        DefaultPair<Double, Double> result =
            instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(Double.MIN_VALUE, result.getSecond());

        // Values that are one and two times the smallest positive double.
        data.clear();
        double x1 = Double.MIN_VALUE;
        double x2 = x1 + x1;
        data.add(example(0.0, x1, true));
        data.add(example(0.0, x1, true));
        data.add(example(0.0, x2, false));
        data.add(example(0.0, x2, false));

        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(x2, result.getSecond());
    }

}