package quickml.supervised.classifier.logisticRegression; /** * Created by alexanderhawk on 10/12/15. */ import com.google.common.collect.Lists; import quickml.data.instances.ClassifierInstance; import quickml.supervised.EnhancedPredictiveModelBuilder; import quickml.supervised.Utils; import quickml.supervised.classifier.Classifier; import quickml.supervised.dataProcessing.instanceTranformer.ProductFeatureAppender; import quickml.supervised.regressionModel.IsotonicRegression.PoolAdjacentViolatorsModel; import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Set; /** * Created by alexanderhawk on 10/9/15. */ public class LogisticRegressionBuilder<D extends LogisticRegressionDTO<D>> implements EnhancedPredictiveModelBuilder<LogisticRegression, ClassifierInstance, SparseClassifierInstance, D> { public boolean calibrateWithPoolAdjacentViolators = false; public static final String MIN_OBSERVATIONS_OF_ATTRIBUTE= "minObservationsOfAttribute"; public static final String PRODUCT_FEATURE_APPENDER = "productFeatureAppender"; public static final String CALIBRATE_WITH_POOL_ADJACENT_VIOLATORS = "calibrateWithPoolAdjacentViolators"; public static final String POOL_ADJACENT_VIOLATORS_MIN_WEIGHT = "poolAdjacentViolatorsMinWeight"; public StandardDataTransformer<D> logisticRegressionDataTransformer; private ProductFeatureAppender<ClassifierInstance> productFeatureAppender; GradientDescent<SparseClassifierInstance> gradientDescent = new SparseSGD(); private int minWeightForPavBuckets =2; public LogisticRegressionBuilder(StandardDataTransformer<D> dataTransformer) { this.logisticRegressionDataTransformer = dataTransformer; } public LogisticRegressionBuilder<D> productFeatureAppender(ProductFeatureAppender<ClassifierInstance> productFeatureAppender) { logisticRegressionDataTransformer.productFeatureAppender(productFeatureAppender); return this; } public LogisticRegressionBuilder<D> minObservationsOfAttribute(int minObservationsOfAttribute) { logisticRegressionDataTransformer.minObservationsOfAttribute(minObservationsOfAttribute); return this; } public LogisticRegressionBuilder<D> gradientDescent(GradientDescent gradientDescent) { this.gradientDescent = gradientDescent; return this; } public LogisticRegressionBuilder<D> calibrateWithPoolAdjacentViolators(boolean calibrateWithPoolAdjacentViolators) { this.calibrateWithPoolAdjacentViolators = calibrateWithPoolAdjacentViolators; return this; } public LogisticRegressionBuilder<D> poolAdjacentViolatorsMinWeight(int minWeightForPavBuckets) { this.minWeightForPavBuckets = minWeightForPavBuckets; return this; } @Override public D transformData(List<ClassifierInstance> rawInstances){ return logisticRegressionDataTransformer.transformData(rawInstances); } @Override public LogisticRegression buildPredictiveModel(D logisticRegressionDTO) { List<SparseClassifierInstance> sparseClassifierInstances =logisticRegressionDTO.getTransformedInstances(); double[] weights = gradientDescent.minimize(sparseClassifierInstances, logisticRegressionDTO.getNameToIndexMap().size()); LogisticRegression uncalibrated = getUncalibratedModel(logisticRegressionDTO, weights); if (calibrateWithPoolAdjacentViolators) { PoolAdjacentViolatorsModel poolAdjacentViolatorsModel = new PoolAdjacentViolatorsModel(LogisticRegressionBuilder.<SparseClassifierInstance>getPavPredictions(logisticRegressionDTO.getTransformedInstances(), uncalibrated), minWeightForPavBuckets); return new LogisticRegression(uncalibrated, poolAdjacentViolatorsModel); } return uncalibrated; } private LogisticRegressionDTO getLogisticRegressionDTO(Iterable<? extends ClassifierInstance> trainingData) { List<ClassifierInstance> trainingDataList = Utils.iterableToListOfClassifierInstances(trainingData); return logisticRegressionDataTransformer.transformData(trainingDataList); } // Could have a model factory that has no generics on D store all the information that the DTO stores...have an object specific setter, and a "getModelMethod. This factory it would consume the // model builder...and then finush off the build. Would it have to be generic? private LogisticRegression getUncalibratedModel(D logisticRegressionDTO, double[] weights) { LogisticRegression uncalibrated; if (logisticRegressionDTO.getNumericClassLabels() == null) { Set<Double> classifications = InstanceTransformerUtils.getClassifications(logisticRegressionDTO.getTransformedInstances()); uncalibrated = new LogisticRegression(weights, logisticRegressionDTO.getNameToIndexMap(), classifications); } else { uncalibrated = new LogisticRegression(weights,logisticRegressionDTO.getNameToIndexMap(),logisticRegressionDTO.getNumericClassLabels()); } return uncalibrated; } public static <I extends ClassifierInstance> List<PoolAdjacentViolatorsModel.Observation> getPavPredictions(List<I> trainingData, Classifier classifier) { List<PoolAdjacentViolatorsModel.Observation> observations = Lists.newArrayList(); for (I instance : trainingData) { double uncalibratedProbability = classifier.getProbability(instance.getAttributes(), 1.0); PoolAdjacentViolatorsModel.Observation ob = new PoolAdjacentViolatorsModel.Observation(uncalibratedProbability, (Double) instance.getLabel(), instance.getWeight()); observations.add(ob); } return observations; } @Override public void updateBuilderConfig(final Map<String, Serializable> config) { gradientDescent.updateBuilderConfig(config); if (config.containsKey(MIN_OBSERVATIONS_OF_ATTRIBUTE)) { minObservationsOfAttribute((Integer) config.get(MIN_OBSERVATIONS_OF_ATTRIBUTE)); } if (config.containsKey(PRODUCT_FEATURE_APPENDER)) { productFeatureAppender((ProductFeatureAppender<ClassifierInstance>) config.get(PRODUCT_FEATURE_APPENDER)); } if (config.containsKey(CALIBRATE_WITH_POOL_ADJACENT_VIOLATORS)) { calibrateWithPoolAdjacentViolators((Boolean) config.get(CALIBRATE_WITH_POOL_ADJACENT_VIOLATORS)); } if (config.containsKey(POOL_ADJACENT_VIOLATORS_MIN_WEIGHT)) { poolAdjacentViolatorsMinWeight((Integer) config.get(POOL_ADJACENT_VIOLATORS_MIN_WEIGHT)); } } }