/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml.builder.impl; import java.util.List; import java.util.Set; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.core.dtrain.dt.IndependentTreeModel; import ml.shifu.shifu.core.dtrain.dt.Node; import ml.shifu.shifu.core.dtrain.dt.Split; import ml.shifu.shifu.core.pmml.builder.creator.AbstractPmmlElementCreator; import org.dmg.pmml.Array; import org.dmg.pmml.FieldName; import org.dmg.pmml.Predicate; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.SimpleSetPredicate; import org.dmg.pmml.True; import org.encog.ml.BasicML; public class TreeNodePmmlElementCreator extends AbstractPmmlElementCreator<org.dmg.pmml.Node> { public TreeNodePmmlElementCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) { super(modelConfig, columnConfigList); } private IndependentTreeModel treeModel = null; public TreeNodePmmlElementCreator(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, IndependentTreeModel treeModel) { super(modelConfig, columnConfigList); this.treeModel = treeModel; } public void setTreeMode(IndependentTreeModel treeModel) { this.treeModel = treeModel; } public org.dmg.pmml.Node build(BasicML basicML) { return null; } public org.dmg.pmml.Node convert(Node node) { org.dmg.pmml.Node pmmlNode = new org.dmg.pmml.Node(); pmmlNode.setId(String.valueOf(node.getId())); pmmlNode.setDefaultChild(null); pmmlNode.setPredicate(new True()); pmmlNode.setEmbeddedModel(null); List<org.dmg.pmml.Node> childList = pmmlNode.getNodes(); org.dmg.pmml.Node left = convert(node.getLeft(), true, node.getSplit()); childList.add(left); org.dmg.pmml.Node right = convert(node.getRight(), false, node.getSplit()); childList.add(right); return pmmlNode; } public org.dmg.pmml.Node convert(Node node, boolean isLeft, Split split) { org.dmg.pmml.Node pmmlNode = new org.dmg.pmml.Node(); pmmlNode.setId(String.valueOf(node.getId())); if(node.getPredict() != null) { pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue() : node .getPredict().getPredict())); } pmmlNode.setDefaultChild(null); Predicate predicate = null; ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum()); if(columnConfig.isNumerical()) { SimplePredicate p = new SimplePredicate(); p.setValue(String.valueOf(split.getThreshold())); p.setField(new FieldName(columnConfig.getColumnName())); if(isLeft) { p.setOperator(SimplePredicate.Operator.fromValue("lessThan")); } else { p.setOperator(SimplePredicate.Operator.fromValue("greaterOrEqual")); } predicate = p; } else if(columnConfig.isCategorical()) { SimpleSetPredicate p = new SimpleSetPredicate(); Set<Short> childCategories = split.getLeftOrRightCategories(); p.setField(new FieldName(columnConfig.getColumnName())); StringBuilder arrayStr = new StringBuilder(); List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum()); for(Short sh: childCategories) { if(sh >= valueList.size()) { arrayStr.append(" \"\""); continue; } String s = valueList.get(sh); arrayStr.append(" "); if(s.contains("\"")) { String tmp = s.replaceAll("\"", "\\\\\\\""); if(s.contains(" ")) { arrayStr.append("\""); arrayStr.append(tmp); arrayStr.append("\""); } else { arrayStr.append(tmp); } } else { if(s.contains(" ")) { arrayStr.append("\""); arrayStr.append(s); arrayStr.append("\""); } else { arrayStr.append(s); } } } Array array = new Array(arrayStr.toString().trim(), Array.Type.fromValue("string")); p.setArray(array); if(isLeft) { if(split.isLeft()) { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isIn")); } else { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isNotIn")); } } else { if(split.isLeft()) { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isNotIn")); } else { p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue("isIn")); } } predicate = p; } pmmlNode.setPredicate(predicate); if(node.getSplit() == null || node.isRealLeaf()) { return pmmlNode; } List<org.dmg.pmml.Node> childList = pmmlNode.getNodes(); org.dmg.pmml.Node left = convert(node.getLeft(), true, node.getSplit()); org.dmg.pmml.Node right = convert(node.getRight(), false, node.getSplit()); childList.add(left); childList.add(right); return pmmlNode; } }