package quickml.supervised.tree.decisionTree;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.tree.Tree;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Leaf;
import quickml.supervised.tree.nodes.LeafDepthStats;
import quickml.supervised.tree.nodes.Node;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
 * Classification tree: every prediction is read from the {@link ClassificationCounter} held by the
 * {@link Leaf} that an {@link AttributesMap} is routed to, starting from {@link #root}.
 *
 * <p>Originally created by janie on 6/26/13.
 */
public class DecisionTree extends AbstractClassifier implements Tree<PredictionMap> {
    static final long serialVersionUID = 56394564395635672L;

    /** Root node of the tree; all lookups start here. */
    public final Node<ClassificationCounter> root;

    /** All classifications known to the tree; iterated by {@link #predictWithoutAttributes}. */
    private final HashSet<Serializable> classifications;

    /**
     * Not referenced anywhere in this class; retained because it is protected and may be used by
     * subclasses.
     */
    protected transient volatile Map.Entry<Serializable, Double> bestClassificationEntry = null;

    /**
     * @param root            root node of the tree
     * @param classifications every classification the tree can produce; copied into a new set
     */
    public DecisionTree(Node<ClassificationCounter> root, Set<Serializable> classifications) {
        this.root = root;
        this.classifications = Sets.newHashSet(classifications);
    }

    /**
     * Returns the tree's internal classification set itself (not a copy), so changes made through
     * the returned set are seen by {@link #predictWithoutAttributes}.
     */
    public Set<Serializable> getClassifications() {
        return classifications;
    }

    /**
     * Returns {@code count(classification) / total} from the counter of the leaf that
     * {@code attributes} is routed to.
     */
    @Override
    public double getProbability(AttributesMap attributes, Serializable classification) {
        Leaf<ClassificationCounter> leaf = root.getLeaf(attributes);
        return probabilityOf(leaf.getValueCounter(), classification);
    }

    /**
     * Mean leaf depth as reported by {@code root.calcLeafDepthStats}: {@code ttlDepth / ttlSamples}.
     * The result is not finite when no samples were recorded.
     */
    public double calcMeanDepth() {
        LeafDepthStats leafDepthStats = new LeafDepthStats();
        root.calcLeafDepthStats(leafDepthStats);
        return (1.0 * leafDepthStats.ttlDepth) / leafDepthStats.ttlSamples;
    }

    /**
     * Median leaf depth: the smallest depth at which the cumulative count in
     * {@code depthDistribution} reaches ceil(ttlSamples / 2). For an even total this is the lower
     * median. Returns 0 when there are no samples.
     */
    public double calcMedianDepth() {
        LeafDepthStats leafDepthStats = new LeafDepthStats();
        root.calcLeafDepthStats(leafDepthStats);
        // The median is the ceil(n/2)-th sample in depth order. Using n/2 picks the wrong sample for
        // odd n (e.g. a single sample would yield target 0 and report depth 0 regardless).
        long target = (leafDepthStats.ttlSamples + 1L) / 2;
        // Stop once every recorded depth has been consumed, so an inconsistent distribution
        // (summing to less than the target) cannot make this loop run forever.
        int remainingDepths = leafDepthStats.depthDistribution.size();
        long counts = 0;
        int depth = 0;
        while (true) {
            if (leafDepthStats.depthDistribution.containsKey(depth)) {
                counts += leafDepthStats.depthDistribution.get(depth);
                remainingDepths--;
            }
            if (counts >= target || remainingDepths <= 0) {
                return depth;
            }
            depth++;
        }
    }

    /**
     * Probability of {@code classification} when every branch testing an attribute in
     * {@code attributesToIgnore} is marginalised out: both children are evaluated and weighted by
     * the branch's {@code getProbabilityOfTrueChild()} and its complement.
     */
    @Override
    public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
        return getProbabilityWithoutAttributesHelper(root, attributes, classification, attributesToIgnore);
    }

    /**
     * Recursive worker for {@link #getProbabilityWithoutAttributes}.
     *
     * @throws RuntimeException if {@code node} is neither a {@link Branch} nor a {@link Leaf}
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // Branch is used as a raw type; children are Node<ClassificationCounter>
    private double getProbabilityWithoutAttributesHelper(Node<ClassificationCounter> node, AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
        if (node instanceof Branch) {
            Branch branch = (Branch) node;
            if (attributesToIgnore.contains(branch.attribute)) {
                // Ignored attribute: take the expectation over both subtrees.
                double probabilityOfTrueChild = branch.getProbabilityOfTrueChild();
                return probabilityOfTrueChild * getProbabilityWithoutAttributesHelper(branch.getTrueChild(), attributes, classification, attributesToIgnore)
                        + (1.0 - probabilityOfTrueChild) * getProbabilityWithoutAttributesHelper(branch.getFalseChild(), attributes, classification, attributesToIgnore);
            }
            if (branch.decide(attributes)) {
                return getProbabilityWithoutAttributesHelper(branch.getTrueChild(), attributes, classification, attributesToIgnore);
            }
            return getProbabilityWithoutAttributesHelper(branch.getFalseChild(), attributes, classification, attributesToIgnore);
        }
        if (node instanceof Leaf) {
            Leaf<ClassificationCounter> leaf = (Leaf<ClassificationCounter>) node;
            return probabilityOf(leaf.getValueCounter(), classification);
        }
        throw new RuntimeException("node not a branch or a leaf");
    }

    /**
     * Returns a probability for every classification present in the counter of the leaf that
     * {@code attributes} is routed to.
     */
    @Override
    public PredictionMap predict(AttributesMap attributes) {
        Leaf<ClassificationCounter> leaf = root.getLeaf(attributes);
        ClassificationCounter valueCounter = leaf.getValueCounter();
        Map<Serializable, Double> probsByClassification = Maps.newHashMap();
        for (Serializable classification : valueCounter.allClassifications()) {
            probsByClassification.put(classification, probabilityOf(valueCounter, classification));
        }
        return new PredictionMap(probsByClassification);
    }

    /**
     * Returns {@link #getProbabilityWithoutAttributes} for every classification in this tree's
     * classification set.
     */
    @Override
    public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
        Map<Serializable, Double> probsByClassification = Maps.newHashMap();
        for (Serializable classification : classifications) {
            probsByClassification.put(classification, getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore));
        }
        return new PredictionMap(probsByClassification);
    }

    /** Returns the counter's {@code mostPopular()} classification at the leaf {@code attributes} reaches. */
    @Override
    public Serializable getClassificationByMaxProb(AttributesMap attributes) {
        Leaf<ClassificationCounter> leaf = root.getLeaf(attributes);
        ClassificationCounter classificationCounter = leaf.getValueCounter();
        return classificationCounter.mostPopular().getValue0();
    }

    /** Two trees are equal when they are of the same class and their roots are equal. */
    @Override
    public boolean equals(final Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        return root.equals(((DecisionTree) o).root);
    }

    /** Consistent with {@link #equals}: derived from the root only. */
    @Override
    public int hashCode() {
        return root.hashCode();
    }

    /** Share of {@code counter}'s total attributed to {@code classification}. */
    private static double probabilityOf(ClassificationCounter counter, Serializable classification) {
        return counter.getCount(classification) / counter.getTotal();
    }
}