package quickml.supervised;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.joda.time.DateTime;
import quickml.data.AttributesMap;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.Instance;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.data.PredictionMap;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.classifier.logisticRegression.SparseClassifierInstance;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;
import quickml.supervised.dataProcessing.AttributeCharacteristics;
import quickml.supervised.dataProcessing.BinaryAttributeCharacteristics;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.LeafDepthStats;
import quickml.supervised.tree.nodes.Node;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Serializable;
import java.math.BigDecimal;
import java.nio.DoubleBuffer;
import java.util.*;
/**
* Created by alexanderhawk on 7/31/14.
*/
public class Utils {
/**
 * Builds one {@link LabelPredictionWeight} per instance by combining the instance's label
 * and weight with the model's prediction for the instance's attributes.
 *
 * @param instances       instances to score; each element is read as an {@code Instance<R, L>}
 * @param predictiveModel model whose {@code predict} method is invoked once per instance
 * @return a new list holding the results in the iteration order of {@code instances}
 */
public static <R, L, P> List<LabelPredictionWeight<L, P>> createLabelPredictionWeights(List<? extends Instance> instances, PredictiveModel<R, P> predictiveModel) {
    List<LabelPredictionWeight<L, P>> scored = Lists.newArrayList();
    for (Instance<R, L> current : instances) {
        L label = current.getLabel();
        P prediction = predictiveModel.predict(current.getAttributes());
        scored.add(new LabelPredictionWeight<L, P>(label, prediction, current.getWeight()));
    }
    return scored;
}
/**
 * Same as scoring every instance with the model, except that each prediction is obtained
 * through {@code predictWithoutAttributes}, passing along the attribute names to ignore.
 *
 * @param instances          instances to score
 * @param predictiveModel    model whose {@code predictWithoutAttributes} is invoked once per instance
 * @param attributesToIgnore attribute names handed to the model on every call
 * @return a new list with one {@link LabelPredictionWeight} per instance, in iteration order
 */
public static <R, L, P> List<LabelPredictionWeight<L, P>> createLabelPredictionWeightsWithoutAttributes(List<? extends Instance<R, L>> instances, PredictiveModel<R, P> predictiveModel, Set<String> attributesToIgnore) {
    List<LabelPredictionWeight<L, P>> scored = Lists.newArrayList();
    for (Instance<R, L> current : instances) {
        L label = current.getLabel();
        P prediction = predictiveModel.predictWithoutAttributes(current.getAttributes(), attributesToIgnore);
        scored.add(new LabelPredictionWeight<L, P>(label, prediction, current.getWeight()));
    }
    return scored;
}
/**
 * Sums the weights of all given instances.
 *
 * @param instances instances whose {@code getWeight()} values are added up
 * @return the total weight; {@code 0} for an empty list
 */
public static double getInstanceWeights(List<? extends Instance> instances) {
    double total = 0;
    Iterator<? extends Instance> cursor = instances.iterator();
    while (cursor.hasNext()) {
        total += cursor.next().getWeight();
    }
    return total;
}
/**
 * Scores a validation set with a regression model.
 *
 * @param predictiveModel model producing a {@code Double} prediction from an {@link AttributesMap}
 * @param validationSet   instances to score
 * @return a new list with one (label, prediction, weight) entry per instance, in iteration order
 */
public static List<LabelPredictionWeight<Double, Double>> getRegLabelsPredictionsWeights(PredictiveModel<AttributesMap, Double> predictiveModel, List<? extends Instance<AttributesMap, Double>> validationSet) {
    List<LabelPredictionWeight<Double, Double>> scored = new ArrayList<>();
    for (Instance<AttributesMap, Double> example : validationSet) {
        Double label = example.getLabel();
        Double predicted = predictiveModel.predict(example.getAttributes());
        scored.add(new LabelPredictionWeight<Double, Double>(label, predicted, example.getWeight()));
    }
    return scored;
}
/**
 * Scores a validation set with a regression model and, for every instance, writes a
 * {@code id,label,prediction} line to the supplied writer.
 *
 * <p>Every instance is cast to {@link RegressionInstance} to read its {@code id}; an element
 * of another type causes a {@link ClassCastException}. A failed write is reported via
 * {@code printStackTrace()} and scoring continues with the next instance. The writer is
 * neither flushed nor closed here.
 *
 * @param predictiveModel model producing a {@code Double} prediction from an {@link AttributesMap}
 * @param validationSet   instances to score
 * @param bw              destination for the per-instance lines
 * @return a new list with one (label, prediction, weight) entry per instance, in iteration order
 */
public static List<LabelPredictionWeight<Double, Double>> getRegLabelsPredictionsWeights(PredictiveModel<AttributesMap, Double> predictiveModel, List<? extends Instance<AttributesMap, Double>> validationSet, BufferedWriter bw) {
    List<LabelPredictionWeight<Double, Double>> scored = new ArrayList<>();
    for (Instance<AttributesMap, Double> example : validationSet) {
        Double predicted = predictiveModel.predict(example.getAttributes());
        Long id = ((RegressionInstance) example).id;
        scored.add(new LabelPredictionWeight<Double, Double>(example.getLabel(), predicted, example.getWeight()));
        String csvLine = "" + id + "," + example.getLabel() + "," + predicted + "\n";
        try {
            bw.write(csvLine);
        } catch (IOException e) {
            // best effort: report the failure and keep scoring the remaining instances
            e.printStackTrace();
        }
    }
    return scored;
}
/**
 * Runs the classifier over the validation set and wraps the outcome.
 *
 * @param predictiveModel classifier whose {@code predict} is invoked once per instance
 * @param validationSet   instances to classify
 * @return a {@link PredictionMapResults} built from one {@link PredictionMapResult}
 *         (prediction, label, weight) per instance, in iteration order
 */
public static PredictionMapResults calcResultPredictions(Classifier predictiveModel, List<? extends InstanceWithAttributesMap<?>> validationSet) {
    ArrayList<PredictionMapResult> collected = new ArrayList<>();
    for (InstanceWithAttributesMap<?> example : validationSet) {
        PredictionMap prediction = predictiveModel.predict(example.getAttributes());
        collected.add(new PredictionMapResult(prediction, example.getLabel(), example.getWeight()));
    }
    return new PredictionMapResults(collected);
}
/**
 * Runs the classifier over the validation set via {@code predictWithoutAttributes},
 * passing the attribute names to ignore on every call, and wraps the outcome.
 *
 * @param predictiveModel    classifier to query
 * @param validationSet      instances to classify
 * @param attributesToIgnore attribute names handed to the classifier on every call
 * @return a {@link PredictionMapResults} built from one {@link PredictionMapResult}
 *         (prediction, label, weight) per instance, in iteration order
 */
public static PredictionMapResults calcResultpredictionsWithoutAttrs(Classifier predictiveModel, List<? extends InstanceWithAttributesMap<?>> validationSet, Set<String> attributesToIgnore) {
    ArrayList<PredictionMapResult> collected = new ArrayList<>();
    for (InstanceWithAttributesMap<?> example : validationSet) {
        collected.add(new PredictionMapResult(
                predictiveModel.predictWithoutAttributes(example.getAttributes(), attributesToIgnore),
                example.getLabel(),
                example.getWeight()));
    }
    return new PredictionMapResults(collected);
}
/**
 * Scores a regression validation set via {@code predictWithoutAttributes}, passing the
 * attribute names to ignore on every call.
 *
 * <p>Each result is built as (label, prediction, weight), the same argument order used by
 * the other helpers in this class. Because label and prediction are both {@code Double},
 * swapping them compiles silently, so the order is spelled out with named locals.
 *
 * @param predictiveModel    model producing a {@code Double} prediction from an {@link AttributesMap}
 * @param validationSet      instances to score
 * @param attributesToIgnore attribute names handed to the model on every call
 * @return a new list with one (label, prediction, weight) entry per instance, in iteration order
 */
public static List<LabelPredictionWeight<Double, Double>> calcLabelPredictionsWeightsWithoutAttrs(PredictiveModel<AttributesMap, Double> predictiveModel, List<? extends RegressionInstance> validationSet, Set<String> attributesToIgnore) {
    List<LabelPredictionWeight<Double, Double>> results = new ArrayList<>();
    for (RegressionInstance instance : validationSet) {
        Double prediction = predictiveModel.predictWithoutAttributes(instance.getAttributes(), attributesToIgnore);
        Double label = instance.getLabel();
        // label first, prediction second: matches LabelPredictionWeight<L, P> as used elsewhere in this class
        results.add(new LabelPredictionWeight<Double, Double>(label, prediction, instance.getWeight()));
    }
    return results;
}
/**
 * Sorts the training data in place, ascending by the {@link DateTime} that the extractor
 * returns for each instance.
 *
 * @param trainingData      list to reorder; must support {@link Collections#sort}
 * @param dateTimeExtractor supplies the timestamp compared for each instance
 */
public static <T extends InstanceWithAttributesMap<?>> void sortTrainingInstancesByTime(List<T> trainingData, final DateTimeExtractor<T> dateTimeExtractor) {
    final Comparator<T> chronological = new Comparator<T>() {
        @Override
        public int compare(T left, T right) {
            DateTime leftTime = dateTimeExtractor.extractDateTime(left);
            return leftTime.compareTo(dateTimeExtractor.extractDateTime(right));
        }
    };
    Collections.sort(trainingData, chronological);
}
/**
 * Returns the given iterable as a {@link List}.
 *
 * @param trainingData the elements to expose as a list
 * @return {@code trainingData} itself (cast, not copied) when it already is a {@code List};
 *         otherwise a new list containing its elements in iteration order
 */
public static <T> List<T> iterableToList(Iterable<T> trainingData) {
    if (!(trainingData instanceof List)) {
        return Lists.newArrayList(trainingData);
    }
    // already a list: hand back the same object rather than copying
    return (List<T>) trainingData;
}
/**
 * Copies the elements of an iterable of {@link ClassifierInstance} subtypes into a new
 * {@code List<ClassifierInstance>}.
 *
 * <p>The iterable is traversed twice: once by {@code Iterables.size} to pick the initial
 * capacity and once to copy the elements.
 *
 * @param trainingData elements to copy
 * @return a new list with the elements in iteration order
 */
public static <T extends ClassifierInstance> List<ClassifierInstance> iterableToListOfClassifierInstances(Iterable<T> trainingData) {
    final int expectedSize = Iterables.size(trainingData);
    List<ClassifierInstance> copy = Lists.newArrayListWithCapacity(expectedSize);
    Iterator<T> source = trainingData.iterator();
    while (source.hasNext()) {
        copy.add(source.next());
    }
    return copy;
}
/**
 * Copies the elements of an iterable of {@link SparseClassifierInstance} subtypes into a
 * new {@code List<SparseClassifierInstance>}.
 *
 * <p>The iterable is traversed twice: once by {@code Iterables.size} to pick the initial
 * capacity and once to copy the elements.
 *
 * @param trainingData elements to copy
 * @return a new list with the elements in iteration order
 */
public static <T extends SparseClassifierInstance> List<SparseClassifierInstance> iterableToListOfSparseClassifierInstances(Iterable<T> trainingData) {
    final int expectedSize = Iterables.size(trainingData);
    List<SparseClassifierInstance> copy = Lists.newArrayListWithCapacity(expectedSize);
    Iterator<T> source = trainingData.iterator();
    while (source.hasNext()) {
        copy.add(source.next());
    }
    return copy;
}
/**
 * Splits the training data on a branch decision without copying any instances.
 *
 * <p>The list is reordered in place by {@code repartitionTrainingData}, starting with the
 * "false" region empty (its first index equal to the list size). The two halves are then
 * returned as {@link List#subList} views of {@code trainingData}, so they share storage
 * with it and reflect the new order. The original order of the instances is not preserved.
 *
 * @param trainingData instances to split; reordered in place, must support {@code set}
 * @param bestNode     branch whose {@code decide} result assigns each instance to a half
 * @return the pair of views: indices {@code [0, k)} as the true set and {@code [k, size)}
 *         as the false set, where {@code k} is the index returned by the repartitioning
 */
public static <T extends InstanceWithAttributesMap<?>> TrueFalsePair<T> setTrueAndFalseTrainingSets(List<T> trainingData, Branch bestNode) {
    int firstIndexOfFalseSet = trainingData.size();
    int trialFirstIndexOfFalseSet = firstIndexOfFalseSet - 1;
    firstIndexOfFalseSet = repartitionTrainingData(trainingData, bestNode, firstIndexOfFalseSet, trialFirstIndexOfFalseSet);
    List<T> trueTrainingSet = trainingData.subList(0, firstIndexOfFalseSet);
    List<T> falseTrainingSet = trainingData.subList(firstIndexOfFalseSet, trainingData.size());
    // parameterized (not raw) construction keeps the element type checked by the compiler
    return new TrueFalsePair<T>(trueTrainingSet, falseTrainingSet);
}
/**
 * Reorders {@code trainingData} in place so that instances for which
 * {@code bestNode.decide(...)} is true come before those for which it is false.
 *
 * <p>Scans forward with {@code i}; a false instance found at {@code i} is swapped with the
 * element at {@code trialFirstIndexOfFalseSet}, and the false-set boundary moves one
 * position toward the front. The relative order of instances is not preserved.
 *
 * <p>NOTE(review): the result is only a full partition if every element at an index
 * {@code >= firstIndexOfFalseSet} is already a false instance on entry and
 * {@code trialFirstIndexOfFalseSet == firstIndexOfFalseSet - 1}; the caller in this class
 * passes {@code size} and {@code size - 1}, which satisfies both. Confirm before calling
 * with other values.
 *
 * @param trainingData              list to reorder; must support {@code get} and {@code set}
 * @param bestNode                  branch whose {@code decide} result classifies each instance
 * @param firstIndexOfFalseSet      index of the first element already known to be false
 * @param trialFirstIndexOfFalseSet index of the next candidate slot for a false instance
 * @return the index of the first false instance after reordering (the list size if none)
 */
public static <T extends InstanceWithAttributesMap<?>> int repartitionTrainingData(List<T> trainingData, Branch bestNode, int firstIndexOfFalseSet, int trialFirstIndexOfFalseSet) {
for (int i = 0; i < trainingData.size() && firstIndexOfFalseSet > i; i++) {
T instance = trainingData.get(i);
if (bestNode.decide(instance.getAttributes())) {
continue; // the condition above means the instance at position i belongs to the true set
} else {
// The instance at i is not in the true set, so swap it with whatever instance sits just before firstIndexOfFalseSet.
// If the instance now at i is in the true set, return to the loop over i. If not, the boundary has already been
// decremented, so swap again. Repeat until either a true instance lands at position i or
// firstIndexOfFalseSet turns out to be i itself.
while (!bestNode.decide(trainingData.get(i).getAttributes()) && (trialFirstIndexOfFalseSet >= i)) {
if (i == trialFirstIndexOfFalseSet) { // edge case: no candidate slots remain behind i
firstIndexOfFalseSet = trialFirstIndexOfFalseSet; // the loop condition just verified that the instance at i is in the false set
break;
}
// swap the false instance at i into the candidate slot
swap(i, trialFirstIndexOfFalseSet, trainingData);
firstIndexOfFalseSet = trialFirstIndexOfFalseSet; // the instance moved into position trialFirstIndexOfFalseSet is known to be in the false set
trialFirstIndexOfFalseSet--;
}
}
}
return firstIndexOfFalseSet;
}
/**
 * Exchanges the elements at positions {@code i} and {@code trialFirstIndexOfFalseSet}
 * of {@code trainingData}.
 */
private static <T extends InstanceWithAttributesMap<?>> void swap(int i, int trialFirstIndexOfFalseSet, List<T> trainingData) {
    Collections.swap(trainingData, i, trialFirstIndexOfFalseSet);
}
/**
 * Simple holder for two lists of instances, named for the "true" and "false" sides of a
 * split. The lists are stored by reference; no copies are made.
 */
public static class TrueFalsePair<T extends InstanceWithAttributesMap<?>> {
    /** The list supplied as the "true" side. */
    public List<T> trueTrainingSet;

