package com.blazebit.ai.decisiontree; import com.blazebit.ai.decisiontree.impl.ID3AttributeSelector; import com.blazebit.ai.decisiontree.impl.SimpleAttributeValue; import com.blazebit.ai.decisiontree.impl.SimpleDecisionTree; import com.blazebit.ai.decisiontree.impl.SimpleDiscreteAttribute; import com.blazebit.ai.decisiontree.impl.SimpleExample; import com.blazebit.ai.decisiontree.impl.SimpleItem; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import org.junit.Ignore; import org.junit.Test; /** * * @author Christian Beikov */ public class ID3DecisionTreeTest { @Ignore @Test public void testCreate() { Map<RestaurantExample, Boolean> results = new LinkedHashMap<RestaurantExample, Boolean>(); Set<Attribute> attributes = new HashSet<Attribute>(); attributes.add(Bool.forName("alternate")); attributes.add(Bool.forName("bar")); attributes.add(Bool.forName("friday")); attributes.add(Bool.forName("hungry")); attributes.add(Bool.forName("raining")); attributes.add(Bool.forName("reservation")); attributes.add(Patron.ATTRIBUTE); attributes.add(Price.ATTRIBUTE); attributes.add(Type.ATTRIBUTE); attributes.add(WaitEstimate.ATTRIBUTE); results.put(new RestaurantExample(true, false, false, true, Patron.SOME, Price.EXPENSIVE, false, true, Type.FRENCH, WaitEstimate.ZERO_TO_TEN), true); results.put(new RestaurantExample(true, false, false, true, Patron.FULL, Price.CHEAP, false, false, Type.THAI, WaitEstimate.THIRTY_TO_SIXTY), false); results.put(new RestaurantExample(false, true, false, false, Patron.SOME, Price.CHEAP, false, false, Type.BURGER, WaitEstimate.ZERO_TO_TEN), true); results.put(new RestaurantExample(true, false, true, true, Patron.FULL, Price.CHEAP, false, false, Type.THAI, WaitEstimate.TEN_TO_THIRTY), true); results.put(new RestaurantExample(true, false, true, false, Patron.FULL, Price.EXPENSIVE, false, true, Type.FRENCH, WaitEstimate.GREATER_THAN_SIXTY), false); results.put(new RestaurantExample(false, true, false, true, Patron.SOME, Price.MEDIUM, true, true, Type.ITALIAN, WaitEstimate.ZERO_TO_TEN), true); results.put(new RestaurantExample(false, true, false, false, Patron.NONE, Price.CHEAP, true, false, Type.BURGER, WaitEstimate.ZERO_TO_TEN), false); results.put(new RestaurantExample(false, false, false, true, Patron.SOME, Price.MEDIUM, true, true, Type.THAI, WaitEstimate.ZERO_TO_TEN), true); results.put(new RestaurantExample(false, true, true, false, Patron.FULL, Price.CHEAP, true, false, Type.BURGER, WaitEstimate.GREATER_THAN_SIXTY), false); results.put(new RestaurantExample(true, true, true, true, Patron.FULL, Price.EXPENSIVE, false, true, Type.ITALIAN, WaitEstimate.TEN_TO_THIRTY), false); results.put(new RestaurantExample(false, false, false, false, Patron.NONE, Price.CHEAP, false, false, Type.THAI, WaitEstimate.ZERO_TO_TEN), false); results.put(new RestaurantExample(true, true, true, true, Patron.FULL, Price.CHEAP, false, false, Type.BURGER, WaitEstimate.THIRTY_TO_SIXTY), true); DecisionTree<Boolean> tree = new SimpleDecisionTree<Boolean>(attributes, examples(results), new ID3AttributeSelector()); assertEquals(1, tree.apply(item(new RestaurantExample(null, null, null, null, Patron.NONE, null, null, null, null, null))).size()); assertFalse(tree.apply(item(new RestaurantExample(null, null, null, null, Patron.NONE, null, null, null, null, null))).iterator().next()); assertEquals(1, tree.apply(item(new RestaurantExample(null, null, null, null, Patron.SOME, null, null, null, null, null))).size()); assertTrue(tree.apply(item(new RestaurantExample(null, null, null, null, Patron.SOME, null, null, null, null, null))).iterator().next()); assertEquals(2, tree.apply(item(new RestaurantExample(null, null, null, false, Patron.FULL, null, null, null, null, null))).size()); assertFalse(tree.apply(item(new RestaurantExample(null, null, null, false, Patron.FULL, null, null, null, null, null))).iterator().next()); /* Make clear we have both options in this general case */ assertEquals(2, tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, null, null))).size()); /* Make clear we can not make decisions in this case */ assertEquals(0, tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, Type.FRENCH, null))).size()); assertEquals(1, tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, Type.ITALIAN, null))).size()); assertFalse(tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, Type.ITALIAN, null))).iterator().next()); assertEquals(1, tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, Type.BURGER, null))).size()); assertTrue(tree.apply(item(new RestaurantExample(null, null, null, true, Patron.FULL, null, null, null, Type.BURGER, null))).iterator().next()); /* Hungry -> go to restaurant */ assertEquals(1, tree.apply(item(new RestaurantExample(null, null, true, true, Patron.FULL, null, null, null, Type.THAI, null))).size()); assertTrue(tree.apply(item(new RestaurantExample(null, null, true, true, Patron.FULL, null, null, null, Type.THAI, null))).iterator().next()); /* Not hungry -> do not go to restaurant */ assertEquals(1, tree.apply(item(new RestaurantExample(null, null, false, true, Patron.FULL, null, null, null, Type.THAI, null))).size()); assertFalse(tree.apply(item(new RestaurantExample(null, null, false, true, Patron.FULL, null, null, null, Type.THAI, null))).iterator().next()); /* We can add any additional attributes, since the attributes that are used for making the decision are fulfilled the rest does not affect anything */ assertEquals(1, tree.apply(item(new RestaurantExample(null, true, false, true, Patron.FULL, null, null, null, Type.THAI, null))).size()); assertFalse(tree.apply(item(new RestaurantExample(null, true, false, true, Patron.FULL, null, null, null, Type.THAI, null))).iterator().next()); } private static Item item(RestaurantExample v){ return new SimpleItem(valueMap(v)); } private static Set<Example<Boolean>> examples(Map<RestaurantExample, Boolean> restaurantExamples){ Set<Example<Boolean>> examples = new HashSet<Example<Boolean>>(restaurantExamples.size()); for(Map.Entry<RestaurantExample, Boolean> entry : restaurantExamples.entrySet()){ final RestaurantExample restaurantExample = entry.getKey(); examples.add(new SimpleExample<Boolean>(valueMap(restaurantExample), entry.getValue())); } return examples; } private static Map<Attribute, AttributeValue> valueMap(RestaurantExample v){ Map<Attribute, AttributeValue> values = new HashMap<Attribute, AttributeValue>(); if(v.alternate != null){ values.put(Bool.forName("alternate"), Bool.value(v.alternate)); } if(v.bar != null){ values.put(Bool.forName("bar"), Bool.value(v.bar)); } if(v.friday != null){ values.put(Bool.forName("friday"), Bool.value(v.friday)); } if(v.hungry != null){ values.put(Bool.forName("hungry"), Bool.value(v.hungry)); } if(v.raining != null){ values.put(Bool.forName("raining"), Bool.value(v.raining)); } if(v.reservation != null){ values.put(Bool.forName("reservation"), Bool.value(v.reservation)); } if(v.patron != null){ values.put(Patron.ATTRIBUTE, new SimpleAttributeValue(v.patron)); } if(v.price != null){ values.put(Price.ATTRIBUTE, new SimpleAttributeValue(v.price)); } if(v.type != null){ values.put(Type.ATTRIBUTE, new SimpleAttributeValue(v.type)); } if(v.estimate != null){ values.put(WaitEstimate.ATTRIBUTE, new SimpleAttributeValue(v.estimate)); } return values; } static class RestaurantExample{ Boolean alternate; Boolean bar; Boolean friday; Boolean hungry; Patron patron; Price price; Boolean raining; Boolean reservation; Type type; WaitEstimate estimate; public RestaurantExample(Boolean alternate, Boolean bar, Boolean friday, Boolean hungry, Patron patron, Price price, Boolean raining, Boolean reservation, Type type, WaitEstimate estimate) { this.alternate = alternate; this.bar = bar; this.friday = friday; this.hungry = hungry; this.patron = patron; this.price = price; this.raining = raining; this.reservation = reservation; this.type = type; this.estimate = estimate; } @Override public int hashCode() { int hash = 5; hash = 89 * hash + (this.alternate != null ? this.alternate.hashCode() : 0); hash = 89 * hash + (this.bar != null ? this.bar.hashCode() : 0); hash = 89 * hash + (this.friday != null ? this.friday.hashCode() : 0); hash = 89 * hash + (this.hungry != null ? this.hungry.hashCode() : 0); hash = 89 * hash + (this.raining != null ? this.raining.hashCode() : 0); hash = 89 * hash + (this.reservation != null ? this.reservation.hashCode() : 0); hash = 89 * hash + (this.patron != null ? this.patron.hashCode() : 0); hash = 89 * hash + (this.price != null ? this.price.hashCode() : 0); hash = 89 * hash + (this.type != null ? this.type.hashCode() : 0); hash = 89 * hash + (this.estimate != null ? this.estimate.hashCode() : 0); return hash; } @Override public boolean equals(Object obj) { if (obj == null) { return false; } if (getClass() != obj.getClass()) { return false; } final RestaurantExample other = (RestaurantExample) obj; if (this.alternate != other.alternate && (this.alternate == null || !this.alternate.equals(other.alternate))) { return false; } if (this.bar != other.bar && (this.bar == null || !this.bar.equals(other.bar))) { return false; } if (this.friday != other.friday && (this.friday == null || !this.friday.equals(other.friday))) { return false; } if (this.hungry != other.hungry && (this.hungry == null || !this.hungry.equals(other.hungry))) { return false; } if (this.raining != other.raining && (this.raining == null || !this.raining.equals(other.raining))) { return false; } if (this.reservation != other.reservation && (this.reservation == null || !this.reservation.equals(other.reservation))) { return false; } if (this.patron != other.patron) { return false; } if (this.price != other.price) { return false; } if (this.type != other.type) { return false; } if (this.estimate != other.estimate) { return false; } return true; } } static enum Bool{ YES, NO; public static AttributeValue value(Boolean value){ if(value == null){ return null; } return new SimpleAttributeValue(value ? YES : NO); } public static Attribute forName(String name){ return new SimpleDiscreteAttribute(name, attributeValues(values())); } } static enum Patron{ NONE, SOME, FULL; public static final Attribute ATTRIBUTE = new SimpleDiscreteAttribute("patron", attributeValues(values())); } static enum Price{ CHEAP, MEDIUM, EXPENSIVE; public static final Attribute ATTRIBUTE = new SimpleDiscreteAttribute("price", attributeValues(values())); } static enum Type{ FRENCH, THAI, BURGER, ITALIAN; public static final Attribute ATTRIBUTE = new SimpleDiscreteAttribute("type", attributeValues(values())); } static enum WaitEstimate{ ZERO_TO_TEN, TEN_TO_THIRTY, THIRTY_TO_SIXTY, GREATER_THAN_SIXTY; public static final Attribute ATTRIBUTE = new SimpleDiscreteAttribute("waitEstimate", attributeValues(values())); } public static Set<AttributeValue> attributeValues(Object[] objects){ Set<AttributeValue> values = new HashSet<AttributeValue>(objects.length); for(Object o : objects){ values.add(new SimpleAttributeValue(o)); } return values; } }