package ml.shifu.shifu.core.pmml; import java.util.List; import java.util.Map; import ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator; import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.ClassificationMap; import org.jpmml.evaluator.FieldValue; import org.jpmml.evaluator.NeuralNetworkEvaluator; import org.testng.annotations.Test; /** * Created by zhanhu on 2/8/17. */ public class PMMLScoreGenTest { @Test public void verifyPmml() throws Exception { String dataPath = "src/test/resources/example/pmml-test/test-data.line100"; String delimiter = ","; PMML pmml = PMMLUtils.loadPMML("src/test/resources/example/pmml-test/ATOM17_SEG3_35.pmml"); NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(pmml); List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, dataPath, delimiter); for (Map<FieldName, FieldValue> rawLine : input) { int score = runPmmlModel(evaluator, rawLine); System.out.println(score); } } @SuppressWarnings("unchecked") private int runPmmlModel(NeuralNetworkEvaluator evaluator, Map<FieldName, FieldValue> rawInput) { switch (evaluator.getModel().getFunctionName()) { case REGRESSION: Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(rawInput); return regressionTerm.get(new FieldName(AbstractSpecifCreator.FINAL_RESULT)).intValue(); case CLASSIFICATION: Map<FieldName, ClassificationMap<String>> classificationTerm = (Map<FieldName, ClassificationMap<String>>) evaluator.evaluate(rawInput); for (ClassificationMap<String> cMap : classificationTerm.values()) { for (Map.Entry<String, Double> entry : cMap.entrySet()) { return (int) (entry.getValue() * 1000); } } default: return -1; } } }