package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree; import com.google.common.base.Preconditions; import com.google.common.collect.Iterables; import quickml.data.AttributesMap; import quickml.data.instances.ClassifierInstance; import java.io.IOException; import java.io.Serializable; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicLong; public class OldLeaf extends OldNode { private static final long serialVersionUID = -5617660873196498754L; private static final AtomicLong guidCounter = new AtomicLong(0); public final long guid; /** * How deep in the oldTree is this label? A lower number typically indicates a * more confident getBestClassification. */ public final int depth; /** * How many training examples matched this leaf? A higher number indicates a * more confident getBestClassification. */ public double exampleCount; /** * The actual getBestClassification counts */ public final OldClassificationCounter classificationCounts; protected transient volatile Map.Entry<Serializable, Double> bestClassificationEntry = null; public OldLeaf(OldNode parent, final Iterable<? extends ClassifierInstance> instances, final int depth) { this(parent, OldClassificationCounter.countAll(instances), depth); Preconditions.checkArgument(!Iterables.isEmpty(instances), "Can't create leaf with no instances"); } public OldLeaf(OldNode parent, final OldClassificationCounter classificationCounts, final int depth) { super(parent); guid = guidCounter.incrementAndGet(); this.classificationCounts = classificationCounts; Preconditions.checkState(classificationCounts.getTotal() > 0, "Classifications must be > 0"); exampleCount = classificationCounts.getTotal(); this.depth = depth; } /** * @return The most likely classification */ public Serializable getBestClassification() { return getBestClassificationEntry().getKey(); } protected synchronized Map.Entry<Serializable, Double> getBestClassificationEntry() { if (bestClassificationEntry != null) return bestClassificationEntry; for (Map.Entry<Serializable, Double> e : classificationCounts.getCounts().entrySet()) { if (bestClassificationEntry == null || e.getValue() > bestClassificationEntry.getValue()) { bestClassificationEntry = e; } } return bestClassificationEntry; } @Override public void dump(final int indent, final Appendable ap) { try { for (int x = 0; x < indent; x++) { ap.append(' '); } ap.append(this + "\n"); } catch (IOException e) { e.printStackTrace(); throw new RuntimeException(); } } @Override public OldLeaf getLeaf(final AttributesMap attributes) { return this; } @Override public int size() { return 1; } @Override protected void calcLeafDepthStats(final LeafDepthStats stats) { stats.ttlDepth += depth * exampleCount; stats.ttlSamples += exampleCount; Map<Integer, Long> dist = stats.depthDistribution; if (dist.containsKey(depth)) { dist.put(depth, dist.get(depth) + (long) exampleCount); } else { dist.put(depth, (long) exampleCount); } } @Override public String toString() { StringBuilder builder = new StringBuilder(); for (Serializable key : getClassifications()) { builder.append(key + "=" + this.getProbability(key) + " "); } return builder.toString(); } public double getProbability(Serializable classification) { final double totalCount = classificationCounts.getTotal(); if (totalCount == 0) { throw new IllegalStateException("Trying to get a probability from a Leaf with no examples"); } final double probability = classificationCounts.getCount(classification) / totalCount; return probability; } public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attribute) { return getProbability(classification); } public Set<Serializable> getClassifications() { return classificationCounts.getCounts().keySet(); } @Override public boolean equals(final Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; final OldLeaf oldLeaf = (OldLeaf) o; if (depth != oldLeaf.depth) return false; if (Double.compare(oldLeaf.exampleCount, exampleCount) != 0) return false; if (!classificationCounts.equals(oldLeaf.classificationCounts)) return false; return true; } @Override public int hashCode() { int result; long temp; result = depth; temp = Double.doubleToLongBits(exampleCount); result = 31 * result + (int) (temp ^ (temp >>> 32)); result = 31 * result + classificationCounts.hashCode(); return result; } }