package quickml.supervised.classifier;
import com.google.common.collect.Maps;
import org.javatuples.Pair;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.Utils;
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilder;
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.data.FoldedData;
import quickml.supervised.crossValidation.data.OutOfTimeData;
import quickml.supervised.crossValidation.data.TrainingDataCycler;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
import quickml.supervised.predictiveModelOptimizer.SimplePredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import static quickml.supervised.tree.constants.ForestOptions.*;
/**
* Created by alexanderhawk on 3/5/15.
*/
/* FIXME: This is unnecessarily specialized to out-of-time cross-validation, should be generalized
* so that it can support alternate ways to separate training from test set (for example,
* any Comparable class can be used to sort the training instances, not just DateTime).
*/
public class Classifiers {
    private static final Logger logger = LoggerFactory.getLogger(Classifiers.class);

    /**
     * Finds an optimized hyperparameter configuration for a {@link RandomDecisionForest} via
     * cross-validation over folded data, then retrains a final forest on ALL training data
     * using that configuration.
     *
     * @param trainingData labeled instances used both for optimization and the final model
     * @return the optimal hyperparameter map paired with the forest trained under it
     */
    public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData) {
        FoldedData<T> foldedData = new FoldedData<>(trainingData, 10, 2);
        ClassifierLossChecker<T, RandomDecisionForest> lossChecker = new ClassifierLossChecker<>(new ClassifierRMSELossFunction());
        RandomDecisionForestBuilder<T> modelBuilder = new RandomDecisionForestBuilder<T>();
        PredictiveModelOptimizer optimizer = new SimplePredictiveModelOptimizerBuilder<RandomDecisionForest, T>()
                .modelBuilder(modelBuilder)
                .dataCycler(foldedData)
                .lossChecker(lossChecker)
                .valuesToTest(Classifiers.createConfig())
                .iterations(3)
                .build();
        Map<String, Serializable> optimalConfig = optimizer.determineOptimalConfig();
        // Retrain on the full data set with the winning configuration.
        modelBuilder.updateBuilderConfig(optimalConfig);
        return Pair.with(optimalConfig, modelBuilder.buildPredictiveModel(trainingData));
    }

    /**
     * Optimizes the hyperparameters of a {@link DownsamplingClassifier} using out-of-time
     * cross-validation, then retrains the optimized builder on ALL training data.
     *
     * @param trainingData                labeled instances; sorted by time as a side effect
     * @param rebuildsPerValidation       number of times the model is rebuilt with a new training set
     *                                    while estimating the loss of a particular set of hyperparameters
     * @param fractionOfDataForValidation fraction of the training data that out-of-time validation is
     *                                    performed on during parameter optimization. Note: the final
     *                                    model returned by the method uses all data.
     * @param lossFunction                loss used to compare candidate configurations
     * @param dateTimeExtractor           extracts the timestamp used to order instances
     * @param modelBuilder                builder whose configuration is optimized and which produces the final model
     * @param config                      hyperparameter search space (field name -> candidate values)
     * @return the best hyperparameter map paired with the classifier trained under it
     */
    public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder, Map<String, FieldValueRecommender> config) {
        int timeSliceHours = getTimeSliceHours(trainingData, rebuildsPerValidation, fractionOfDataForValidation, dateTimeExtractor);
        // BUGFIX: the validation fraction was previously hard-coded to 0.2, silently
        // ignoring the fractionOfDataForValidation parameter documented above.
        TrainingDataCycler<T> outOfTimeData = new OutOfTimeData<T>(trainingData, fractionOfDataForValidation, timeSliceHours, dateTimeExtractor);
        ClassifierLossChecker<T, DownsamplingClassifier> lossChecker = new ClassifierLossChecker<>(lossFunction);
        PredictiveModelOptimizer optimizer = new SimplePredictiveModelOptimizerBuilder<DownsamplingClassifier, T>()
                .modelBuilder(modelBuilder)
                .dataCycler(outOfTimeData)
                .lossChecker(lossChecker)
                .valuesToTest(config)
                .iterations(3)
                .build();
        Map<String, Serializable> bestParams = optimizer.determineOptimalConfig();
        // BUGFIX: the final model is now built from the same builder that was optimized.
        // Previously a fresh, hard-coded builder (24 trees, 0.1 minority proportion) was
        // configured with bestParams instead, discarding the caller-supplied modelBuilder
        // and applying the optimal config to a builder it was not tuned for. This also
        // matches the retrain pattern used in getOptimizedRandomForest above.
        modelBuilder.updateBuilderConfig(bestParams);
        DownsamplingClassifier downsamplingClassifier = modelBuilder.buildPredictiveModel(trainingData);
        return Pair.with(bestParams, downsamplingClassifier);
    }

    /**
     * Convenience overload: uses the default hyperparameter search space from {@link #createConfig()}.
     */
    public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder) {
        Map<String, FieldValueRecommender> config = createConfig();
        return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);
    }

    /**
     * Convenience overload: uses a default model builder (random forest downsampled to a
     * 0.1 minority-instance proportion) with the caller-supplied search space.
     */
    public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Map<String, FieldValueRecommender> config) {
        DownsamplingClassifierBuilder<T> modelBuilder = new DownsamplingClassifierBuilder<T>(new RandomDecisionForestBuilder<T>(), .1);
        return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);
    }

    /**
     * Convenience overload: default model builder AND default search space.
     */
    public static Pair<Map<String, Serializable>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
        Map<String, FieldValueRecommender> config = createConfig();
        return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config);
    }

    /**
     * Computes the time-slice size (in hours) used by {@link OutOfTimeData}: the span of the
     * validation window divided by the number of rebuilds per validation.
     * Sorts {@code trainingData} by time as a side effect.
     *
     * @param fractionOfDataForValidation fraction of (time-sorted) data treated as the
     *                                    validation window; the boundary index is derived from
     *                                    it instead of the previously hard-coded 0.8 split
     * @throws IllegalArgumentException if {@code trainingData} is empty
     */
    private static <I extends ClassifierInstance> int getTimeSliceHours(List<I> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, DateTimeExtractor<I> dateTimeExtractor) {
        if (trainingData.isEmpty()) {
            throw new IllegalArgumentException("trainingData must not be empty");
        }
        Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
        DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size() - 1));
        int indexOfEarliestValidationInstance = Math.max(0, (int) ((1.0 - fractionOfDataForValidation) * trainingData.size()) - 1);
        DateTime earliestValidationTime = dateTimeExtractor.extractDateTime(trainingData.get(indexOfEarliestValidationInstance));
        Duration validationPeriod = new Duration(earliestValidationTime, latestDateTime);
        int validationPeriodHours = (int) validationPeriod.getStandardHours();
        // BUGFIX: integer division could yield 0 when the validation window spans fewer hours
        // than rebuildsPerValidation; clamp to 1 so OutOfTimeData gets a positive slice size.
        return Math.max(1, validationPeriodHours / rebuildsPerValidation);
    }

    // FIXME: Since most users of QuickML will be content with a default set of hyperparameters,
    // we shouldn't force them to assemble a search-space config; consider exposing these
    // defaults publicly and/or providing a no-optimization fast path.
    private static Map<String, FieldValueRecommender> createConfig() {
        Map<String, FieldValueRecommender> config = Maps.newHashMap();
        config.put(MAX_DEPTH.name(), new FixedOrderRecommender(4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
        config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), new FixedOrderRecommender(7, 14));
        config.put(MIN_LEAF_INSTANCES.name(), new FixedOrderRecommender(0, 15));
        config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2));
        config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), new FixedOrderRecommender(1.0, 0.75));
        return config;
    }
}