    /** The list supplied as the "false" side. */
    public List<T> falseTrainingSet;

    /**
     * @param trueTrainingSet  list to expose as {@link #trueTrainingSet}
     * @param falseTrainingSet list to expose as {@link #falseTrainingSet}
     */
    public TrueFalsePair(List<T> trueTrainingSet, List<T> falseTrainingSet) {
        this.falseTrainingSet = falseTrainingSet;
        this.trueTrainingSet = trueTrainingSet;
    }
}
/**
 * Computes the ratio of {@code ttlDepth} to {@code ttlSamples} after letting the node fill
 * in a fresh {@link LeafDepthStats} through {@code calcLeafDepthStats}.
 *
 * @param node node on which {@code calcLeafDepthStats} is invoked
 * @return {@code ttlDepth / ttlSamples}, evaluated in floating point
 */
public static <VC extends ValueCounter<VC>> double meanDepth(Node<VC> node) {
    final LeafDepthStats depthStats = new LeafDepthStats();
    node.calcLeafDepthStats(depthStats);
    final double totalDepth = (double) depthStats.ttlDepth;
    return totalDepth / depthStats.ttlSamples;
}
/**
 * Accumulates a {@link MeanStdMaxMin} per attribute over all instances, restricted to the
 * attributes whose {@link AttributeCharacteristics} entry has {@code isNumber} set.
 *
 * <p>Every attribute key seen on an instance is looked up in
 * {@code attributeCharacteristics}; a key with no entry causes a
 * {@link NullPointerException}. Values of selected attributes are cast to {@link Number}.
 *
 * @param attributeCharacteristics per-attribute flags consulted for every key
 * @param instances                instances whose attribute values are accumulated
 * @return a new map from attribute name to its accumulated statistics
 */
public static <I extends InstanceWithAttributesMap<?>> Map<String, MeanStdMaxMin> getMeanStdMaxMins(Map<String, AttributeCharacteristics> attributeCharacteristics,
                                                                                                   List<I> instances) {
    Map<String, MeanStdMaxMin> statsByAttribute = Maps.newHashMap();
    for (I instance : instances) {
        AttributesMap attributes = instance.getAttributes();
        for (String name : attributes.keySet()) {
            if (!attributeCharacteristics.get(name).isNumber) {
                continue;
            }
            MeanStdMaxMin stats = statsByAttribute.get(name);
            if (stats == null) {
                stats = new MeanStdMaxMin();
                statsByAttribute.put(name, stats);
            }
            Number value = (Number) attributes.get(name);
            stats.update(value.doubleValue());
        }
    }
    return statsByAttribute;
}
/**
 * Accumulates a {@link MeanStdMaxMin} per attribute over all instances.
 *
 * <p>Every attribute value is cast to {@link Number}; a value of any other type causes a
 * {@link ClassCastException}.
 *
 * @param instances instances whose attribute values are accumulated
 * @return a new map from attribute name to its accumulated statistics
 */
public static <I extends InstanceWithAttributesMap<?>> Map<String, MeanStdMaxMin> getMeanStdMaxMins(List<I> instances) {
    Map<String, MeanStdMaxMin> statsByAttribute = Maps.newHashMap();
    for (I instance : instances) {
        AttributesMap attributes = instance.getAttributes();
        for (String name : attributes.keySet()) {
            MeanStdMaxMin stats = statsByAttribute.get(name);
            if (stats == null) {
                stats = new MeanStdMaxMin();
                statsByAttribute.put(name, stats);
            }
            Number value = (Number) attributes.get(name);
            stats.update(value.doubleValue());
        }
    }
    return statsByAttribute;
}
/**
 * Running accumulator for the weighted mean, a non-zero standard deviation and the range
 * (max - min) of a stream of values.
 *
 * <p>Sums are kept as {@link BigDecimal}; the divisions that turn them into moments are
 * rounded to three decimal places (half-up), so results are approximate at that precision.
 * Values passed to {@code update} must be finite, because {@code new BigDecimal(double)}
 * rejects NaN and infinities with a {@link NumberFormatException}.
 */
public static class MeanStdMaxMin {
    /** Number of decimal places kept when dividing the running sums by the total weight. */
    private static final int DIVISION_SCALE = 3;

