/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.alg;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.shifu.container.ModelInitInputObject;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.AbstractTrainer;
import ml.shifu.shifu.core.ConvergeJudger;
import ml.shifu.shifu.core.MSEWorker;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.JSONUtils;
import org.apache.commons.io.FileUtils;
import org.encog.engine.network.activation.ActivationLOG;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSIN;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.mathutil.IntRange;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.concurrency.DetermineWorkload;
import org.encog.util.concurrency.EngineConcurrency;
import org.encog.util.concurrency.TaskGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Neural network trainer based on Encog's {@link BasicNetwork}. Five training
 * algorithms are supported, selected by a single-letter code: S (Scaled
 * Conjugate Gradient), R (Resilient Propagation), M (Manhattan Propagation),
 * B (Back Propagation) and Q (Quick Propagation).
 */
public class NNTrainer extends AbstractTrainer {
private static final Logger LOG = LoggerFactory.getLogger(NNTrainer.class);
    private final static double EPSILON = 1.0; // initial weights are drawn uniformly from [-EPSILON, EPSILON]
public static final Map<String, Double> defaultLearningRate;
public static final Map<String, String> learningAlgMap;
private BasicNetwork network;
private volatile boolean toPersistentModel = true;
private volatile boolean toLoggingProcess = true;
/**
* Convergence judger instance for convergence criteria checking.
*/
private ConvergeJudger judger = new ConvergeJudger();
static {
// TODO use UnmodifiableMap or use other immutable Collections such as guava's
Map<String, Double> tmpLearningRate = new HashMap<String, Double>();
tmpLearningRate.put("S", 0.1);
tmpLearningRate.put("R", 0.1);
tmpLearningRate.put("Q", 2.0);
tmpLearningRate.put("B", 0.01);
tmpLearningRate.put("M", 0.00001);
defaultLearningRate = Collections.unmodifiableMap(tmpLearningRate);
Map<String, String> tmpLearningAlgMap = new HashMap<String, String>();
tmpLearningAlgMap.put("S", "Scaled Conjugate Gradient");
tmpLearningAlgMap.put("R", "Resilient Propagation");
tmpLearningAlgMap.put("M", "Manhattan Propagation");
tmpLearningAlgMap.put("B", "Back Propagation");
tmpLearningAlgMap.put("Q", "Quick Propagation");
learningAlgMap = Collections.unmodifiableMap(tmpLearningAlgMap);
}
public NNTrainer(ModelConfig modelConfig, int trainerID, Boolean dryRun) {
super(modelConfig, trainerID, dryRun);
}
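    /**
     * Builds the Encog {@link BasicNetwork}: a linear input layer sized to the
     * training set, the configured hidden layers, and a sigmoid output layer.
     * <p>
     * A sketch of the hidden-layer parameters this method reads, assuming the
     * usual Shifu key names (the authoritative values live in
     * {@link CommonConstants}):
     *
     * <pre>
     * "params" : {
     *     "NumHiddenLayers" : 2,
     *     "ActivationFunc" : [ "tanh", "sigmoid" ],
     *     "NumHiddenNodes" : [ 30, 20 ]
     * }
     * </pre>
     *
     * When {@code fixInitialInput} is enabled, initial weights are loaded from
     * (or persisted to) {@code ./init<trainerID>.json} instead of being
     * randomized by {@code network.reset()}.
     */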
@SuppressWarnings("unchecked")
public void buildNetwork() {
network = new BasicNetwork();
network.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
int numLayers = (Integer) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_LAYERS);
List<String> actFunc = (List<String>) modelConfig.getParams().get(CommonConstants.ACTIVATION_FUNC);
List<Integer> hiddenNodeList = (List<Integer>) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_NODES);
if(numLayers != 0 && (numLayers != actFunc.size() || numLayers != hiddenNodeList.size())) {
            throw new RuntimeException(
                    "The number of hidden layers does not match the size of the activation function list or the hidden node list.");
}
if(toLoggingProcess)
LOG.info(" - total " + numLayers + " layers, each layers are: "
+ Arrays.toString(hiddenNodeList.toArray()) + " the activation function are: "
+ Arrays.toString(actFunc.toArray()));
for(int i = 0; i < numLayers; i++) {
String func = actFunc.get(i);
Integer numHiddenNode = hiddenNodeList.get(i);
            // if-else chain rather than switch-on-String, which would require Java 7+
if("linear".equalsIgnoreCase(func)) {
network.addLayer(new BasicLayer(new ActivationLinear(), true, numHiddenNode));
} else if(func.equalsIgnoreCase("sigmoid")) {
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
} else if(func.equalsIgnoreCase("tanh")) {
network.addLayer(new BasicLayer(new ActivationTANH(), true, numHiddenNode));
} else if(func.equalsIgnoreCase("log")) {
network.addLayer(new BasicLayer(new ActivationLOG(), true, numHiddenNode));
} else if(func.equalsIgnoreCase("sin")) {
network.addLayer(new BasicLayer(new ActivationSIN(), true, numHiddenNode));
} else {
LOG.info("Unsupported activation function: " + func
+ " !! Set this layer activation function to be Sigmoid ");
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
}
}
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
network.getStructure().finalizeStructure();
if(!modelConfig.isFixInitialInput()) {
network.reset();
} else {
int numWeight = 0;
for(int i = 0; i < network.getLayerCount() - 1; i++) {
numWeight = numWeight + network.getLayerTotalNeuronCount(i) * network.getLayerNeuronCount(i + 1);
}
LOG.info(" - You have " + numWeight + " weights to be initialize");
loadWeightsInput(numWeight);
}
}
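    /**
     * Trains the network for up to the configured number of epochs. The model
     * with the lowest validation error seen so far is persisted to
     * {@code ./models/model<trainerID>.nn}, and a temporary snapshot is saved
     * periodically. Training stops early once the average of the train and
     * validation errors falls below the convergence threshold.
     *
     * @return the lowest validation MSE observed (or training MSE when there
     *         is no validation set)
     * @throws IOException
     *             if persisting a model file fails
     */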
@Override
public double train() throws IOException {
        if(toLoggingProcess) {
            LOG.info("Using neural network algorithm...");
            if(this.dryRun) {
                LOG.info("Start Training (Dry Run)... Model #" + this.trainerID);
} else {
LOG.info("Start Training... Model #" + this.trainerID);
}
LOG.info(" - Input Size: " + trainSet.getInputSize());
LOG.info(" - Ideal Size: " + trainSet.getIdealSize());
LOG.info(" - Training Records Count: " + trainSet.getRecordCount());
LOG.info(" - Validation Records Count: " + validSet.getRecordCount());
}
// set up the model
buildNetwork();
Propagation mlTrain = getMLTrain();
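        // a thread count of 0 tells Encog to auto-detect the number of worker threads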
mlTrain.setThreadCount(0);
if(this.dryRun) {
return 0.0;
}
int epochs = this.modelConfig.getNumTrainEpochs();
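        // save a temporary snapshot every `factor` epochs: roughly one per 2% of the run, but never more often than every 10 epochs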
int factor = Math.max(epochs / 50, 10);
// Get convergence threshold from modelConfig.
double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain()
.getConvergenceThreshold().doubleValue();
        String formattedThreshold = df.format(threshold);
setBaseMSE(Double.MAX_VALUE);
for(int i = 0; i < epochs; i++) {
mlTrain.iteration();
if(i % factor == 0) {
this.saveTmpNN(i);
}
double validMSE = (this.validSet.getRecordCount() > 0) ? getValidSetError() : mlTrain.getError();
String extra = "";
if(validMSE < getBaseMSE()) {
setBaseMSE(validMSE);
saveNN();
extra = " <-- NN saved: ./models/model" + this.trainerID + ".nn";
}
if(toLoggingProcess)
LOG.info(" Trainer-" + trainerID + "> Epoch #" + (i + 1) + " Train Error: "
+ df.format(mlTrain.getError()) + " Validation Error: "
+ ((this.validSet.getRecordCount() > 0) ? df.format(validMSE) : "N/A") + " " + extra);
// Convergence judging.
double avgErr = (mlTrain.getError() + validMSE) / 2;
if(judger.judge(avgErr, threshold)) {
LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1),
                        df.format(avgErr), formattedThreshold);
break;
} else {
if(toLoggingProcess) {
LOG.info("Trainer-{}> Epoch #{} Average Error: {}, Threshold: {}", trainerID, (i + 1),
                            df.format(avgErr), formattedThreshold);
}
}
}
mlTrain.finishTraining();
if(toLoggingProcess)
LOG.info("Trainer #" + this.trainerID + " is Finished!");
return getBaseMSE();
}
public BasicNetwork getNetwork() {
return network;
}
public void enableModelPersistence() {
this.toPersistentModel = true;
}
public void disableModelPersistence() {
this.toPersistentModel = false;
}
public void enableLogging() {
this.toLoggingProcess = true;
}
public void disableLogging() {
this.toLoggingProcess = false;
}
/**
* @param network
* the network to set
*/
public void setNetwork(BasicNetwork network) {
this.network = network;
}
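    /**
     * Creates the Encog propagation trainer selected by the single-letter
     * algorithm code: B (Back), Q (Quick), M (Manhattan), R (Resilient) or S
     * (Scaled Conjugate Gradient). B, Q and M honor the configured learning
     * rate; R and S adapt their step sizes internally and take no rate
     * parameter.
     */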
private Propagation getMLTrain() {
String alg = (String) modelConfig.getParams().get(CommonConstants.PROPAGATION);
if(!(defaultLearningRate.containsKey(alg))) {
throw new RuntimeException("Learning algorithm is invalid: " + alg);
}
        // start from the algorithm's default rate, then honor an explicit setting
        double rate = defaultLearningRate.get(alg);
Object rateObj = modelConfig.getParams().get(CommonConstants.LEARNING_RATE);
if(rateObj instanceof Double) {
rate = (Double) rateObj;
} else if(rateObj instanceof Integer) {
            // the user may have written the learning rate as an integer in the config
rate = ((Integer) rateObj).doubleValue();
} else if(rateObj instanceof Float) {
rate = ((Float) rateObj).doubleValue();
}
if(toLoggingProcess)
LOG.info(" - Learning Algorithm: " + learningAlgMap.get(alg));
if(alg.equals("Q") || alg.equals("B") || alg.equals("M")) {
if(toLoggingProcess)
LOG.info(" - Learning Rate: " + rate);
}
if(alg.equals("B")) {
return new Backpropagation(network, trainSet, rate, 0);
} else if(alg.equals("Q")) {
return new QuickPropagation(network, trainSet, rate);
} else if(alg.equals("M")) {
return new ManhattanPropagation(network, trainSet, rate);
} else if(alg.equals("R")) {
return new ResilientPropagation(network, trainSet);
} else if(alg.equals("S")) {
return new ScaledConjugateGradient(network, trainSet);
} else {
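            // unreachable: alg was validated against defaultLearningRate above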
return null;
}
}
private double getValidSetError() {
return calculateMSEParallel(this.network, this.validSet);
}
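    /**
     * Computes the mean squared error of {@code network} over {@code dataSet}
     * in parallel: each {@link MSEWorker} scores a disjoint range of records
     * on its own clone of the network, and the per-worker error sums are
     * combined at the end.
     */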
public double calculateMSEParallel(BasicNetwork network, MLDataSet dataSet) {
int numRecords = (int) dataSet.getRecordCount();
assert numRecords > 0;
// setup workers
final DetermineWorkload determine = new DetermineWorkload(0, numRecords);
        // one MSEWorker per thread chosen by Encog's workload calculation
MSEWorker[] workers = new MSEWorker[determine.getThreadCount()];
int index = 0;
TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
for(final IntRange r: determine.calculateWorkers()) {
workers[index++] = new MSEWorker((BasicNetwork) network.clone(), dataSet.openAdditional(), r.getLow(),
r.getHigh());
}
for(final MSEWorker worker: workers) {
EngineConcurrency.getInstance().processTask(worker, group);
}
group.waitForComplete();
double totalError = 0;
for(final MSEWorker worker: workers) {
totalError += worker.getTotalError();
}
return totalError / numRecords;
}
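    /** Persists the current best network to ./models/model<trainerID>.nn. */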
private void saveNN() throws IOException {
if(!toPersistentModel) {
return;
}
File folder = new File(pathFinder.getModelsPath(SourceType.LOCAL));
if(!folder.exists()) {
FileUtils.forceMkdir(folder);
}
EncogDirectoryPersistence.saveObject(new File(folder, "model" + this.trainerID + ".nn"), network);
}
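    /** Persists a periodic snapshot to the tmp models folder, tagged with the epoch. */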
private void saveTmpNN(int epoch) throws IOException {
if(!toPersistentModel) {
return;
}
File tmpFolder = new File(pathFinder.getTmpModelsPath(SourceType.LOCAL));
if(!tmpFolder.exists()) {
FileUtils.forceMkdir(tmpFolder);
}
EncogDirectoryPersistence.saveObject(new File(tmpFolder, "model" + trainerID + "-" + epoch + ".nn"), network);
}
public MLDataSet getValidSet() {
return validSet;
}
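    /**
     * @return the best validation MSE recorded by {@link #train()}, or
     *         {@code null} if training has not run yet
     */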
public Double getBaseMSE() {
if(baseMSE == null) {
LOG.error("baseMSE is not available. Run train() First!");
return null;
}
return baseMSE;
}
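    /**
     * Loads fixed initial weights from ./init<trainerID>.json so repeated runs
     * start from the same point. If the file is missing, or its weight count
     * no longer matches the network, fresh random weights are generated and
     * written back.
     */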
    private void loadWeightsInput(int numWeights) {
        BufferedReader reader = null;
        try {
            File file = new File("./init" + this.trainerID + ".json");
            if(!file.exists()) {
                ModelInitInputObject io = new ModelInitInputObject();
                io.setWeights(randomSetWeights(numWeights));
                io.setNumWeights(numWeights);
                setWeights(io.getWeights());
                JSONUtils.writeValue(file, io);
            } else {
                reader = ShifuFileUtils.getReader("./init" + this.trainerID + ".json", SourceType.LOCAL);
                ModelInitInputObject io = JSONUtils.readValue(reader, ModelInitInputObject.class);
                if(io == null) {
                    io = new ModelInitInputObject();
                }
                if(io.getNumWeights() != numWeights) {
                    io.setNumWeights(numWeights);
                    io.setWeights(randomSetWeights(numWeights));
                    JSONUtils.writeValue(file, io);
                }
                setWeights(io.getWeights());
            }
        } catch (IOException e) {
            LOG.error("Failed to read or write the initial weights file.", e);
        } finally {
            if(reader != null) {
                try {
                    reader.close();
                } catch (IOException ignore) {
                    // best-effort close
                }
            }
        }
    }
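    /**
     * Copies {@code weights} into the network layer by layer, iterating
     * from-neuron-major, in the same order the weight list was generated.
     */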
private void setWeights(List<Double> weights) {
if(network == null)
return;
int i = 0;
for(int numLayer = 0; numLayer < network.getLayerCount() - 1; numLayer++) {
int fromCount = network.getLayerTotalNeuronCount(numLayer);
int toCount = network.getLayerNeuronCount(numLayer + 1);
for(int fromNeuron = 0; fromNeuron < fromCount; fromNeuron++) {
for(int toNeuron = 0; toNeuron < toCount; toNeuron++) {
network.setWeight(numLayer, fromNeuron, toNeuron, weights.get(i++));
}
}
}
}
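    /** Draws {@code numWeights} values uniformly from [-EPSILON, EPSILON]. */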
private List<Double> randomSetWeights(int numWeights) {
List<Double> weights = new ArrayList<Double>();
for(int i = 0; i < numWeights; i++) {
            weights.add(this.random.nextDouble() * 2 * EPSILON - EPSILON);
}
return weights;
}
}