package quickml.supervised.classifier.temporallyWeightClassifier;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.joda.time.DateTimeConstants;
import org.joda.time.Hours;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import java.io.Serializable;
import java.util.*;
/**
* Created by ian on 5/29/14.
*/
public class TemporallyReweightedClassifierBuilder implements PredictiveModelBuilder<TemporallyReweightedClassifier, ClassifierInstance> {
public static final String HALF_LIFE_OF_NEGATIVE = "halfLifeOfNegative";
public static final String HALF_LIFE_OF_POSITIVE = "halfLifeOfPositive";
private static final double DEFAULT_DECAY_CONSTANT = 173; //approximately 5 days
private double decayConstantOfPositive = DEFAULT_DECAY_CONSTANT;
private double decayConstantOfNegative = DEFAULT_DECAY_CONSTANT;
private final PredictiveModelBuilder<Classifier, ClassifierInstance> wrappedBuilder;
private final Serializable positiveClassification;
private final DateTimeExtractor dateTimeExtractor;
public TemporallyReweightedClassifierBuilder(PredictiveModelBuilder<Classifier, ClassifierInstance> wrappedBuilder, Serializable positiveClassification, DateTimeExtractor dateTimeExtractor) {
this.wrappedBuilder = wrappedBuilder;
this.positiveClassification = positiveClassification;
this.dateTimeExtractor = dateTimeExtractor;
}
@Override
public void updateBuilderConfig(Map<String, Serializable> config) {
wrappedBuilder.updateBuilderConfig(config);
if (config.containsKey(HALF_LIFE_OF_NEGATIVE))
halfLifeOfNegative((Double) config.get(HALF_LIFE_OF_NEGATIVE));
if (config.containsKey(HALF_LIFE_OF_POSITIVE))
halfLifeOfPositive((Double) config.get(HALF_LIFE_OF_POSITIVE));
}
public TemporallyReweightedClassifierBuilder halfLifeOfPositive(double halfLifeOfPositiveInDays) {
this.decayConstantOfPositive = halfLifeOfPositiveInDays * DateTimeConstants.HOURS_PER_DAY / Math.log(2);
return this;
}
public TemporallyReweightedClassifierBuilder halfLifeOfNegative(double halfLifeOfNegativeInDays) {
this.decayConstantOfNegative = halfLifeOfNegativeInDays * DateTimeConstants.HOURS_PER_DAY / Math.log(2);
return this;
}
public TemporallyReweightedClassifierBuilder DateTimeExtractor(double halfLifeOfNegativeInDays) {
this.decayConstantOfNegative = halfLifeOfNegativeInDays * DateTimeConstants.HOURS_PER_DAY / Math.log(2);
return this;
}
@Override
public TemporallyReweightedClassifier buildPredictiveModel(Iterable<ClassifierInstance> trainingData) {
validateData(trainingData);
DateTime mostRecent = getMostRecentInstance(trainingData);
List<ClassifierInstance> trainingDataList = sortAndReweightTrainingData(trainingData, mostRecent);
final Classifier predictiveModel = wrappedBuilder.buildPredictiveModel(trainingDataList);
return new TemporallyReweightedClassifier(predictiveModel);
}
private List<ClassifierInstance> sortAndReweightTrainingData(Iterable<ClassifierInstance> trainingData, DateTime mostRecentInstance) {
ArrayList<ClassifierInstance> sortedData = Lists.newArrayList();
for (ClassifierInstance inst : trainingData) {
sortedData.add(inst);
}
Collections.sort(sortedData, new Comparator<ClassifierInstance>() {
@Override
public int compare(ClassifierInstance o1, ClassifierInstance o2) {
DateTime d1 = dateTimeExtractor.extractDateTime(o1);
DateTime d2 = dateTimeExtractor.extractDateTime(o2);
return -d1.compareTo(d2); //later times shoudl be sorted ahead of earlier times
}
});
ArrayList<ClassifierInstance> trainingDataList = Lists.newArrayList();
for (ClassifierInstance instance : sortedData) {
double decayConstant = (instance.getLabel().equals(positiveClassification)) ? decayConstantOfPositive : decayConstantOfNegative;
double hoursBack = Hours.hoursBetween(mostRecentInstance, dateTimeExtractor.extractDateTime(instance)).getHours();
double newWeight = Math.exp(-1.0 * hoursBack / decayConstant);
//TODO[mk] Reweight needs to be moved / removed
trainingDataList.add(new ClassifierInstance(instance.getAttributes(), instance.getLabel(), newWeight));
}
return trainingDataList;
}
private void validateData(Iterable<ClassifierInstance> trainingData) {
ClassificationCounter classificationCounter = ClassificationCounter.countAll(trainingData);
Preconditions.checkArgument(classificationCounter.getCounts().keySet().size() <= 2, "trainingData must contain only 2 classifications, but it had %s", classificationCounter.getCounts().keySet().size());
}
private DateTime getMostRecentInstance(Iterable<ClassifierInstance> newData) {
DateTime mostRecent = null;
for (ClassifierInstance instance : newData) {
if (mostRecent == null || dateTimeExtractor.extractDateTime(instance).isAfter(mostRecent)) {
mostRecent = dateTimeExtractor.extractDateTime(instance);
}
}
return mostRecent;
}
}