package quickml.supervised.crossValidation.lossfunctions;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import java.util.List;
/**
 * Static helpers that compute weighted mean-squared-error style losses over
 * cross-validation results.
 *
 * <p>Every method returns {@code 0} when the accumulated weight is not strictly
 * positive, so empty inputs do not produce a division by zero.
 */
public class LossFunctions {

    /**
     * Computes the weighted mean squared error for classifier results.
     *
     * <p>The per-result error is {@code 1.0 - result.getPredictionForLabel()}; the
     * squared error is multiplied by the result's weight, summed over all results,
     * and divided by {@code results.totalWeight()}.
     *
     * @param results the classifier results to score
     * @return the weighted mean squared error, or {@code 0} if
     *         {@code results.totalWeight()} is not greater than zero
     */
    public static double mseClassifierLoss(PredictionMapResults results) {
        double totalLoss = 0;
        for (PredictionMapResult result : results) {
            // NOTE(review): presumably getPredictionForLabel() is the probability given
            // to the correct label, making 1.0 the ideal value - confirm against
            // PredictionMapResult.
            final double error = 1.0 - result.getPredictionForLabel();
            totalLoss += error * error * result.getWeight();
        }
        final double totalWeight = results.totalWeight();
        return totalWeight > 0 ? totalLoss / totalWeight : 0;
    }

    /**
     * Computes the square root of {@link #mseClassifierLoss(PredictionMapResults)}.
     *
     * @param results the classifier results to score
     * @return the weighted root mean squared error
     */
    public static double rmseClassifierLoss(PredictionMapResults results) {
        return Math.sqrt(mseClassifierLoss(results));
    }

    /**
     * Computes the weighted mean squared error for regression results:
     * {@code sum(weight * (label - prediction)^2) / sum(weight)}.
     *
     * @param results the label/prediction/weight triples to score
     * @return the weighted mean squared error, or {@code 0} if the summed weight is
     *         not greater than zero (including an empty list)
     */
    public static double mseRegressionLoss(List<LabelPredictionWeight<Double, Double>> results) {
        double totalLoss = 0;
        double totalWeight = 0;
        for (LabelPredictionWeight<Double, Double> result : results) {
            final double weight = result.getWeight();
            final double error = result.getLabel() - result.getPrediction();
            // The weight is applied exactly once; applying it a second time would
            // turn the numerator into sum(weight^2 * error^2) and bias the mean.
            totalLoss += error * error * weight;
            totalWeight += weight;
        }
        return totalWeight > 0 ? totalLoss / totalWeight : 0;
    }

    /**
     * Computes the square root of {@link #mseRegressionLoss(List)}.
     *
     * @param results the label/prediction/weight triples to score
     * @return the weighted root mean squared error
     */
    public static double rmseRegressionLoss(List<LabelPredictionWeight<Double, Double>> results) {
        return Math.sqrt(mseRegressionLoss(results));
    }
}