/*
* (c) Copyright Christian P. Fries, Germany. All rights reserved. Contact: email@christianfries.com.
*
* Created on 12.07.2014
*/
package net.finmath.tests.optimizer;
import java.util.ArrayList;
import org.junit.Assert;
import org.junit.Test;
import net.finmath.optimizer.LevenbergMarquardt;
import net.finmath.optimizer.OptimizerInterface;
import net.finmath.optimizer.SolverException;
/**
* @author Christian Fries
*
*/
public class LevenbergMarquardtTest {
@Test
public void testSmallLinearSystem() throws CloneNotSupportedException, SolverException {
LevenbergMarquardt optimizer = new LevenbergMarquardt() {
// Override your objective function here
@Override
public void setValues(double[] parameters, double[] values) {
values[0] = parameters[0] * 0.0 + parameters[1];
values[1] = parameters[0] * 2.0 + parameters[1];
}
};
// Set solver parameters
optimizer.setInitialParameters(new double[] { 0, 0 });
optimizer.setWeights(new double[] { 1, 1 });
optimizer.setMaxIteration(100);
optimizer.setTargetValues(new double[] { 5, 10 });
optimizer.run();
double[] bestParameters = optimizer.getBestFitParameters();
System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. Accuracy is " + optimizer.getRootMeanSquaredError() + ". The best fit parameters are:");
for (int i = 0; i < bestParameters.length; i++) System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
System.out.println();
Assert.assertTrue(Math.abs(bestParameters[0] - 2.5) < 1E-12);
Assert.assertTrue(Math.abs(bestParameters[1] - 5.0) < 1E-12);
/*
* Creating a clone, continuing the search with new target values.
* Note that we do not re-define the setValues method.
*/
OptimizerInterface optimizer2 = optimizer.getCloneWithModifiedTargetValues(new double[] { 5.1, 10.2 }, new double[] { 1, 1 }, true);
optimizer2.run();
double[] bestParameters2 = optimizer2.getBestFitParameters();
System.out.println("The solver for problem 2 required " + optimizer2.getIterations() + " iterations. Accuracy is " + optimizer2.getRootMeanSquaredError() + ". The best fit parameters are:");
for (int i = 0; i < bestParameters2.length; i++) System.out.println("\tparameter[" + i + "]: " + bestParameters2[i]);
System.out.println();
Assert.assertTrue(Math.abs(bestParameters2[0] - 2.55) < 1E-12);
Assert.assertTrue(Math.abs(bestParameters2[1] - 5.10) < 1E-12);
}
@Test
public void testMultiThreaddedOptimizer() throws SolverException {
LevenbergMarquardt optimizer = new LevenbergMarquardt(
new double[] { 0, 0, 0 }, // Initial parameters
new double[] { 5, 10, 2 }, // Target values
100, // Max iterations
10 // Number of threads
) {
// Override your objective function here
@Override
public void setValues(double[] parameters, double[] values) {
values[0] = 1.0 * parameters[0] + 2.0 * parameters[1] + parameters[2] + parameters[0] * parameters[1];
values[1] = 2.0 * parameters[0] + 1.0 * parameters[1] + parameters[2] + parameters[1] * parameters[2];
values[2] = 3.0 * parameters[0] + 0.0 * parameters[1] + parameters[2];
}
};
optimizer.run();
double[] bestParameters = optimizer.getBestFitParameters();
System.out.println("The solver for problem 3 required " + optimizer.getIterations() + " iterations. Accuracy is " + optimizer.getRootMeanSquaredError() + ". The best fit parameters are:");
for (int i = 0; i < bestParameters.length; i++) System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
double[] values = new double[3];
optimizer.setValues(bestParameters, values);
for (int i = 0; i < bestParameters.length; i++) System.out.println("\tvalue[" + i + "]: " + values[i]);
System.out.println();
Assert.assertTrue(optimizer.getRootMeanSquaredError() < 1E-1);
}
@Test
public void testRosenbrockFunction() throws SolverException {
LevenbergMarquardt optimizer = new LevenbergMarquardt(
new double[] { 0.5, 0.5 }, // Initial parameters
new double[] { 0.0, 0.0 }, // Target values
100, // Max iterations
10 // Number of threads
) {
// Override your objective function here
@Override
public void setValues(double[] parameters, double[] values) {
values[0] = 10.0 * (parameters[1] - parameters[0]*parameters[0]);
values[1] = 1.0 - parameters[0];
}
};
optimizer.run();
double[] bestParameters = optimizer.getBestFitParameters();
System.out.println("The solver for problem 'Rosebrock' required " + optimizer.getIterations() + " iterations. Accuracy is " + optimizer.getRootMeanSquaredError() + ". The best fit parameters are:");
for (int i = 0; i < bestParameters.length; i++) System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
double[] values = new double[2];
optimizer.setValues(bestParameters, values);
for (int i = 0; i < values.length; i++) System.out.println("\tvalue[" + i + "]: " + values[i]);
System.out.println();
Assert.assertTrue(Math.abs(bestParameters[0] - 1.0) < 1E-10);
Assert.assertTrue(Math.abs(bestParameters[1] - 1.0) < 1E-10);
}
@Test
public void testRosenbrockFunctionWithList() throws SolverException {
ArrayList<Number> initialParams = new ArrayList<Number>();
initialParams.add(0.5);
initialParams.add(0.5);
ArrayList<Number> targetValues = new ArrayList<Number>();
targetValues.add(0.0);
targetValues.add(0.0);
LevenbergMarquardt optimizer = new LevenbergMarquardt(
initialParams, // Initial parameters
targetValues, // Target values
100, // Max iterations
10 // Number of threads
) {
// Override your objective function here
@Override
public void setValues(double[] parameters, double[] values) {
values[0] = 10.0 * (parameters[1] - parameters[0]*parameters[0]);
values[1] = 1.0 - parameters[0];
}
};
optimizer.run();
double[] bestParameters = optimizer.getBestFitParameters();
System.out.println("The solver for problem 'Rosebrock' required " + optimizer.getIterations() + " iterations. Accuracy is " + optimizer.getRootMeanSquaredError() + ". The best fit parameters are:");
for (int i = 0; i < bestParameters.length; i++) System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
double[] values = new double[2];
optimizer.setValues(bestParameters, values);
for (int i = 0; i < values.length; i++) System.out.println("\tvalue[" + i + "]: " + values[i]);
System.out.println();
Assert.assertTrue(Math.abs(bestParameters[0] - 1.0) < 1E-10);
Assert.assertTrue(Math.abs(bestParameters[1] - 1.0) < 1E-10);
}
}