package com.blazebit.ai.decisiontree.impl; import com.blazebit.ai.decisiontree.Attribute; import com.blazebit.ai.decisiontree.AttributeSelector; import com.blazebit.ai.decisiontree.AttributeValue; import com.blazebit.ai.decisiontree.DiscreteAttribute; import com.blazebit.ai.decisiontree.Example; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Set; /** * * @author Christian Beikov */ public class ID3AttributeSelector implements AttributeSelector<Boolean> { @Override public Attribute select(final Set<Example<Boolean>> examples, final Set<Attribute> availableAttributes, final Set<Attribute> usedAttributes) { Attribute attribute = null; float attributeRem = Float.MAX_VALUE; int attributeValueCount = Integer.MAX_VALUE; float positives = 0; float negatives = 0; final Map<Attribute, Map<AttributeValue, Pair>> attributeUsage = new HashMap<Attribute, Map<AttributeValue, Pair>>(); /* Make array for performance */ final Example<Boolean>[] exampleArray = examples.toArray(new Example[0]); final int examplesSize = exampleArray.length; for (final Attribute attr : availableAttributes) { if (usedAttributes.contains(attr)) { continue; } final Map<AttributeValue, Pair> valueUsage = new HashMap<AttributeValue, Pair>(); attributeUsage.put(attr, valueUsage); for (int i = 0; i < examplesSize; i++) { final AttributeValue value = exampleArray[i].getValues().get(attr); Pair valueUsageExamples = valueUsage.get(value); if (valueUsageExamples == null) { valueUsageExamples = new Pair(); valueUsage.put(value, valueUsageExamples); } if (exampleArray[i].getResult()) { ++valueUsageExamples.positive; ++positives; } else { ++valueUsageExamples.negative; ++negatives; } } } if (positives > 0 && negatives > 0) { for (final Map.Entry<Attribute, Map<AttributeValue, Pair>> entry : attributeUsage.entrySet()) { final Attribute attr = entry.getKey(); final float rem = Pair.rem(entry.getValue().values(), positives, negatives); if (rem < attributeRem) { attribute = attr; attributeRem = rem; if (attr instanceof DiscreteAttribute) { attributeValueCount = ((DiscreteAttribute) attr).getValues().size(); } } else if (attr instanceof DiscreteAttribute && (rem == attributeRem) && ((DiscreteAttribute) attr).getValues().size() < attributeValueCount) { attribute = attr; attributeRem = rem; attributeValueCount = ((DiscreteAttribute) attr).getValues().size(); } } } return attribute; } private static class Pair { static final float log2 = (float) Math.log(2); float positive = 0; float negative = 0; double entropy() { final float localPositive = positive; final float localNegative = negative; final float localLog2 = log2; if (localPositive == 0 || localNegative == 0) { return 0; } final float p = localPositive / (localPositive + localNegative); return -p * (Math.log(p) / localLog2) - (1 - p) * (Math.log(1 - p) / localLog2); } static float rem(final Collection<Pair> pairs, final float positives, final float negatives) { float rem = 0; for (final Pair p : pairs) { rem += ((p.positive + p.negative) / (positives + negatives)) * p.entropy(); } return rem; } } }