package quickml.supervised.ensembles.randomForest.randomDecisionForest;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.AtomicDouble;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.ensembles.randomForest.RandomForest;
import quickml.supervised.tree.decisionTree.DecisionTree;
import java.io.Serializable;
import java.util.*;
/**
 * A classifier that combines a list of {@link DecisionTree}s by averaging the
 * probabilities the individual trees produce.
 *
 * <p>Created by ian on 4/18/13.
 */
public class RandomDecisionForest extends AbstractClassifier implements RandomForest<PredictionMap, DecisionTree> {
    static final long serialVersionUID = 56394564395638954L;

    /**
     * The member trees. Checked to be non-empty at construction time; the list
     * handed to the constructor is stored as-is (no defensive copy is made).
     */
    public final List<DecisionTree> decisionTrees;

    /** The class labels supplied at construction time. */
    private Set<Serializable> classifications = new HashSet<>();

    /** {@code true} when at most two class labels were supplied at construction time. */
    private boolean binaryClassification = true;

    /**
     * Creates a forest from the given trees and class labels.
     *
     * @param decisionTrees   the member trees; must contain at least one tree
     * @param classifications the class labels; must contain at least one label
     * @throws IllegalArgumentException if {@code decisionTrees} is empty
     * @throws RuntimeException         if {@code classifications} is empty
     */
    protected RandomDecisionForest(List<DecisionTree> decisionTrees, Set<Serializable> classifications) {
        Preconditions.checkArgument(decisionTrees.size() > 0, "We must have at least one oldTree");
        this.decisionTrees = decisionTrees;
        this.classifications = classifications;
        if (classifications.size() > 2) {
            binaryClassification = false;
        } else if (classifications.size() < 1) {
            throw new RuntimeException("no classes listed in classifications");
        }
    }

    /**
     * Returns the mean, over all trees, of the probability each tree assigns to
     * {@code classification} for the given attributes.
     *
     * @param attributes     the instance to score
     * @param classification the class label whose probability is requested
     * @return the average of the per-tree probabilities
     * @throws RuntimeException if any tree returns an infinite or NaN probability
     */
    @Override
    public double getProbability(AttributesMap attributes, Serializable classification) {
        double total = 0;
        for (DecisionTree decisionTree : decisionTrees) {
            total += requireFiniteProbability(decisionTree.getProbability(attributes, classification));
        }
        return total / decisionTrees.size();
    }

