package quickml.supervised.predictiveModelOptimizer; import quickml.data.instances.ClassifierInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.crossValidation.attributeImportance.LossFunctionTracker; import quickml.supervised.crossValidation.data.TrainingDataCycler; import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction; import quickml.supervised.classifier.Classifier; import java.util.List; import static quickml.supervised.Utils.calcResultPredictions; public class MultiLossModelTester { private TrainingDataCycler<ClassifierInstance> dataCycler; private final PredictiveModelBuilder<? extends Classifier, ClassifierInstance> modelBuilder; public MultiLossModelTester(PredictiveModelBuilder<? extends Classifier, ClassifierInstance> modelBuilder, TrainingDataCycler<ClassifierInstance> dataCycler) { this.dataCycler = dataCycler; this.modelBuilder = modelBuilder; } public LossFunctionTracker getMultilossForModel(List<ClassifierLossFunction> lossFunctions) { dataCycler.reset(); LossFunctionTracker lossFunctionTracker = new LossFunctionTracker(lossFunctions); do { List<ClassifierInstance> validationSet = dataCycler.getValidationSet(); Classifier predictiveModel = modelBuilder.buildPredictiveModel(dataCycler.getTrainingSet()); lossFunctionTracker.updateLosses(calcResultPredictions(predictiveModel, validationSet)); dataCycler.nextCycle(); } while (dataCycler.hasMore()); return lossFunctionTracker; } }