package quickml.supervised.classifier.downsampling; import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.collections.MapUtils; import quickml.data.instances.ClassifierInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.data.instances.InstanceWithAttributesMap; import quickml.supervised.classifier.Classifier; import java.io.Serializable; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import static com.google.common.base.Preconditions.checkArgument; /** * Created by ian on 4/22/14. */ public class DownsamplingClassifierBuilder<T extends ClassifierInstance> implements PredictiveModelBuilder<Classifier, T> { public static final String MINORITY_INSTANCE_PROPORTION = "minorityInstanceProportion"; private static final Logger logger = LoggerFactory.getLogger(DownsamplingClassifierBuilder.class); private double targetMinorityProportion; private final PredictiveModelBuilder<? extends Classifier, T> predictiveModelBuilder; public DownsamplingClassifierBuilder(PredictiveModelBuilder<? extends Classifier, T> predictiveModelBuilder, double targetMinorityProportion) { checkArgument(targetMinorityProportion > 0 && targetMinorityProportion < 1, "targetMinorityProportion must be between 0 and 1 (was %s)", targetMinorityProportion); this.predictiveModelBuilder = predictiveModelBuilder; this.targetMinorityProportion = targetMinorityProportion; } @Override public DownsamplingClassifier buildPredictiveModel(Iterable<T> trainingData) { final Map<Serializable, Double> classificationProportions = getClassificationProportions(trainingData); if (classificationProportions.size() != 2) { printSampleInstancesForInspection(trainingData); } checkArgument(classificationProportions.size() == 2, "trainingData must contain only 2 classifications, but it had %s. mapOfClassificationsToOutcomes: %s", classificationProportions.size(), classificationProportions.get(1.0), classificationProportions.toString()); final Map.Entry<Serializable, Double> majorityEntry = MapUtils.getEntryWithHighestValue(classificationProportions).get(); final Map.Entry<Serializable, Double> minorityEntry = MapUtils.getEntryWithLowestValue(classificationProportions).get(); Serializable majorityClassification = majorityEntry.getKey(); final double majorityProportion = majorityEntry.getValue(); final double naturalMinorityProportion = 1.0 - majorityProportion; if (naturalMinorityProportion >= targetMinorityProportion) { final Classifier wrappedPredictiveModel = predictiveModelBuilder.buildPredictiveModel(trainingData); return new DownsamplingClassifier(wrappedPredictiveModel, majorityClassification, minorityEntry.getKey(), 0); } final double dropProbability = (naturalMinorityProportion > targetMinorityProportion) ? 0 : 1.0 - ((naturalMinorityProportion - targetMinorityProportion * naturalMinorityProportion) / (targetMinorityProportion - targetMinorityProportion * naturalMinorityProportion)); Iterable<T> downsampledTrainingData = Iterables.filter(trainingData, new RandomDroppingInstanceFilter(majorityClassification, dropProbability)); final Classifier wrappedPredictiveModel = predictiveModelBuilder.buildPredictiveModel(downsampledTrainingData); return new DownsamplingClassifier(wrappedPredictiveModel, majorityClassification, minorityEntry.getKey(), dropProbability); } @Override public void updateBuilderConfig(Map<String, Serializable> cfg) { predictiveModelBuilder.updateBuilderConfig(cfg); if (cfg.containsKey(MINORITY_INSTANCE_PROPORTION)) targetMinorityProportion((Double) cfg.get(MINORITY_INSTANCE_PROPORTION)); } public DownsamplingClassifierBuilder targetMinorityProportion(double targetMinorityProportion) { this.targetMinorityProportion = targetMinorityProportion; return this; } private void printSampleInstancesForInspection(Iterable<? extends InstanceWithAttributesMap> trainingData) { logger.info("length of training data" + Iterables.size(trainingData)); int counter = 0; for (InstanceWithAttributesMap instance : trainingData) { if (counter++ % 100 == 0) { if (instance.getLabel().equals(Double.valueOf(1.0))) { logger.info("instance " + counter); logger.info(instance.getAttributes().toString()); logger.info("label:" + instance.getLabel().toString()); logger.info("weight:" + instance.getWeight()); } } if (counter > 1000) break; } } private Map<Serializable, Double> getClassificationProportions(final Iterable<? extends InstanceWithAttributesMap> trainingData) { Map<Serializable, AtomicLong> classificationCounts = Maps.newHashMap(); long total = 0; for (InstanceWithAttributesMap instance : trainingData) { AtomicLong count = classificationCounts.get(instance.getLabel()); if (count == null) { count = new AtomicLong(0); classificationCounts.put(instance.getLabel(), count); } count.getAndIncrement(); total++; } Map<Serializable, Double> classificationProportions = Maps.newHashMap(); for (Map.Entry<Serializable, AtomicLong> classCount : classificationCounts.entrySet()) { classificationProportions.put(classCount.getKey(), classCount.getValue().doubleValue() / (double) total); } return classificationProportions; } }