package quickml.supervised.classifier.logRegression; import com.google.common.collect.Maps; import org.junit.Ignore; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.BenchmarkTest; import quickml.InstanceLoader; import quickml.data.OnespotDateTimeExtractor; import quickml.data.instances.ClassifierInstance; import quickml.supervised.classifier.logisticRegression.*; import quickml.supervised.crossValidation.EnhancedCrossValidator; import quickml.supervised.crossValidation.data.FoldedDataFactory; import quickml.supervised.crossValidation.data.OutOfTimeDataFactory; import quickml.supervised.crossValidation.ClassifierLossChecker; import quickml.supervised.crossValidation.SimpleCrossValidator; import quickml.supervised.crossValidation.data.FoldedData; import quickml.supervised.crossValidation.data.OutOfTimeData; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction; import quickml.supervised.dataProcessing.instanceTranformer.CommonCoocurrenceProductFeatureAppender; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder; import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender; import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer; import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender; import quickml.supervised.tree.decisionTree.DecisionTreeBuilder; import java.util.List; import java.util.Map; import static quickml.supervised.classifier.logisticRegression.LogisticRegressionBuilder.MIN_OBSERVATIONS_OF_ATTRIBUTE; import static quickml.supervised.classifier.logisticRegression.SparseSGD.*; /** * Created by alexanderhawk on 10/13/15. */ public class RidgeRegressionBuilderTest { public static final Logger logger = LoggerFactory.getLogger(RidgeRegressionBuilderTest.class); @Ignore //test takes too long, but is illustrative of how to build a model @Test public void testAdInstances() { List<ClassifierInstance> instances = InstanceLoader.getAdvertisingInstances(); logger.info("got instances"); CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>() .setMinObservationsOfRawAttribute(35) .setAllowCategoricalProductFeatures(true) .setAllowNumericProductFeatures(false) .setApproximateOverlap(true) .setMinOverlap(20) .setIgnoreAttributesCommonToAllInsances(true); DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer() .minObservationsOfAttribute(35) .usingProductFeatures(true) .productFeatureAppender(productFeatureAppender); LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO> logisticRegressionBuilder = new LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO>(lrdt) .calibrateWithPoolAdjacentViolators(false) .gradientDescent(new SparseSGD() .ridgeRegularizationConstant(0.1) .learningRate(.0025) .minibatchSize(1000) .minEpochs(1000) .maxEpochs(1000) .minPredictedProbablity(1E-3) .sparseParallelization(true) ); double start = System.nanoTime(); EnhancedCrossValidator<LogisticRegression, ClassifierInstance, SparseClassifierInstance, MeanNormalizedAndDatedLogisticRegressionDTO> enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder, new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)), new OutOfTimeDataFactory(0.25, 48), instances); double lossForSGD = enhancedCrossValidator.getLossForModel(); double stop = System.nanoTime(); logger.info("LR out of time loss: {}, in {} nanoseconds", lossForSGD, stop-start); RandomDecisionForestBuilder<ClassifierInstance> randomDecisionForestBuilder = new RandomDecisionForestBuilder<>(new DecisionTreeBuilder<>().minAttributeValueOccurences(2).maxDepth(12).minLeafInstances(0).minSplitFraction(.005).ignoreAttributeProbability(0.5)).numTrees(64); SimpleCrossValidator<LogisticRegression, ClassifierInstance> simpleCrossValidator = new SimpleCrossValidator(randomDecisionForestBuilder, new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)), new OutOfTimeData<ClassifierInstance>(instances, 0.25, 48, new OnespotDateTimeExtractor())); logger.info("RF out of time loss: {}", simpleCrossValidator.getLossForModel()); } @Test public void testDiabetesInstances() { //need a builder List<ClassifierInstance> instances = BenchmarkTest.loadDiabetesDataset(); CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>().setMinObservationsOfRawAttribute(1) .setAllowNumericProductFeatures(true) .setApproximateOverlap(true) .setMinOverlap(0); DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer() .minObservationsOfAttribute(1) .usingProductFeatures(true) .productFeatureAppender(productFeatureAppender); LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO> logisticRegressionBuilder = new LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO>(lrdt); logisticRegressionBuilder.gradientDescent(new SparseSGD() .executorThreadCount(3) .sparseParallelization(false) .ridgeRegularizationConstant(.1) .learningRate(.001) .minibatchSize(600) .minEpochs(16000) .maxEpochs(16000) .useBoldDriver(false) .learningRateReductionFactor(0.01)); ClassifierLossFunction lossFunction = new ClassifierRMSELossFunction();//);//new ClassifierRMSELossFunction();//new WeightedAUCCrossValLossFunction(1.0);//new ClassifierRMSELossFunction();//new ClassifierLogCVLossFunction(1E-5);//new WeightedAUCCrossValLossFunction(1.0); EnhancedCrossValidator<LogisticRegression, ClassifierInstance, SparseClassifierInstance, MeanNormalizedAndDatedLogisticRegressionDTO> enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder, new ClassifierLossChecker(lossFunction), new FoldedDataFactory(4, 4), instances); logger.info("LR out of time loss: {}", enhancedCrossValidator.getLossForModel()); RandomDecisionForestBuilder<ClassifierInstance> randomDecisionForestBuilder = new RandomDecisionForestBuilder<>(new DecisionTreeBuilder<>().minAttributeValueOccurences(2).maxDepth(5).minLeafInstances(20).minSplitFraction(.005).ignoreAttributeProbability(0.5)).numTrees(64); SimpleCrossValidator<LogisticRegression, ClassifierInstance> simpleCrossValidator = new SimpleCrossValidator(randomDecisionForestBuilder, new ClassifierLossChecker(lossFunction), new FoldedData(instances, 4, 4)); logger.info("RF out of time loss: {}", simpleCrossValidator.getLossForModel()); } @Ignore @Test public void optimizerTest(){ List<ClassifierInstance> instances = InstanceLoader.getAdvertisingInstances().subList(0,1000); CommonCoocurrenceProductFeatureAppender productFeatureAppender = new CommonCoocurrenceProductFeatureAppender<>() .setMinObservationsOfRawAttribute(35) .setAllowCategoricalProductFeatures(false) .setAllowNumericProductFeatures(false) .setApproximateOverlap(true) .setMinOverlap(20) .setIgnoreAttributesCommonToAllInsances(true); DatedAndMeanNormalizedLogisticRegressionDataTransformer lrdt = new DatedAndMeanNormalizedLogisticRegressionDataTransformer() .minObservationsOfAttribute(35) .usingProductFeatures(false) .productFeatureAppender(productFeatureAppender); LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO> logisticRegressionBuilder = new LogisticRegressionBuilder<MeanNormalizedAndDatedLogisticRegressionDTO>(lrdt) .calibrateWithPoolAdjacentViolators(false) .gradientDescent(new SparseSGD() .ridgeRegularizationConstant(0.1) .learningRate(.0025) .minibatchSize(1000) .minEpochs(500) .maxEpochs(500) .minPredictedProbablity(1E-3) .sparseParallelization(true) ); double start = System.nanoTime(); EnhancedCrossValidator<LogisticRegression, ClassifierInstance, SparseClassifierInstance, MeanNormalizedAndDatedLogisticRegressionDTO> enhancedCrossValidator = new EnhancedCrossValidator<>(logisticRegressionBuilder, new ClassifierLossChecker(new WeightedAUCCrossValLossFunction(1.0)), new OutOfTimeDataFactory(0.25, 48), instances); Map<String, FieldValueRecommender> sgdParams = Maps.newHashMap(); sgdParams.put(RIDGE, new FixedOrderRecommender(.0001));//;, .001, .01, .1, 1));//MonotonicConvergenceRecommender(numTreesList, 0.01)); sgdParams.put(MIN_EPOCHS, new FixedOrderRecommender(8000));// 16000)); sgdParams.put(MAX_EPOCHS, new FixedOrderRecommender(16000));//, 3200)); sgdParams.put(LEARNING_RATE, new FixedOrderRecommender(.0025));//, .001, .005));//11, 14, 16 //Pbest 12 sgdParams.put(MIN_OBSERVATIONS_OF_ATTRIBUTE, new FixedOrderRecommender(20, 50));// 16000)); PredictiveModelOptimizer modelOptimizer = new PredictiveModelOptimizer(sgdParams, enhancedCrossValidator, 2); logger.info("Optimal sgd parameters: {}", modelOptimizer.determineOptimalConfig()); } }