/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.engine.network.train.prop;
import java.util.HashMap;
import java.util.Map;
import org.encog.engine.EncogEngine;
import org.encog.engine.EncogEngineError;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.flat.ValidateForOpenCL;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.opencl.kernels.KernelNetworkTrain;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ErrorCalculationMode;
/**
 * Train a flat network using OpenCL. One of the learnXXXX methods
 * (learnRPROP, learnBPROP or learnManhattan) must be called to select a
 * training algorithm and compile the kernel before iteration() is invoked.
 */
public class TrainFlatNetworkOpenCL implements TrainFlatNetwork {

	/**
	 * Learn RPROP.
	 */
	public static final int LEARN_RPROP = 0;

	/**
	 * Learn backpropagation.
	 */
	public static final int LEARN_BPROP = 1;

	/**
	 * Learn Manhattan update rule.
	 */
	public static final int LEARN_MANHATTAN = 2;

	/**
	 * The current error.
	 */
	private double error;

	/**
	 * The network to train.
	 */
	private final FlatNetwork network;

	/**
	 * The training data.
	 */
	private final EngineIndexableSet training;

	/**
	 * Training type. Initialized to -1 ("not set") so that iteration() can
	 * detect that no learnXXXX method has been called yet. (Without the
	 * explicit initializer Java defaults the field to 0, which collides with
	 * LEARN_RPROP and silently disables the guard.)
	 */
	private int learningType = -1;

	/**
	 * The learning rate.
	 */
	private double learningRate;

	/**
	 * The momentum.
	 */
	private double momentum;

	/**
	 * The initial update.
	 */
	private double initialUpdate;

	/**
	 * The max step.
	 */
	private double maxStep;

	/**
	 * The kernel in use.
	 */
	private KernelNetworkTrain kernel;

	/**
	 * The current iteration number.
	 */
	private int iteration;

	/**
	 * The OpenCL training profile that sizes the workgroups.
	 */
	private final OpenCLTrainingProfile profile;

	/**
	 * Train a flat network using OpenCL.
	 *
	 * @param network
	 *            The network to train.
	 * @param training
	 *            The training data to use; must be indexable.
	 * @param profile
	 *            The OpenCL training profile.
	 * @throws EncogEngineError
	 *             If the training data is not indexable or OpenCL has not
	 *             been enabled.
	 */
	public TrainFlatNetworkOpenCL(final FlatNetwork network,
			final EngineDataSet training, final OpenCLTrainingProfile profile) {

		(new ValidateForOpenCL()).validate(network);

		if (!(training instanceof EngineIndexableSet)) {
			throw new EncogEngineError(
					"Training data must be Indexable for this training type.");
		}

		if (EncogEngine.getInstance().getCL() == null) {
			throw new EncogEngineError(
					"You must enable OpenCL before using this training type.");
		}

		this.profile = profile;
		this.network = network;
		this.training = (EngineIndexableSet) training;
	}

	/**
	 * Call the kernel and accumulate the per-workitem errors into this.error.
	 *
	 * @param start
	 *            The starting training element.
	 * @param size
	 *            The number of training elements per call.
	 * @param learn
	 *            Should we learn?
	 * @param iterations
	 *            The number of iterations.
	 */
	private void callKernel(final int start, final int size,
			final boolean learn, final int iterations) {
		this.kernel.calculate(start, size, learn, iterations);

		double e = 0;
		for (int i = 0; i < this.kernel.getGlobalWork(); i++) {
			e += this.kernel.getErrors()[i];
		}

		this.error += e;
	}

