package quickml.supervised.predictiveModelOptimizer; import com.beust.jcommander.internal.Sets; import com.google.common.collect.Maps; import org.junit.Before; import org.junit.Test; import quickml.data.instances.ClassifierInstance; import quickml.data.OnespotDateTimeExtractor; import static quickml.supervised.tree.constants.ForestOptions.*; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest; import quickml.supervised.tree.attributeIgnoringStrategies.CompositeAttributeIgnoringStrategy; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesInSet; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder; import quickml.supervised.crossValidation.ClassifierLossChecker; import quickml.supervised.crossValidation.data.OutOfTimeData; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction; import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender; import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.MonotonicConvergenceRecommender; import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory; import quickml.supervised.tree.decisionTree.scorers.PenalizedInformationGainScorerFactory; import java.util.*; import static java.util.Arrays.asList; import static quickml.InstanceLoader.getAdvertisingInstances; public class PredictiveModelOptimizerIntegrationTest { private PredictiveModelOptimizer optimizer; @Before public void setUp() throws Exception { List<ClassifierInstance> advertisingInstances = getAdvertisingInstances(); advertisingInstances = advertisingInstances.subList(0, 3000); optimizer = new SimplePredictiveModelOptimizerBuilder<RandomDecisionForest, ClassifierInstance>() .modelBuilder(new RandomDecisionForestBuilder<>()) .dataCycler(new OutOfTimeData<>(advertisingInstances, 0.2, 12, new OnespotDateTimeExtractor())) .lossChecker(new ClassifierLossChecker<ClassifierInstance, RandomDecisionForest>(new WeightedAUCCrossValLossFunction(1.0))) .valuesToTest(createConfig()) .iterations(2) .build(); } @Test public void testOptimizer() throws Exception { System.out.println("optimalConfig = " + optimizer.determineOptimalConfig()); } private Map<String, FieldValueRecommender> createConfig() { Map<String, FieldValueRecommender> config = Maps.newHashMap(); Set<String> attributesToIgnore = Sets.newHashSet(); attributesToIgnore.addAll(Arrays.asList("browser", "eap", "destinationId", "seenPixel", "internalCreativeId")); double probabilityOfDiscardingFromAttributesToIgnore = 0.3; CompositeAttributeIgnoringStrategy compositeAttributeIgnoringStrategy = new CompositeAttributeIgnoringStrategy(Arrays.asList( new IgnoreAttributesWithConstantProbability(0.7), new IgnoreAttributesInSet(attributesToIgnore, probabilityOfDiscardingFromAttributesToIgnore) )); config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new FixedOrderRecommender(new IgnoreAttributesWithConstantProbability(0.7), compositeAttributeIgnoringStrategy )); config.put(NUM_TREES.name(), new MonotonicConvergenceRecommender(asList(20))); config.put(MAX_DEPTH.name(), new FixedOrderRecommender( 4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9)); config.put(MIN_SCORE.name(), new FixedOrderRecommender(0.00000000000001));//, Double.MIN_VALUE, 0.0, 0.000001, 0.0001, 0.001, 0.01, 0.1)); config.put(ATTRIBUTE_VALUE_THRESHOLD_OBSERVATIONS.name(), new FixedOrderRecommender(2, 11, 16, 30 )); config.put(MIN_LEAF_INSTANCES.name(), new FixedOrderRecommender(0, 20, 40)); config.put(SCORER_FACTORY.name(), new FixedOrderRecommender(new PenalizedInformationGainScorerFactory(), new GRPenalizedGiniImpurityScorerFactory())); config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), new FixedOrderRecommender(1.0, 0.75, .5 )); return config; } }