/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml;
import java.io.File;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.shifu.ShifuCLI;
import ml.shifu.shifu.container.obj.ModelTrainConf;
import ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator;
import ml.shifu.shifu.core.pmml.builder.impl.NNSpecifCreator;
import ml.shifu.shifu.core.processor.ExportModelProcessor;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ClassificationMap;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.NeuralNetworkEvaluator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Created by zhanhu on 7/15/16.
*/
public class PMMLVerifySuit {
private static Logger logger = LoggerFactory.getLogger(PMMLVerifySuit.class);
private String modelName;
private String modelConfigPath;
private String columnConfigpath;
private String modelsPath;
private ModelTrainConf.ALGORITHM algorithm = ModelTrainConf.ALGORITHM.NN;
private int modelCnt;
private String evalSetName;
private String evalDataPath;
private String delimiter;
private double scoreDiff;
private boolean isConcisePmml;
public PMMLVerifySuit(String modelName,
String modelConfigPath,
String columnConfigpath,
String modelsPath,
int modelCnt,
String evalSetName,
String evalDataPath,
String delimiter,
double scoreDiff,
boolean isConcisePmml) {
this.modelName = modelName;
this.modelConfigPath = modelConfigPath;
this.columnConfigpath = columnConfigpath;
this.modelsPath = modelsPath;
this.modelCnt = modelCnt;
this.evalSetName = evalSetName;
this.evalDataPath = evalDataPath;
this.delimiter = delimiter;
this.scoreDiff = scoreDiff;
this.isConcisePmml = isConcisePmml;
}
public PMMLVerifySuit(String modelName,
String modelConfigPath,
String columnConfigpath,
String modelsPath,
ModelTrainConf.ALGORITHM algorithm,
int modelCnt,
String evalSetName,
String evalDataPath,
String delimiter,
double scoreDiff,
boolean isConcisePmml) {
this(modelName, modelConfigPath, columnConfigpath, modelsPath, modelCnt,
evalSetName, evalDataPath, delimiter, scoreDiff, isConcisePmml);
this.algorithm = algorithm;
}
public boolean doVerification() throws Exception {
// Step 1. Eval the scores using SHIFU
File originModel = new File(this.modelConfigPath);
File tmpModel = new File("ModelConfig.json");
File originColumn = new File(this.columnConfigpath);
File tmpColumn = new File("ColumnConfig.json");
File modelsDir = new File(this.modelsPath);
File tmpModelsDir = new File("models");
FileUtils.copyFile(originModel, tmpModel);
FileUtils.copyFile(originColumn, tmpColumn);
FileUtils.copyDirectory(modelsDir, tmpModelsDir);
// run evaluation set
ShifuCLI.runEvalScore(this.evalSetName);
File evalScore = new File("evals" + File.separator + this.evalSetName + File.separator + "EvalScore");
Map<String, Object> params = new HashMap<String, Object>();
params.put(ExportModelProcessor.IS_CONCISE, this.isConcisePmml);
ShifuCLI.exportModel(null, params);
// Step 2. Eval the scores using PMML and compare it with SHIFU output
String DataPath = this.evalDataPath;
String OutPath = "./pmml_out.dat";
for (int index = 0; index < modelCnt; index++) {
String num = Integer.toString(index);
String pmmlPath = "pmmls" + File.separator + this.modelName + num + ".pmml";
if ( ModelTrainConf.ALGORITHM.NN.equals(algorithm) ) {
evalNNPmml(pmmlPath, DataPath, OutPath, this.delimiter, "model" + num);
} else if ( ModelTrainConf.ALGORITHM.LR.equals(algorithm) ) {
evalLRPmml(pmmlPath, DataPath, OutPath, this.delimiter, "model" + num);
} else {
logger.error("The algorithm - {} is not supported yet.", algorithm);
return false;
}
boolean status = compareScore(evalScore, new File(OutPath), "model" + num, "\\|", this.scoreDiff);
if ( ! status ) {
return status;
}
FileUtils.deleteQuietly(new File(OutPath));
}
FileUtils.deleteQuietly(tmpModel);
FileUtils.deleteQuietly(tmpColumn);
FileUtils.deleteDirectory(tmpModelsDir);
FileUtils.deleteQuietly(new File("./pmmls"));
FileUtils.deleteQuietly(new File("evals"));
return true;
}
private boolean compareScore(File test, File control, String scoreName, String sep, Double errorRange) throws Exception {
List<String> testData = FileUtils.readLines(test);
List<String> controlData = FileUtils.readLines(control);
String[] testSchema = testData.get(0).trim().split(sep);
String[] controlSchema = controlData.get(0).trim().split(sep);
for (int row = 1; row < controlData.size(); row++) {
Map<String, Object> ctx = new HashMap<String, Object>();
Map<String, Object> controlCtx = new HashMap<String, Object>();
String[] testRowValue = testData.get(row).split(sep, testSchema.length);
for (int index = 0; index < testSchema.length; index++) {
ctx.put(testSchema[index], testRowValue[index]);
}
String[] controlRowValue = controlData.get(row).split(sep, controlSchema.length);
for (int index = 0; index < controlSchema.length; index++) {
controlCtx.put(controlSchema[index], controlRowValue[index]);
}
Double controlScore = Double.valueOf((String) controlCtx.get(scoreName));
Double testScore = Double.valueOf((String) ctx.get(scoreName));
if ( Math.abs(controlScore - testScore) > errorRange ) {
logger.error("The score doens't match {} vs {}.", controlScore, testScore);
return false;
}
}
return true;
}
@SuppressWarnings("unchecked")
private void evalNNPmml(String pmmlPath, String DataPath, String OutPath, String sep, String scoreName) throws Exception {
PMML pmml = PMMLUtils.loadPMML(pmmlPath);
NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(pmml);
PrintWriter writer = new PrintWriter(OutPath, "UTF-8");
writer.println(scoreName);
List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, DataPath, sep);
for (Map<FieldName, FieldValue> maps : input) {
switch (evaluator.getModel().getFunctionName()) {
case REGRESSION:
Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(maps);
writer.println(regressionTerm.get(new FieldName(AbstractSpecifCreator.FINAL_RESULT)).intValue());
break;
case CLASSIFICATION:
Map<FieldName, ClassificationMap<String>> classificationTerm = (Map<FieldName, ClassificationMap<String>>) evaluator.evaluate(maps);
for (ClassificationMap<String> cMap : classificationTerm.values())
for (Map.Entry<String, Double> entry : cMap.entrySet())
System.out.println(entry.getValue() * 1000);
default:
break;
}
}
IOUtils.closeQuietly(writer);
}
@SuppressWarnings("unchecked")
private void evalLRPmml(String pmmlPath, String DataPath, String OutPath, String sep, String scoreName)
throws Exception {
PMML pmml = PMMLUtils.loadPMML(pmmlPath);
Model m =pmml.getModels().get(0);
ModelEvaluator<?> evaluator = ModelEvaluatorFactory.getInstance().getModelManager(pmml, m);
PrintWriter writer = new PrintWriter(OutPath, "UTF-8");
writer.println(scoreName);
List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, DataPath, sep);
for(Map<FieldName, FieldValue> maps: input) {
Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(maps);
writer.println(regressionTerm.get(new FieldName(NNSpecifCreator.FINAL_RESULT)).intValue());
}
IOUtils.closeQuietly(writer);
}
}