package quickml.supervised.tree.decisionTree.valueCounters; import com.google.common.collect.Maps; import org.javatuples.Pair; import quickml.collections.ValueSummingMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.tree.summaryStatistics.ValueCounter; import java.io.Serializable; import java.util.*; import java.util.Map.Entry; public class ClassificationCounter extends ValueCounter<ClassificationCounter> implements Serializable { private static final long serialVersionUID = -6821237234748044623L; private final ValueSummingMap<Serializable> counts = new ValueSummingMap<Serializable>(); //TODO: remove hasSuffientData stuff after debugging private boolean hasSufficientData = true; public void setHasSufficientData(boolean hasSufficientData) { this.hasSufficientData = hasSufficientData; } public boolean hasSufficientData() { return hasSufficientData; } public ClassificationCounter() {} public ClassificationCounter(Serializable attrVal) { super(attrVal); } public boolean isEmpty() { return counts.isEmpty(); } public ClassificationCounter(ClassificationCounter classificationCounter) { super(classificationCounter.attrVal); this.counts.putAll(classificationCounter.counts); } public ClassificationCounter(HashMap<Serializable, ? extends Number> mapOfCounts) { for (Serializable classification: mapOfCounts.keySet()) { counts.addToValue(classification, mapOfCounts.get(classification).doubleValue()); } } public static ClassificationCounter merge(ClassificationCounter a, ClassificationCounter b) { ClassificationCounter newCC = new ClassificationCounter(); newCC.counts.putAll(a.counts); for (Entry<Serializable, Number> e : b.counts.entrySet()) { newCC.counts.addToValue(e.getKey(), e.getValue().doubleValue()); } return newCC; } public static Serializable getLeastPopularClass(ClassificationCounter classificationCounter) { Serializable minClass = null; double minCounts = Double.MAX_VALUE; for (Serializable classification : classificationCounter.allClassifications()) { if (classificationCounter.getCount(classification) < minCounts) { minCounts = classificationCounter.getCount(classification); minClass = classification; } } return minClass; } public static Serializable getMostPopularClass(ClassificationCounter classificationCounter) { Serializable maxClass = null; double maxCounts = 0; Serializable leastPopular = getLeastPopularClass(classificationCounter); //want to ensure don't have the same leastPopular as mostPopular when class ballance is 50/50 for (Serializable classification : classificationCounter.allClassifications()) { if (classificationCounter.getCount(classification) > maxCounts || !classification.equals(leastPopular)) { maxCounts = classificationCounter.getCount(classification); maxClass = classification; } } return maxClass; } //should be abstracted. Data should be in an inner class public Map<Serializable, Double> getCounts() { Map<Serializable, Double> ret = Maps.newHashMap(); for (Entry<Serializable, Number> serializableNumberEntry : counts.entrySet()) { ret.put(serializableNumberEntry.getKey(), serializableNumberEntry.getValue().doubleValue()); } return ret; } public static ClassificationCounter countAll(final Iterable<? extends ClassifierInstance> instances) { final ClassificationCounter result = new ClassificationCounter(); for (ClassifierInstance instance : instances) { result.addClassification(instance.getLabel(), instance.getWeight()); } return result; } public void addClassification(final Serializable classification, double weight) { counts.addToValue(classification, weight); } public double getCount(final Serializable classification) { Number count = counts.get(classification); if (count == null) { return 0; } else { return count.doubleValue(); } } public Set<Serializable> allClassifications() { return counts.keySet(); } public ClassificationCounter add(final ClassificationCounter other) { final ClassificationCounter result = new ClassificationCounter(); result.counts.putAll(counts); for (final Entry<Serializable, Number> e : other.counts.entrySet()) { result.counts.addToValue(e.getKey(), e.getValue().doubleValue()); } return result; } public ClassificationCounter subtract(final ClassificationCounter other) { final ClassificationCounter result = new ClassificationCounter(); result.counts.putAll(counts); for (final Entry<Serializable, Number> e : other.counts.entrySet()) { result.counts.addToValue(e.getKey(), -other.getCount(e.getKey())); } return result; } @Override public double getTotal() { return counts.getSumOfValues(); } public Pair<Serializable, Double> mostPopular() { Entry<Serializable, Number> best = null; for (final Entry<Serializable, Number> e : counts.entrySet()) { if (best == null || e.getValue().doubleValue() > best.getValue().doubleValue()) { best = e; } } return Pair.with(best.getKey(), best.getValue().doubleValue()); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ClassificationCounter that = (ClassificationCounter) o; if (!counts.equals(that.counts)) return false; return true; } @Override public int hashCode() { return counts.hashCode(); } @Override public String toString() { return getCounts().toString(); } }