package gdsc.smlm.fitting; import java.util.Arrays; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.analysis.MultivariateMatrixFunction; import org.apache.commons.math3.analysis.MultivariateVectorFunction; import org.apache.commons.math3.distribution.BinomialDistribution; import org.apache.commons.math3.exception.ConvergenceException; import org.apache.commons.math3.exception.TooManyEvaluationsException; import org.apache.commons.math3.exception.TooManyIterationsException; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum; import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem; import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer; import org.apache.commons.math3.linear.DiagonalMatrix; import org.apache.commons.math3.optim.ConvergenceChecker; import org.apache.commons.math3.optim.InitialGuess; import org.apache.commons.math3.optim.MaxEval; import org.apache.commons.math3.optim.MaxIter; import org.apache.commons.math3.optim.OptimizationData; import org.apache.commons.math3.optim.PointValuePair; import org.apache.commons.math3.optim.SimpleBounds; import org.apache.commons.math3.optim.SimpleValueChecker; import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.Well19937c; import org.apache.commons.math3.util.CombinatoricsUtils; import org.apache.commons.math3.util.FastMath; /*----------------------------------------------------------------------------- * GDSC SMLM Software * * Copyright (C) 2013 Alex Herbert * Genome Damage and Stability Centre * University of Sussex, UK * * This program is free software; you can 
redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *---------------------------------------------------------------------------*/

import gdsc.core.ij.Utils;
import gdsc.core.logging.Logger;
import gdsc.core.utils.Maths;

/**
 * Fit a binomial distribution to a histogram.
 * <p>
 * For each fit the number of trials n is fixed and only the success probability p is optimised. The search is
 * scored by either the log-likelihood (maximum likelihood fitting, the default) or the sum-of-squares. In both cases
 * the value reported for a solution is the sum-of-squares so that fits with different n can be compared.
 */
public class BinomialFitter
{
	private Logger logger = null;
	private boolean maximumLikelihood = true;
	private int fitRestarts = 5;

	/**
	 * Create a fitter that does not log progress.
	 */
	public BinomialFitter()
	{
	}

	/**
	 * Create a fitter that reports progress messages.
	 *
	 * @param logger
	 *            Logging interface to report progress messages
	 */
	public BinomialFitter(Logger logger)
	{
		this.logger = logger;
	}

	/**
	 * Create a histogram from n=0 to n=N as a normalised probability.
	 * N = p.length - 1;
	 *
	 * @param data
	 *            The input data (integer counts)
	 * @param cumulative
	 *            Build a cumulative histogram
	 * @return The histogram (p), cumulative if requested
	 * @throws IllegalArgumentException
	 *             If any of the input data values are negative
	 */
	public static double[] getHistogram(int[] data, boolean cumulative)
	{
		double[] newData = new double[data.length];
		for (int i = 0; i < data.length; i++)
		{
			if (data[i] < 0)
				throw new IllegalArgumentException("Input data must be positive");
			newData[i] = data[i];
		}
		return calculateHistogram(newData, cumulative);
	}

	/**
	 * Create a histogram from n=0 to n=N as a normalised probability.
	 * N = p.length - 1;
	 *
	 * @param data
	 *            The input data (must hold non-negative integer values)
	 * @param cumulative
	 *            Build a cumulative histogram
	 * @return The histogram (p), cumulative if requested
	 * @throws IllegalArgumentException
	 *             If any of the input data values are negative or non-integer
	 */
	public static double[] getHistogram(double[] data, boolean cumulative)
	{
		for (int i = 0; i < data.length; i++)
		{
			if (data[i] < 0)
				throw new IllegalArgumentException("Input data must be positive");
			if ((int) data[i] != data[i])
				throw new IllegalArgumentException("Input data must be integers");
		}
		return calculateHistogram(data, cumulative);
	}

	/**
	 * Create a histogram from n=0 to n=N as a normalised probability.
	 * N = p.length - 1;
	 *
	 * @param data
	 *            The input data
	 * @param cumulative
	 *            Build a cumulative histogram
	 * @return The histogram (p), cumulative if requested. Returns {1} for empty input.
	 */
	private static double[] calculateHistogram(double[] data, boolean cumulative)
	{
		double[][] histogram = Maths.cumulativeHistogram(data, true);

		if (histogram[0].length == 0)
			return new double[] { 1 };

		// Pad to include all values
		double[] nValues = histogram[0];
		double[] pValues = histogram[1];

		int N = (int) nValues[nValues.length - 1];
		double[] p = new double[N + 1];

		// Pad the histogram out for any missing values between 0 and N
		for (int i = 1; i < nValues.length; i++)
		{
			int j = (int) nValues[i - 1];
			int k = (int) nValues[i];
			for (int ii = j; ii < k; ii++)
				p[ii] = pValues[i - 1];
		}
		p[N] = pValues[pValues.length - 1];

		// We need the original histogram, not the cumulative histogram
		if (!cumulative)
		{
			for (int i = p.length; i-- > 1;)
			{
				p[i] -= p[i - 1];
			}
		}

		return p;
	}

