/*
* Copyright (c) 2012 Diamond Light Source Ltd.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*/
package uk.ac.diamond.scisoft.analysis.optimize;
/**
* GradientDescent Class
*/
/**
 * Basic implementation of a gradient descent optimiser.
 * <p>
 * The only departure from the plain method is damping: every parameter
 * carries its own step weight, which is halved whenever the sign of its
 * gradient component flips between two iterations and is otherwise allowed
 * to recover (up to 1.0). This lets the optimiser deal with sloppy surfaces
 * more effectively.
 * <p>
 * Instances hold mutable working state and are not safe for concurrent use.
 */
public class GradientDescent extends AbstractOptimizer {

    /** Step length every optimisation run starts with. */
    private static final double INITIAL_STEP_SIZE = 0.1;

    /** Finite-difference half-width every optimisation run starts with. */
    private static final double INITIAL_DIFF_SIZE = 0.001;

    /** Factor applied to the half-width when a gradient component evaluates to exactly zero. */
    private static final double DIFF_GROWTH = 1.5;

    /** Number of half-widths tried before zero gradient components are accepted as genuine. */
    private static final int MAX_DIFF_ATTEMPTS = 20;

    /**
     * Finish criterion handed to {@link #optimise(double[], double)} by
     * {@link #internalOptimize()}.
     */
    double qualityFactor = 0.1;

    /**
     * Current finite-difference half-width. It may be enlarged while a run is
     * in progress and is reset at the start of every run.
     */
    private double diffsize = INITIAL_DIFF_SIZE;

    /**
     * Creates an optimiser with the default quality factor (0.1).
     */
    public GradientDescent() {
    }

    /**
     * Creates an optimiser with the given quality factor.
     *
     * @param quality
     *            finish criterion used by {@link #internalOptimize()}
     */
    public GradientDescent(double quality) {
        // Assigned directly so that no overridable method runs during construction.
        qualityFactor = quality;
    }

    /**
     * Sets the quality factor used by subsequent optimisations.
     *
     * @param quality
     *            finish criterion used by {@link #internalOptimize()}
     */
    public void setAccuracy(double quality) {
        qualityFactor = quality;
    }

    /**
     * Optimises starting from the current parameter values and stores the
     * best parameters found back into the function.
     */
    @Override
    void internalOptimize() {
        double[] bestParams = optimise(getParameterValues(), qualityFactor);
        function.setParameterValues(bestParams);
    }

    /**
     * The main optimisation method.
     * <p>
     * Repeatedly steps against the normalised gradient. A step that lowers the
     * residual is accepted and the step length grows by 10%; a step that does
     * not is discarded and the step length shrinks by 25%. The search ends
     * once the step length is no longer greater than
     * {@code finishCriteria * 0.1}.
     * <p>
     * Every call starts from the same initial step length and
     * finite-difference width, so an instance can be used for any number of
     * runs.
     *
     * @param parameters
     *            starting point; the array itself is never modified
     * @param finishCriteria
     *            convergence threshold; must be non-negative
     * @return a double array of the parameters for the optimisation (the
     *         {@code parameters} array itself if no step improved on it)
     * @throws IllegalArgumentException
     *             if {@code finishCriteria} is negative or NaN
     */
    public double[] optimise(double[] parameters,
            double finishCriteria) {
        // A negative threshold can never be reached (the loop would not end)
        // and NaN would silently skip the optimisation altogether.
        if (!(finishCriteria >= 0.0)) {
            throw new IllegalArgumentException(
                    "finishCriteria must be non-negative but was " + finishCriteria);
        }

        // Working state is per run: carrying the shrunken step length over
        // from a previous run would end this one before its first iteration.
        double stepsize = INITIAL_STEP_SIZE;
        diffsize = INITIAL_DIFF_SIZE;

        double[] solution = parameters;
        double[] stepweight = new double[solution.length];
        java.util.Arrays.fill(stepweight, 1.0);

        double[] oldd = deriv(solution);

        while (stepsize > finishCriteria * 0.1) {
            double[] d = deriv(solution);

            // now adjust the stepweights: damp a parameter whose gradient
            // component changed sign, let the others recover towards 1.0
            for (int i = 0; i < stepweight.length; i++) {
                if (d[i] * oldd[i] < 0) {
                    stepweight[i] *= 0.5;
                } else {
                    stepweight[i] *= 1.1;
                }
                if (stepweight[i] > 1.0) {
                    stepweight[i] = 1.0;
                }
            }
            oldd = d;

            double value = calculateResidual(solution);

            // Trial point; solution is only replaced, never written to, so
            // the caller's array stays untouched.
            double[] test = solution.clone();
            for (int i = 0; i < d.length; i++) {
                test[i] -= d[i] * stepsize * stepweight[i];
            }

            double testValue = calculateResidual(test);
            if (testValue < value) {
                solution = test;
                stepsize *= 1.1;
            } else {
                stepsize *= 0.75;
            }
        }

        return solution;
    }

    /**
     * Estimates the gradient of the residual at the given position by central
     * differences and scales it to unit length.
     * <p>
     * A component that evaluates to exactly zero is first taken as a sign that
     * the finite-difference width is too small, so the width is enlarged and
     * the estimate repeated. If zeros persist after
     * {@link #MAX_DIFF_ATTEMPTS} widths, the width is restored and the zeros
     * are accepted as genuine.
     *
     * @param position
     *            point at which to differentiate; not modified
     * @return the normalised derivative, or all zeros if every component is
     *         zero
     */
    private double[] deriv(double[] position) {
        final double[] gradient = new double[position.length];
        final double entryDiffsize = diffsize;

        boolean complete = false;
        for (int attempt = 0; attempt < MAX_DIFF_ATTEMPTS && !complete; attempt++) {
            if (attempt > 0) {
                diffsize *= DIFF_GROWTH;
            }
            complete = estimateGradient(position, gradient, false);
        }
        if (!complete) {
            // Widening did not help, so undo it rather than keep a coarse
            // width for the components that do vary.
            diffsize = entryDiffsize;
            estimateGradient(position, gradient, true);
        }

        // Summed only over the final, complete estimate.
        double length = 0.0;
        for (int i = 0; i < gradient.length; i++) {
            length += gradient[i] * gradient[i];
        }
        length = Math.sqrt(length);

        // An all-zero gradient is returned as is; dividing by zero would turn
        // it into NaNs.
        if (length != 0.0) {
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = gradient[i] / length;
            }
        }
        return gradient;
    }

    /**
     * Fills {@code gradient} with central-difference estimates taken with the
     * current finite-difference width.
     *
     * @param position
     *            point at which to differentiate; not modified
     * @param gradient
     *            receives one estimate per parameter
     * @param acceptZeros
     *            whether a component that is exactly zero is a valid result
     * @return {@code true} if every component was filled in; {@code false} if
     *         the estimate was abandoned at a zero component
     */
    private boolean estimateGradient(double[] position, double[] gradient, boolean acceptZeros) {
        for (int i = 0; i < gradient.length; i++) {
            double[] pos = position.clone();
            double[] neg = position.clone();
            pos[i] += diffsize;
            neg[i] -= diffsize;

            double posvalue = calculateResidual(pos);
            double negvalue = calculateResidual(neg);

            gradient[i] = (posvalue - negvalue) / (2.0 * diffsize);
            if (gradient[i] == 0 && !acceptZeros) {
                return false;
            }
        }
        return true;
    }
}