package aima.test.core.unit.learning.learners; import java.util.ArrayList; import java.util.List; import org.junit.Assert; import org.junit.Test; import aima.core.learning.framework.DataSet; import aima.core.learning.framework.DataSetFactory; import aima.core.learning.framework.Learner; import aima.core.learning.inductive.DecisionTree; import aima.core.learning.learners.AdaBoostLearner; import aima.core.learning.learners.StumpLearner; /** * @author Ravi Mohan * */ public class EnsembleLearningTest { private static final String YES = "Yes"; @Test public void testAdaBoostEnablesCollectionOfStumpsToClassifyDataSetAccurately() throws Exception { DataSet ds = DataSetFactory.getRestaurantDataSet(); List<DecisionTree> stumps = DecisionTree.getStumpsFor(ds, YES, "No"); List<Learner> learners = new ArrayList<Learner>(); for (Object stump : stumps) { DecisionTree sl = (DecisionTree) stump; StumpLearner stumpLearner = new StumpLearner(sl, "No"); learners.add(stumpLearner); } AdaBoostLearner learner = new AdaBoostLearner(learners, ds); learner.train(ds); int[] result = learner.test(ds); Assert.assertEquals(12, result[0]); Assert.assertEquals(0, result[1]); } }