package org.geogebra.common.kernel.statistics;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;
import org.geogebra.common.kernel.Construction;
import org.geogebra.common.kernel.algos.AlgoElement;
import org.geogebra.common.kernel.arithmetic.MyDouble;
import org.geogebra.common.kernel.commands.Commands;
import org.geogebra.common.kernel.geos.GeoElement;
import org.geogebra.common.kernel.geos.GeoFunction;
import org.geogebra.common.kernel.geos.GeoList;
import org.geogebra.common.kernel.geos.GeoPoint;
import org.geogebra.common.kernel.optimization.FitRealFunction;
import org.geogebra.common.util.debug.Log;
/*
GeoGebra - Dynamic Mathematics for Everyone
http://www.geogebra.org
This file is part of GeoGebra.
This program is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by
the Free Software Foundation.
*/
import org.apache.commons.math3.fitting.CurveFitter;
/**
* <pre>
* AlgoFitNL: (NL=NonLinear Curvefit)
* A general curvefit:
* Fit[<List of Points>,<Function>]
* Example:
* f(x)=a+b/(x-c)
* L={A,B,...}
* g(x)=Fit[L,f]
* will give a function
* g(x)=p1+p2/(x-p3)
* where p1, p2 and p3 are calculated to give the least sum of squared errors.
*
* The nonlinear curve-fitting is done with an iterative algorithm, which is not
* guaranteed to converge.
* The values of a, b and c are taken as starting points for the iteration algorithm.
* If the iteration does not converge or the number of iterations gets too large,
* the result is undefined, a signal for the user to try to adjust the starting
* point with the sliders a, b and c.
*
* Uses Levenberg-Marquardt algorithm in org.apache.commons library
*
* ToDo: The gradient in FitRealFunction could be more sophisticated, but the Apache lib is quite robust :-)
* Some tuning of numerical precision both here and in the setup of LM-optimizer
* </pre>
*
* @author Hans-Petter Ulven
* @version 2011-03-15
*/
@SuppressWarnings("deprecation")
public class AlgoFitNL extends AlgoElement implements FitAlgo {
	private final GeoList pointlist; // input
	private final GeoFunction inputfunction; // input
	private final GeoFunction outputfunction; // output

	// Working data, rebuilt on every compute():
	private int datasize = 0; // number of elements in pointlist
	private double[] xdata = null;
	private double[] ydata = null;
	/**
	 * Parametric copy of inputfunction handed to the Apache fitter; null
	 * until compute() has got far enough to build it.
	 */
	private FitRealFunction prfunction = null;

	private final LevenbergMarquardtOptimizer LMO = new LevenbergMarquardtOptimizer();
	// CurveFitter is deprecated in commons-math3, hence the class-level
	// @SuppressWarnings("deprecation").
	private final CurveFitter<ParametricUnivariateFunction> curvefitter = new CurveFitter<ParametricUnivariateFunction>(
			LMO);

	/**
	 * @param cons
	 *            construction
	 * @param pointlist
	 *            points
	 * @param inputfunction
	 *            function with parameters
	 */
	public AlgoFitNL(Construction cons, GeoList pointlist,
			GeoFunction inputfunction) {
		super(cons);
		this.pointlist = pointlist;
		this.inputfunction = inputfunction;
		outputfunction = new GeoFunction(cons);
		setInputOutput();
		compute();
	}

	@Override
	public Commands getClassName() {
		return Commands.Fit;
	}

	@Override
	protected void setInputOutput() {
		input = new GeoElement[2];
		input[0] = pointlist;
		input[1] = inputfunction;
		setOnlyOutput(outputfunction);
		setDependencies();
	}

	/**
	 * @return output function
	 */
	public GeoFunction getFitNL() {
		return outputfunction;
	}

	/**
	 * Fits the parameters of the input function to the point list with the
	 * Levenberg-Marquardt optimizer. The current parameter values of the
	 * input function are used as starting values. Any failure (undefined
	 * input, empty list, non-point elements, bad parameters, or an exception
	 * from the optimizer) leaves the output function undefined.
	 */
	@Override
	public final void compute() {
		datasize = pointlist.size(); // Points in dataset
		if (!pointlist.isDefined() || !inputfunction.isDefined()
				|| datasize < 1) {
			outputfunction.setUndefined();
			return;
		}
		// Cheap early rejection; makeDataArrays() re-checks every element.
		GeoElement first = pointlist.get(0);
		if (!inputfunction.isGeoFunction() || !first.isGeoPoint()) {
			outputfunction.setUndefined();
			return;
		}
		try {
			// Get points as x[] and y[] from the list
			if (!makeDataArrays()) {
				outputfunction.setUndefined();
				return;
			}
			// prfunction makes itself a copy of inputfunction with
			// parameters instead of GeoNumerics
			prfunction = new FitRealFunction(inputfunction.getFunction());
			if (!prfunction.isParametersOK()) {
				outputfunction.setUndefined();
				return;
			}
			// The fitter is reused between computes, so observations from
			// the previous run must be discarded first.
			curvefitter.clearObservations();
			for (int i = 0; i < datasize; i++) {
				curvefitter.addObservedPoint(1.0, xdata[i], ydata[i]);
			}
			curvefitter.fit(prfunction, prfunction.getStartValues());

			outputfunction.setFunction(prfunction.getFunction());
			outputfunction.setDefined(true);
		} catch (Throwable t) {
			// Non-convergence, too many iterations etc. are reported by the
			// optimizer as exceptions: the fit is simply undefined.
			outputfunction.setUndefined();
			Log.debug(t.getMessage());
		}
	}

	/**
	 * Copies the point coordinates from the list into xdata/ydata.
	 *
	 * @return false if some element is not a GeoPoint or is undefined, true
	 *         otherwise
	 */
	private final boolean makeDataArrays() {
		datasize = pointlist.size();
		xdata = new double[datasize];
		ydata = new double[datasize];
		for (int i = 0; i < datasize; i++) {
			GeoElement geo = pointlist.get(i);
			// Guard the cast directly instead of relying on isGeoPoint(),
			// so an unexpected point type cannot raise ClassCastException.
			if (!(geo instanceof GeoPoint)) {
				return false;
			}
			GeoPoint point = (GeoPoint) geo;
			// An undefined point would feed NaN into the optimizer.
			if (!point.isDefined()) {
				return false;
			}
			xdata[i] = point.getX();
			ydata[i] = point.getY();
		}
		return true;
	}

	/**
	 * Returns the parameter values of the most recently built fit function.
	 * NOTE(review): if a later compute() fails before rebuilding prfunction,
	 * these are the coefficients of an earlier fit.
	 *
	 * @return fitted coefficients, or an empty array if no fit function has
	 *         been built yet
	 */
	@Override
	public double[] getCoeffs() {
		if (prfunction == null) {
			// compute() returned before creating the parametric function;
			// previously this dereferenced null.
			return new double[0];
		}
		MyDouble[] coeffs = prfunction.getCoeffs();
		double[] ret = new double[coeffs.length];
		for (int i = 0; i < coeffs.length; i++) {
			ret[i] = coeffs[i].getDouble();
		}
		return ret;
	}
}