package ij.measure; import ij.*; import ij.gui.*; import ij.macro.*; import ij.util.Tools; /** Curve fitting class based on the Simplex method described * in the article "Fitting Curves to Data" in the May 1984 * issue of Byte magazine, pages 340-362. * * 2001/02/14: Modified to handle a gamma variate curve. * Uses altered Simplex method based on method in "Numerical Recipes in C". * This method tends to converge closer in less iterations. * Has the option to restart the simplex at the initial best solution in * case it is "stuck" in a local minimum (by default, restarted twice). Also includes * settings dialog option for user control over simplex parameters and functions to * evaluate the goodness-of-fit. The results can be easily reported with the * getResultString() method. * Kieran Holland (holki659 at student.otago.ac.nz) * * 2008/01/21: Modified to do Gaussian fitting by Stefan Woerz (s.woerz at dkfz.de). * */ public class CurveFitter { public static final int STRAIGHT_LINE=0,POLY2=1,POLY3=2,POLY4=3, EXPONENTIAL=4,POWER=5,LOG=6,RODBARD=7,GAMMA_VARIATE=8, LOG2=9, RODBARD2=10, EXP_WITH_OFFSET=11, GAUSSIAN=12, EXP_RECOVERY=13; private static final int CUSTOM = 20; public static final int IterFactor = 500; public static final String[] fitList = {"Straight Line","2nd Degree Polynomial", "3rd Degree Polynomial", "4th Degree Polynomial","Exponential","Power", "Log","Rodbard", "Gamma Variate", "y = a+b*ln(x-c)","Rodbard (NIH Image)", "Exponential with Offset","Gaussian", "Exponential Recovery"}; // fList and doFit() must also be updated public static final String[] fList = {"y = a+bx","y = a+bx+cx^2", "y = a+bx+cx^2+dx^3", "y = a+bx+cx^2+dx^3+ex^4","y = a*exp(bx)","y = ax^b", "y = a*ln(bx)", "y = d+(a-d)/(1+(x/c)^b)", "y = a*(x-b)^c*exp(-(x-b)/d)", "y = a+b*ln(x-c)", "y = d+(a-d)/(1+(x/c)^b)", "y = a*exp(-bx) + c", "y = a + (b-a)*exp(-(x-c)*(x-c)/(2*d*d))", "y=a*(1-exp(-b*x)) + c"}; private static final double alpha = -1.0; // reflection coefficient private static final double beta = 0.5; // contraction coefficient private static final double gamma = 2.0; // expansion coefficient private static final double root2 = 1.414214; // square root of 2 private int fit; // Number of curve type to fit private double[] xData, yData; // x,y data to fit private int numPoints; // number of data points private int numParams; // number of parametres private int numVertices; // numParams+1 (includes sumLocalResiduaalsSqrd) private int worst; // worst current parametre estimates private int nextWorst; // 2nd worst current parametre estimates private int best; // best current parametre estimates private double[][] simp; // the simplex (the last element of the array at each vertice is the sum of the square of the residuals) private double[] next; // new vertex to be tested private int numIter; // number of iterations so far private int maxIter; // maximum number of iterations per restart private int restarts; // number of times to restart simplex after first soln. private static int defaultRestarts = 2; // default number of restarts private int nRestarts; // the number of restarts that occurred private static double maxError = 1e-10; // maximum error tolerance private double[] initialParams; // user specified initial parameters private long time; //elapsed time in ms private String customFormula; private int customParamCount; private Interpreter macro; private double[] initialValues; /** Construct a new CurveFitter. */ public CurveFitter (double[] xData, double[] yData) { this.xData = xData; this.yData = yData; numPoints = xData.length; } /** Perform curve fitting with the simplex method * doFit(fitType) just does the fit * doFit(fitType, true) pops up a dialog allowing control over simplex parameters * alpha is reflection coefficient (-1) * beta is contraction coefficient (0.5) * gamma is expansion coefficient (2) */ public void doFit(int fitType) { doFit(fitType, false); } public void doFit(int fitType, boolean showSettings) { if (fitType<STRAIGHT_LINE || (fitType>EXP_RECOVERY&&fitType!=CUSTOM)) throw new IllegalArgumentException("Invalid fit type"); int saveFitType = fitType; if (fitType==RODBARD2) { double[] temp; temp = xData; xData = yData; yData = temp; fitType = RODBARD; } fit = fitType; initialize(); if (initialParams!=null) { for (int i=0; i<numParams; i++) simp[0][i] = initialParams[i]; initialParams = null; } if (showSettings) settingsDialog(); long startTime = System.currentTimeMillis(); restart(0); numIter = 0; boolean done = false; double[] center = new double[numParams]; // mean of simplex vertices while (!done) { numIter++; for (int i = 0; i < numParams; i++) center[i] = 0.0; // get mean "center" of vertices, excluding worst for (int i = 0; i < numVertices; i++) if (i != worst) for (int j = 0; j < numParams; j++) center[j] += simp[i][j]; // Reflect worst vertex through centre for (int i = 0; i < numParams; i++) { center[i] /= numParams; next[i] = center[i] + alpha*(simp[worst][i] - center[i]); } sumResiduals(next); // if it's better than the best... if (next[numParams] <= simp[best][numParams]) { newVertex(); // try expanding it for (int i = 0; i < numParams; i++) next[i] = center[i] + gamma * (simp[worst][i] - center[i]); sumResiduals(next); // if this is even better, keep it if (next[numParams] <= simp[worst][numParams]) newVertex(); } // else if better than the 2nd worst keep it... else if (next[numParams] <= simp[nextWorst][numParams]) { newVertex(); } // else try to make positive contraction of the worst else { for (int i = 0; i < numParams; i++) next[i] = center[i] + beta*(simp[worst][i] - center[i]); sumResiduals(next); // if this is better than the second worst, keep it. if (next[numParams] <= simp[nextWorst][numParams]) { newVertex(); } // if all else fails, contract simplex in on best else { for (int i = 0; i < numVertices; i++) { if (i != best) { for (int j = 0; j < numVertices; j++) simp[i][j] = beta*(simp[i][j]+simp[best][j]); sumResiduals(simp[i]); } } } } order(); double rtol = 2 * Math.abs(simp[best][numParams] - simp[worst][numParams]) / (Math.abs(simp[best][numParams]) + Math.abs(simp[worst][numParams]) + 0.0000000001); if (numIter >= maxIter) done = true; else if (rtol < maxError) { restarts--; if (restarts < 0) done = true; else restart(best); } } fit = saveFitType; time = System.currentTimeMillis()-startTime; } public int doCustomFit(String equation, double[] initialValues, boolean showSettings) { customFormula = null; customParamCount = 0; Program pgm = (new Tokenizer()).tokenize(equation); if (!pgm.hasWord("y")) return 0; if (!pgm.hasWord("x")) return 0; String[] params = {"a","b","c","d","e","f"}; for (int i=0; i<params.length; i++) { if (pgm.hasWord(params[i])) customParamCount++; } if (customParamCount==0) return 0; customFormula = equation; String code = "var x, a, b, c, d, e, f;\n"+ "function dummy() {}\n"+ equation+";\n"; // starts at program counter location 21 macro = new Interpreter(); macro.run(code, null); if (macro.wasError()) return 0; this.initialValues = initialValues; doFit(CUSTOM, showSettings); return customParamCount; } /** Pop up a dialog allowing control over simplex starting parameters */ private void settingsDialog() { GenericDialog gd = new GenericDialog("Simplex Fitting Options"); gd.addMessage("Function name: " + getName() + "\n" + "Formula: " + getFormula()); char pChar = 'a'; for (int i = 0; i < numParams; i++) { gd.addNumericField("Initial "+(new Character(pChar)).toString()+":", simp[0][i], 2); pChar++; } gd.addNumericField("Maximum iterations:", maxIter, 0); gd.addNumericField("Number of restarts:", defaultRestarts, 0); gd.addNumericField("Error tolerance [1*10^(-x)]:", -(Math.log(maxError)/Math.log(10)), 0); gd.showDialog(); if (gd.wasCanceled() || gd.invalidNumber()) { IJ.error("Parameter setting canceled.\nUsing default parameters."); } // Parametres: for (int i = 0; i < numParams; i++) { simp[0][i] = gd.getNextNumber(); } maxIter = (int) gd.getNextNumber(); defaultRestarts = restarts = (int) gd.getNextNumber(); maxError = Math.pow(10.0, -gd.getNextNumber()); } /** Initialise the simplex */ void initialize() { // Calculate some things that might be useful for predicting parametres numParams = getNumParams(); numVertices = numParams + 1; // need 1 more vertice than parametres, simp = new double[numVertices][numVertices]; next = new double[numVertices]; double firstx = xData[0]; double firsty = yData[0]; double lastx = xData[numPoints-1]; double lasty = yData[numPoints-1]; double xmean = (firstx+lastx)/2.0; double ymean = (firsty+lasty)/2.0; double miny=firsty, maxy=firsty; if (fit==GAUSSIAN) { for (int i=1; i<numPoints; i++) { if (yData[i]>maxy) maxy = yData[i]; if (yData[i]<miny) miny = yData[i]; } } double slope; if ((lastx - firstx) != 0.0) slope = (lasty - firsty)/(lastx - firstx); else slope = 1.0; double yintercept = firsty - slope * firstx; if (maxIter==0) maxIter = IterFactor * numParams * numParams; // Where does this estimate come from? restarts = defaultRestarts; nRestarts = 0; switch (fit) { case STRAIGHT_LINE: simp[0][0] = yintercept; simp[0][1] = slope; break; case POLY2: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; break; case POLY3: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; simp[0][3] = 0.0; break; case POLY4: simp[0][0] = yintercept; simp[0][1] = slope; simp[0][2] = 0.0; simp[0][3] = 0.0; simp[0][4] = 0.0; break; case EXPONENTIAL: simp[0][0] = 0.1; simp[0][1] = 0.01; break; case EXP_WITH_OFFSET: simp[0][0] = 0.1; simp[0][1] = 0.01; simp[0][2] = 0.1; break; case EXP_RECOVERY: simp[0][0] = 0.1; simp[0][1] = 0.01; simp[0][2] = 0.1; break; case GAUSSIAN: simp[0][0] = miny; // a0 simp[0][1] = maxy; // a1 simp[0][2] = xmean; // x0 simp[0][3] = 3.0; // sigma break; case POWER: simp[0][0] = 0.0; simp[0][1] = 1.0; break; case LOG: simp[0][0] = 1.0; simp[0][1] = 1.0; break; case RODBARD: case RODBARD2: simp[0][0] = firsty; simp[0][1] = 1.0; simp[0][2] = xmean; simp[0][3] = lasty; break; case GAMMA_VARIATE: // First guesses based on following observations: // t0 [b] = time of first rise in gamma curve - so use the user specified first limit // tm = t0 + a*B [c*d] where tm is the time of the peak of the curve // therefore an estimate for a and B is sqrt(tm-t0) // K [a] can now be calculated from these estimates simp[0][0] = firstx; double ab = xData[getMax(yData)] - firstx; simp[0][2] = Math.sqrt(ab); simp[0][3] = Math.sqrt(ab); simp[0][1] = yData[getMax(yData)] / (Math.pow(ab, simp[0][2]) * Math.exp(-ab/simp[0][3])); break; case LOG2: simp[0][0] = 0.5; simp[0][1] = 0.05; simp[0][2] = 0.0; break; case CUSTOM: if (macro==null) throw new IllegalArgumentException("No custom formula!"); if (initialValues!=null && initialValues.length>=numParams) { for (int i=0; i<numParams; i++) simp[0][i] = initialValues[i]; } else { for (int i=0; i<numParams; i++) simp[0][i] = 1.0; } break; } } /** Restart the simplex at the nth vertex */ void restart(int n) { // Copy nth vertice of simplex to first vertice for (int i = 0; i < numParams; i++) { simp[0][i] = simp[n][i]; } sumResiduals(simp[0]); // Get sum of residuals^2 for first vertex double[] step = new double[numParams]; for (int i = 0; i < numParams; i++) { step[i] = simp[0][i] / 2.0; // Step half the parametre value if (step[i] == 0.0) // We can't have them all the same or we're going nowhere step[i] = 0.01; } // Some kind of factor for generating new vertices double[] p = new double[numParams]; double[] q = new double[numParams]; for (int i = 0; i < numParams; i++) { p[i] = step[i] * (Math.sqrt(numVertices) + numParams - 1.0)/(numParams * root2); q[i] = step[i] * (Math.sqrt(numVertices) - 1.0)/(numParams * root2); } // Create the other simplex vertices by modifing previous one. for (int i = 1; i < numVertices; i++) { for (int j = 0; j < numParams; j++) { simp[i][j] = simp[i-1][j] + q[j]; } simp[i][i-1] = simp[i][i-1] + p[i-1]; sumResiduals(simp[i]); } // Initialise current lowest/highest parametre estimates to simplex 1 best = 0; worst = 0; nextWorst = 0; order(); nRestarts++; } // Display simplex [Iteration: s0(p1, p2....), s1(),....] in Log window void showSimplex(int iter) { ij.IJ.log("" + iter); for (int i = 0; i < numVertices; i++) { String s = ""; for (int j=0; j < numVertices; j++) s += " "+ ij.IJ.d2s(simp[i][j], 6); ij.IJ.log(s); } } /** Get number of parameters for current fit formula */ public int getNumParams() { switch (fit) { case STRAIGHT_LINE: return 2; case POLY2: return 3; case POLY3: return 4; case POLY4: return 5; case EXPONENTIAL: return 2; case POWER: return 2; case LOG: return 2; case RODBARD: case RODBARD2: return 4; case GAMMA_VARIATE: return 4; case LOG2: return 3; case EXP_WITH_OFFSET: return 3; case GAUSSIAN: return 4; case EXP_RECOVERY: return 3; case CUSTOM: return customParamCount; } return 0; } /** Returns formula value for parameters 'p' at 'x' */ public double f(double[] p, double x) { if (fit==CUSTOM) { macro.setVariable("x", x); macro.setVariable("a", p[0]); if (customParamCount>1) macro.setVariable("b", p[1]); if (customParamCount>2) macro.setVariable("c", p[2]); if (customParamCount>3) macro.setVariable("d", p[3]); if (customParamCount>4) macro.setVariable("e", p[4]); if (customParamCount>5) macro.setVariable("f", p[5]); macro.run(21); return macro.getVariable("y"); } else return f(fit, p, x); } /** Returns 'fit' formula value for parameters "p" at "x" */ public static double f(int fit, double[] p, double x) { double y; switch (fit) { case STRAIGHT_LINE: return p[0] + p[1]*x; case POLY2: return p[0] + p[1]*x + p[2]* x*x; case POLY3: return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x; case POLY4: return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x + p[4]*x*x*x*x; case EXPONENTIAL: return p[0]*Math.exp(p[1]*x); case EXP_WITH_OFFSET: return p[0]*Math.exp(p[1]*x*-1)+p[2]; case EXP_RECOVERY: return p[0]*(1-Math.exp(-p[1]*x))+p[2]; case GAUSSIAN: return p[0]+(p[1]-p[0])*Math.exp(-(x-p[2])*(x-p[2])/(2.0*p[3]*p[3])); case POWER: if (x == 0.0) return 0.0; else return p[0]*Math.exp(p[1]*Math.log(x)); //y=ax^b case LOG: if (x == 0.0) x = 0.5; return p[0]*Math.log(p[1]*x); case RODBARD: double ex; if (x == 0.0) ex = 0.0; else ex = Math.exp(Math.log(x/p[2])*p[1]); y = p[0]-p[3]; y = y/(1.0+ex); return y+p[3]; case GAMMA_VARIATE: if (p[0] >= x) return 0.0; if (p[1] <= 0) return -100000.0; if (p[2] <= 0) return -100000.0; if (p[3] <= 0) return -100000.0; double pw = Math.pow((x - p[0]), p[2]); double e = Math.exp((-(x - p[0]))/p[3]); return p[1]*pw*e; case LOG2: double tmp = x-p[2]; if (tmp<0.001) tmp = 0.001; return p[0]+p[1]*Math.log(tmp); case RODBARD2: if (x<=p[0]) y = 0.0; else { y = (p[0]-x)/(x-p[3]); y = Math.exp(Math.log(y)*(1.0/p[1])); //y=y**(1/b) y = y*p[2]; } return y; default: return 0.0; } } /** Get the set of parameter values from the best corner of the simplex */ public double[] getParams() { order(); return simp[best]; } /** Returns residuals array ie. differences between data and curve. */ public double[] getResiduals() { int saveFit = fit; if (fit==RODBARD2) fit=RODBARD; double[] params = getParams(); double[] residuals = new double[numPoints]; if (fit==CUSTOM) { for (int i=0; i<numPoints; i++) residuals[i] = yData[i] - f(params, xData[i]); } else { for (int i=0; i<numPoints; i++) residuals[i] = yData[i] - f(fit, params, xData[i]); } fit = saveFit; return residuals; } /* Last "parametre" at each vertex of simplex is sum of residuals * for the curve described by that vertex */ public double getSumResidualsSqr() { double sumResidualsSqr = (getParams())[getNumParams()]; return sumResidualsSqr; } /** Returns the standard deviation of the residuals. */ public double getSD() { double[] residuals = getResiduals(); int n = residuals.length; double sum=0.0, sum2=0.0; for (int i=0; i<n; i++) { sum += residuals[i]; sum2 += residuals[i]*residuals[i]; } double stdDev = (n*sum2-sum*sum)/n; return Math.sqrt(stdDev/(n-1.0)); } /** Returns R^2, where 1.0 is best. <pre> r^2 = 1 - SSE/SSD where: SSE = sum of the squares of the errors SSD = sum of the squares of the deviations about the mean. </pre> */ public double getRSquared() { double sumY = 0.0; for (int i=0; i<numPoints; i++) sumY += yData[i]; double mean = sumY/numPoints; double sumMeanDiffSqr = 0.0; for (int i=0; i<numPoints; i++) sumMeanDiffSqr += sqr(yData[i]-mean); double rSquared = 0.0; if (sumMeanDiffSqr>0.0) rSquared = 1.0 - getSumResidualsSqr()/sumMeanDiffSqr; return rSquared; } /** Get a measure of "goodness of fit" where 1.0 is best. */ public double getFitGoodness() { double sumY = 0.0; for (int i = 0; i < numPoints; i++) sumY += yData[i]; double mean = sumY / numPoints; double sumMeanDiffSqr = 0.0; int degreesOfFreedom = numPoints - getNumParams(); double fitGoodness = 0.0; for (int i = 0; i < numPoints; i++) { sumMeanDiffSqr += sqr(yData[i] - mean); } if (sumMeanDiffSqr > 0.0 && degreesOfFreedom != 0) fitGoodness = 1.0 - (getSumResidualsSqr() / degreesOfFreedom) * ((numPoints) / sumMeanDiffSqr); return fitGoodness; } /** Get a string description of the curve fitting results * for easy output. */ public String getResultString() { String results = "\nFormula: " + getFormula() + "\nTime: "+time+"ms" + "\nNumber of iterations: " + getIterations() + " (" + getMaxIterations() + ")" + "\nNumber of restarts: " + (nRestarts-1) + " (" + defaultRestarts + ")" + "\nSum of residuals squared: " + IJ.d2s(getSumResidualsSqr(),4) + "\nStandard deviation: " + IJ.d2s(getSD(),4) + "\nR^2: " + IJ.d2s(getRSquared(),4) + "\nParameters:"; char pChar = 'a'; double[] pVal = getParams(); for (int i = 0; i < numParams; i++) { results += ("\n " + pChar + " = " + IJ.d2s(pVal[i],4)); pChar++; } return results; } double sqr(double d) { return d * d; } /** Adds sum of square of residuals to end of array of parameters */ void sumResiduals (double[] x) { x[numParams] = 0.0; if (fit==CUSTOM) { for (int i=0; i<numPoints; i++) x[numParams] = x[numParams] + sqr(f(x,xData[i])-yData[i]); } else { for (int i=0; i<numPoints; i++) x[numParams] = x[numParams] + sqr(f(fit,x,xData[i])-yData[i]); } } /** Keep the "next" vertex */ void newVertex() { for (int i = 0; i < numVertices; i++) simp[worst][i] = next[i]; } /** Find the worst, nextWorst and best current set of parameter estimates */ void order() { for (int i = 0; i < numVertices; i++) { if (simp[i][numParams] < simp[best][numParams]) best = i; if (simp[i][numParams] > simp[worst][numParams]) worst = i; } nextWorst = best; for (int i = 0; i < numVertices; i++) { if (i != worst) { if (simp[i][numParams] > simp[nextWorst][numParams]) nextWorst = i; } } // IJ.log("B: " + simp[best][numParams] + " 2ndW: " + simp[nextWorst][numParams] + " W: " + simp[worst][numParams]); } /** Get number of iterations performed */ public int getIterations() { return numIter; } /** Get maximum number of iterations allowed */ public int getMaxIterations() { return maxIter; } /** Set maximum number of iterations allowed */ public void setMaxIterations(int x) { maxIter = x; } /** Get number of simplex restarts to do */ public int getRestarts() { return defaultRestarts; } /** Set number of simplex restarts to do */ public void setRestarts(int n) { defaultRestarts = n; } /** Sets the initial parameters, which override the default initial parameters. */ public void setInitialParameters(double[] params) { initialParams = params; } /** * Gets index of highest value in an array. * * @param Double array. * @return Index of highest value. */ public static int getMax(double[] array) { double max = array[0]; int index = 0; for(int i = 1; i < array.length; i++) { if(max < array[i]) { max = array[i]; index = i; } } return index; } public double[] getXPoints() { return xData; } public double[] getYPoints() { return yData; } public int getFit() { return fit; } public String getName() { if (fit==CUSTOM) return "User-defined"; else return fitList[fit]; } public String getFormula() { if (fit==CUSTOM) return customFormula; else return fList[fit]; } }