package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions; import quickml.supervised.crossValidation.PredictionMapResult; import quickml.supervised.crossValidation.PredictionMapResults; public class ClassifierLogCVLossFunction extends ClassifierLossFunction { private static final double DEFAULT_MIN_PROBABILITY = 10E-7; public static final String NAME = "LOG_CV"; public double minProbability; public double maxError; public ClassifierLogCVLossFunction(double minProbability) { this.minProbability = minProbability; this.maxError = -Math.log(minProbability); } private double lossForInstance(double correctProbability, double weight) { return (correctProbability > minProbability) ? -weight * Math.log(correctProbability) : weight * maxError; } @Override public Double getLoss(PredictionMapResults results) { double totalLoss = 0; for (PredictionMapResult result : results) { totalLoss += lossForInstance(result.getPredictionForLabel(), result.getWeight()); } return results.totalWeight() > 0 ? totalLoss / results.totalWeight() : 0; } @Override public String getName() { return NAME; } }