package aima.test.core.unit.learning.learners; import java.util.ArrayList; import org.junit.Assert; import org.junit.Test; import aima.core.learning.framework.DataSet; import aima.core.learning.framework.DataSetFactory; import aima.core.learning.inductive.DLTest; import aima.core.learning.inductive.DLTestFactory; import aima.core.learning.learners.CurrentBestLearner; import aima.core.learning.learners.DecisionListLearner; import aima.core.learning.learners.DecisionTreeLearner; import aima.core.learning.learners.MajorityLearner; import aima.test.core.unit.learning.framework.MockDataSetSpecification; import aima.test.core.unit.learning.inductive.MockDLTestFactory; /** * @author Ravi Mohan * */ public class LearnerTest { @Test public void testMajorityLearner() throws Exception { MajorityLearner learner = new MajorityLearner(); DataSet ds = DataSetFactory.getRestaurantDataSet(); learner.train(ds); int[] result = learner.test(ds); Assert.assertEquals(6, result[0]); Assert.assertEquals(6, result[1]); } @Test public void testDefaultUsedWhenTrainingDataSetHasNoExamples() throws Exception { // tests RecursionBaseCase#1 DataSet ds = DataSetFactory.getRestaurantDataSet(); DecisionTreeLearner learner = new DecisionTreeLearner(); DataSet ds2 = ds.emptyDataSet(); Assert.assertEquals(0, ds2.size()); learner.train(ds2); Assert.assertEquals("Unable To Classify", learner.predict(ds.getExample(0))); } @Test public void testClassificationReturnedWhenAllExamplesHaveTheSameClassification() throws Exception { // tests RecursionBaseCase#2 DataSet ds = DataSetFactory.getRestaurantDataSet(); DecisionTreeLearner learner = new DecisionTreeLearner(); DataSet ds2 = ds.emptyDataSet(); // all 3 examples have the same classification (willWait = yes) ds2.add(ds.getExample(0)); ds2.add(ds.getExample(2)); ds2.add(ds.getExample(3)); learner.train(ds2); Assert.assertEquals("Yes", learner.predict(ds.getExample(0))); } @Test public void testMajorityReturnedWhenAttributesToExamineIsEmpty() throws Exception { // tests RecursionBaseCase#2 DataSet ds = DataSetFactory.getRestaurantDataSet(); DecisionTreeLearner learner = new DecisionTreeLearner(); DataSet ds2 = ds.emptyDataSet(); // 3 examples have classification = "yes" and one ,"no" ds2.add(ds.getExample(0)); ds2.add(ds.getExample(1));// "no" ds2.add(ds.getExample(2)); ds2.add(ds.getExample(3)); ds2.setSpecification(new MockDataSetSpecification("will_wait")); learner.train(ds2); Assert.assertEquals("Yes", learner.predict(ds.getExample(1))); } @Test public void testInducedTreeClassifiesDataSetCorrectly() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); DecisionTreeLearner learner = new DecisionTreeLearner(); learner.train(ds); int[] result = learner.test(ds); Assert.assertEquals(12, result[0]); Assert.assertEquals(0, result[1]); } @Test public void testDecisionListLearnerReturnsNegativeDLWhenDataSetEmpty() throws Exception { // tests first base case of DL Learner DecisionListLearner learner = new DecisionListLearner("Yes", "No", new MockDLTestFactory(null)); DataSet ds = DataSetFactory.getRestaurantDataSet(); DataSet empty = ds.emptyDataSet(); learner.train(empty); Assert.assertEquals("No", learner.predict(ds.getExample(0))); Assert.assertEquals("No", learner.predict(ds.getExample(1))); Assert.assertEquals("No", learner.predict(ds.getExample(2))); } @Test public void testDecisionListLearnerReturnsFailureWhenTestsEmpty() throws Exception { // tests second base case of DL Learner DecisionListLearner learner = new DecisionListLearner("Yes", "No", new MockDLTestFactory(new ArrayList<DLTest>())); DataSet ds = DataSetFactory.getRestaurantDataSet(); learner.train(ds); Assert.assertEquals(DecisionListLearner.FAILURE, learner.predict(ds.getExample(0))); } @Test public void testDecisionListTestRunOnRestaurantDataSet() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); DecisionListLearner learner = new DecisionListLearner("Yes", "No", new DLTestFactory()); learner.train(ds); int[] result = learner.test(ds); Assert.assertEquals(12, result[0]); Assert.assertEquals(0, result[1]); } @Test public void testCurrentBestLearnerOnRestaurantDataSet() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); CurrentBestLearner learner = new CurrentBestLearner("Yes"); learner.train(ds); int[] result = learner.test(ds); Assert.assertEquals(12, result[0]); Assert.assertEquals(0, result[1]); } }