package quickml.supervised.predictiveModelOptimizer; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.supervised.crossValidation.CrossValidator; import java.io.Serializable; import java.util.*; public class PredictiveModelOptimizer { private static final Logger logger = LoggerFactory.getLogger(PredictiveModelOptimizer.class); private Map<String, ? extends FieldValueRecommender> fieldsToOptimize; //should pass in private final CrossValidator crossValidator; private HashMap<String, Serializable> localBestConfig; //should be a param: not be stateful private HashMap<String, Serializable> bestConfig; //should be a param: not be stateful private final int iterations;//should be param private int iteration; //should not be a field private List<ConfigWithLoss> configsWithLosses = Lists.newArrayList(); /** * Do not call directly - Use PredictiveModelOptimizerBuilder to an instance * @param fieldsToOptimize - key is the field - e.g. maxDepth, FixedOrderRecommender is a set of values for maxDepth to try * @param crossValidator - Model tester takes a configuration and returns the loss */ public PredictiveModelOptimizer(Map<String, ? extends FieldValueRecommender> fieldsToOptimize, CrossValidator crossValidator, int iterations) { this.fieldsToOptimize = fieldsToOptimize; this.crossValidator = crossValidator; this.iterations = iterations; this.localBestConfig = setBestConfigToFirstValues(fieldsToOptimize); } /** * We find the value for each field that results in the lowest loss * Then repeat the process starting with the optimized configuration * Keep going until we are no longer improving or we have reached max_iterations */ public Map<String, Serializable> determineOptimalConfig() { for (iteration = 0; iteration < iterations; iteration++) { logger.info("Starting iteration - {}", iteration); HashMap<String, Serializable> previousConfig = copyOf(localBestConfig); updateBestConfig(); if (localBestConfig.equals(previousConfig)) break; } sortConfigsWithLosses(); logger.info("best loss: is {} for: \n{}", configsWithLosses.get(0).loss, configsWithLosses.get(0).config); return configsWithLosses.get(0).config; } public List<ConfigWithLoss> exploreConfigs() { configsWithLosses = Lists.newArrayList(); for (iteration = 0; iteration < iterations; iteration++) { logger.info("Starting iteration - {}", iteration); HashMap<String, Serializable> previousConfig = copyOf(localBestConfig); updateBestConfig(); if (localBestConfig.equals(previousConfig)) break; } return configsWithLosses; } private void sortConfigsWithLosses() { Collections.sort(configsWithLosses, new Comparator<ConfigWithLoss>() { @Override public int compare(ConfigWithLoss o1, ConfigWithLoss o2) { return Double.compare(o1.loss, o2.loss); } }); } private void updateBestConfig() { for (String field : fieldsToOptimize.keySet()) { logger.info("optimizing {}", field); findBestValueForField(field); } } private void findBestValueForField(String field) { FieldLosses losses = new FieldLosses(); FieldValueRecommender fieldValueRecommender = fieldsToOptimize.get(field); if (fieldValueRecommender.getValues().size() == 1) { return; } //localBestConfig is not actually localBestConfig inth for loop logger.info("values to try: {} ", fieldValueRecommender.getValues().toString()); for (Serializable value : fieldValueRecommender.getValues()) { //TODO: make so it does not repeat a conf already seen in present iteration (e.g. keep a set of configs) if (localBestConfig.get(field).equals(value) && iteration > 0) { logger.info("skipping field value {} bc value {} already tried ", field, value); continue; //safe to continue bc everything else about the config is the same. } localBestConfig.put(field, value); double lossForModel = crossValidator.getLossForModel(localBestConfig); logger.info("loss: {}, for field {}, config {}", lossForModel, field, localBestConfig); losses.addFieldLoss(value, lossForModel); if (configsWithLosses != null) { configsWithLosses.add(new ConfigWithLoss(lossForModel, copyOf(localBestConfig))); } if (!fieldValueRecommender.shouldContinue(losses.getLosses())) break; } if (!losses.getLosses().isEmpty()) { localBestConfig.put(field, losses.valueWithLowestLoss()); } } private HashMap<String, Serializable> setBestConfigToFirstValues(Map<String, ? extends FieldValueRecommender> config) { HashMap<String, Serializable> map = new HashMap<>(); for (Map.Entry<String, ? extends FieldValueRecommender> entry : config.entrySet()) { map.put(entry.getKey(), entry.getValue().first()); } logger.info("Initial Configuration - {}", map); return map; } private HashMap<String, Serializable> copyOf(final HashMap<String, Serializable> map) { return Maps.newHashMap(map); } /** * Convience classes to sort and return the value with the lowest loss */ public static class FieldLosses { private List<FieldLoss> losses = new ArrayList<>(); public void add(FieldLoss fieldLoss) { losses.add(fieldLoss); } public Serializable valueWithLowestLoss() { Collections.sort(losses); return losses.get(0).fieldValue; } public void addFieldLoss(Serializable fieldValue, double loss) { add(new FieldLoss(fieldValue, loss)); } public List<Double> getLosses() { List<Double> rawLosses = Lists.newArrayList(); for (FieldLoss loss : losses) { rawLosses.add(loss.loss); } return rawLosses; } } private static class FieldLoss implements Comparable<FieldLoss> { private final Serializable fieldValue; private final double loss; public FieldLoss(Serializable fieldValue, double loss) { this.fieldValue = fieldValue; this.loss = loss; } @Override public int compareTo(FieldLoss o) { return Double.compare(this.loss, o.loss); } } }