package quickml.supervised.classifier.splitOnAttribute;

import com.google.common.collect.Lists;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import quickml.InstanceLoader;
import quickml.data.OnespotDateTimeExtractor;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilder;
import quickml.supervised.crossValidation.attributeImportance.LossFunctionTracker;
import quickml.supervised.crossValidation.data.OutOfTimeData;
import quickml.supervised.crossValidation.lossfunctions.LossFunctionCorrectedForDownsampling;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.predictiveModelOptimizer.MultiLossModelTester;
import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory;

import java.io.Serializable;
import java.util.*;

import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertTrue;

public class SplitOnAttributeClassifierBuilderTest {

    private List<ClassifierInstance> instances;

    @Before
    public void setUp() throws Exception {
        instances = InstanceLoader.getAdvertisingInstances();
    }

    @Ignore
    @Test
    public void advertisingDataTest() {
        // Define two model groups split on campaignId, with group 0 as the default.
        List<SplitOnAttributeClassifierBuilder.SplitModelGroup> splitModelGroupCollection = new ArrayList<>();
        splitModelGroupCollection.add(createSplitModelGroup(0, newHashSet("_830", "_833"), 0.1, 2000));
        splitModelGroupCollection.add(createSplitModelGroup(1, newHashSet("_792"), 0.4, 100));
        int defaultGroup = 0;

        RandomDecisionForestBuilder randomDecisionForestBuilder = new RandomDecisionForestBuilder();
        DownsamplingClassifierBuilder downsamplingBuilder =
                new DownsamplingClassifierBuilder(randomDecisionForestBuilder, 0.30D);
        SplitOnAttributeClassifierBuilder splitOnAttributeClassifierBuilder =
                new SplitOnAttributeClassifierBuilder("campaignId", splitModelGroupCollection, defaultGroup, downsamplingBuilder);
        splitOnAttributeClassifierBuilder.updateBuilderConfig(createModelConfig());

        List<ClassifierLossFunction> lossFunctions = Lists.newArrayList();
        lossFunctions.add(new LossFunctionCorrectedForDownsampling(new WeightedAUCCrossValLossFunction(1.0), 0.99, 0.0));
        lossFunctions.add(new LossFunctionCorrectedForDownsampling(new ClassifierLogCVLossFunction(0.000001), 0.99, 0.0));
        lossFunctions.add(new LossFunctionCorrectedForDownsampling(new ClassifierRMSELossFunction(), 0.99, 0.0));
        lossFunctions.add(new WeightedAUCCrossValLossFunction(1.0));
        lossFunctions.add(new ClassifierLogCVLossFunction(0.00001));
        lossFunctions.add(new ClassifierRMSELossFunction());

        // Get the losses for the split model
        MultiLossModelTester splitModelTester = new MultiLossModelTester(splitOnAttributeClassifierBuilder,
                new OutOfTimeData<>(instances, 0.15, 24, new OnespotDateTimeExtractor()));
        LossFunctionTracker splitLosses = splitModelTester.getMultilossForModel(lossFunctions);

        // Get the losses for the non-split model
        MultiLossModelTester singleModelTester = new MultiLossModelTester(downsamplingBuilder,
                new OutOfTimeData<>(instances, 0.15, 24, new OnespotDateTimeExtractor()));
        LossFunctionTracker singleLosses = singleModelTester.getMultilossForModel(lossFunctions);

        // Log the losses
        splitLosses.logLosses();
        singleLosses.logLosses();

        // TODO: determine why the split model is so much worse on the downsampling-corrected log loss
        for (String function : splitLosses.lossFunctionNames()) {
            double singleModelLoss = singleLosses.getLossForFunction(function);
            double splitModelLoss = splitLosses.getLossForFunction(function);
            assertTrue("singleModelLoss: " + singleModelLoss + ", splitModelLoss: " + splitModelLoss,
                    val1NotWorseThanVal2(0.2, splitModelLoss, singleModelLoss));
        }
    }

    private SplitOnAttributeClassifierBuilder.SplitModelGroup createSplitModelGroup(int id, Set<String> campaignIds,
                                                                                    double percentageOfCrossData, int minTotalSamples) {
        HashMap<Integer, Double> relativeImportance = new HashMap<>();
        relativeImportance.put(1, 1.0);
        return new SplitOnAttributeClassifierBuilder.SplitModelGroup(id, campaignIds, minTotalSamples, percentageOfCrossData, relativeImportance);
    }

    private Map<String, Serializable> createModelConfig() {
        Map<String, Serializable> predictiveModelParameters = new HashMap<>();
        predictiveModelParameters.put("numTrees", 16);
        predictiveModelParameters.put("bagSize", 0); // TODO: remove bagSize from the builders, since bagging is not used
        predictiveModelParameters.put("ignoreAttrProb", 0.7);
        predictiveModelParameters.put("minScore", 0.000001);
        predictiveModelParameters.put("maxDepth", 16);
        predictiveModelParameters.put("minCatAttrOcc", 29);
        predictiveModelParameters.put("minLeafInstances", 0);
        predictiveModelParameters.put("scorerFactory", new GRPenalizedGiniImpurityScorerFactory());
        predictiveModelParameters.put("rebuildThreshold", 1);
        predictiveModelParameters.put("splitNodeThreshold", 1);
        predictiveModelParameters.put("minorityInstanceProportion", 0.30);
        return predictiveModelParameters;
    }

    // val1 is "not worse" than val2 if it is no greater than val2, or exceeds it
    // by less than the given relative tolerance.
    private boolean val1NotWorseThanVal2(double tolerance, double val1, double val2) {
        return val1 <= val2 || Math.abs((val2 - val1) / val1) < tolerance;
    }
}