/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM;
import ml.shifu.shifu.core.alg.NNTrainer;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.io.FileUtils;
import org.encog.Encog;
import org.encog.engine.network.activation.*;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * TestNG tests around {@link NNTrainer} and around training and persisting an
 * Encog {@link BasicNetwork} directly.
 *
 * <p>The tests write scratch files into the working directory; everything is
 * removed again in {@link #shutDown()}.
 */
public class NNTrainerTest {

    /** Scratch directory the trained network is persisted to. */
    private static final String MODEL_FOLDER = "model_folder";

    /** Training stops once the error moves by no more than this between two iterations. */
    private static final double ERROR_DELTA_THRESHOLD = 0.001;

    /** Hard cap on training iterations so a run that never settles cannot hang the build. */
    private static final int MAX_TRAIN_ITERATIONS = 10000;

    /** XOR truth table used as training data. */
    private static final MLDataSet XOR_TRAIN_SET = new BasicMLDataSet();

    /** XOR truth table used as validation data (shares its pair instances with the training set). */
    private static final MLDataSet XOR_VALID_SET = new BasicMLDataSet();

    static {
        addXorSample(0., 0., 0.);
        addXorSample(0., 1., 1.);
        addXorSample(1., 0., 1.);
        addXorSample(1., 1., 0.);
    }

    /** Network trained by {@link #testAndOperation()}; built once in {@link #setUp()}. */
    private BasicNetwork network;

    /**
     * Adds one row of the XOR truth table to both the training and the validation set.
     *
     * @param first    first input value
     * @param second   second input value
     * @param expected ideal output for the two inputs
     */
    private static void addXorSample(double first, double second, double expected) {
        MLDataPair pair = newPair(first, second, expected);
        XOR_TRAIN_SET.add(pair);
        XOR_VALID_SET.add(pair);
    }

    /**
     * Creates a data pair with two inputs and one ideal output.
     *
     * @param first  first input value
     * @param second second input value
     * @param ideal  ideal output value
     * @return the new pair
     */
    private static MLDataPair newPair(double first, double second, double ideal) {
        return new BasicMLDataPair(
                new BasicMLData(new double[]{first, second}),
                new BasicMLData(new double[]{ideal}));
    }

    /**
     * Builds a 2-4-3-3-1 network with mixed activation functions and
     * randomizes its weights.
     */
    @BeforeClass
    public void setUp() {
        network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationLinear(), true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 4));
        network.addLayer(new BasicLayer(new ActivationLOG(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSIN(), true, 3));
        network.addLayer(new BasicLayer(new ActivationTANH(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();
    }

    /**
     * Trains an {@link NNTrainer} on the XOR truth table and checks both the
     * thresholded predictions and the resulting layer count.
     *
     * <p>NOTE(review): this test was already switched off in the original
     * source (its annotation was commented out); the reason is not recorded.
     * It stays disabled, but explicitly so.
     *
     * @throws IOException if the trainer fails on I/O
     */
    @Test(enabled = false)
    public void testXorOperation() throws IOException {
        ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".");
        config.getTrain().setBaggingSampleRate(1.0);
        config.getTrain().setValidSetRate(0.1);
        config.getTrain().getParams().put("Propagation", "Q");
        config.getTrain().getParams().put("NumHiddenLayers", 1);
        config.getTrain().getParams().put("LearningRate", 1);

        List<Integer> nodes = new ArrayList<Integer>();
        nodes.add(5);
        List<String> func = new ArrayList<String>();
        func.add("tanh");
        config.getTrain().getParams().put("NumHiddenNodes", nodes);
        config.getTrain().getParams().put("ActivationFunc", func);
        config.getTrain().setNumTrainEpochs(100);

        NNTrainer trainer = new NNTrainer(config, 0, false);
        trainer.setTrainSet(XOR_TRAIN_SET);
        trainer.setValidSet(XOR_VALID_SET);
        trainer.train();

        BasicNetwork bn = trainer.getNetwork();

        // Expected "score below 500" flag per XOR row, in insertion order.
        boolean[] cases = {true, false, false, true};
        int i = 0;
        for (MLDataPair data : XOR_VALID_SET) {
            double[] score = bn.compute(data.getInput()).getData();
            Assert.assertEquals(score[0] * 1000 < 500, cases[i]);
            i++;
        }

        int hiddenLayers = (Integer) config.getTrain().getParams().get("NumHiddenLayers");
        // Hidden layers plus the input and the output layer.
        Assert.assertEquals(bn.getLayerCount(), hiddenLayers + 2);
    }

    /**
     * Configures two hidden layers but supplies three node counts and a single
     * activation function, and expects a {@link RuntimeException} from the
     * trainer setup.
     *
     * <p>An {@link IOException} while loading the data set is no longer
     * swallowed: it propagates and fails the test, so the test cannot pass on
     * a half-initialized trainer.
     *
     * @throws IOException if the configuration or the data set cannot be set up
     */
    @Test(expectedExceptions = RuntimeException.class)
    public void testExceptionWhileSetupModel() throws IOException {
        ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".");
        config.getTrain().getParams().put("Propagation", "Q");
        config.getTrain().getParams().put("NumHiddenLayers", 2);
        config.getTrain().getParams().put("LearningRate", 0.1);

        List<Integer> nodes = new ArrayList<Integer>();
        nodes.add(3);
        nodes.add(3);
        nodes.add(3);
        List<String> func = new ArrayList<String>();
        func.add("tanh");
        config.getTrain().getParams().put("NumHiddenNodes", nodes);
        config.getTrain().getParams().put("ActivationFunc", func);
        config.getTrain().setNumTrainEpochs(50);

        NNTrainer trainer = new NNTrainer(config, 0, false);
        trainer.setDataSet(XOR_TRAIN_SET);
        trainer.buildNetwork();
    }

    /**
     * Trains the shared network on the AND truth table with quick propagation
     * until the error stabilizes, then verifies the network can be persisted.
     *
     * <p>The training data is local to this test so the outcome does not
     * depend on which other tests ran before it.
     *
     * @throws IOException if the model directory cannot be created
     */
    @Test
    public void testAndOperation() throws IOException {
        MLDataSet andSet = new BasicMLDataSet();
        andSet.add(newPair(0.0, 0.0, 0.0));
        andSet.add(newPair(0.0, 1.0, 0.0));
        andSet.add(newPair(1.0, 0.0, 0.0));
        andSet.add(newPair(1.0, 1.0, 1.0));

        Propagation propagation = new QuickPropagation(network, andSet, 0.1);

        double error = 0.0;
        double lastError;
        int iterCnt = 0;
        do {
            propagation.iteration();
            lastError = error;
            error = propagation.getError();
            System.out.println("The #" + (++iterCnt)
                    + " error is " + error);
            // The iteration cap guards against an error that keeps oscillating.
        } while (Math.abs(lastError - error) > ERROR_DELTA_THRESHOLD
                && iterCnt < MAX_TRAIN_ITERATIONS);
        propagation.finishTraining();

        File modelDir = new File(MODEL_FOLDER);
        // forceMkdir is a no-op for an existing directory and fails loudly
        // when a regular file is in the way.
        FileUtils.forceMkdir(modelDir);

        File modelFile = new File(modelDir, "model6.nn");
        try {
            EncogDirectoryPersistence.saveObject(modelFile, network);
            Assert.assertTrue(modelFile.exists());
        } finally {
            // Remove the model even when the assertion fails.
            FileUtils.deleteQuietly(modelFile);
        }
    }

    /**
     * Scores every file found in the model folder against four fixed samples
     * and prints the results; only reports to stderr when the folder is absent.
     *
     * @throws IOException if the model folder exists but cannot be listed
     */
    @Test
    public void testExistingModels() throws IOException {
        MLDataPair dataPair0 = newPair(-0.866025, -0.866025, 0.0);
        MLDataPair dataPair1 = newPair(-0.866025, 0.866025, 0.0);
        MLDataPair dataPair2 = newPair(0.866025, -0.866025, 0.0);
        MLDataPair dataPair3 = newPair(0.866025, 0.866025, 1.0);

        File modelDir = new File(MODEL_FOLDER);
        if (!modelDir.isDirectory()) {
            System.err.println("No ./model_folder exist!");
            return;
        }

        File[] files = modelDir.listFiles();
        if (files == null) {
            throw new IOException(String.format("Failed to list files in %s", modelDir.getAbsolutePath()));
        }

        for (File modelFile : files) {
            System.out.println("result of " + modelFile.getName() + ":");
            computeScore(modelFile, dataPair0, dataPair1, dataPair2, dataPair3);
        }
    }

    /**
     * Loads a persisted network and prints, for each sample, the first output
     * value scaled by 1000 and truncated to an int.
     *
     * @param modelFile file holding a persisted {@link BasicNetwork}
     * @param samples   samples whose inputs are fed to the network, in order
     */
    private void computeScore(File modelFile, MLDataPair... samples) {
        BasicNetwork model = (BasicNetwork) EncogDirectoryPersistence
                .loadObject(modelFile);
        for (MLDataPair sample : samples) {
            System.out.println((int) (model.compute(sample.getInput())
                    .getData(0) * 1000));
        }
    }

    /**
     * Deletes the scratch files and directories left in the working directory
     * and shuts Encog down.
     *
     * @throws IOException if a directory cannot be deleted
     */
    @AfterClass
    public void shutDown() throws IOException {
        try {
            FileUtils.deleteDirectory(new File("./models/"));
            FileUtils.deleteDirectory(new File("./modelsTmp/"));
            FileUtils.deleteDirectory(new File(MODEL_FOLDER));
            FileUtils.deleteDirectory(new File("tmp"));
            FileUtils.deleteQuietly(new File("init0.json"));
            FileUtils.deleteDirectory(new File(Constants.COLUMN_META_FOLDER_NAME));
        } finally {
            // Shut Encog down even when a cleanup step fails.
            Encog.getInstance().shutdown();
        }
    }
}