package quickml.supervised.tree.regressionTree;
import quickml.data.AttributesMap;
import quickml.supervised.tree.Tree;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Leaf;
import quickml.supervised.tree.nodes.LeafDepthStats;
import quickml.supervised.tree.nodes.Node;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;
/**
 * A {@link Tree} that predicts a {@code Double} by routing an instance to a leaf and returning
 * that leaf's accumulated value divided by its total, both read from the leaf's
 * {@link MeanValueCounter}.
 *
 * <p>Besides prediction, the class offers two summary statistics over leaf depths
 * ({@link #calcMeanDepth()} and {@link #calcMedianDepth()}) and a prediction mode that averages
 * over both children of any branch whose attribute the caller asked to ignore.
 */
public class RegressionTree implements Tree<Double> {
    static final long serialVersionUID = 56394564395635672L;

    /** Root node of the tree; every prediction starts here. */
    public final Node<MeanValueCounter> root;

    /**
     * Not referenced anywhere in this class. Kept (transient, so it is not serialized) because it
     * is protected and subclasses may depend on it.
     */
    protected transient volatile Map.Entry<Serializable, Double> bestClassificationEntry = null;

    /**
     * Creates a tree backed by the given root node.
     *
     * @param root the root node; it is stored as-is and is not validated here
     */
    public RegressionTree(Node<MeanValueCounter> root) {
        this.root = root;
    }

    /**
     * Predicts the value for the given attributes.
     *
     * @param attributes the instance to score
     * @return the accumulated value of the leaf selected by {@code root.getLeaf(attributes)},
     *         divided by that leaf's total
     */
    @Override
    public Double predict(AttributesMap attributes) {
        Leaf<MeanValueCounter> leaf = root.getLeaf(attributes);
        return meanOf(leaf.getValueCounter());
    }

    /**
     * Computes the mean leaf depth from the statistics collected by
     * {@code root.calcLeafDepthStats}.
     *
     * @return {@code ttlDepth / ttlSamples} using floating-point division; there is no guard
     *         against {@code ttlSamples} being zero
     */
    public double calcMeanDepth() {
        LeafDepthStats leafDepthStats = new LeafDepthStats();
        root.calcLeafDepthStats(leafDepthStats);
        return (1.0 * leafDepthStats.ttlDepth) / leafDepthStats.ttlSamples;
    }

    /**
     * Computes a median leaf depth from the statistics collected by
     * {@code root.calcLeafDepthStats}.
     *
     * <p>Depths are visited in increasing order starting at 0, summing the counts stored in
     * {@code depthDistribution}; the first depth at which the running count is no longer below
     * {@code ttlSamples / 2} is returned.
     *
     * @return the depth at which the cumulative count reaches {@code ttlSamples / 2}
     */
    public double calcMedianDepth() {
        LeafDepthStats leafDepthStats = new LeafDepthStats();
        root.calcLeafDepthStats(leafDepthStats);
        long counts = 0;
        int depth = 0;
        // NOTE(review): termination relies on the depthDistribution counts summing to at least
        // ttlSamples / 2 -- presumably guaranteed by calcLeafDepthStats; confirm.
        while (counts < leafDepthStats.ttlSamples / 2) {
            if (leafDepthStats.depthDistribution.containsKey(depth)) {
                counts += leafDepthStats.depthDistribution.get(depth);
            }
            // Only advance while still short of the threshold, so the returned depth is the one
            // that pushed the running count over it.
            if (counts < leafDepthStats.ttlSamples / 2) {
                depth++;
            }
        }
        return depth;
    }

    /**
     * Predicts the value for the given attributes while disregarding some of them.
     *
     * @param attributes the instance to score
     * @param attributesToIgnore attributes whose branches must not be decided; at such a branch
     *        the result is the average of both subtrees weighted by
     *        {@code getProbabilityOfTrueChild()}
     * @return the resulting prediction
     * @throws RuntimeException if a node that is neither a {@link Branch} nor a {@link Leaf} is
     *         encountered
     */
    @Override
    public Double predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
        return getPredictionWithoutAttributesHelper(root, attributes, attributesToIgnore);
    }

    /** Recursive worker for {@link #predictWithoutAttributes(AttributesMap, Set)}. */
    private double getPredictionWithoutAttributesHelper(Node<MeanValueCounter> node, AttributesMap attributes, Set<String> attributesToIgnore) {
        if (node instanceof Branch) {
            Branch branch = (Branch) node;
            if (attributesToIgnore.contains(branch.attribute)) {
                // The splitting attribute is ignored: blend both subtrees instead of choosing one.
                double probabilityOfTrueChild = branch.getProbabilityOfTrueChild();
                return probabilityOfTrueChild * getPredictionWithoutAttributesHelper(branch.getTrueChild(), attributes, attributesToIgnore)
                        + (1.0 - probabilityOfTrueChild) * getPredictionWithoutAttributesHelper(branch.getFalseChild(), attributes, attributesToIgnore);
            }
            if (branch.decide(attributes)) {
                return getPredictionWithoutAttributesHelper(branch.getTrueChild(), attributes, attributesToIgnore);
            }
            return getPredictionWithoutAttributesHelper(branch.getFalseChild(), attributes, attributesToIgnore);
        } else if (node instanceof Leaf) {
            Leaf<MeanValueCounter> leaf = (Leaf<MeanValueCounter>) node;
            return meanOf(leaf.getValueCounter());
        } else {
            throw new RuntimeException("node not a branch or a leaf");
        }
    }

    /** Returns the counter's accumulated value divided by its total. */
    private static double meanOf(MeanValueCounter valueCounter) {
        return valueCounter.getAccumulatedValue() / valueCounter.getTotal();
    }

    /**
     * Two regression trees are equal when they are of the same class and their roots are equal.
     *
     * @param o the object to compare with
     * @return {@code true} if {@code o} is a tree of the same class with an equal root
     */
    @Override
    public boolean equals(final Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        // Cast to this class: the getClass() check above guarantees o is a RegressionTree, so
        // casting to any unrelated tree type would fail with ClassCastException.
        final RegressionTree other = (RegressionTree) o;
        return root.equals(other.root);
    }

    /**
     * Hash code derived from the root only, matching {@link #equals(Object)}.
     *
     * @return the root's hash code
     */
    @Override
    public int hashCode() {
        return root.hashCode();
    }
}