package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Ordering; import org.javatuples.Pair; import quickml.collections.ValueSummingMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; import java.io.Serializable; import java.util.*; import java.util.Map.Entry; import static quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldTreeBuilder.MISSING_VALUE; public class OldClassificationCounter implements Serializable { private static final long serialVersionUID = -6821237234748044623L; private final ValueSummingMap<Serializable> counts = new ValueSummingMap<Serializable>(); public OldClassificationCounter() { } public OldClassificationCounter(OldClassificationCounter classificationCounter) { this.counts.putAll(classificationCounter.counts); } public OldClassificationCounter(ClassificationCounter classificationCounter) { this.counts.putAll(classificationCounter.getCounts()); } private boolean hasSufficientData = true; public void setHasSufficientData(boolean hasSufficientData) { this.hasSufficientData = hasSufficientData; } public boolean hasSufficientData() { return hasSufficientData; } public static OldClassificationCounter merge(OldClassificationCounter a, OldClassificationCounter b) { OldClassificationCounter newCC = new OldClassificationCounter(); newCC.counts.putAll(a.counts); for (Entry<Serializable, Number> e : b.counts.entrySet()) { newCC.counts.addToValue(e.getKey(), e.getValue().doubleValue()); } return newCC; } public static Pair<OldClassificationCounter, Map<Serializable, OldClassificationCounter>> countAllByAttributeValues( final Iterable<? extends ClassifierInstance> instances, final String attribute) { final Map<Serializable, OldClassificationCounter> result = Maps.newHashMap(); final OldClassificationCounter totals = new OldClassificationCounter(); for (ClassifierInstance instance : instances) { final Serializable attrVal = instance.getAttributes().get(attribute); OldClassificationCounter cc; boolean acceptableMissingValue = attrVal == null; if (attrVal != null) cc = result.get(attrVal); else if (acceptableMissingValue) cc = result.get(MISSING_VALUE); else continue; if (cc == null) { cc = new OldClassificationCounter(); Serializable newKey = (attrVal != null) ? attrVal : MISSING_VALUE; result.put(newKey, cc); } cc.addClassification(instance.getLabel(), instance.getWeight()); totals.addClassification(instance.getLabel(), instance.getWeight()); } return Pair.with(totals, result); } public static Pair<OldClassificationCounter, List<OldAttributeValueWithClassificationCounter>> getSortedListOfAttributeValuesWithClassificationCounters( final Iterable<? extends ClassifierInstance> instances, final String attribute, final Serializable minorityClassification) { Pair<OldClassificationCounter, Map<Serializable, OldClassificationCounter>> totalsClassificationCounterPairedWithMapofClassificationCounters = countAllByAttributeValues(instances, attribute); final Map<Serializable, OldClassificationCounter> result = totalsClassificationCounterPairedWithMapofClassificationCounters.getValue1(); final OldClassificationCounter totals = totalsClassificationCounterPairedWithMapofClassificationCounters.getValue0(); List<OldAttributeValueWithClassificationCounter> attributesWithClassificationCounters = Lists.newArrayList(); for (Serializable key : result.keySet()) { attributesWithClassificationCounters.add(new OldAttributeValueWithClassificationCounter(key, result.get(key))); } Collections.sort(attributesWithClassificationCounters, new Comparator<OldAttributeValueWithClassificationCounter>() { @Override public int compare(OldAttributeValueWithClassificationCounter cc1, OldAttributeValueWithClassificationCounter cc2) { double probOfMinority1 = cc1.classificationCounter.getCount(minorityClassification) / cc1.classificationCounter.getTotal(); double probOfMinority2 = cc2.classificationCounter.getCount(minorityClassification) / cc2.classificationCounter.getTotal(); return Ordering.natural().reverse().compare(probOfMinority1, probOfMinority2); } }); return Pair.with(totals, attributesWithClassificationCounters); } public Map<Serializable, Double> getCounts() { Map<Serializable, Double> ret = Maps.newHashMap(); for (Entry<Serializable, Number> serializableNumberEntry : counts.entrySet()) { ret.put(serializableNumberEntry.getKey(), serializableNumberEntry.getValue().doubleValue()); } return ret; } public static OldClassificationCounter countAll(final Iterable<? extends ClassifierInstance> instances) { final OldClassificationCounter result = new OldClassificationCounter(); for (ClassifierInstance instance : instances) { result.addClassification(instance.getLabel(), instance.getWeight()); } return result; } public void addClassification(final Serializable classification, double weight) { counts.addToValue(classification, weight); } public double getCount(final Serializable classification) { Number count = counts.get(classification); if (count == null) { return 0; } else { return count.doubleValue(); } } public Set<Serializable> allClassifications() { return counts.keySet(); } public OldClassificationCounter add(final OldClassificationCounter other) { final OldClassificationCounter result = new OldClassificationCounter(); result.counts.putAll(counts); for (final Entry<Serializable, Number> e : other.counts.entrySet()) { result.counts.addToValue(e.getKey(), e.getValue().doubleValue()); } return result; } public OldClassificationCounter subtract(final OldClassificationCounter other) { final OldClassificationCounter result = new OldClassificationCounter(); result.counts.putAll(counts); for (final Entry<Serializable, Number> e : other.counts.entrySet()) { result.counts.addToValue(e.getKey(), -other.getCount(e.getKey())); } return result; } public double getTotal() { return counts.getSumOfValues(); } public Pair<Serializable, Double> mostPopular() { Entry<Serializable, Number> best = null; for (final Entry<Serializable, Number> e : counts.entrySet()) { if (best == null || e.getValue().doubleValue() > best.getValue().doubleValue()) { best = e; } } return Pair.with(best.getKey(), best.getValue().doubleValue()); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OldClassificationCounter that = (OldClassificationCounter) o; if (!counts.equals(that.counts)) return false; return true; } @Override public int hashCode() { return counts.hashCode(); } }