/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.pmml;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel;
import org.apache.commons.io.IOUtils;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.MiningModelEvaluator;
import org.testng.annotations.Test;
/**
 * Cross-checks a GBT model evaluated natively through {@link IndependentTreeModel} against the PMML export of the same
 * model evaluated with JPMML's {@link MiningModelEvaluator}: for every row of the test data set the two scores, both
 * scaled by {@link #SCORE_SCALE}, must differ by at most {@link #MAX_SCORE_DIFF}.
 */
public class TreeModelPmmlTest {

    /** Serialized GBT model, loaded via {@link IndependentTreeModel#loadFromStream}. */
    private static final String MODEL_PATH = "src/test/resources/dttest/model/model-5.gbt";

    /** PMML export of the same model, loaded via {@code PMMLUtils.loadPMML}. */
    private static final String PMML_PATH = "src/test/resources/dttest/model/model-5.pmml";

    /** Evaluation rows read by {@code CsvUtil.load}. */
    private static final String DATA_PATH = "src/test/resources/dttest/data/tmdata.csv";

    /** Delimiter handed to {@code CsvUtil.load}; as a regex it matches a literal '|'. */
    private static final String DATA_DELIMITER = "\\|";

    /** Factor applied to both raw scores before comparing them. */
    private static final double SCORE_SCALE = 1000d;

    /** Largest tolerated absolute difference between the two scaled scores. */
    private static final double MAX_SCORE_DIFF = 1d;

    @Test
    public void testTreeModel() throws Exception {
        InputStream is = null;
        try {
            is = new FileInputStream(MODEL_PATH);
            IndependentTreeModel model = IndependentTreeModel.loadFromStream(is);
            PMML pmml = PMMLUtils.loadPMML(PMML_PATH);
            MiningModelEvaluator evaluator = new MiningModelEvaluator(pmml);
            List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, DATA_PATH, DATA_DELIMITER);

            int rowIndex = 0;
            for(Map<FieldName, FieldValue> row: input) {
                double pmmlScore = pmmlScore(evaluator, row);
                double ownScore = model.compute(toModelInput(row))[0] * SCORE_SCALE;
                // Negated "<=" (instead of ">") so that a NaN on either side still fails the test.
                if(!(Math.abs(pmmlScore - ownScore) <= MAX_SCORE_DIFF)) {
                    org.testng.Assert.fail("Score mismatch at row " + rowIndex + ": pmmlScore=" + pmmlScore
                            + ", ownScore=" + ownScore + ", allowed difference=" + MAX_SCORE_DIFF);
                }
                rowIndex++;
            }
        } finally {
            IOUtils.closeQuietly(is);
        }
    }

    /**
     * Evaluates one row with the PMML evaluator and returns its score multiplied by {@link #SCORE_SCALE}.
     *
     * If the evaluator reports several result fields, the value of the last one visited during iteration is used
     * (unchanged from the original test).
     *
     * @param evaluator
     *            evaluator built from the PMML export
     * @param row
     *            prepared input fields for one data row
     * @return the scaled PMML score
     */
    private static double pmmlScore(MiningModelEvaluator evaluator, Map<FieldName, FieldValue> row) {
        // NOTE(review): assumes every result value is a Double (auto-unboxed below) - confirm against the PMML's
        // output fields; a non-Double value would surface as a ClassCastException.
        @SuppressWarnings("unchecked")
        Map<FieldName, Double> results = (Map<FieldName, Double>) evaluator.evaluate(row);
        if(results == null || results.isEmpty()) {
            // Without this check the score would silently default to 0 and hide a broken PMML evaluation.
            org.testng.Assert.fail("PMML evaluation returned no result fields for row: " + row);
        }
        double score = 0d;
        for(Double value: results.values()) {
            score = value * SCORE_SCALE;
        }
        return score;
    }

    /**
     * Converts a JPMML input row into the name-to-value map consumed by {@link IndependentTreeModel#compute}.
     * Continuous fields become {@link Double}s and categorical fields become {@link String}s; fields of any other op
     * type are not passed to the model (unchanged from the original test).
     *
     * @param row
     *            prepared input fields for one data row
     * @return a new mutable map keyed by field name
     */
    private static Map<String, Object> toModelInput(Map<FieldName, FieldValue> row) {
        Map<String, Object> modelInput = new HashMap<String, Object>();
        for(Entry<FieldName, FieldValue> entry: row.entrySet()) {
            FieldValue value = entry.getValue();
            if(value == null) {
                // NOTE(review): a null FieldValue is assumed to mark a missing value; the field is left out
                // (presumably treated as missing by IndependentTreeModel - TODO confirm) instead of throwing an NPE.
                continue;
            }
            String name = entry.getKey().getValue();
            switch(value.getOpType()) {
                case CONTINUOUS:
                    modelInput.put(name, Double.parseDouble(value.getValue().toString()));
                    break;
                case CATEGORICAL:
                    modelInput.put(name, value.getValue().toString());
                    break;
                default:
                    // Other op types (e.g. ORDINAL) are intentionally not forwarded to the model.
                    break;
            }
        }
        return modelInput;
    }
}