/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.networks.training.nm;
import org.encog.Encog;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;
/**
* The Nelder-Mead method is a commonly used parameter optimization method that
* can be used for neural network training. It typically provides a good error
* rate and is relatively fast.
*
* Nelder-Mead must build a simplex, which is an n*(n+1) matrix of weights. If
* you have a large number of weights, this matrix can quickly exhaust memory;
* with 10,000 weights, for example, the simplex alone requires roughly 800MB
* of doubles.
*
* The biggest enhancement needed for this trainer is to use multi-threaded
* code to speed up the error evaluations when training on a multi-core
* machine.
*
* This implementation is based on the source code provided by John Burkardt
* (http://people.sc.fsu.edu/~jburkardt/)
*
* http://people.sc.fsu.edu/~jburkardt/c_src/asa047/asa047.c
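*
* A minimal usage sketch (the BasicNetwork named network and the MLDataSet
* named trainingSet are assumed to already exist):
*
* <pre>
* NelderMeadTraining train = new NelderMeadTraining(network, trainingSet);
* while (!train.isTrainingDone()) {
*     train.iteration();
*     System.out.println("Error: " + train.getError());
* }
* train.finishTraining();
* </pre>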
*/
public class NelderMeadTraining extends BasicTraining {
/**
* The network to be trained.
*/
private final BasicNetwork network;
/**
* The best (lowest) error found; inside the inner loop it temporarily holds
* the highest vertex error.
*/
private double ynewlo;
/**
* True if the network has converged, and no further training is needed.
*/
private boolean converged = false;
/**
* The contraction coefficient.
*/
private final double ccoeff = 0.5;
/** Scale factor for the step when building the simplex; reused as the perturbation in the final check. */
private double del;
/** The expansion coefficient. */
private final double ecoeff = 2.0;
/** Relative perturbation used by the final local-minimum check. */
private final double eps = 0.001;
/** Index of the simplex vertex with the highest (worst) error. */
private int ihi;
/** Index of the simplex vertex with the lowest (best) error. */
private int ilo;
/** Inner-loop passes remaining before the next convergence check. */
private int jcount;
/** Counts vertices whose error exceeds that of the reflected point. */
private int l;
/** The number of simplex vertices (weight count + 1). */
private final int nn;
/** The simplex vertices, stored as an n-by-(n+1) array in column-major order. */
private final double[] p;
/** The expansion or contraction point. */
private final double[] p2star;
/** The centroid of the simplex, excluding the worst vertex. */
private final double[] pbar;
/** The reflected point. */
private final double[] pstar;
/** The reflection coefficient. */
private final double rcoeff = 1.0;
/** Convergence tolerance on the spread of the vertex errors. */
private final double rq;
/** The error at each simplex vertex. */
private final double[] y;
/** The error at the expansion or contraction point. */
private double y2star;
/** The lowest (best) error among the simplex vertices. */
private double ylo;
/** The error at the reflected point. */
private double ystar;
/** Working variable for accumulating sums and holding trial errors. */
private double z;
/** The weights the current simplex is built around. */
private final double[] start;
/** The best weights found so far; written back to the network after each iteration. */
private final double[] trainedWeights;
/** The step sizes used to construct the initial simplex. */
private final double[] step;
/** How often, in iterations, convergence is checked. */
private int konvge;
/**
* Construct a Nelder-Mead trainer with a step size of 100.
*
* @param network
* The network to train.
* @param training
* The training set to use.
*/
public NelderMeadTraining(final BasicNetwork network,
final MLDataSet training) {
this(network, training, 100);
}
/**
* Construct a Nelder-Mead trainer with a definable step.
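*
* For example, a smaller step such as 10 keeps the initial simplex closer to
* the starting weights (network and trainingSet are assumed to already
* exist):
*
* <pre>
* NelderMeadTraining train = new NelderMeadTraining(network, trainingSet, 10);
* </pre>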
*
* @param network
* The network to train.
* @param training
* The training data to use.
* @param stepValue
The step value. This value defines, to some degree, the range
of weights that will be tried.
*/
public NelderMeadTraining(final BasicNetwork network,
final MLDataSet training, final double stepValue) {
super(TrainingImplementationType.OnePass);
this.network = network;
setTraining(training);
this.start = NetworkCODEC.networkToArray(network);
this.trainedWeights = NetworkCODEC.networkToArray(network);
final int n = this.start.length;
this.p = new double[n * (n + 1)];
this.pstar = new double[n];
this.p2star = new double[n];
this.pbar = new double[n];
this.y = new double[n + 1];
this.nn = n + 1;
this.del = 1.0;
this.rq = Encog.DEFAULT_DOUBLE_EQUAL * n;
this.step = new double[NetworkCODEC.networkSize(network)];
this.jcount = this.konvge = 500;
EngineArray.fill(this.step, stepValue);
}
/**
* {@inheritDoc}
*/
@Override
public boolean canContinue() {
return false;
}
/**
* Calculate the error for the neural network with a given set of weights.
*
* @param weights
* The weights to use.
* @return The current error.
*/
public double fn(final double[] weights) {
NetworkCODEC.arrayToNetwork(weights, this.network);
return this.network.calculateError(getTraining());
}
/**
* {@inheritDoc}
*/
@Override
public MLMethod getMethod() {
return this.network;
}
/**
* {@inheritDoc}
*/
@Override
public boolean isTrainingDone() {
if (this.converged) {
return true;
} else {
return super.isTrainingDone();
}
}
/**
* {@inheritDoc}
*/
@Override
public void iteration() {
if (this.converged) {
return;
}
final int n = this.start.length;
for (int i = 0; i < n; i++) {
this.p[i + n * n] = this.start[i];
}
this.y[n] = fn(this.start);
for (int j = 0; j < n; j++) {
final double x = this.start[j];
this.start[j] = this.start[j] + this.step[j] * this.del;
for (int i = 0; i < n; i++) {
this.p[i + j * n] = this.start[i];
}
this.y[j] = fn(this.start);
this.start[j] = x;
}
/*
* The simplex construction is complete.
*
* Find highest and lowest Y values. YNEWLO = Y(IHI) indicates the
* vertex of the simplex to be replaced.
*/
this.ylo = this.y[0];
this.ilo = 0;
for (int i = 1; i < this.nn; i++) {
if (this.y[i] < this.ylo) {
this.ylo = this.y[i];
this.ilo = i;
}
}
/*
* Inner loop.
*/
for (;;) {
/*
* The original ASA047 code enforced an evaluation-count limit here:
* if (kcount <= icount) { break; }
*/
this.ynewlo = this.y[0];
this.ihi = 0;
for (int i = 1; i < this.nn; i++) {
if (this.ynewlo < this.y[i]) {
this.ynewlo = this.y[i];
this.ihi = i;
}
}
/*
* Calculate PBAR, the centroid of the simplex vertices excepting
* the vertex with Y value YNEWLO.
*/
for (int i = 0; i < n; i++) {
this.z = 0.0;
for (int j = 0; j < this.nn; j++) {
this.z = this.z + this.p[i + j * n];
}
this.z = this.z - this.p[i + this.ihi * n];
this.pbar[i] = this.z / n;
}
/*
* Reflection through the centroid.
*/
for (int i = 0; i < n; i++) {
this.pstar[i] = this.pbar[i] + this.rcoeff
* (this.pbar[i] - this.p[i + this.ihi * n]);
}
this.ystar = fn(this.pstar);
/*
* Successful reflection, so extension.
*/
if (this.ystar < this.ylo) {
for (int i = 0; i < n; i++) {
this.p2star[i] = this.pbar[i] + this.ecoeff
* (this.pstar[i] - this.pbar[i]);
}
this.y2star = fn(this.p2star);
/*
* Check extension.
*/
if (this.ystar < this.y2star) {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.pstar[i];
}
this.y[this.ihi] = this.ystar;
}
/*
* Retain extension or contraction.
*/
else {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.p2star[i];
}
this.y[this.ihi] = this.y2star;
}
}
/*
* No extension.
*/
else {
this.l = 0;
for (int i = 0; i < this.nn; i++) {
if (this.ystar < this.y[i]) {
this.l = this.l + 1;
}
}
if (1 < this.l) {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.pstar[i];
}
this.y[this.ihi] = this.ystar;
}
/*
* Contraction on the Y(IHI) side of the centroid.
*/
else if (this.l == 0) {
for (int i = 0; i < n; i++) {
this.p2star[i] = this.pbar[i] + this.ccoeff
* (this.p[i + this.ihi * n] - this.pbar[i]);
}
this.y2star = fn(this.p2star);
/*
* Contract the whole simplex.
*/
if (this.y[this.ihi] < this.y2star) {
for (int j = 0; j < this.nn; j++) {
for (int i = 0; i < n; i++) {
this.p[i + j * n] = (this.p[i + j * n] + this.p[i
+ this.ilo * n]) * 0.5;
this.trainedWeights[i] = this.p[i + j * n];
}
this.y[j] = fn(this.trainedWeights);
}
this.ylo = this.y[0];
this.ilo = 0;
for (int i = 1; i < this.nn; i++) {
if (this.y[i] < this.ylo) {
this.ylo = this.y[i];
this.ilo = i;
}
}
continue;
}
/*
* Retain contraction.
*/
else {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.p2star[i];
}
this.y[this.ihi] = this.y2star;
}
}
/*
* Contraction on the reflection side of the centroid.
*/
else if (this.l == 1) {
for (int i = 0; i < n; i++) {
this.p2star[i] = this.pbar[i] + this.ccoeff
* (this.pstar[i] - this.pbar[i]);
}
this.y2star = fn(this.p2star);
/*
* Retain reflection?
*/
if (this.y2star <= this.ystar) {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.p2star[i];
}
this.y[this.ihi] = this.y2star;
} else {
for (int i = 0; i < n; i++) {
this.p[i + this.ihi * n] = this.pstar[i];
}
this.y[this.ihi] = this.ystar;
}
}
}
/*
* Check if YLO improved.
*/
if (this.y[this.ihi] < this.ylo) {
this.ylo = this.y[this.ihi];
this.ilo = this.ihi;
}
this.jcount = this.jcount - 1;
if (0 < this.jcount) {
continue;
}
/*
* Check to see if minimum reached.
*/
// In ASA047 this block was guarded by an evaluation-count check: if (icount <= kcount)
{
this.jcount = this.konvge;
this.z = 0.0;
for (int i = 0; i < this.nn; i++) {
this.z = this.z + this.y[i];
}
final double x = this.z / this.nn;
this.z = 0.0;
for (int i = 0; i < this.nn; i++) {
this.z = this.z + Math.pow(this.y[i] - x, 2);
}
if (this.z <= this.rq) {
break;
}
}
}
/*
* Factorial tests to check that YNEWLO is a local minimum.
*/
for (int i = 0; i < n; i++) {
this.trainedWeights[i] = this.p[i + this.ilo * n];
}
this.ynewlo = this.y[this.ilo];
boolean fault = false;
for (int i = 0; i < n; i++) {
this.del = this.step[i] * this.eps;
this.trainedWeights[i] += this.del;
this.z = fn(this.trainedWeights);
if (this.z < this.ynewlo) {
fault = true;
break;
}
this.trainedWeights[i] = this.trainedWeights[i] - this.del
- this.del;
this.z = fn(this.trainedWeights);
if (this.z < this.ynewlo) {
fault = true;
break;
}
this.trainedWeights[i] += this.del;
}
if (!fault) {
this.converged = true;
} else {
/*
* Restart the procedure.
*/
for (int i = 0; i < n; i++) {
this.start[i] = this.trainedWeights[i];
}
this.del = this.eps;
}
setError(this.ynewlo);
NetworkCODEC.arrayToNetwork(this.trainedWeights, this.network);
}
/**
* {@inheritDoc}
*/
@Override
public TrainingContinuation pause() {
// Nelder-Mead training has no persistable state, so there is nothing to pause.
return null;
}
/**
* {@inheritDoc}
*/
@Override
public void resume(final TrainingContinuation state) {
// Nelder-Mead training cannot be paused, so there is nothing to resume.
}
}