package quickml.supervised.crossValidation;
import com.google.common.collect.Lists;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.Utils;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import java.util.List;
/**
 * A {@link LossChecker} that scores a {@link Classifier} on a validation set after reducing each
 * validation instance to a single label.
 *
 * <p>For every validation instance, a new {@link ClassifierInstance} is built with the same
 * attributes and weight, but with the label returned by the configured
 * {@link InstanceTargetSelector}. The classifier's predictions on these relabeled instances are
 * then scored with the configured {@link ClassifierLossFunction}.
 *
 * <p>NOTE(review): the class name suggests incoming instances carry multiple targets and the
 * selector picks one of them — confirm against the {@code InstanceTargetSelector} implementations.
 *
 * @param <T> type of the validation instances
 */
public class MultiTargetLossChecker<T extends ClassifierInstance> implements LossChecker<Classifier, T> {
    /** Loss function applied to the classifier's predictions on the relabeled validation set. */
    private final ClassifierLossFunction lossFunction;

    /** Supplies the label each validation instance is scored against. */
    private final InstanceTargetSelector instanceTargetSelector;

    /**
     * Creates a loss checker.
     *
     * @param lossFunction loss function used to score the classifier's predictions
     * @param instanceTargets selector whose {@code getSingleLabel} result becomes the label of each
     *     relabeled validation instance
     */
    public MultiTargetLossChecker(ClassifierLossFunction lossFunction, InstanceTargetSelector instanceTargets) {
        this.lossFunction = lossFunction;
        this.instanceTargetSelector = instanceTargets;
    }

    /**
     * Computes the loss of {@code predictiveModel} on {@code validationSet}, using for each instance
     * the label chosen by the instance target selector.
     *
     * @param predictiveModel classifier to evaluate
     * @param validationSet instances to evaluate on; their attributes and weights are reused as-is
     * @return the value returned by the loss function for the classifier's predictions on the
     *     relabeled instances
     */
    @Override
    public double calculateLoss(Classifier predictiveModel, List<T> validationSet) {
        // Exactly one relabeled instance is produced per validation instance, so presize the list.
        List<ClassifierInstance> singleTargetValidationSet =
                Lists.newArrayListWithCapacity(validationSet.size());
        for (T instance : validationSet) {
            singleTargetValidationSet.add(new ClassifierInstance(
                    instance.getAttributes(),
                    instanceTargetSelector.getSingleLabel(instance),
                    instance.getWeight()));
        }
        return lossFunction.getLoss(Utils.calcResultPredictions(predictiveModel, singleTargetValidationSet));
    }
}