/**
* Copyright (C) 2017 Jan Schäfer (jansch@users.sourceforge.net)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jskat.ai.nn.util;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.PersistBasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.resilient.RPROPType;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
/**
* Wraps the Encog network to fulfill the interface {@link INeuralNetwork}
*/
public class EncogNetworkWrapper implements INeuralNetwork {

    /** Learning rate passed to {@link Backpropagation} for single-sample training. */
    private static final double LEARNING_RATE = 0.07;

    /** Momentum passed to {@link Backpropagation} for single-sample training. */
    private static final double MOMENTUM = 0.02;

    /** The wrapped network; replaced as a whole by {@link #loadNetwork}. */
    private BasicNetwork network;

    /** Reads and writes the wrapped network from and to streams. */
    private final PersistBasicNetwork networkPersister;

    /**
     * Builds a sigmoid network with one input layer, the hidden layers
     * described by the topology and a single output neuron, then randomizes
     * its weights.
     *
     * @param topo
     *            Network topology (input neuron count, hidden layer count and
     *            hidden neuron counts are read from it)
     * @param useBias
     *            TRUE, if bias nodes should be used; the flag is applied to
     *            every layer, including the output layer
     */
    public EncogNetworkWrapper(NetworkTopology topo, boolean useBias) {
        network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationSigmoid(), useBias, topo.getInputNeuronCount()));
        for (int i = 0; i < topo.getHiddenLayerCount(); i++) {
            network.addLayer(new BasicLayer(new ActivationSigmoid(), useBias, topo.getHiddenNeuronCount(i)));
        }
        network.addLayer(new BasicLayer(new ActivationSigmoid(), useBias, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        networkPersister = new PersistBasicNetwork();
    }

    /**
     * {@inheritDoc}
     * <p>
     * This wrapper does not track an average difference; it always returns
     * {@code 0.0}.
     */
    @Override
    public double getAvgDiff() {
        return 0.0;
    }

    /**
     * Runs one backpropagation iteration on a single input/output sample.
     * <p>
     * A fresh trainer is created on every call, so no trainer state is carried
     * over between calls.
     *
     * @param inputValues
     *            Input values of the sample
     * @param outputValues
     *            Expected output values of the sample
     * @return the error reported by the trainer after the iteration
     */
    @Override
    public synchronized double adjustWeights(final double[] inputValues, final double[] outputValues) {
        List<MLDataPair> data = new ArrayList<>();
        data.add(new BasicMLDataPair(new BasicMLData(inputValues), new BasicMLData(outputValues)));
        MLDataSet trainingSet = new BasicMLDataSet(data);

        final Backpropagation trainer = new Backpropagation(network, trainingSet, LEARNING_RATE, MOMENTUM);
        trainer.setBatchSize(1);
        trainer.iteration();
        return trainer.getError();
    }

    /**
     * Runs one resilient propagation (iRPROP+) iteration on a batch of
     * samples.
     *
     * @param inputValues
     *            Input values, one row per sample
     * @param outputValues
     *            Expected output values, one row per sample
     * @return the error reported by the trainer after the iteration
     */
    @Override
    public synchronized double adjustWeightsBatch(final double[][] inputValues, final double[][] outputValues) {
        MLDataSet trainingSet = new BasicMLDataSet(inputValues, outputValues);

        final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
        train.setRPROPType(RPROPType.iRPROPp);
        train.setBatchSize(0);
        train.iteration();
        return train.getError();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public synchronized void resetNetwork() {
        network.reset();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Returns the value of the first output neuron.
     */
    @Override
    public synchronized double getPredictedOutcome(final double[] inputValues) {
        MLData output = network.compute(new BasicMLData(inputValues));
        return output.getData(0);
    }

    /**
     * {@inheritDoc}
     * <p>
     * This wrapper does not count iterations; it always returns {@code 0}.
     */
    @Override
    public long getIterations() {
        return 0;
    }

    /**
     * Writes the current network to a file on the file system.
     *
     * @param fileName
     *            Path of the file to write
     * @return {@code true} if the network was written and the file closed
     *         without an I/O error, {@code false} otherwise
     */
    @Override
    public synchronized boolean saveNetwork(final String fileName) {
        // try-with-resources makes sure the file handle is released even when
        // the persister fails part way through.
        try (FileOutputStream out = new FileOutputStream(fileName)) {
            networkPersister.save(out, network);
            return true;
        } catch (IOException e) {
            // Diagnostic output kept as before; the failure is now also
            // reported to the caller through the return value.
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Replaces the current network with one read from a classpath resource.
     * <p>
     * The resource is resolved with {@link Class#getResourceAsStream(String)}
     * relative to this object's runtime class. The neuron count parameters
     * are not used by this implementation; the topology comes from the stored
     * network.
     *
     * @param fileName
     *            Name of the classpath resource holding the network
     * @param inputNeurons
     *            Ignored
     * @param hiddenNeurons
     *            Ignored
     * @param outputNeurons
     *            Ignored
     * @throws IllegalArgumentException
     *             if no resource with the given name can be found
     * @throws UncheckedIOException
     *             if the resource stream cannot be closed
     */
    @Override
    public synchronized void loadNetwork(final String fileName, final int inputNeurons, final int hiddenNeurons,
            final int outputNeurons) {
        try (InputStream in = getClass().getResourceAsStream(fileName)) {
            if (in == null) {
                // getResourceAsStream signals a missing resource with null;
                // fail with a clear message instead of deep inside the persister.
                throw new IllegalArgumentException("Network resource not found on classpath: " + fileName);
            }
            network = (BasicNetwork) networkPersister.read(in);
        } catch (IOException e) {
            throw new UncheckedIOException("Could not close network resource: " + fileName, e);
        }
    }
}