/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.utils.math; import hivemall.utils.lang.Preconditions; import javax.annotation.Nonnull; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.NotPositiveException; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularValueDecomposition; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.MathArrays; import java.util.AbstractMap; import java.util.Map; public final class StatsUtils { private StatsUtils() {} /** * probit(p)=sqrt(2)erf^-1(2p-1) * * <pre> * probit(1)=INF, probit(0)=-INF, probit(0.5)=0 * </pre> * * @param p must be in [0,1] * @link http://en.wikipedia.org/wiki/Probit */ public static double probit(double p) { if (p < 0 || p > 1) { throw new IllegalArgumentException("p must be in [0,1]"); } return Math.sqrt(2.d) * MathUtils.inverseErf(2.d * p - 1.d); } public static double probit(double p, double range) { if (range <= 0) { throw new IllegalArgumentException("range must be > 0: " + range); } if (p == 0) { return -range; } if (p == 1) { return range; } double v = probit(p); if (v < 0) { return Math.max(v, -range); } else { return Math.min(v, range); } } /** * @return value of probabilistic density function */ public static double pdf(final double x, final double x_hat, final double sigma) { if (sigma == 0.d) { return 0.d; } double diff = x - x_hat; double numerator = Math.exp(-0.5d * diff * diff / sigma); double denominator = Math.sqrt(2.d * Math.PI) * Math.sqrt(sigma); return numerator / denominator; } /** * pdf(x, x_hat) = exp(-0.5 * (x-x_hat) * inv(Σ) * (x-x_hat)T) / ( 2π^0.5d * det(Σ)^0.5) * * @return value of probabilistic density function * @link https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Density_function */ public static double pdf(@Nonnull final RealVector x, @Nonnull final RealVector x_hat, @Nonnull final RealMatrix sigma) { final int dim = x.getDimension(); Preconditions.checkArgument(x_hat.getDimension() == dim, "|x| != |x_hat|, |x|=" + dim + ", |x_hat|=" + x_hat.getDimension()); Preconditions.checkArgument(sigma.getRowDimension() == dim, "|x| != |sigma|, |x|=" + dim + ", |sigma|=" + sigma.getRowDimension()); Preconditions.checkArgument(sigma.isSquare(), "Sigma is not square matrix"); LUDecomposition LU = new LUDecomposition(sigma); final double detSigma = LU.getDeterminant(); double denominator = Math.pow(2.d * Math.PI, 0.5d * dim) * Math.pow(detSigma, 0.5d); if (denominator == 0.d) { // avoid divide by zero return 0.d; } final RealMatrix invSigma; DecompositionSolver solver = LU.getSolver(); if (solver.isNonSingular() == false) { SingularValueDecomposition svd = new SingularValueDecomposition(sigma); invSigma = svd.getSolver().getInverse(); // least square solution } else { invSigma = solver.getInverse(); } //EigenDecomposition eigen = new EigenDecomposition(sigma); //double detSigma = eigen.getDeterminant(); //RealMatrix invSigma = eigen.getSolver().getInverse(); RealVector diff = x.subtract(x_hat); RealVector premultiplied = invSigma.preMultiply(diff); double sum = premultiplied.dotProduct(diff); double numerator = Math.exp(-0.5d * sum); return numerator / denominator; } public static double logLoss(final double actual, final double predicted, final double sigma) { double p = pdf(actual, predicted, sigma); if (p == 0.d) { return 0.d; } return -Math.log(p); } public static double logLoss(@Nonnull final RealVector actual, @Nonnull final RealVector predicted, @Nonnull final RealMatrix sigma) { double p = pdf(actual, predicted, sigma); if (p == 0.d) { return 0.d; } return -Math.log(p); } /** * @param mu1 mean of the first normal distribution * @param sigma1 variance of the first normal distribution * @param mu2 mean of the second normal distribution * @param sigma2 variance of the second normal distribution * @return the Hellinger distance between two normal distributions * @link https://en.wikipedia.org/wiki/Hellinger_distance#Examples */ public static double hellingerDistance(@Nonnull final double mu1, @Nonnull final double sigma1, @Nonnull final double mu2, @Nonnull final double sigma2) { double sigmaSum = sigma1 + sigma2; if (sigmaSum == 0.d) { return 0.d; } double numerator = Math.pow(sigma1, 0.25d) * Math.pow(sigma2, 0.25d) * Math.exp(-0.25d * Math.pow(mu1 - mu2, 2d) / sigmaSum); double denominator = Math.sqrt(sigmaSum / 2d); if (denominator == 0.d) { return 1.d; } return 1.d - numerator / denominator; } /** * @param mu1 mean vector of the first normal distribution * @param sigma1 covariance matrix of the first normal distribution * @param mu2 mean vector of the second normal distribution * @param sigma2 covariance matrix of the second normal distribution * @return the Hellinger distance between two multivariate normal distributions * @link https://en.wikipedia.org/wiki/Hellinger_distance#Examples */ public static double hellingerDistance(@Nonnull final RealVector mu1, @Nonnull final RealMatrix sigma1, @Nonnull final RealVector mu2, @Nonnull final RealMatrix sigma2) { RealVector muSub = mu1.subtract(mu2); RealMatrix sigmaMean = sigma1.add(sigma2).scalarMultiply(0.5d); LUDecomposition LUsigmaMean = new LUDecomposition(sigmaMean); double denominator = Math.sqrt(LUsigmaMean.getDeterminant()); if (denominator == 0.d) { return 1.d; // avoid divide by zero } RealMatrix sigmaMeanInv = LUsigmaMean.getSolver().getInverse(); // has inverse iff det != 0 double sigma1Det = MatrixUtils.det(sigma1); double sigma2Det = MatrixUtils.det(sigma2); double numerator = Math.pow(sigma1Det, 0.25d) * Math.pow(sigma2Det, 0.25d) * Math.exp(-0.125d * sigmaMeanInv.preMultiply(muSub).dotProduct(muSub)); return 1.d - numerator / denominator; } /** * @param observed means non-negative vector * @param expected means positive vector * @return chi2 value */ public static double chiSquare(@Nonnull final double[] observed, @Nonnull final double[] expected) { if (observed.length < 2) { throw new DimensionMismatchException(observed.length, 2); } if (expected.length != observed.length) { throw new DimensionMismatchException(observed.length, expected.length); } MathArrays.checkPositive(expected); for (double d : observed) { if (d < 0.d) { throw new NotPositiveException(d); } } double sumObserved = 0.d; double sumExpected = 0.d; for (int i = 0; i < observed.length; i++) { sumObserved += observed[i]; sumExpected += expected[i]; } double ratio = 1.d; boolean rescale = false; if (FastMath.abs(sumObserved - sumExpected) > 10e-6) { ratio = sumObserved / sumExpected; rescale = true; } double sumSq = 0.d; for (int i = 0; i < observed.length; i++) { if (rescale) { final double dev = observed[i] - ratio * expected[i]; sumSq += dev * dev / (ratio * expected[i]); } else { final double dev = observed[i] - expected[i]; sumSq += dev * dev / expected[i]; } } return sumSq; } /** * @param observed means non-negative vector * @param expected means positive vector * @return p value */ public static double chiSquareTest(@Nonnull final double[] observed, @Nonnull final double[] expected) { final ChiSquaredDistribution distribution = new ChiSquaredDistribution( expected.length - 1.d); return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected)); } /** * This method offers effective calculation for multiple entries rather than calculation * individually * * @param observeds means non-negative matrix * @param expecteds means positive matrix * @return (chi2 value[], p value[]) */ public static Map.Entry<double[], double[]> chiSquare(@Nonnull final double[][] observeds, @Nonnull final double[][] expecteds) { Preconditions.checkArgument(observeds.length == expecteds.length); final int len = expecteds.length; final int lenOfEach = expecteds[0].length; final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d); final double[] chi2s = new double[len]; final double[] ps = new double[len]; for (int i = 0; i < len; i++) { chi2s[i] = chiSquare(observeds[i], expecteds[i]); ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]); } return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps); } }