    BigDecimal runningSum = new BigDecimal(0);
    BigDecimal runningSumOfSquares = new BigDecimal(0);
    double totalWeight = 0;
    double mean = 0;
    double max = 0;
    double min = 0;
    double std = 0;
    /** False until the first value arrives; lets max/min start from real data instead of 0. */
    private boolean hasSample = false;

    public MeanStdMaxMin() {
    }

    /**
     * Records one value with weight {@code 1.0}.
     *
     * @param val the value to accumulate; must be finite
     */
    public void update(double val) {
        this.update(val, 1.0);
    }

    /**
     * Records one weighted value: adds {@code weight * val} and {@code weight * val * val}
     * to the running sums, adds {@code weight} to the total weight, and widens the observed
     * range to include {@code val}.
     *
     * @param val    the value to accumulate; must be finite
     * @param weight the weight of this value; must be finite
     */
    public void update(double val, double weight) {
        BigDecimal bigVal = new BigDecimal(val);
        // weight the contribution so the sums stay consistent with totalWeight
        BigDecimal weightedVal = bigVal.multiply(new BigDecimal(weight));
        runningSum = runningSum.add(weightedVal);
        runningSumOfSquares = runningSumOfSquares.add(weightedVal.multiply(bigVal));
        totalWeight += weight;
        if (!hasSample) {
            // the first sample defines both ends of the range; starting from 0 would report
            // min == 0 for all-positive data and max == 0 for all-negative data
            max = val;
            min = val;
            hasSample = true;
        } else {
            if (max < val) {
                max = val;
            }
            if (min > val) {
                min = val;
            }
        }
    }

