package cz.cuni.lf1.lge.ThunderSTORM.calibration; import java.util.Arrays; import java.util.Comparator; import org.apache.commons.math3.analysis.ParametricUnivariateFunction; import org.apache.commons.math3.fitting.CurveFitter; import org.apache.commons.math3.fitting.WeightedObservedPoint; import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer; import org.apache.commons.math3.exception.TooManyEvaluationsException; import org.apache.commons.math3.optim.SimplePointChecker; public class IterativeFitting { private double inlierFraction; private int maxIterations; public IterativeFitting(int maxIterations, double inlierFraction) { this.maxIterations = maxIterations; this.inlierFraction = inlierFraction; } private double[] fit(double[] x, double[] y, ParametricUnivariateFunction function, double[] initialParams, int maxIter) { int numberOfInliers = (int) (inlierFraction * x.length); CurveFitter<ParametricUnivariateFunction> fitter = new CurveFitter<ParametricUnivariateFunction>(new LevenbergMarquardtOptimizer(new SimplePointChecker(10e-3, 10e-3, maxIter))); WeightedObservedPoint[] points = new WeightedObservedPoint[x.length]; //fit using all points for(int i = 0; i < x.length; i++) { points[i] = new WeightedObservedPoint(1, x[i], y[i]); fitter.addObservedPoint(points[i]); } double[] parameters = fitter.fit(maxIter, function, initialParams); double[] residuals = new double[x.length]; for(int it = 0; it < maxIterations; it++) { //fit again with only inlier points (points with smallest error) computeResiduals(parameters, function, x, y, residuals); int[] inliers = findIndicesOfSmallestN(residuals, numberOfInliers); // System.out.println("residuals : " + Arrays.toString(residuals)); // System.out.println("inliers : " + Arrays.toString(inliers)); fitter.clearObservations(); for(int i : inliers) { fitter.addObservedPoint(points[i]); } try { parameters = fitter.fit(maxIter, function, parameters); } catch(TooManyEvaluationsException ex) { if(it > 0) { return parameters; } else { throw ex; } } } return parameters; } public DefocusFunction fitParams(DefocusFunction defocusModel, double[] x, double[] y, int maxIter) { double min = y[0]; int minIndex = 0; for(int i = 0; i < x.length; i++) { if(y[i] < min) { min = y[i]; minIndex = i; } } return defocusModel.getNewInstance(defocusModel.transformParams(fit(x, y, defocusModel.getFittingFunction(), defocusModel.transformParamsInverse(defocusModel.getInitialParams(x[minIndex], min)), maxIter)), false); } private void computeResiduals(double[] parameters, ParametricUnivariateFunction function, double[] x, double[] y, double[] residualArray) { for(int i = 0; i < x.length; i++) { double dx = y[i] - function.value(x[i], parameters); residualArray[i] = dx * dx; } } protected static int[] findIndicesOfSmallestN(final double[] values, int n) { if (values.length < n) { throw new IllegalArgumentException("`values` must not have less than `n` elements!"); } Integer[] indices = new Integer[values.length]; for(int i = 0; i < values.length; i++) indices[i] = i; Arrays.sort(indices, new Comparator<Integer>() { @Override public int compare(Integer o1, Integer o2) { return Double.compare(values[o1], values[o2]); } }); int[] result = new int[n]; for(int i = 0; i < result.length; i++) result[i] = indices[i]; return result; } }