package quickml.supervised.crossValidation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.data.instances.Instance; import quickml.data.instances.RegressionInstance; import quickml.supervised.PredictiveModel; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.crossValidation.data.TrainingDataCycler; import java.io.BufferedWriter; import java.io.FileWriter; import java.io.IOException; import java.io.Serializable; import java.util.HashMap; import java.util.List; import java.util.Map; import static quickml.supervised.Utils.getInstanceWeights; public class SimpleCrossValidatorWithWriter<PM extends PredictiveModel, T extends Instance> implements CrossValidator { private static final Logger logger = LoggerFactory.getLogger(SimpleCrossValidatorWithWriter.class); private LossChecker<PM, T> lossChecker; private TrainingDataCycler<T> dataCycler; private final PredictiveModelBuilder<PM, T> modelBuilder; public SimpleCrossValidatorWithWriter(PredictiveModelBuilder<PM, T> modelBuilder, LossChecker<PM, T> lossChecker, TrainingDataCycler<T> dataCycler) { this.lossChecker = lossChecker; this.dataCycler = dataCycler; this.modelBuilder = modelBuilder; } /** * Get the loss for a model without updating the model config */ public double getLossForModel() { return getLossForModel(new HashMap<String, Serializable>()); } public double getLossForModel(Map<String, Serializable> config) { dataCycler.reset(); if (config.size()!=0) { modelBuilder.updateBuilderConfig(config); } double loss = testModel(); logger.info("Loss {} for config {}", loss, config.toString()); return loss; } /** * We keep cycling through the test data, updating the running losses for each run. */ private double testModel() { try { double runningLoss = 0; double runningWeightOfValidationSet = 0; boolean gotNextCycle = false; int cycle = 0; while (dataCycler.hasMore() || gotNextCycle) { BufferedWriter trainingWriter = new BufferedWriter(new FileWriter("training" + cycle)); trainingWriter.write("FileNo"); for (T instance : dataCycler.getValidationSet()) { trainingWriter.write("" + ((RegressionInstance) instance).id + "\n"); } BufferedWriter testWriter = new BufferedWriter(new FileWriter("test" + cycle)); testWriter.write("FileNo,actual,predicted\n"); testWriter = new BufferedWriter(new FileWriter("validation" + cycle)); List<T> validationSet = dataCycler.getValidationSet(); double validationSetWeight = getInstanceWeights(validationSet); PM predictiveModel = modelBuilder.buildPredictiveModel(dataCycler.getTrainingSet()); runningLoss += ((RegressionLossChecker) lossChecker).calculateLoss(predictiveModel, validationSet, testWriter) * validationSetWeight; runningWeightOfValidationSet += validationSetWeight; gotNextCycle = dataCycler.nextCycle(); cycle++; testWriter.flush(); trainingWriter.flush(); testWriter.close(); trainingWriter.close(); } return runningLoss / runningWeightOfValidationSet; } catch (IOException e) { logger.error("couldn't write"); } return 0.0; } }