/*
 * File:                RandomForestFactory.java
 * Authors:             Justin Basilico
 * Project:             Cognitive Foundry
 *
 * Copyright 2015 Cognitive Foundry. All rights reserved.
 */

package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.ensemble.BaggingCategorizerLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.BaggingRegressionLearner;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.Random;

/**
 * A factory class for creating Random Forest learners. A random forest is a
 * combination of using bagging to create an ensemble where the ensemble members
 * are decision trees where each split in the tree is created by only
 * considering a random subset of the features (random subspace method).
 * <p>
 * All functionality is exposed through static factory methods; one method
 * builds a learner for categorization (classification) outputs and the other
 * for regression (real-valued) outputs.
 *
 * @author  Justin Basilico
 * @since   3.4.2
 */
@PublicationReference(
    title="Bagging Predictors",
    author="Leo Breiman",
    year=1996,
    type=PublicationType.Journal,
    publication="Machine Learning",
    pages={123, 140},
    url="http://www.springerlink.com/index/L4780124W2874025.pdf")
public class RandomForestFactory
    extends AbstractCloneableSerializable
{

    /**
     * Creates a random forest learner for categorization outputs.
     * <p>
     * Each ensemble member is a {@link CategorizationTreeLearner} whose
     * splits are chosen by a {@link VectorThresholdInformationGainLearner}
     * that only sees a random subset of the input dimensions at each node
     * (via {@link RandomSubVectorThresholdLearner}). The trees are combined
     * with a {@link BaggingCategorizerLearner}.
     *
     * @param   <CategoryType>
     *      The type of categories.
     * @param   ensembleSize
     *      The size of the ensemble to learn. Must be non-negative.
     * @param   baggingFraction
     *      The fraction of the data to sample (with replacement) to train
     *      each ensemble member.
     * @param   dimensionsFraction
     *      The fraction of the dimensions to sample at each node in each
     *      tree when training in order to determine the best split point.
     * @param   maxTreeDepth
     *      The maximum allowed tree depth. Must be positive.
     * @param   minLeafSize
     *      The minimum allowed number of examples that are allowed to fall
     *      into a leaf. The minimum split size used for the trees is derived
     *      from it as {@code 2 * minLeafSize}.
     * @param   random
     *      The random number generator to use. The same generator is shared
     *      by the per-node dimension sampling and by the bagging learner.
     * @return
     *      A new algorithm object for learning a random forest.
     */
    public static <CategoryType> BaggingCategorizerLearner<Vector, CategoryType> createCategorizationLearner(
        final int ensembleSize,
        final double baggingFraction,
        final double dimensionsFraction,
        final int maxTreeDepth,
        final int minLeafSize,
        final Random random)
    {
        // The minimum size for a split has to be at least double the leaf
        // size (presumably so that both sides of a threshold split can meet
        // the minimum leaf size).
        final int minSplitSize = 2 * minLeafSize;

        final CategorizationTreeLearner<Vector, CategoryType> treeLearner =
            new CategorizationTreeLearner<>(
                new RandomSubVectorThresholdLearner<>(
                    new VectorThresholdInformationGainLearner<CategoryType>(
                        minLeafSize),
                    dimensionsFraction, random),
                minSplitSize, maxTreeDepth);
        return new BaggingCategorizerLearner<>(treeLearner,
            ensembleSize, baggingFraction, random);
    }

    /**
     * Creates a random forest learner for regression (real-valued) outputs.
     * <p>
     * Each ensemble member is a {@link RegressionTreeLearner} whose splits
     * are chosen by a {@link VectorThresholdVarianceLearner} that only sees a
     * random subset of the input dimensions at each node (via
     * {@link RandomSubVectorThresholdLearner}). The trees are combined with a
     * {@link BaggingRegressionLearner}.
     *
     * @param   ensembleSize
     *      The size of the ensemble to learn. Must be non-negative.
     * @param   baggingFraction
     *      The fraction of the data to sample (with replacement) to train
     *      each ensemble member.
     * @param   dimensionsFraction
     *      The fraction of the dimensions to sample at each node in each
     *      tree when training in order to determine the best split point.
     * @param   maxTreeDepth
     *      The maximum allowed tree depth. Must be positive.
     * @param   minLeafSize
     *      The minimum allowed number of examples that are allowed to fall
     *      into a leaf. The minimum split size used for the trees is derived
     *      from it as {@code 2 * minLeafSize}.
     * @param   random
     *      The random number generator to use. The same generator is shared
     *      by the per-node dimension sampling and by the bagging learner.
     * @return
     *      A new algorithm object for learning a random forest.
     */
    public static BaggingRegressionLearner<Vector> createRegressionLearner(
        final int ensembleSize,
        final double baggingFraction,
        final double dimensionsFraction,
        final int maxTreeDepth,
        final int minLeafSize,
        final Random random)
    {
        // The minimum size for a split has to be at least double the leaf
        // size (presumably so that both sides of a threshold split can meet
        // the minimum leaf size).
        final int minSplitSize = 2 * minLeafSize;

        final RegressionTreeLearner<Vector> treeLearner =
            new RegressionTreeLearner<>(
                new RandomSubVectorThresholdLearner<>(
                    new VectorThresholdVarianceLearner(minLeafSize),
                    dimensionsFraction, random),
                // No explicit leaf regression learner is supplied.
                // NOTE(review): presumably RegressionTreeLearner falls back
                // to its default leaf model when this is null — confirm.
                null,
                minSplitSize, maxTreeDepth);
        return new BaggingRegressionLearner<>(treeLearner,
            ensembleSize, baggingFraction, random);
    }

}