/*
* File: RandomForestFactoryTest.java
* Authors: Justin Basilico
* Project: Cognitive Foundry
*
* Copyright 2015 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.algorithm.ensemble.BaggingCategorizerLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.BaggingRegressionLearner;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Unit tests for class {@link RandomForestFactory}.
*
* @author Justin Basilico
* @since 3.4.2
*/
public class RandomForestFactoryTest
extends Object
{
protected Random random = new Random(3333);
/**
* Test of createCategorizationLearner method, of class RandomForestFactory.
*/
@Test
public void testCreateCategorizationLearner()
{
int ensembleSize = 3 + random.nextInt(1000);
double baggingFraction = random.nextDouble();
double dimensionsFraction = random.nextDouble();
int maxTreeDepth = 3 + random.nextInt(10);
int minLeafSize = 4 + random.nextInt(10);
Random random = new Random();
BaggingCategorizerLearner<Vector, String> result
= RandomForestFactory.createCategorizationLearner(ensembleSize,
baggingFraction, dimensionsFraction, maxTreeDepth, minLeafSize,
random);
assertEquals(ensembleSize, result.getMaxIterations());
assertEquals(baggingFraction, result.getPercentToSample(), 0.0);
assertSame(random, result.getRandom());
@SuppressWarnings("rawtypes")
CategorizationTreeLearner treeLearner =
(CategorizationTreeLearner) result.getLearner();
assertEquals(maxTreeDepth, treeLearner.getMaxDepth());
assertTrue(treeLearner.getLeafCountThreshold() >= 2 * minLeafSize);
RandomSubVectorThresholdLearner<?> randomSubspace = (RandomSubVectorThresholdLearner<?>)
treeLearner.getDeciderLearner();
assertEquals(dimensionsFraction, randomSubspace.getPercentToSample(), 0.0);
assertSame(random, randomSubspace.getRandom());
VectorThresholdInformationGainLearner<?> splitLearner = (VectorThresholdInformationGainLearner<?>)
randomSubspace.getSubLearner();
assertEquals(minLeafSize, splitLearner.getMinSplitSize());
}
/**
* Test of createRegressionLearner method, of class RandomForestFactory.
*/
@Test
public void testCreateRegressionLearner()
{
int ensembleSize = 3 + random.nextInt(1000);
double baggingFraction = random.nextDouble();
double dimensionsFraction = random.nextDouble();
int maxTreeDepth = 3 + random.nextInt(10);
int minLeafSize = 4 + random.nextInt(10);
Random random = new Random();
BaggingRegressionLearner<Vector> result
= RandomForestFactory.createRegressionLearner(ensembleSize,
baggingFraction, dimensionsFraction, maxTreeDepth, minLeafSize,
random);
assertEquals(ensembleSize, result.getMaxIterations());
assertEquals(baggingFraction, result.getPercentToSample(), 0.0);
assertSame(random, result.getRandom());
@SuppressWarnings("rawtypes")
RegressionTreeLearner treeLearner =
(RegressionTreeLearner) result.getLearner();
assertEquals(maxTreeDepth, treeLearner.getMaxDepth());
assertTrue(treeLearner.getLeafCountThreshold() >= 2 * minLeafSize);
assertNull(treeLearner.getRegressionLearner());
RandomSubVectorThresholdLearner<?> randomSubspace = (RandomSubVectorThresholdLearner<?>)
treeLearner.getDeciderLearner();
assertEquals(dimensionsFraction, randomSubspace.getPercentToSample(), 0.0);
assertSame(random, randomSubspace.getRandom());
VectorThresholdVarianceLearner splitLearner = (VectorThresholdVarianceLearner)
randomSubspace.getSubLearner();
assertEquals(minLeafSize, splitLearner.getMinSplitSize());
}
}