package quickml.supervised.featureEngineering1.enrichStrategies.probabilityInjector; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import quickml.data.instances.InstanceWithAttributesMap; import quickml.supervised.featureEngineering1.AttributesEnrichStrategy; import quickml.supervised.featureEngineering1.AttributesEnricher; import java.io.Serializable; import java.util.Map; import java.util.Set; /** * This strategy will inject new attributes for a particular set of existing attributes corresponding to * the probability of a specified classification given the value associated with that attribute. So, for example, * if we are predicting a person's likelihood to have an illness based on a variety of factors including gender, * and a generic male's overall probability of having the illness is 0.2 based on our training data, then it will * enrich with an attribute like "male-PROB"=0.2. */ public class ProbabilityEnrichStrategy implements AttributesEnrichStrategy { private static final int DEFAULT_MAX_VALUE_COUNT = 20000; private final Set<String> attributeKeysToInject; private final Serializable classification; private final int maxValueCount; /** * * @param attributeKeysToInject The attributes to enrich with probabilities * @param classification The classification whose probability we should use. If there are only two * classifications then it doesn't particularly matter which one we use. If there * are more than two you might wish to create multiple enrich strategies, each * looking at a different classification. */ public ProbabilityEnrichStrategy(Set<String> attributeKeysToInject, Serializable classification) { this(attributeKeysToInject, classification, DEFAULT_MAX_VALUE_COUNT); } /** * @param attributeKeysToInject The attributes to enrich with probabilities * @param classification The classification whose probability we should use. If there are only two * classifications then it doesn't particularly matter which one we use. If there * are more than two you might wish to create multiple enrich strategies, each * looking at a different classification. * @param maxValueCount This is the maximum number of values an attribute can have before it will be * ignored by ProbabilityEnrichStrategy. If unspecified the default is 20,000. */ public ProbabilityEnrichStrategy(Set<String> attributeKeysToInject, Serializable classification, final int maxValueCount) { this.attributeKeysToInject = attributeKeysToInject; this.classification = classification; this.maxValueCount = maxValueCount; } @Override public AttributesEnricher build(final Iterable<InstanceWithAttributesMap<?>> trainingData) { Map<String, Map<Serializable, ProbCounter>> valueProbCountersByAttribute = Maps.newHashMap(); Set<String> attributesWithTooManyValues = Sets.newHashSet(); for (InstanceWithAttributesMap instance : trainingData) { int classificationMatch = instance.getLabel().equals(classification) ? 1 : 0; for (String attributeKey : attributeKeysToInject) { if (attributesWithTooManyValues.contains(attributeKey)) { continue; } Map<Serializable, ProbCounter> attributeValueProbabilities = valueProbCountersByAttribute.get(attributeKey); if (attributeValueProbabilities == null) { attributeValueProbabilities = Maps.newHashMap(); valueProbCountersByAttribute.put(attributeKey, attributeValueProbabilities); } if (attributeValueProbabilities.size() > maxValueCount) { attributesWithTooManyValues.add(attributeKey); valueProbCountersByAttribute.remove(attributeKey); continue; } Serializable value = instance.getAttributes().get(attributeKey); if (value == null) { value = Integer.MIN_VALUE; } ProbCounter probCounter = attributeValueProbabilities.get(value); if (probCounter == null) { probCounter = new ProbCounter(); attributeValueProbabilities.put(value, probCounter); } probCounter.add(classificationMatch, instance.getWeight()); } } Map<String, Map<Serializable, Double>> attributeValueProbabilitiesByAttribute = Maps.newHashMap(); for (Map.Entry<String, Map<Serializable, ProbCounter>> attributeValueProbEntry : valueProbCountersByAttribute.entrySet()) { Map<Serializable, Double> probabilitiesByValue = Maps.newHashMap(); for (Map.Entry<Serializable, ProbCounter> valueProbEntry : attributeValueProbEntry.getValue().entrySet()) { probabilitiesByValue.put(valueProbEntry.getKey(), valueProbEntry.getValue().getProb()); } attributeValueProbabilitiesByAttribute.put(attributeValueProbEntry.getKey(), probabilitiesByValue); } return new ProbabilityInjectingEnricher(attributeValueProbabilitiesByAttribute); } /** * Keeps a running average of the classificationMatch value, weighted accordingly */ private static class ProbCounter { private double sum = 0; private double total = 0; public void add(int classificationMatch, double weight) { sum += classificationMatch * weight; total += weight; } public double getProb() { return sum / total; } } }