package quickml.supervised.crossValidation; import org.junit.Before; import org.junit.Test; import quickml.InstanceLoader; import quickml.data.instances.ClassifierInstance; import quickml.data.OnespotDateTimeExtractor; import quickml.supervised.crossValidation.data.OutOfTimeData; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLogCVLossFunction; import quickml.supervised.tree.decisionTree.DecisionTree; import quickml.supervised.tree.decisionTree.DecisionTreeBuilder; import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory; import java.util.List; /** * Created by alexanderhawk on 7/8/15. */ public class SimpleCrossValidatorIntegrationTest { private List<ClassifierInstance> instances; @Before public void setUp() throws Exception { instances = InstanceLoader.getAdvertisingInstances().subList(0,1000); } @Test public void testCrossValidation() throws Exception { System.out.println("\n \n \n new attrImportanceTest"); DecisionTreeBuilder<ClassifierInstance> modelBuilder = new DecisionTreeBuilder<ClassifierInstance>().scorerFactory(new GRPenalizedGiniImpurityScorerFactory()).maxDepth(16).minLeafInstances(0).minAttributeValueOccurences(11).attributeIgnoringStrategy(new quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability(0.7)); SimpleCrossValidator<DecisionTree, ClassifierInstance> cv = new SimpleCrossValidator<>(modelBuilder, new ClassifierLossChecker<ClassifierInstance, DecisionTree>(new ClassifierLogCVLossFunction(.000001)), new OutOfTimeData<>(instances, .25, 12, new OnespotDateTimeExtractor() ) ); for (int i =0; i<3; i++) { System.out.println("Loss: " + cv.getLossForModel()); } } }