package quickml.supervised.crossValidation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.Instance;
import quickml.supervised.PredictiveModel;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.crossValidation.data.TrainingDataCycler;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static quickml.supervised.Utils.getInstanceWeights;
/**
 * Cross validator that builds a model for each training/validation split supplied by a
 * {@link TrainingDataCycler} and reports the average loss over those splits, where each
 * split's loss is weighted by the total instance weight of its validation set.
 *
 * @param <PM> type of predictive model produced by the builder
 * @param <T>  type of instance the model is trained and validated on
 */
public class SimpleCrossValidator<PM extends PredictiveModel, T extends Instance> implements CrossValidator {
    private static final Logger logger = LoggerFactory.getLogger(SimpleCrossValidator.class);

    // All collaborators are assigned exactly once, in the constructor.
    private final LossChecker<PM, T> lossChecker;
    private final TrainingDataCycler<T> dataCycler;
    private final PredictiveModelBuilder<PM, T> modelBuilder;

    /**
     * @param modelBuilder builds a model from each training set
     * @param lossChecker  computes the loss of a built model on a validation set
     * @param dataCycler   supplies the successive training/validation splits
     */
    public SimpleCrossValidator(PredictiveModelBuilder<PM, T> modelBuilder, LossChecker<PM, T> lossChecker, TrainingDataCycler<T> dataCycler) {
        this.lossChecker = lossChecker;
        this.dataCycler = dataCycler;
        this.modelBuilder = modelBuilder;
    }

    /**
     * Get the loss for a model without updating the model config.
     *
     * @return the weighted average loss over all validation cycles
     */
    public double getLossForModel() {
        return getLossForModel(new HashMap<String, Serializable>());
    }

    /**
     * Resets the data cycler, applies {@code config} to the model builder when it contains
     * any entries, and then computes the cross-validated loss.
     *
     * @param config builder configuration to apply before validating; a {@code null} or
     *               empty map leaves the builder's current configuration untouched
     * @return the weighted average loss over all validation cycles
     */
    public double getLossForModel(Map<String, Serializable> config) {
        dataCycler.reset();
        // Treat null like an empty config instead of failing after the cycler was already reset.
        if (config != null && !config.isEmpty()) {
            modelBuilder.updateBuilderConfig(config);
        }
        double loss = testModel();
        // SLF4J renders the argument itself, so no explicit toString() is needed.
        logger.info("Loss {} for config {}", loss, config);
        return loss;
    }

    /**
     * We keep cycling through the validation data, updating the running losses for each run.
     * <p>
     * Each cycle's loss is multiplied by the total instance weight of its validation set, and
     * the sum is divided by the accumulated weight, giving a weight-averaged loss. If no cycle
     * runs, or the accumulated weight is zero, the floating-point division yields
     * {@code NaN} (or an infinity) rather than throwing.
     *
     * @return accumulated weighted loss divided by accumulated validation-set weight
     */
    private double testModel() {
        double runningLoss = 0;
        double runningWeightOfValidationSet = 0;
        boolean gotNextCycle = false;
        // Continue while the cycler reports more data or the previous nextCycle() call returned true.
        while (dataCycler.hasMore() || gotNextCycle) {
            List<T> validationSet = dataCycler.getValidationSet();
            double validationSetWeight = getInstanceWeights(validationSet);
            PM predictiveModel = modelBuilder.buildPredictiveModel(dataCycler.getTrainingSet());
            runningLoss += lossChecker.calculateLoss(predictiveModel, validationSet) * validationSetWeight;
            runningWeightOfValidationSet += validationSetWeight;
            gotNextCycle = dataCycler.nextCycle();
        }
        return runningLoss / runningWeightOfValidationSet;
    }
}