package aima.test.core.unit.learning.learners; import java.util.ArrayList; import java.util.List; import org.junit.Assert; import org.junit.Test; import aima.core.learning.framework.DataSet; import aima.core.learning.framework.DataSetFactory; import aima.core.learning.inductive.DecisionTree; import aima.core.learning.learners.DecisionTreeLearner; import aima.core.util.Util; /** * @author Ravi Mohan * */ public class DecisionTreeTest { private static final String YES = "Yes"; private static final String NO = "No"; @Test public void testActualDecisionTreeClassifiesRestaurantDataSetCorrectly() throws Exception { DecisionTreeLearner learner = new DecisionTreeLearner( createActualRestaurantDecisionTree(), "Unable to clasify"); int[] results = learner.test(DataSetFactory.getRestaurantDataSet()); Assert.assertEquals(12, results[0]); Assert.assertEquals(0, results[1]); } @Test public void testInducedDecisionTreeClassifiesRestaurantDataSetCorrectly() throws Exception { DecisionTreeLearner learner = new DecisionTreeLearner( createInducedRestaurantDecisionTree(), "Unable to clasify"); int[] results = learner.test(DataSetFactory.getRestaurantDataSet()); Assert.assertEquals(12, results[0]); Assert.assertEquals(0, results[1]); } @Test public void testStumpCreationForSpecifiedAttributeValuePair() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); List<String> unmatchedValues = new ArrayList<String>(); unmatchedValues.add(NO); DecisionTree dt = DecisionTree.getStumpFor(ds, "alternate", YES, YES, unmatchedValues, NO); Assert.assertNotNull(dt); } @Test public void testStumpCreationForDataSet() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); List<DecisionTree> dt = DecisionTree.getStumpsFor(ds, YES, "Unable to classify"); Assert.assertEquals(26, dt.size()); } @Test public void testStumpPredictionForDataSet() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); List<String> unmatchedValues = new ArrayList<String>(); unmatchedValues.add(NO); DecisionTree tree = DecisionTree.getStumpFor(ds, "hungry", YES, YES, unmatchedValues, "Unable to Classify"); DecisionTreeLearner learner = new DecisionTreeLearner(tree, "Unable to Classify"); int[] result = learner.test(ds); Assert.assertEquals(5, result[0]); Assert.assertEquals(7, result[1]); } // // PRIVATE METHODS // private static DecisionTree createInducedRestaurantDecisionTree() { // from AIMA 2nd ED // Fig 18.6 // friday saturday node DecisionTree frisat = new DecisionTree("fri/sat"); frisat.addLeaf(Util.YES, Util.YES); frisat.addLeaf(Util.NO, Util.NO); // type node DecisionTree type = new DecisionTree("type"); type.addLeaf("French", Util.YES); type.addLeaf("Italian", Util.NO); type.addNode("Thai", frisat); type.addLeaf("Burger", Util.YES); // hungry node DecisionTree hungry = new DecisionTree("hungry"); hungry.addLeaf(Util.NO, Util.NO); hungry.addNode(Util.YES, type); // patrons node DecisionTree patrons = new DecisionTree("patrons"); patrons.addLeaf("None", Util.NO); patrons.addLeaf("Some", Util.YES); patrons.addNode("Full", hungry); return patrons; } private static DecisionTree createActualRestaurantDecisionTree() { // from AIMA 2nd ED // Fig 18.2 // raining node DecisionTree raining = new DecisionTree("raining"); raining.addLeaf(Util.YES, Util.YES); raining.addLeaf(Util.NO, Util.NO); // bar node DecisionTree bar = new DecisionTree("bar"); bar.addLeaf(Util.YES, Util.YES); bar.addLeaf(Util.NO, Util.NO); // friday saturday node DecisionTree frisat = new DecisionTree("fri/sat"); frisat.addLeaf(Util.YES, Util.YES); frisat.addLeaf(Util.NO, Util.NO); // second alternate node to the right of the diagram below hungry DecisionTree alternate2 = new DecisionTree("alternate"); alternate2.addNode(Util.YES, raining); alternate2.addLeaf(Util.NO, Util.YES); // reservation node DecisionTree reservation = new DecisionTree("reservation"); frisat.addNode(Util.NO, bar); frisat.addLeaf(Util.YES, Util.YES); // first alternate node to the left of the diagram below waitestimate DecisionTree alternate1 = new DecisionTree("alternate"); alternate1.addNode(Util.NO, reservation); alternate1.addNode(Util.YES, frisat); // hungry node DecisionTree hungry = new DecisionTree("hungry"); hungry.addLeaf(Util.NO, Util.YES); hungry.addNode(Util.YES, alternate2); // wait estimate node DecisionTree waitEstimate = new DecisionTree("wait_estimate"); waitEstimate.addLeaf(">60", Util.NO); waitEstimate.addNode("30-60", alternate1); waitEstimate.addNode("10-30", hungry); waitEstimate.addLeaf("0-10", Util.YES); // patrons node DecisionTree patrons = new DecisionTree("patrons"); patrons.addLeaf("None", Util.NO); patrons.addLeaf("Some", Util.YES); patrons.addNode("Full", waitEstimate); return patrons; } }