package quickml.supervised.crossValidation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.data.instances.Instance; import quickml.supervised.PredictiveModel; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.crossValidation.data.TrainingDataCycler; import java.io.Serializable; import java.util.HashMap; import java.util.List; import java.util.Map; import static quickml.supervised.Utils.getInstanceWeights; public class SimpleCrossValidator<PM extends PredictiveModel, T extends Instance> implements CrossValidator { private static final Logger logger = LoggerFactory.getLogger(SimpleCrossValidator.class); private LossChecker<PM, T> lossChecker; private TrainingDataCycler<T> dataCycler; private final PredictiveModelBuilder<PM, T> modelBuilder; public SimpleCrossValidator(PredictiveModelBuilder<PM, T> modelBuilder, LossChecker<PM, T> lossChecker, TrainingDataCycler<T> dataCycler) { this.lossChecker = lossChecker; this.dataCycler = dataCycler; this.modelBuilder = modelBuilder; } /** * Get the loss for a model without updating the model config */ public double getLossForModel() { return getLossForModel(new HashMap<String, Serializable>()); } public double getLossForModel(Map<String, Serializable> config) { dataCycler.reset(); if (config.size()!=0) { modelBuilder.updateBuilderConfig(config); } double loss = testModel(); logger.info("Loss {} for config {}", loss, config.toString()); return loss; } /** * We keep cycling through the test data, updating the running losses for each run. */ private double testModel() { double runningLoss = 0; double runningWeightOfValidationSet = 0; boolean gotNextCycle= false; while (dataCycler.hasMore() || gotNextCycle){ List<T> validationSet = dataCycler.getValidationSet(); double validationSetWeight = getInstanceWeights(validationSet); PM predictiveModel = modelBuilder.buildPredictiveModel(dataCycler.getTrainingSet()); runningLoss += lossChecker.calculateLoss(predictiveModel, validationSet) * validationSetWeight; runningWeightOfValidationSet += validationSetWeight; gotNextCycle = dataCycler.nextCycle(); } return runningLoss / runningWeightOfValidationSet; } }