package quickml.supervised.dataProcessing.instanceTranformer; import com.google.common.collect.Lists; import it.unimi.dsi.fastutil.objects.Object2LongArrayMap; import quickml.data.AttributesMap; import quickml.data.instances.InstanceFactory; import quickml.data.instances.InstanceWithAttributesMap; import quickml.supervised.dataProcessing.AttributeCharacteristics; import java.io.Serializable; import java.util.List; import java.util.Map; import static quickml.supervised.classifier.logisticRegression.InstanceTransformerUtils.oneHotEncode; /** * Created by alexanderhawk on 10/14/15. */ public class OneHotEncoder<L extends Serializable, I extends InstanceWithAttributesMap<L>, R extends InstanceWithAttributesMap<Serializable>> implements InstanceTransformer<I, R> { public static final String INSUFFICIENT_CAT_ATTR = "insufficientCatAttr"; final Map<String, AttributeCharacteristics> attributeCharacteristics; final InstanceFactory<R, AttributesMap, L> instanceFactory; final int minObservationsOfAnAtribute; public OneHotEncoder(Map<String, AttributeCharacteristics> attributeCharacteristics, final InstanceFactory<R, AttributesMap, L> instanceFactory, int minObservationsOfAnAtribute) { this.attributeCharacteristics = attributeCharacteristics; this.instanceFactory = instanceFactory; this.minObservationsOfAnAtribute = minObservationsOfAnAtribute; } public OneHotEncoder(Map<String, AttributeCharacteristics> attributeCharacteristics, final InstanceFactory<R, AttributesMap, L> instanceFactory) { this(attributeCharacteristics, instanceFactory, 1); } public List<R> transformAll(List<I> instances) { Object2LongArrayMap<String> expandedCatAttributeToCounts = new Object2LongArrayMap<>(); Object2LongArrayMap<String> attributeToCounts = new Object2LongArrayMap<>(); for (I instance : instances) { for (Map.Entry<String, Serializable> entry : instance.getAttributes().entrySet()) { String attribute = entry.getKey(); updateCounts(attributeToCounts, attribute); if (!attributeCharacteristics.get(attribute).isNumber) { String expandedAttribute = oneHotEncode(attribute, entry.getValue()); updateCounts(expandedCatAttributeToCounts, expandedAttribute); } } } List<R> transformed = Lists.newArrayList(); for (I instance : instances) { transformed.add(transformInstance(instance, expandedCatAttributeToCounts, attributeToCounts, minObservationsOfAnAtribute)); } return transformed; } private void updateCounts(Object2LongArrayMap<String> counts, String key) { if (!counts.containsKey(key)) { counts.put(key, 0L); } counts.put(key, counts.getLong(key) + 1L); } public R transformInstance(I instance, Object2LongArrayMap<String> expandedCatAttributeToCounts, Object2LongArrayMap<String> attributeToCounts, int minObservationsOfAnAtribute) { /**attributes with insufficient data are ignored altogether...less arbitrary than counting number of attributes with insufficient-data **/ AttributesMap attributesMap = AttributesMap.newHashMap(); AttributesMap rawAttributes = instance.getAttributes(); for (Map.Entry<String, Serializable> entry : rawAttributes.entrySet()) { String attribute = entry.getKey(); if (!attributeCharacteristics.get(attribute).isNumber) { String expandedAttribute = oneHotEncode(attribute, entry.getValue()); if (expandedCatAttributeToCounts.get(expandedAttribute) >= (long) minObservationsOfAnAtribute) { attributesMap.put(expandedAttribute, 1.0); } else if (attributeToCounts.get(attribute) >= (long) minObservationsOfAnAtribute) { String insufficientDataAttribute = attribute + "--" + INSUFFICIENT_CAT_ATTR; attributesMap.put(insufficientDataAttribute, 1.0); } } else { if (attributeToCounts.get(attribute) >= (long) minObservationsOfAnAtribute) { attributesMap.put(attribute, ((Number) entry.getValue()).doubleValue()); } } } return instanceFactory.createInstance(attributesMap, instance.getLabel(), instance.getWeight()); } @Override public R transformInstance(I instance) { AttributesMap attributesMap = AttributesMap.newHashMap(); AttributesMap rawAttributes = instance.getAttributes(); for (String key : rawAttributes.keySet()) { if (!attributeCharacteristics.get(key).isNumber) { attributesMap.put(oneHotEncode(key, rawAttributes.get(key)), 1.0); } else { attributesMap.put(key, ((Number) rawAttributes.get(key)).doubleValue()); } } return instanceFactory.createInstance(attributesMap, instance.getLabel(), instance.getWeight()); } }