/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import ml.shifu.shifu.container.ScoreObject;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM;
import ml.shifu.shifu.core.alg.NNTrainer;
import ml.shifu.shifu.core.alg.SVMTrainer;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.io.FileUtils;
import org.encog.ml.BasicML;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
public class ScorerTest {
private ModelConfig modelConfig;
List<BasicML> models = new ArrayList<BasicML>();
MLDataSet set = new BasicMLDataSet();
@BeforeClass
public void setup() throws IOException {
modelConfig = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".");
modelConfig.getTrain().getParams().put("Propagation", "B");
modelConfig.getTrain().getParams().put("NumHiddenLayers", 2);
modelConfig.getTrain().getParams().put("LearningRate", 0.5);
List<Integer> nodes = new ArrayList<Integer>();
nodes.add(3);
nodes.add(4);
List<String> func = new ArrayList<String>();
func.add("linear");
func.add("tanh");
modelConfig.getTrain().getParams().put("NumHiddenNodes", nodes);
modelConfig.getTrain().getParams().put("ActivationFunc", func);
NNTrainer trainer = new NNTrainer(modelConfig, 0, false);
double[] input = { 0., 0., };
double[] ideal = { 1. };
MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
set.add(pair);
input = new double[] { 0., 1., };
ideal = new double[] { 0. };
pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
set.add(pair);
input = new double[] { 1., 0., };
ideal = new double[] { 0. };
pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
set.add(pair);
input = new double[] { 1., 1., };
ideal = new double[] { 1. };
pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
set.add(pair);
trainer.setTrainSet(set);
trainer.setValidSet(set);
trainer.train();
modelConfig.getTrain().setAlgorithm("SVM");
modelConfig.getTrain().getParams().put("Kernel", "rbf");
modelConfig.getTrain().getParams().put("Const", 0.1);
modelConfig.getTrain().getParams().put("Gamma", 1.0);
modelConfig.getVarSelect().setFilterNum(2);
SVMTrainer svm = new SVMTrainer(modelConfig, 1, false);
svm.setTrainSet(set);
svm.setValidSet(set);
svm.train();
models.add(trainer.getNetwork());
models.add(svm.getSVM());
}
// @Test
public void scoreTest() {
List<ColumnConfig> list = new ArrayList<ColumnConfig>();
ColumnConfig col = new ColumnConfig();
col.setColumnType(ColumnType.N);
col.setColumnName("A");
col.setColumnNum(0);
col.setFinalSelect(true);
list.add(col);
col = new ColumnConfig();
col.setColumnType(ColumnType.N);
col.setColumnName("B");
col.setColumnNum(1);
col.setFinalSelect(true);
list.add(col);
Scorer s = new Scorer(models, list, "NN", modelConfig);
double[] input = { 0., 0., };
double[] ideal = { 1. };
MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
ScoreObject o = s.score(pair, null);
List<Double> scores = o.getScores();
Assert.assertTrue(scores.get(0) > 400);
Assert.assertTrue(scores.get(1) == 1000);
}
// @Test
public void scoreNull() {
Scorer s = new Scorer(models, null, "NN", modelConfig);
Assert.assertNull(s.score(null, null));
}
// @Test
public void scoreModelsException() {
List<ColumnConfig> list = new ArrayList<ColumnConfig>();
ColumnConfig col = new ColumnConfig();
col.setColumnType(ColumnType.N);
col.setColumnName("A");
col.setColumnNum(0);
col.setFinalSelect(true);
list.add(col);
col = new ColumnConfig();
col.setColumnType(ColumnType.N);
col.setColumnName("B");
col.setColumnNum(1);
col.setFinalSelect(true);
list.add(col);
Scorer s = new Scorer(models, list, "NN", modelConfig);
double[] input = { 0., 0., 3. };
double[] ideal = { 1. };
MLDataPair pair = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
Assert.assertEquals(s.score(pair, null).getScores().size(), 0);
}
@AfterClass
public void delete() throws IOException {
FileUtils.deleteDirectory(new File("tmp"));
FileUtils.deleteDirectory(new File("models"));
FileUtils.deleteDirectory(new File("test-output"));
FileUtils.deleteDirectory(new File(Constants.COLUMN_META_FOLDER_NAME));
}
}