package quickml.supervised.tree.regressionTree;

import com.google.common.collect.Maps;
import org.javatuples.Pair;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.Utils;
import quickml.supervised.crossValidation.RegressionLossChecker;
import quickml.supervised.crossValidation.data.FoldedData;
import quickml.supervised.crossValidation.data.TrainingDataCycler;
import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionRMSELossFunction;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForest;
import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForestBuilder;
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
import quickml.supervised.predictiveModelOptimizer.SimplePredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

import static quickml.supervised.tree.constants.ForestOptions.*;

/**
 * Created by alexanderhawk on 3/5/15.
 */
/* FIXME: This is unnecessarily specialized to out-of-time cross-validation and should be generalized
 * so that it can support alternate ways to separate the training set from the test set (for example,
 * any Comparable class could be used to sort the training instances, not just DateTime).
 */
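/* Usage sketch (illustrative only): the loader method and the candidate hyperparameter values
 * below are hypothetical and not part of this class; only the getOptimizedRandomForest overloads,
 * the ForestOptions keys, and FixedOrderRecommender are taken from the surrounding code.
 *
 *   List<RegressionInstance> trainingData = loadTrainingData(); // hypothetical data loader
 *
 *   // Option 1: rely on the built-in default search space (createConfig).
 *   Pair<Map<String, Serializable>, RandomRegressionForest> result =
 *           OptimizedRegressionForests.getOptimizedRandomForest(trainingData);
 *
 *   // Option 2: supply a custom search space, one FieldValueRecommender per hyperparameter.
 *   Map<String, FieldValueRecommender> config = Maps.newHashMap();
 *   config.put(MAX_DEPTH.name(), new FixedOrderRecommender(5, 9, 12));
 *   config.put(NUM_TREES.name(), new FixedOrderRecommender(8, 16));
 *   Pair<Map<String, Serializable>, RandomRegressionForest> tunedResult =
 *           OptimizedRegressionForests.getOptimizedRandomForest(trainingData, config);
 *
 *   Map<String, Serializable> bestHyperparameters = result.getValue0();
 *   RandomRegressionForest forest = result.getValue1();
 */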
public class OptimizedRegressionForests {
    private static final Logger logger = LoggerFactory.getLogger(OptimizedRegressionForests.class);

    public static <T extends RegressionInstance> Pair<Map<String, Serializable>, RandomRegressionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config) {
        TrainingDataCycler<T> dataCycler = new FoldedData<>(trainingData, 6, 2);
        return getOptimizedRandomForest(trainingData, config, dataCycler);
    }

    public static <T extends RegressionInstance> Pair<Map<String, Serializable>, RandomRegressionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config, TrainingDataCycler<T> trainingDataCycler) {
        RegressionLossChecker<RandomRegressionForest, T> lossChecker = new RegressionLossChecker<>(new RegressionRMSELossFunction());
        RandomRegressionForestBuilder<T> modelBuilder = new RandomRegressionForestBuilder<T>();
        PredictiveModelOptimizer optimizer = new SimplePredictiveModelOptimizerBuilder<RandomRegressionForest, T>()
                .modelBuilder(modelBuilder)
                .dataCycler(trainingDataCycler)
                .lossChecker(lossChecker)
                .valuesToTest(config)
                .iterations(2)
                .build();
        // Search the supplied hyperparameter space, then rebuild a forest on the full training set
        // using the best configuration found.
        Map<String, Serializable> optimalConfig = optimizer.determineOptimalConfig();
        modelBuilder.updateBuilderConfig(optimalConfig);
        return Pair.with(optimalConfig, modelBuilder.buildPredictiveModel(trainingData));
    }

    public static <T extends RegressionInstance> Pair<Map<String, Serializable>, RandomRegressionForest> getOptimizedRandomForest(List<T> trainingData) {
        Map<String, FieldValueRecommender> config = OptimizedRegressionForests.createConfig();
        return getOptimizedRandomForest(trainingData, config, new FoldedData<>(trainingData, 6, 2));
    }
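    /**
     * Computes the length, in hours, of each validation time slice: the training data is sorted
     * chronologically, the span covered by the most recent ~20% of instances is taken as the
     * validation period, and that span is divided by the requested number of model rebuilds.
     * Note: this helper is currently unused within this class; presumably it is meant to feed a
     * time-based TrainingDataCycler (out-of-time cross-validation, per the FIXME above) rather
     * than the FoldedData cycler used by the overloads above.
     */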
    private static <I extends RegressionInstance> int getTimeSliceHours(List<I> trainingData, int rebuildsPerValidation, DateTimeExtractor<I> dateTimeExtractor) {
        Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
        DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size() - 1));
        int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1;
        DateTime earliestValidationTime = dateTimeExtractor.extractDateTime(trainingData.get(indexOfEarliestValidationInstance));
        Duration duration = new Duration(earliestValidationTime, latestDateTime);
        int validationPeriodHours = (int) duration.getStandardHours();
        return validationPeriodHours / rebuildsPerValidation;
    }

    // FIXME: Since most users of QuickML will be content with a default set of hyperparameters, we shouldn't force
    // them to run a hyperparameter search; most recommenders below therefore contain a single value.
    private static Map<String, FieldValueRecommender> createConfig() {
        Map<String, FieldValueRecommender> config = Maps.newHashMap();
        config.put(MAX_DEPTH.name(), new FixedOrderRecommender(12)); // Integer.MAX_VALUE, 2, 3, 5, 6, 9
        // config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), new FixedOrderRecommender(7, 14));
        config.put(MIN_LEAF_INSTANCES.name(), new FixedOrderRecommender(5)); // 10
        config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new FixedOrderRecommender(
                // new IgnoreAttributesWithConstantProbability(0.65),
                new IgnoreAttributesWithConstantProbability(0.75),
                new IgnoreAttributesWithConstantProbability(0.85),
                new IgnoreAttributesWithConstantProbability(0.9)
        ));
        config.put(MIN_SLPIT_FRACTION.name(), new FixedOrderRecommender(0.0)); // 0.05, 0.2
        config.put(NUM_NUMERIC_BINS.name(), new FixedOrderRecommender(2)); // 5, 8
        config.put(NUM_SAMPLES_PER_NUMERIC_BIN.name(), new FixedOrderRecommender(25));
        // config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2));
        config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), new FixedOrderRecommender(1.0, 0.75));
        config.put(NUM_TREES.name(), new FixedOrderRecommender(8));
        return config;
        // .degreeOfGainRatioPenalty(1.0)
        // .attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.85))
        // .maxDepth(12)
        // .minLeafInstances(5)
        // .minSplitFraction(0.1)
        // .numNumericBins(6)
        // .numSamplesPerNumericBin(50)
    }
}