    /**
     * Same as {@link #getProbability(AttributesMap, Serializable)}, but each tree is
     * queried through {@code getProbabilityWithoutAttributes} with the given set of
     * attribute names to ignore.
     *
     * @param attributes         the instance to score
     * @param classification     the class label whose probability is requested
     * @param attributesToIgnore attribute names passed through to every tree
     * @return the average of the per-tree probabilities
     * @throws RuntimeException if any tree returns an infinite or NaN probability
     */
    public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
        double total = 0;
        for (DecisionTree decisionTree : decisionTrees) {
            total += requireFiniteProbability(
                    decisionTree.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore));
        }
        return total / decisionTrees.size();
    }

    /**
     * Predicts class probabilities for the given attributes.
     *
     * <p>When at most two class labels were supplied at construction time the
     * two-class shortcut is used; otherwise the per-class probabilities of all
     * trees are averaged.
     *
     * @param attributes the instance to score
     * @return a map from class label to averaged probability
     */
    @Override
    public PredictionMap predict(final AttributesMap attributes) {
        if (binaryClassification) {
            return getPredictionForTwoClasses(attributes);
        }
        return getPredictionForNClasses(attributes);
    }

    /**
     * Averages, per class label, the probabilities returned by every tree's
     * {@code predict} call.
     */
    private PredictionMap getPredictionForNClasses(AttributesMap attributes) {
        PredictionMap sumsByClassification = new PredictionMap(new HashMap<Serializable, Double>());
        for (DecisionTree decisionTree : decisionTrees) {
            addProbabilities(sumsByClassification, decisionTree.predict(attributes));
        }
        return divideByTreeCount(sumsByClassification);
    }

    /**
     * Predicts class probabilities while asking every tree to ignore the given
     * attributes. This always averages the full per-class predictions of the
     * trees, regardless of how many class labels the forest was built with.
     *
     * @param attributes         the instance to score
     * @param attributesToIgnore attribute names passed through to every tree
     * @return a map from class label to averaged probability
     */
    @Override
    public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
        PredictionMap sumsByClassification = new PredictionMap(new HashMap<Serializable, Double>());
        for (DecisionTree decisionTree : decisionTrees) {
            addProbabilities(sumsByClassification, decisionTree.predictWithoutAttributes(attributes, attributesToIgnore));
        }
        return divideByTreeCount(sumsByClassification);
    }

    /**
     * Two-class shortcut: computes the averaged probability of the first class
     * label returned by the label set's iterator and, if a second label exists,
     * assigns it the complement ({@code 1.0 - firstProbability}).
     *
     * @throws RuntimeException if the label set is empty
     */
    private PredictionMap getPredictionForTwoClasses(AttributesMap attributes) {
        PredictionMap probsByClassification = PredictionMap.newMap();
        Iterator<Serializable> classIterator = classifications.iterator();
        if (!classIterator.hasNext()) {
            throw new RuntimeException("no class labels present in classification set");
        }
        Serializable firstClassification = classIterator.next();
        double firstProbability = getProbability(attributes, firstClassification);
        probsByClassification.put(firstClassification, firstProbability);
        if (classIterator.hasNext()) {
            Serializable secondClassification = classIterator.next();
            probsByClassification.put(secondClassification, 1.0 - firstProbability);
        }
        return probsByClassification;
    }

    /**
     * Returns the class label with the largest probability summed over all trees.
     *
     * @param attributes the instance to classify
     * @return the label with the highest total probability, or {@code null} if no
     *         tree produced any prediction entry
     */
    @Override
    public Serializable getClassificationByMaxProb(AttributesMap attributes) {
        Map<Serializable, AtomicDouble> probTotals = Maps.newHashMap();
        for (DecisionTree decisionTree : decisionTrees) {
            PredictionMap predictionMap = decisionTree.predict(attributes);
            for (Map.Entry<Serializable, Double> prediction : predictionMap.entrySet()) {
                AtomicDouble runningTotal = probTotals.get(prediction.getKey());
                if (runningTotal == null) {
                    probTotals.put(prediction.getKey(), new AtomicDouble(prediction.getValue()));
                } else {
                    // Accumulate in place. Re-wrapping the result of getAndAdd would
                    // store the value from *before* the addition and lose this tree's vote.
                    runningTotal.addAndGet(prediction.getValue());
                }
            }
        }
        Serializable bestClassification = null;
        double bestClassificationTtlProb = 0;
        for (Map.Entry<Serializable, AtomicDouble> classificationProb : probTotals.entrySet()) {
            double totalProb = classificationProb.getValue().get();
            if (bestClassification == null || totalProb > bestClassificationTtlProb) {
                bestClassification = classificationProb.getKey();
                bestClassificationTtlProb = totalProb;
            }
        }
        return bestClassification;
    }

    /** Two forests are equal when they are of the same class and their tree lists are equal. */
    @Override
    public boolean equals(final Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        final RandomDecisionForest that = (RandomDecisionForest) o;
        return decisionTrees.equals(that.decisionTrees);
    }

    /** Hash code derived from the tree list, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return decisionTrees.hashCode();
    }

    /** Returns {@code probability} unchanged, or throws if it is infinite or NaN. */
    private static double requireFiniteProbability(double probability) {
        if (Double.isInfinite(probability) || Double.isNaN(probability)) {
            throw new RuntimeException("Probability must be a normal number, not "+probability);
        }
        return probability;
    }

    /** Adds every entry of {@code treeProbs} onto the running per-class sums in {@code sums}. */
    private static void addProbabilities(PredictionMap sums, PredictionMap treeProbs) {
        for (Map.Entry<Serializable, Double> tpe : treeProbs.entrySet()) {
            Double sum = sums.get(tpe.getKey());
            if (sum == null) sum = 0.0;
            sum += tpe.getValue();
            sums.put(tpe.getKey(), sum);
        }
    }

    /** Builds a new map holding each per-class sum divided by the number of trees. */
    private PredictionMap divideByTreeCount(PredictionMap sumsByClassification) {
        PredictionMap probsByClassification = new PredictionMap(new HashMap<Serializable, Double>());
        for (Map.Entry<Serializable, Double> sumEntry : sumsByClassification.entrySet()) {
            probsByClassification.put(sumEntry.getKey(), sumEntry.getValue() / decisionTrees.size());
        }
        return probsByClassification;
    }
}