package quickml.supervised.tree.decisionTree.nodes;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import quickml.data.AttributesMap;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Leaf;
import quickml.supervised.tree.nodes.LeafDepthStats;

import java.io.Serializable;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Leaf node of a decision tree. It holds the {@link ClassificationCounter}
 * accumulated for the training instances that reached it, together with its
 * depth and a reference to its parent branch.
 *
 * <p>Equality and hashing are based on {@link #depth}, {@link #exampleCount}
 * and the classification counts only; {@link #guid} and {@link #parent} are
 * not considered.
 */
public class DTLeaf implements Leaf<ClassificationCounter>, Serializable {
    private static final long serialVersionUID = -5617660873196498754L;

    /** Source of {@link #guid} values; incremented once per constructed leaf. */
    private static final AtomicLong guidCounter = new AtomicLong(0);

    /** Identifier drawn from {@link #guidCounter} when the leaf is constructed. */
    public final long guid;

    /**
     * How deep in the tree is this leaf? A lower number typically indicates a
     * more confident classification.
     */
    public final int depth;

    /** The branch this leaf hangs from, exactly as passed to the constructor. */
    public final Branch<ClassificationCounter> parent;

    /**
     * How many training examples matched this leaf? A higher number indicates a
     * more confident classification. Initialised from the counter's total.
     */
    public double exampleCount;

    /** The actual classification counts. */
    private final ClassificationCounter classificationCounts;

    /**
     * Builds a leaf by counting the classifications of the given instances.
     *
     * <p>The emptiness check necessarily runs after the delegated constructor,
     * because {@code this(...)} must be the first statement. An empty input is
     * therefore presumably rejected earlier by the delegated constructor's
     * total-count check (assuming {@code countAll} yields a zero total for no
     * instances - TODO confirm).
     *
     * @param parent    the branch this leaf belongs to
     * @param instances the training instances that reached this leaf
     * @param depth     depth of this leaf in the tree
     * @throws IllegalStateException    if the counted total is not greater than zero
     * @throws IllegalArgumentException if {@code instances} is empty
     */
    public DTLeaf(Branch<ClassificationCounter> parent,
                  final Iterable<? extends ClassifierInstance> instances,
                  final int depth) {
        this(parent, ClassificationCounter.countAll(instances), depth);
        Preconditions.checkArgument(!Iterables.isEmpty(instances), "Can't create leaf with no instances");
    }

    /**
     * Builds a leaf from an already-computed counter.
     *
     * @param parent               the branch this leaf belongs to
     * @param classificationCounts counts for this leaf; its total must be greater than zero
     * @param depth                depth of this leaf in the tree
     * @throws IllegalStateException if the counter's total is not greater than zero
     */
    public DTLeaf(Branch<ClassificationCounter> parent,
                  final ClassificationCounter classificationCounts,
                  final int depth) {
        guid = guidCounter.incrementAndGet();
        this.classificationCounts = classificationCounts;
        Preconditions.checkState(classificationCounts.getTotal() > 0, "Classifications must be > 0");
        exampleCount = classificationCounts.getTotal();
        this.depth = depth;
        this.parent = parent;
    }

    /**
     * @return the classification counts held by this leaf
     */
    public ClassificationCounter getValueCounter() {
        return classificationCounts;
    }

    /**
     * @return the depth supplied at construction
     */
    @Override
    public int getDepth() {
        return depth;
    }

    /**
     * @return the parent branch supplied at construction
     */
    @Override
    public Branch<ClassificationCounter> getParent() {
        return parent;
    }

    /**
     * Returns this leaf; the attributes are not consulted.
     *
     * @param attributes ignored
     * @return this leaf
     */
    @Override
    public DTLeaf getLeaf(final AttributesMap attributes) {
        return this;
    }

    /**
     * @return always 1
     */
    @Override
    public int getSize() {
        return 1;
    }

    /**
     * Adds this leaf's contribution to the given statistics: the depth weighted
     * by {@link #exampleCount}, the example count itself, and an entry in the
     * per-depth distribution (example count truncated to {@code long}).
     *
     * @param stats accumulator that is updated in place
     */
    //TODO: move this up when Java 8 is migrated to
    @Override
    public void calcLeafDepthStats(final LeafDepthStats stats) {
        stats.ttlDepth += depth * exampleCount;
        stats.ttlSamples += exampleCount;

        final long leafSamples = (long) exampleCount;
        final Map<Integer, Long> dist = stats.depthDistribution;
        // Single lookup: a missing entry for this depth starts from this leaf's count.
        final Long previous = dist.get(depth);
        dist.put(depth, previous == null ? leafSamples : previous + leafSamples);
    }

    /**
     * Renders every key of the counter as {@code key=count/total}, each entry
     * followed by a single space.
     */
    @Override
    public String toString() {
        final StringBuilder builder = new StringBuilder();
        for (Serializable key : classificationCounts.getCounts().keySet()) {
            builder.append(key)
                    .append("=")
                    .append(this.classificationCounts.getCounts().get(key) / classificationCounts.getTotal())
                    .append(" ");
        }
        return builder.toString();
    }

    /**
     * Two leaves are equal when they are of the same class and agree on depth,
     * example count and classification counts.
     */
    @Override
    public boolean equals(final Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final DTLeaf other = (DTLeaf) o;
        return depth == other.depth
                && Double.compare(other.exampleCount, exampleCount) == 0
                && classificationCounts.equals(other.classificationCounts);
    }

    /**
     * Hash over the same three components used by {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        final long exampleBits = Double.doubleToLongBits(exampleCount);
        int result = depth;
        result = 31 * result + (int) (exampleBits ^ (exampleBits >>> 32));
        result = 31 * result + classificationCounts.hashCode();
        return result;
    }
}