	/**
	 * {@inheritDoc}
	 */
	public void finishTraining() {
		if (this.kernel != null) {
			this.kernel.release();
		}
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public double getError() {
		return this.error;
	}

	/**
	 * {@inheritDoc}
	 */
	public int getIteration() {
		return this.iteration;
	}

	/**
	 * @return The last gradients, copied from the kernel's temp data.
	 *         NOTE(review): this assumes the RPROP temp-data layout
	 *         (gradients in the first weight-count entries) — confirm before
	 *         relying on it with BPROP or Manhattan training.
	 */
	public double[] getLastGradient() {
		final double[] result = new double[this.network.getWeights().length];
		// Hoist the accessor out of the loop; element type widens to double.
		final float[] temp = this.kernel.getTempDataArray();
		for (int i = 0; i < result.length; i++) {
			result[i] = temp[i];
		}
		return result;
	}

	/**
	 * @return the learningRate
	 */
	public double getLearningRate() {
		return this.learningRate;
	}

	/**
	 * @return the learningType
	 */
	public int getLearningType() {
		return this.learningType;
	}

	/**
	 * @return the maxStep
	 */
	public double getMaxStep() {
		return this.maxStep;
	}

	/**
	 * @return the momentum
	 */
	public double getMomentum() {
		return this.momentum;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public FlatNetwork getNetwork() {
		return this.network;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public int getNumThreads() {
		return 0;
	}

	/**
	 * Get the learning properties (kernel compile options).
	 *
	 * @param learningType
	 *            The learning type define to enable in the kernel.
	 * @return The options.
	 */
	private Map<String, String> getOptions(final String learningType) {
		final Map<String, String> options = new HashMap<String, String>();
		options.put("NEURON_COUNT", "" + this.network.getNeuronCount());
		options.put("WEIGHT_COUNT", "" + this.network.getWeights().length);
		options.put(learningType, null);
		return options;
	}

	/**
	 * @return The training data in use. (Previously a TODO stub that
	 *         returned null; the training set held by this trainer is now
	 *         returned as the interface intends.)
	 */
	@Override
	public EngineDataSet getTraining() {
		return this.training;
	}

	/**
	 * @return The update values, copied from the kernel's temp data.
	 *         NOTE(review): this assumes the RPROP temp-data layout
	 *         (updates in the second weight-count entries) — confirm before
	 *         relying on it with BPROP or Manhattan training.
	 */
	public double[] getUpdateValues() {
		final double[] result = new double[this.network.getWeights().length];
		final int len = this.network.getWeights().length;
		// Hoist the accessor out of the loop; element type widens to double.
		final float[] temp = this.kernel.getTempDataArray();
		for (int i = 0; i < result.length; i++) {
			result[i] = temp[len + i];
		}
		return result;
	}

	/**
	 * {@inheritDoc}
	 */
	public void iteration() {
		iteration(1);
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void iteration(final int iterations) {

		if (this.learningType == -1) {
			throw new EncogEngineError(
					"Learning type has not been defined yet, you must first call one of the learnXXXX methods, such as learnRPROP.");
		}

		this.iteration += iterations;
		int currentIndex = 0;
		this.error = 0;

		int count = this.profile.getKernelNumberOfCalls();

		// If we are using an OpenCL ratio other than 1.0, which means that we
		// are breaking up a single training iteration, there is no reason to
		// try and batch up multiple iterations.
		if ((count > 0) && (iterations > 1)) {
			throw new EncogEngineError(
					"Must use an OpenCL ratio of 1.0 if you are going to use an iteration count > 1.");
		}

		this.kernel.setGlobalWork(this.profile.getKernelGlobalWorkgroup());
		this.kernel.setLocalWork(this.profile.getKernelLocalWorkgroup());

		// handle workloads
		while (count > 0) {
			callKernel(currentIndex, this.profile.getKernelWorkPerCall(),
					false, 1);
			count--;
			currentIndex += this.profile.getKernelWorkPerCall()
					* this.kernel.getGlobalWork();
		}

		// handle the final workload
		this.kernel.setGlobalWork(this.profile.getKernelRemainderGlobal());
		// NOTE(review): local work is set from the remainder *global* size
		// here; confirm this is intentional and not meant to be a remainder
		// local workgroup size.
		this.kernel.setLocalWork(this.profile.getKernelRemainderGlobal());

		callKernel(currentIndex, this.profile.getKernelRemainderPer(), true,
				iterations);

		count = (int) this.training.getRecordCount();
		this.error = this.error / (count * this.training.getIdealSize());

		if (ErrorCalculation.getMode() == ErrorCalculationMode.RMS) {
			this.error = Math.sqrt(this.error);
		}

		// Copy the trained weights back into the network.
		EngineArray.arrayCopy(this.kernel.getWeightOutArray(), this.network
				.getWeights());
	}

	/**
	 * Learn using backpropagation.
	 *
	 * @param learningRate
	 *            The learning rate.
	 * @param momentum
	 *            The momentum.
	 */
	public void learnBPROP(final double learningRate, final double momentum) {
		// (Removed a duplicate assignment of this.learningType that appeared
		// both before and after the field assignments.)
		this.learningType = TrainFlatNetworkOpenCL.LEARN_BPROP;
		this.momentum = momentum;
		this.learningRate = learningRate;

		final Map<String, String> options = getOptions("LEARN_BPROP");

		// Temp data holds one slot per weight plus two for the parameters.
		this.kernel = new KernelNetworkTrain(this.profile.getDevice(),
				this.network, this.training,
				this.network.getWeights().length + 2);
		this.kernel.compile(options, this.profile, this.network);
		this.kernel.getTempDataArray()[0] = (float) learningRate;
		this.kernel.getTempDataArray()[1] = (float) momentum;
	}

	/**
	 * Learn using the Manhattan update rule.
	 *
	 * @param learningRate
	 *            The learning rate.
	 */
	public void learnManhattan(final double learningRate) {
		this.learningType = TrainFlatNetworkOpenCL.LEARN_MANHATTAN;
		this.learningRate = learningRate;

		final Map<String, String> options = getOptions("LEARN_MANHATTAN");

		this.kernel = new KernelNetworkTrain(this.profile.getDevice(),
				this.network, this.training, 1);
		this.kernel.compile(options, this.profile, this.network);
		this.kernel.getTempDataArray()[0] = (float) learningRate;
	}

	/**
	 * Learn using RPROP. Use default max step and initial update.
	 */
	public void learnRPROP() {
		learnRPROP(RPROPConst.DEFAULT_INITIAL_UPDATE,
				RPROPConst.DEFAULT_MAX_STEP);
	}

	/**
	 * Learn using RPROP with a custom initial update and max step.
	 *
	 * @param initialUpdate
	 *            The initial update value.
	 * @param maxStep
	 *            The max step.
	 */
	public void learnRPROP(final double initialUpdate, final double maxStep) {
		this.learningType = TrainFlatNetworkOpenCL.LEARN_RPROP;
		this.initialUpdate = initialUpdate;
		this.maxStep = maxStep;

		final Map<String, String> options = getOptions("LEARN_RPROP");

		// Temp data holds gradients (first half) and update values
		// (second half), one of each per weight.
		this.kernel = new KernelNetworkTrain(this.profile.getDevice(),
				this.network, this.training,
				this.network.getWeights().length * 2);

		this.kernel.compile(options, this.profile, this.network);

		final int weightLength = this.network.getWeights().length;

		for (int i = 0; i < weightLength; i++) {
			this.kernel.getTempDataArray()[i] = 0;
			this.kernel.getTempDataArray()[i + weightLength] = (float) this.initialUpdate;
		}
	}

	/**
	 * {@inheritDoc}
	 */
	public void setIteration(final int iteration) {
		this.iteration = iteration;
	}

	/**
	 * {@inheritDoc}
	 */
	@Override
	public void setNumThreads(final int numThreads) {
	}
}