package func.nn.backprop;
import shared.DataSet;
import shared.GradientErrorMeasure;
import shared.Instance;
import func.nn.NetworkTrainer;
/**
* A standard batch back propagation trainer
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class BatchBackPropagationTrainer extends NetworkTrainer {
/**
* The weight update rule to use
*/
private WeightUpdateRule rule;
/**
* Make a new back propagation trainer
* @param patterns the patterns to train on
* @param network the network to train
* @param errorMeasure the error measure to use
*/
public BatchBackPropagationTrainer(DataSet patterns,
BackPropagationNetwork network,
GradientErrorMeasure errorMeasure,
WeightUpdateRule rule) {
super(patterns, network, errorMeasure);
this.rule = rule;
}
/**
* @see nn.Trainer#train()
*/
public double train() {
BackPropagationNetwork network =
(BackPropagationNetwork) getNetwork();
GradientErrorMeasure measure =
(GradientErrorMeasure) getErrorMeasure();
DataSet patterns = getDataSet();
double error = 0;
for (int i = 0; i < patterns.size(); i++) {
Instance pattern = patterns.get(i);
network.setInputValues(pattern.getData());
network.run();
Instance output = new Instance(network.getOutputValues());
double[] errors = measure.gradient(output, pattern);
error += measure.value(output, pattern);
network.setOutputErrors(errors);
network.backpropagate();
}
network.updateWeights(rule);
network.clearError();
return error / patterns.size();
}
}