    /**
     * @return the weighted mean, rounded to three decimal places
     * @throws ArithmeticException if the total weight is zero
     */
    public double getMean() {
        return runningSum.divide(new BigDecimal(totalWeight), DIVISION_SCALE, java.math.RoundingMode.HALF_UP).doubleValue();
    }

    /**
     * Returns a strictly usable (never zero, never NaN) spread estimate.
     *
     * <p>Normally this is the square root of (second moment - first moment squared). When
     * the total weight is zero, or when that variance is not positive (constant data, or
     * rounding of the moments pushing it to zero or below), the observed range is returned
     * instead, and {@code 1.0} if the range is zero as well.
     *
     * @return the standard deviation, or the fallback described above
     */
    public double getNonZeroStd() {
        if (totalWeight == 0) {
            return nonZeroRange();
        }
        BigDecimal bigTotalWeight = new BigDecimal(totalWeight);
        BigDecimal secondMoment = runningSumOfSquares.divide(bigTotalWeight, DIVISION_SCALE, java.math.RoundingMode.HALF_UP);
        BigDecimal firstMoment = runningSum.divide(bigTotalWeight, DIVISION_SCALE, java.math.RoundingMode.HALF_UP);
        BigDecimal variance = secondMoment.subtract(firstMoment.multiply(firstMoment));
        // signum() ignores scale (BigDecimal.equals does not) and also catches the negative
        // variances that rounding can produce, which would otherwise become NaN under sqrt
        if (variance.signum() <= 0) {
            return nonZeroRange();
        }
        double deviation = Math.sqrt(variance.doubleValue());
        return (deviation == 0) ? 1.0 : deviation;
    }