	/**
	 * Fit the binomial distribution (n,p) to the input data. Performs fitting assuming a fixed n value and attempts to
	 * optimise p. All n from minN to maxN are evaluated. If maxN is zero then all possible n from minN are evaluated
	 * until the fit is worse than the best fit three times in succession.
	 *
	 * @param data
	 *            The input data (all value must be positive)
	 * @param minN
	 *            The minimum n to evaluate
	 * @param maxN
	 *            The maximum n to evaluate. Set to zero to evaluate all possible values.
	 * @param zeroTruncated
	 *            True if the model should ignore n=0 (zero-truncated binomial)
	 * @return The best fit (n, p), or null if no fit succeeded
	 * @throws IllegalArgumentException
	 *             If any of the input data values are negative, or if fitting a zero-truncated binomial and there are
	 *             no values above zero
	 */
	public double[] fitBinomial(int[] data, int minN, int maxN, boolean zeroTruncated)
	{
		double[] histogram = getHistogram(data, false);

		final double initialSS = Double.POSITIVE_INFINITY;
		double bestSS = initialSS;
		double[] parameters = null;
		int worse = 0;
		int N = histogram.length - 1;
		if (minN < 1)
			minN = 1;
		if (maxN > 0)
		{
			if (N > maxN)
			{
				// Limit the number fitted to maximum
				N = maxN;
			}
			else if (N < maxN)
			{
				// Expand the histogram to the maximum
				histogram = Arrays.copyOf(histogram, maxN + 1);
				N = maxN;
			}
		}
		if (minN > N)
			minN = N;

		// The mean (used for the initial guess of p) includes any zero values
		final double mean = getMean(histogram);

		String name = (zeroTruncated) ? "Zero-truncated Binomial distribution" : "Binomial distribution";

		log("Mean cluster size = %s", Utils.rounded(mean));
		// Note: the fitted histogram is the (non-cumulative) probability histogram
		log("Fitting cumulative " + name);

		// Remove the zero bin once here so the per-n fits below do not each repeat it (and its log message).
		// The histogram is a local array so this affects no caller.
		if (zeroTruncated)
			histogram = removeZeroBin(histogram);

		// Since varying the N should be done in integer steps do this
		// for n=1,2,3,... until the SS peaks then falls off (is worse than the best
		// score several times in succession)
		for (int n = minN; n <= N; n++)
		{
			PointValuePair solution = fitBinomial(histogram, mean, n, zeroTruncated);
			if (solution == null)
				continue;

			double p = solution.getPointRef()[0];
			log("Fitted %s : N=%d, p=%s. SS=%g", name, n, Utils.rounded(p), solution.getValue());

			if (bestSS > solution.getValue())
			{
				bestSS = solution.getValue();
				parameters = new double[] { n, p };
				worse = 0;
			}
			else if (bestSS != initialSS)
			{
				if (++worse >= 3)
					break;
			}
		}

		return parameters;
	}

	/**
	 * Fit the binomial distribution (n,p) to the probability histogram. Performs fitting assuming a fixed n value and
	 * attempts to optimise p. The input histogram is not modified.
	 *
	 * @param histogram
	 *            The input histogram
	 * @param n
	 *            The n to evaluate
	 * @param zeroTruncated
	 *            True if the model should ignore n=0 (zero-truncated binomial)
	 * @return The best fit (n, p) with the sum-of-squares as the value, or null if fitting failed
	 * @throws IllegalArgumentException
	 *             If fitting a zero truncated binomial and there are no values above zero
	 */
	public PointValuePair fitBinomial(double[] histogram, int n, boolean zeroTruncated)
	{
		return fitBinomial(histogram, Double.NaN, n, zeroTruncated);
	}

