/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.common; import hivemall.utils.math.MathUtils; /** * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions */ public final class LossFunctions { public enum LossType { SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss } public static LossFunction getLossFunction(String type) { if ("SquaredLoss".equalsIgnoreCase(type)) { return new SquaredLoss(); } else if ("LogLoss".equalsIgnoreCase(type)) { return new LogLoss(); } else if ("HingeLoss".equalsIgnoreCase(type)) { return new HingeLoss(); } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { return new SquaredHingeLoss(); } else if ("QuantileLoss".equalsIgnoreCase(type)) { return new QuantileLoss(); } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) { return new EpsilonInsensitiveLoss(); } throw new IllegalArgumentException("Unsupported type: " + type); } public static LossFunction getLossFunction(LossType type) { switch (type) { case SquaredLoss: return new SquaredLoss(); case LogLoss: return new LogLoss(); case HingeLoss: return new HingeLoss(); case SquaredHingeLoss: return new SquaredHingeLoss(); case QuantileLoss: return new QuantileLoss(); case EpsilonInsensitiveLoss: return new EpsilonInsensitiveLoss(); default: throw new IllegalArgumentException("Unsupported type: " + type); } } public interface LossFunction { /** * Evaluate the loss function. * * @param p The prediction, p = w^T x * @param y The true value (aka target) * @return The loss evaluated at `p` and `y`. */ public float loss(float p, float y); public double loss(double p, double y); /** * Evaluate the derivative of the loss function with respect to the prediction `p`. * * @param p The prediction, p = w^T x * @param y The true value (aka target) * @return The derivative of the loss function w.r.t. `p`. */ public float dloss(float p, float y); public boolean forBinaryClassification(); public boolean forRegression(); } public static abstract class BinaryLoss implements LossFunction { protected static void checkTarget(float y) { if (!(y == 1.f || y == -1.f)) { throw new IllegalArgumentException("target must be [+1,-1]: " + y); } } protected static void checkTarget(double y) { if (!(y == 1.d || y == -1.d)) { throw new IllegalArgumentException("target must be [+1,-1]: " + y); } } @Override public boolean forBinaryClassification() { return true; } @Override public boolean forRegression() { return false; } } public static abstract class RegressionLoss implements LossFunction { @Override public boolean forBinaryClassification() { return false; } @Override public boolean forRegression() { return true; } } /** * Squared loss for regression problems. * * If you're trying to minimize the mean error, use squared-loss. */ public static final class SquaredLoss extends RegressionLoss { @Override public float loss(float p, float y) { final float z = p - y; return z * z * 0.5f; } @Override public double loss(double p, double y) { final double z = p - y; return z * z * 0.5d; } @Override public float dloss(float p, float y) { return p - y; // 2 (p - y) / 2 } } /** * Logistic regression loss for binary classification with y in {-1, 1}. */ public static final class LogLoss extends BinaryLoss { /** * <code>logloss(p,y) = log(1+exp(-p*y))</code> */ @Override public float loss(float p, float y) { checkTarget(y); final float z = y * p; if (z > 18.f) { return (float) Math.exp(-z); } if (z < -18.f) { return -z; } return (float) Math.log(1.d + Math.exp(-z)); } @Override public double loss(double p, double y) { checkTarget(y); final double z = y * p; if (z > 18.d) { return Math.exp(-z); } if (z < -18.d) { return -z; } return Math.log(1.d + Math.exp(-z)); } @Override public float dloss(float p, float y) { checkTarget(y); float z = y * p; if (z > 18.f) { return (float) Math.exp(-z) * -y; } if (z < -18.f) { return -y; } return -y / ((float) Math.exp(z) + 1.f); } } /** * Hinge loss for binary classification tasks with y in {-1,1}. */ public static final class HingeLoss extends BinaryLoss { private float threshold; public HingeLoss() { this(1.f); } /** * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM. * When threshold=0.0, one gets the loss used by the Perceptron. */ public HingeLoss(float threshold) { this.threshold = threshold; } public void setThreshold(float threshold) { this.threshold = threshold; } @Override public float loss(float p, float y) { float loss = hingeLoss(p, y, threshold); return (loss > 0.f) ? loss : 0.f; } @Override public double loss(double p, double y) { double loss = hingeLoss(p, y, threshold); return (loss > 0.d) ? loss : 0.d; } @Override public float dloss(float p, float y) { float loss = hingeLoss(p, y, threshold); return (loss > 0.f) ? -y : 0.f; } } /** * Squared Hinge loss for binary classification tasks with y in {-1,1}. */ public static final class SquaredHingeLoss extends BinaryLoss { @Override public float loss(float p, float y) { return squaredHingeLoss(p, y); } @Override public double loss(double p, double y) { return squaredHingeLoss(p, y); } @Override public float dloss(float p, float y) { checkTarget(y); float d = 1 - (y * p); return (d > 0.f) ? -2.f * d * y : 0.f; } } /** * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase * as long as you get the relative order correct. * * @link http://en.wikipedia.org/wiki/Quantile_regression */ public static final class QuantileLoss extends RegressionLoss { private float tau; public QuantileLoss() { this.tau = 0.5f; } public QuantileLoss(float tau) { setTau(tau); } public void setTau(float tau) { if (tau <= 0 || tau >= 1.0) { throw new IllegalArgumentException("tau must be in range (0, 1): " + tau); } this.tau = tau; } @Override public float loss(float p, float y) { float e = y - p; if (e > 0.f) { return tau * e; } else { return -(1.f - tau) * e; } } @Override public double loss(double p, double y) { double e = y - p; if (e > 0.d) { return tau * e; } else { return -(1.d - tau) * e; } } @Override public float dloss(float p, float y) { float e = y - p; if (e == 0.f) { return 0.f; } return (e > 0.f) ? -tau : (1.f - tau); } } /** * Epsilon-Insensitive loss used by Support Vector Regression (SVR). * <code>loss = max(0, |y - p| - epsilon)</code> */ public static final class EpsilonInsensitiveLoss extends RegressionLoss { private float epsilon; public EpsilonInsensitiveLoss() { this(0.1f); } public EpsilonInsensitiveLoss(float epsilon) { this.epsilon = epsilon; } public void setEpsilon(float epsilon) { this.epsilon = epsilon; } @Override public float loss(float p, float y) { float loss = Math.abs(y - p) - epsilon; return (loss > 0.f) ? loss : 0.f; } @Override public double loss(double p, double y) { double loss = Math.abs(y - p) - epsilon; return (loss > 0.d) ? loss : 0.d; } @Override public float dloss(float p, float y) { if ((y - p) > epsilon) {// real value > predicted value - epsilon return -1.f; } if ((p - y) > epsilon) {// real value < predicted value - epsilon return 1.f; } return 0.f; } } public static float logisticLoss(final float target, final float predicted) { if (predicted > -100.d) { return target - (float) MathUtils.sigmoid(predicted); } else { return target; } } public static float logLoss(final float p, final float y) { BinaryLoss.checkTarget(y); final float z = y * p; if (z > 18.f) { return (float) Math.exp(-z); } if (z < -18.f) { return -z; } return (float) Math.log(1.d + Math.exp(-z)); } public static double logLoss(final double p, final double y) { BinaryLoss.checkTarget(y); final double z = y * p; if (z > 18.d) { return Math.exp(-z); } if (z < -18.d) { return -z; } return Math.log(1.d + Math.exp(-z)); } public static float squaredLoss(float p, float y) { final float z = p - y; return z * z * 0.5f; } public static double squaredLoss(double p, double y) { final double z = p - y; return z * z * 0.5d; } public static float hingeLoss(final float p, final float y, final float threshold) { BinaryLoss.checkTarget(y); float z = y * p; return threshold - z; } public static double hingeLoss(final double p, final double y, final double threshold) { BinaryLoss.checkTarget(y); double z = y * p; return threshold - z; } public static float hingeLoss(float p, float y) { return hingeLoss(p, y, 1.f); } public static double hingeLoss(double p, double y) { return hingeLoss(p, y, 1.d); } public static float squaredHingeLoss(final float p, final float y) { BinaryLoss.checkTarget(y); float z = y * p; float d = 1.f - z; return (d > 0.f) ? (d * d) : 0.f; } public static double squaredHingeLoss(final double p, final double y) { BinaryLoss.checkTarget(y); double z = y * p; double d = 1.d - z; return (d > 0.d) ? d * d : 0.d; } /** * Math.abs(target - predicted) - epsilon */ public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { return Math.abs(target - predicted) - epsilon; } }