    /**
     * @return the difference between the largest and smallest value seen so far;
     *         {@code 0} before any value has been recorded
     */
    public double getMaxMinMinusMin() {
        return max - min;
    }

    /** Range fallback for getNonZeroStd: the observed range, or 1.0 when the range is zero. */
    private double nonZeroRange() {
        double range = getMaxMinMinusMin();
        return (range == 0) ? 1.0 : range;
    }
}
/**
 * Builds one {@link BinaryAttributeCharacteristics} per attribute name and feeds it every
 * value that attribute takes across the training data via {@code updateBinaryStatus}.
 *
 * <p>Every attribute value is cast to {@link Double}; a value of any other type causes a
 * {@link ClassCastException}.
 *
 * @param trainingData instances whose attributes are inspected
 * @return a new map from attribute name to its accumulated characteristics
 */
public static <T extends InstanceWithAttributesMap<?>> Map<String, BinaryAttributeCharacteristics> getMapOfAttributesToBinaryAttributeCharacteristics(List<T> trainingData) {
    Map<String, BinaryAttributeCharacteristics> characteristicsByAttribute = Maps.newHashMap();
    for (T instance : trainingData) {
        for (Map.Entry<String, Serializable> attribute : instance.getAttributes().entrySet()) {
            String name = attribute.getKey();
            if (!characteristicsByAttribute.containsKey(name)) {
                characteristicsByAttribute.put(name, new BinaryAttributeCharacteristics());
            }
            characteristicsByAttribute.get(name).updateBinaryStatus((Double) attribute.getValue());
        }
    }
    return characteristicsByAttribute;
}
}