	/**
	 * Fit the binomial distribution (n,p) to the probability histogram. Performs fitting assuming a fixed n value and
	 * attempts to optimise p. The input histogram is not modified.
	 * <p>
	 * n may exceed the histogram length; missing bins are then simply not fitted.
	 *
	 * @param histogram
	 *            The input histogram
	 * @param mean
	 *            The histogram mean (used to estimate p). Calculated if NaN.
	 * @param n
	 *            The n to evaluate
	 * @param zeroTruncated
	 *            True if the model should ignore n=0 (zero-truncated binomial)
	 * @return The best fit (n, p) with the sum-of-squares as the value, or null if fitting failed
	 * @throws IllegalArgumentException
	 *             If fitting a zero truncated binomial and there are no values above zero
	 */
	public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated)
	{
		if (Double.isNaN(mean))
			mean = getMean(histogram);

		if (zeroTruncated)
			histogram = removeZeroBin(histogram);

		int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
		if (nFittedPoints < 1)
		{
			log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints,
					histogram.length, n, zeroTruncated);
			return null;
		}

		// The model is only fitting the probability p
		// For a binomial n*p = mean => p = mean/n
		final double initialP = FastMath.min(mean / n, 1);

		BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);

		try
		{
			PointValuePair solution;
			boolean noRefit = maximumLikelihood;
			if (n == 1 && zeroTruncated)
			{
				// No need to fit: the zero-truncated model has all its mass at x=1 for any p
				solution = new PointValuePair(new double[] { 1 }, 0);
				noRefit = true;
			}
			else
			{
				solution = searchCMAES(function, initialP);
				if (solution == null)
					return null;
			}

			if (noRefit)
			{
				// Although we fit the log-likelihood, return the sum-of-squares to allow
				// comparison across different n
				return new PointValuePair(solution.getPointRef(), sumOfSquares(function, solution.getPointRef()[0]));
			}

			// We can do a LVM refit if the number of fitted points is more than 1
			if (nFittedPoints > 1)
			{
				PointValuePair refit = refitLVM(histogram, n, zeroTruncated, solution);
				if (refit != null)
					return refit;
			}
			return solution;
		}
		catch (Exception e)
		{
			log("Failed to fit Binomial distribution with N=%d : %s", n, e.getMessage());
		}
		return null;
	}

	/**
	 * Remove the x=0 bin from a histogram for zero-truncated fitting. The remaining bins are renormalised to sum to 1.
	 *
	 * @param histogram
	 *            The histogram
	 * @return The input if it is empty or has no mass at x=0, otherwise a renormalised copy
	 * @throws IllegalArgumentException
	 *             If there are no values above zero
	 */
	private double[] removeZeroBin(double[] histogram)
	{
		if (histogram.length == 0 || !(histogram[0] > 0))
			return histogram;

		log("Fitting zero-truncated histogram but there are zero values - Renormalising to ignore zero");
		double cumul = 0;
		for (int i = 1; i < histogram.length; i++)
			cumul += histogram[i];
		if (cumul == 0)
			throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");

		// Work on a copy so the caller's data is not modified
		final double[] copy = new double[histogram.length];
		for (int i = 1; i < histogram.length; i++)
			copy[i] = histogram[i] / cumul;
		return copy;
	}

	/**
	 * Search for the best p in [0,1] using the bounded CMA-ES optimiser, with restarts.
	 * <p>
	 * Each restart runs once from the initial guess and once from the current best. Optimiser runs that exceed the
	 * evaluation/iteration limits are skipped.
	 *
	 * @param function
	 *            The objective (log-likelihood or sum-of-squares, depending on {@link #isMaximumLikelihood()})
	 * @param initialP
	 *            The initial estimate of p
	 * @return The best solution (objective value), or null if every run failed
	 */
	private PointValuePair searchCMAES(BinomialModelFunction function, double initialP)
	{
		final double[] initialSolution = new double[] { initialP };
		final double[] lB = new double[1];
		final double[] uB = new double[] { 1 };
		final SimpleBounds bounds = new SimpleBounds(lB, uB);

		// CMAESOptimizer or BOBYQAOptimizer support bounds
		// CMAESOptimiser based on Matlab code:
		// https://www.lri.fr/~hansen/cmaes.m
		// Take the defaults from the Matlab documentation
		final int maxIterations = 2000;
		final double stopFitness = 0;
		final boolean isActiveCMA = true;
		final int diagonalOnly = 0;
		final int checkFeasableCount = 1;
		final RandomGenerator random = new Well19937c();
		final boolean generateStatistics = false;
		final ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
		// The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
		final OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
		// NOTE(review): the Matlab default is 4 + floor(3 * log(N)) for N parameters; log(2) is used here
		// (population 6) rather than log(1) (population 4). Presumably deliberate - confirm.
		final OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));

		final boolean maximise = maximumLikelihood;
		final GoalType goalType = (maximise) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;

		final CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCMA, diagonalOnly,
				checkFeasableCount, random, generateStatistics, checker);

		PointValuePair solution = null;
		for (int iteration = 0; iteration <= fitRestarts; iteration++)
		{
			try
			{
				// Start from the initial solution
				PointValuePair result = opt.optimize(new InitialGuess(initialSolution),
						new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations),
						new MaxEval(maxIterations * 2));
				if (isImprovement(result, solution, maximise))
					solution = result;
			}
			catch (TooManyEvaluationsException e)
			{
				// Ignored: this run did not converge; the next restart may
			}
			catch (TooManyIterationsException e)
			{
				// Ignored: this run did not converge; the next restart may
			}

			if (solution == null)
				continue;

			try
			{
				// Also restart from the current optimum
				PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()),
						new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations),
						new MaxEval(maxIterations * 2));
				if (isImprovement(result, solution, maximise))
					solution = result;
			}
			catch (TooManyEvaluationsException e)
			{
				// Ignored: keep the current optimum
			}
			catch (TooManyIterationsException e)
			{
				// Ignored: keep the current optimum
			}
		}
		return solution;
	}

	/**
	 * Check if a candidate solution is better than the current one for the optimisation goal.
	 *
	 * @param candidate
	 *            The candidate
	 * @param current
	 *            The current best (may be null)
	 * @param maximise
	 *            True if larger values are better
	 * @return True if there is no current best, or the candidate improves on it
	 */
	private static boolean isImprovement(PointValuePair candidate, PointValuePair current, boolean maximise)
	{
		if (current == null)
			return true;
		return (maximise) ? candidate.getValue() > current.getValue() : candidate.getValue() < current.getValue();
	}

	/**
	 * Compute the sum-of-squares between the observed histogram of the model and its expected probabilities.
	 *
	 * @param model
	 *            The model (holding the observed histogram)
	 * @param pValue
	 *            The binomial success probability
	 * @return The sum-of-squares over all histogram bins
	 */
	private static double sumOfSquares(BinomialModel model, double pValue)
	{
		final double[] obs = model.p;
		final double[] exp = model.getP(pValue);
		double ss = 0;
		for (int i = 0; i < obs.length; i++)
		{
			final double dx = obs[i] - exp[i];
			ss += dx * dx;
		}
		return ss;
	}

	/**
	 * Try to improve a sum-of-squares solution using the gradient based Levenberg-Marquardt optimiser.
	 *
	 * @param histogram
	 *            The observed histogram
	 * @param n
	 *            The number of trials
	 * @param zeroTruncated
	 *            True if the model should ignore n=0
	 * @param solution
	 *            The current solution (value must be the sum-of-squares)
	 * @return The improved solution, or null if the refit failed, left [0,1], or did not lower the sum-of-squares
	 */
	private PointValuePair refitLVM(double[] histogram, int n, boolean zeroTruncated, PointValuePair solution)
	{
		LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
		try
		{
			final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram, n,
					zeroTruncated);
			//@formatter:off
			LeastSquaresProblem problem = new LeastSquaresBuilder()
					.maxEvaluations(Integer.MAX_VALUE)
					.maxIterations(3000)
					.start(solution.getPointRef())
					.target(gradientFunction.p)
					.weight(new DiagonalMatrix(gradientFunction.getWeights()))
					.model(gradientFunction, new MultivariateMatrixFunction() {
						public double[][] value(double[] point) throws IllegalArgumentException
						{
							return gradientFunction.jacobian(point);
						}}
					)
					.build();
			//@formatter:on

			Optimum lvmSolution = optimizer.optimize(problem);

			// Check the pValue is valid since the LVM is not bounded.
			double p = lvmSolution.getPoint().getEntry(0);
			if (p <= 1 && p >= 0)
			{
				// This is the sum-of-squares because all the weights are 1
				double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
				if (ss < solution.getValue())
					return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
			}
		}
		catch (TooManyIterationsException e)
		{
			log("Failed to re-fit: Too many iterations: %s", e.getMessage());
		}
		catch (ConvergenceException e)
		{
			log("Failed to re-fit: %s", e.getMessage());
		}
		catch (Exception e)
		{
			// Ignored: the refit is a best-effort improvement; the caller keeps the original solution
		}
		return null;
	}

	/**
	 * Compute the mean x of a histogram of frequencies indexed by x.
	 *
	 * @param histogram
	 *            The histogram
	 * @return The mean (NaN if the histogram sums to zero)
	 */
	private double getMean(double[] histogram)
	{
		double sum = 0;
		double count = 0;
		for (int i = 0; i < histogram.length; i++)
		{
			sum += histogram[i] * i;
			count += histogram[i];
		}
		return sum / count;
	}

	/**
	 * Evaluates the binomial probability mass function for x = 0 to N, optionally zero-truncated. Assumes the input
	 * data is a (non-cumulative) probability histogram from 0 to N in integer increments.
	 */
	public class BinomialModel
	{
		int trials;
		double[] p;
		int startIndex;

		/**
		 * Create a new Binomial model using the input p-values
		 *
		 * @param p
		 *            The observed probability for each x
		 * @param trials
		 *            The number of trials
		 * @param zeroTruncated
		 *            Set to true to ignore the x=0 datapoint
		 */
		public BinomialModel(double[] p, int trials, boolean zeroTruncated)
		{
			this.trials = trials;
			startIndex = (zeroTruncated) ? 1 : 0;
			this.p = p;
		}

		/**
		 * @return The last x that is both a possible outcome (x <= trials) and inside the observed histogram
		 */
		int lastIndex()
		{
			return Math.min(trials, p.length - 1);
		}

		/**
		 * Get the probability function for the input pValue
		 *
		 * @param pValue
		 *            The binomial success probability
		 * @return The expected probability for each x (same length as the observed histogram; zero where x is not
		 *         modelled)
		 */
		public double[] getP(double pValue)
		{
			BinomialDistribution dist = new BinomialDistribution(trials, pValue);

			// Optionally ignore x=0 since we cannot see a zero size cluster.
			// This is done by re-normalising the probabilities excluding x=0
			// to match the input histogram.
			//
			// See Zero-truncated (zt) binomial distribution:
			// http://www.vosesoftware.com/ModelRiskHelp/index.htm#Distributions/Discrete_distributions/Zero-truncated_binomial_distribution.htm
			// pi = 1 / ( 1 - f(0) )
			// fzt(x) = pi . f(x)

			final int last = lastIndex();
			double[] p2 = new double[p.length];
			for (int i = startIndex; i <= last; i++)
			{
				p2[i] = dist.probability(i);
			}

			// Renormalise if necessary
			if (startIndex == 1)
			{
				final double pi = 1.0 / (1.0 - dist.probability(0));
				for (int i = 1; i <= last; i++)
				{
					p2[i] *= pi;
				}
			}

			return p2;
		}
	}

	/**
	 * Allow optimisation using Apache Commons Math 3 MultivariateFunction optimisers. The value is the log-likelihood
	 * when maximum likelihood fitting is enabled, otherwise the sum-of-squares.
	 */
	public class BinomialModelFunction extends BinomialModel implements MultivariateFunction
	{
		public BinomialModelFunction(double[] p, int trials, boolean zeroTruncated)
		{
			super(p, trials, zeroTruncated);
		}

		/*
		 * (non-Javadoc)
		 * 
		 * @see org.apache.commons.math3.analysis.MultivariateFunction#value(double[])
		 */
		public double value(double[] parameters)
		{
			double[] p2 = getP(parameters[0]);
			if (maximumLikelihood)
			{
				// Calculate the log-likelihood.
				// We cannot produce a likelihood for any x > trials, and there is no data past the histogram end.
				double ll = 0;
				final int last = lastIndex();
				for (int i = startIndex; i <= last; i++)
				{
					// Sum for all observations the probability of the observation.
					// Use p[i] to indicate the frequency of this observation.
					// Unobserved values contribute 0 (0 * log(0) is taken as 0, avoiding NaN).
					if (p[i] != 0)
						ll += p[i] * Math.log(p2[i]);
				}
				return ll;
			}

			// Calculate the sum of squares
			double ss = 0;
			for (int i = startIndex; i < p.length; i++)
			{
				final double dx = p[i] - p2[i];
				ss += dx * dx;
			}
			return ss;
		}
	}

	/**
	 * Allow least-squares optimisation using Apache Commons Math 3 MultivariateVectorFunction models with an
	 * analytical Jacobian.
	 */
	public class BinomialModelFunctionGradient extends BinomialModel implements MultivariateVectorFunction
	{
		long[] nC;

		public BinomialModelFunctionGradient(double[] histogram, int trials, boolean zeroTruncated)
		{
			super(histogram, trials, zeroTruncated);

			// We could ignore the first p value as it is always zero:
			//p = Arrays.copyOfRange(p, 1, p.length);
			// BUT then we would have to override the getP() method since this has
			// an offset of 1 and assumes the index of p is X.

			// Note: CombinatoricsUtils.binomialCoefficient throws if the result overflows a long (large n)
			final int n = trials;
			nC = new long[n + 1];
			for (int k = 0; k <= n; k++)
			{
				nC[k] = CombinatoricsUtils.binomialCoefficient(n, k);
			}
		}

		/**
		 * @return Unit weights for each histogram bin
		 */
		public double[] getWeights()
		{
			double[] w = new double[p.length];
			Arrays.fill(w, 1);
			return w;
		}

		/*
		 * (non-Javadoc)
		 * 
		 * @see org.apache.commons.math3.analysis.MultivariateVectorFunction#value(double[])
		 */
		public double[] value(double[] point) throws IllegalArgumentException
		{
			return getP(point[0]);
		}

		// Step for numerical derivatives using the desired fractional accuracy.
		// See Numerical Recipes, The Art of Scientific Computing (2nd edition) Chapter 5.7.
		// Not used by the analytical gradient below.
		final double delta = Math.pow(1e-6, 1.0 / 3);

		/**
		 * Compute the analytical Jacobian of {@link #value(double[])} with respect to p.
		 *
		 * @param variables
		 *            The parameters (p)
		 * @return The Jacobian with one row per histogram bin
		 */
		double[][] jacobian(double[] variables)
		{
			final double p = variables[0];
			final int n = trials;
			final int last = lastIndex();
			final double[][] jacobian = new double[this.p.length][1];

			if (startIndex == 0)
			{
				for (int k = 0; k <= last; k++)
					jacobian[k][0] = pmfDerivative(k, p);
			}
			else
			{
				// Zero-truncated: pmf_zt(k) = f(p) . g(k, p) with g the binomial pmf and
				// f(p) = 1 / (1 - (1-p)^n). Apply the product rule (f.g)' = f'.g + f.g' with
				// f'(p) = -n (1-p)^(n-1) / (1 - (1-p)^n)^2
				// (the x=0 row stays zero)
				final double q = 1.0 - nC[0] * Math.pow(1 - p, n);
				final double f = 1.0 / q;
				final double df = -n * Math.pow(1 - p, n - 1) / (q * q);
				for (int k = 1; k <= last; k++)
					jacobian[k][0] = df * pmf(k, p) + f * pmfDerivative(k, p);
			}

			return jacobian;
		}

		/**
		 * Binomial pmf: nCk * p^k * (1-p)^(n-k).
		 */
		private double pmf(int k, double p)
		{
			return nC[k] * Math.pow(p, k) * Math.pow(1 - p, trials - k);
		}

		/**
		 * Derivative of the binomial pmf with respect to p:
		 * nCk * [ k p^(k-1) (1-p)^(n-k) - (n-k) p^k (1-p)^(n-k-1) ].
		 * Terms with a zero coefficient are skipped so the result is finite at p=0 and p=1.
		 */
		private double pmfDerivative(int k, double p)
		{
			final int n = trials;
			double d = 0;
			if (k > 0)
				d += k * Math.pow(p, k - 1) * Math.pow(1 - p, n - k);
			if (k < n)
				d -= (n - k) * Math.pow(p, k) * Math.pow(1 - p, n - k - 1);
			return nC[k] * d;
		}
	}

	private void log(String format, Object... args)
	{
		if (logger != null)
			logger.info(format, args);
	}

	/**
	 * @return True if use maximum likelihood fitting
	 */
	public boolean isMaximumLikelihood()
	{
		return maximumLikelihood;
	}

	/**
	 * @param maximumLikelihood
	 *            True if use maximum likelihood fitting
	 */
	public void setMaximumLikelihood(boolean maximumLikelihood)
	{
		this.maximumLikelihood = maximumLikelihood;
	}

	/**
	 * @return the number of restarts for fitting
	 */
	public int getFitRestarts()
	{
		return fitRestarts;
	}

	/**
	 * Since fitting uses a bounded search seeded with random movements, restarting can improve the fit. Control the
	 * number of restarts used for fitting.
	 *
	 * @param fitRestarts
	 *            the number of restarts for fitting (negative values are treated as zero)
	 */
	public void setFitRestarts(int fitRestarts)
	{
		this.fitRestarts = Math.max(0, fitRestarts);
	}
}