/*
* Encog(tm) Java Examples v3.4
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-examples
*
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.examples.neural.benchmark;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationReLU;
import org.encog.examples.proben.BenchmarkDefinition;
import org.encog.examples.proben.ProBenData;
import org.encog.examples.proben.ProBenResultAccumulator;
import org.encog.examples.proben.ProBenRunner;
import org.encog.mathutil.randomize.XaiverRandomizer;
import org.encog.ml.MLMethod;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.RequiredImprovementStrategy;
import org.encog.ml.train.strategy.end.EarlyStoppingStrategy;
import org.encog.ml.train.strategy.end.EndIterationsStrategy;
import org.encog.neural.error.ATanErrorFunction;
import org.encog.neural.error.CrossEntropyErrorFunction;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
/**
 * ProBen benchmark definition that trains the same network topology with
 * different error functions (linear, arctan and cross entropy) so that their
 * results can be compared side by side.
 *
 * <p>Each instance is bound to one ProBen data folder and one
 * {@link ErrorFunction}; the static {@code benchmark*} methods build an
 * instance, hand it to a {@link ProBenRunner} and return the accumulated
 * results.</p>
 */
public class BenchmarkErrorFunctions implements BenchmarkDefinition {

    /** Hidden layer size as a fraction of (input count + ideal count). */
    private static final double HIDDEN_LAYER_RATIO = 0.5;

    /**
     * Argument passed to {@link RequiredImprovementStrategy}.
     * NOTE(review): presumably the number of cycles allowed without
     * sufficient improvement -- confirm against the Encog documentation.
     */
    private static final int REQUIRED_IMPROVEMENT_CYCLES = 100;

    /** Argument passed to {@link EndIterationsStrategy} (iteration cap). */
    private static final int MAX_ITERATIONS = 2000;

    /** Location of the ProBen data, as supplied to the constructor. */
    private final String probenFolder;

    /** Error function installed on every trainer created by this definition. */
    private final ErrorFunction errorFn;

    /**
     * Creates a benchmark definition.
     *
     * @param theProbenFolder the ProBen data folder reported by
     *                        {@link #getProBenFolder()}
     * @param theErrorFn      the error function to set on each trainer
     */
    BenchmarkErrorFunctions(String theProbenFolder, ErrorFunction theErrorFn) {
        this.probenFolder = theProbenFolder;
        this.errorFn = theErrorFn;
    }

    /**
     * Builds the network to benchmark: an input layer sized to the data set's
     * input count, one ReLU hidden layer, and a linear output layer sized to
     * the data set's ideal count.
     *
     * @param data the ProBen data set that determines the layer sizes
     * @return a freshly randomized {@link BasicNetwork}
     */
    public MLMethod createMethod(ProBenData data) {
        final int inputCount = data.getInputCount();
        final int idealCount = data.getIdealCount();
        // The cast truncates toward zero, e.g. (9 + 2) * 0.5 = 5.5 -> 5.
        final int hiddenCount = (int) ((inputCount + idealCount) * HIDDEN_LAYER_RATIO);

        final BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, inputCount));
        network.addLayer(new BasicLayer(new ActivationReLU(), true, hiddenCount));
        network.addLayer(new BasicLayer(new ActivationLinear(), false, idealCount));
        network.getStructure().finalizeStructure();

        network.reset();
        // The weights are randomized once more with the Xavier randomizer
        // ("Xaiver" is the spelling of the Encog class name).
        new XaiverRandomizer().randomize(network);
        return network;
    }

    /**
     * Creates a resilient propagation trainer for the given method, using this
     * definition's error function and three strategies: early stopping on the
     * validation set, required improvement, and an iteration limit.
     *
     * @param method the method to train; it is cast to {@link ContainsFlat},
     *               so a {@link ClassCastException} results if it is not one
     * @param data   the ProBen data set supplying the training and validation
     *               sets
     * @return the configured trainer
     */
    public MLTrain createTrainer(MLMethod method, ProBenData data) {
        final ResilientPropagation train = new ResilientPropagation(
                (ContainsFlat) method, data.getTrainingDataSet());
        train.setErrorFunction(this.errorFn);
        train.addStrategy(new EarlyStoppingStrategy(data.getValidationDataSet()));
        train.addStrategy(new RequiredImprovementStrategy(REQUIRED_IMPROVEMENT_CYCLES));
        train.addStrategy(new EndIterationsStrategy(MAX_ITERATIONS));
        return train;
    }

    /**
     * @return the ProBen folder supplied to the constructor
     */
    public String getProBenFolder() {
        return this.probenFolder;
    }

    /**
     * Announces the benchmark on standard output, then runs it with the given
     * error function. Shared by the three public {@code benchmark*} methods.
     *
     * @param label       name printed in the "Starting ..." message
     * @param probenPath  the ProBen data folder
     * @param errorFn     the error function to benchmark
     * @return the results returned by {@link ProBenRunner#run()}
     */
    private static ProBenResultAccumulator runBenchmark(
            String label, String probenPath, ErrorFunction errorFn) {
        System.out.println("Starting " + label + "...");
        final BenchmarkErrorFunctions def = new BenchmarkErrorFunctions(probenPath, errorFn);
        return new ProBenRunner(def).run();
    }

    /**
     * Runs the benchmark with a {@link LinearErrorFunction}.
     *
     * @param probenPath the ProBen data folder
     * @return the accumulated results
     */
    public static ProBenResultAccumulator benchmarkLinear(String probenPath) {
        return runBenchmark("Linear", probenPath, new LinearErrorFunction());
    }

    /**
     * Runs the benchmark with an {@link ATanErrorFunction}.
     *
     * @param probenPath the ProBen data folder
     * @return the accumulated results
     */
    public static ProBenResultAccumulator benchmarkArctan(String probenPath) {
        return runBenchmark("Arctan", probenPath, new ATanErrorFunction());
    }

    /**
     * Runs the benchmark with a {@link CrossEntropyErrorFunction}.
     *
     * @param probenPath the ProBen data folder
     * @return the accumulated results
     */
    public static ProBenResultAccumulator benchmarkCrossEntropy(String probenPath) {
        return runBenchmark("CrossEntropy", probenPath, new CrossEntropyErrorFunction());
    }

    /**
     * Runs all three benchmarks in sequence and prints their results.
     *
     * @param args command line arguments, forwarded to
     *             {@link ProBenData#obtainProbenPath(String[])}
     */
    public static void main(String[] args) {
        try {
            final String probenPath = ProBenData.obtainProbenPath(args);
            System.out.println("Starting...");
            final ProBenResultAccumulator linear = benchmarkLinear(probenPath);
            final ProBenResultAccumulator arctan = benchmarkArctan(probenPath);
            final ProBenResultAccumulator crossEntropy = benchmarkCrossEntropy(probenPath);
            System.out.println("Linear: " + linear);
            System.out.println("Arctan: " + arctan);
            System.out.println("Cross Entropy: " + crossEntropy);
        } finally {
            // Shut Encog down even when a benchmark throws, so the failure
            // path performs the same cleanup as the success path.
            Encog.getInstance().shutdown();
        }
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean shouldCenter() {
        return true;
    }

    /**
     * @return always {@code 0}
     */
    @Override
    public double getInputCenter() {
        return 0;
    }

    /**
     * NOTE(review): the output center is 2 while the input center is 0 --
     * confirm that this asymmetry is intentional.
     *
     * @return always {@code 2}
     */
    @Override
    public double getOutputCenter() {
        return 2;
    }
}