/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.ml.factories;

import org.dmg.pmml.Array;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.False;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.Node;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.elasticsearch.ml.modelinput.VectorRange;
import org.elasticsearch.ml.modelinput.VectorRangesToVectorPMML;
import org.elasticsearch.ml.modelinput.PMMLVectorRange;
import org.elasticsearch.ml.models.EsTreeModel;
import org.elasticsearch.ml.modelinput.MapModelInput;
import org.elasticsearch.ml.modelinput.ModelAndModelInputEvaluator;
import org.elasticsearch.script.pmml.ProcessPMMLHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Builds an {@link EsTreeModel} and the matching input converter from a PMML {@link TreeModel}.
 * <p>
 * Only classification trees with {@code binarySplit}, {@code defaultChild} and
 * {@code returnLastPrediction} settings are accepted; everything else is rejected with an
 * {@link UnsupportedOperationException}.
 */
public class TreeModelFactory extends ModelFactory<MapModelInput, String, TreeModel> {

    public TreeModelFactory() {
        super(TreeModel.class);
    }

    /**
     * Converts the given PMML tree model into an evaluator pair (input converter + model).
     *
     * @param treeModel                the PMML tree model to convert
     * @param dataDictionary           the PMML data dictionary used to look up raw fields
     * @param transformationDictionary the PMML transformation dictionary used to look up derived fields
     * @return the input converter together with the converted tree
     * @throws UnsupportedOperationException if the model uses a function name, split characteristic,
     *                                       missing value strategy or no-true-child strategy other
     *                                       than the supported combination
     */
    @Override
    public ModelAndModelInputEvaluator<MapModelInput, String> buildFromPMML(TreeModel treeModel, DataDictionary dataDictionary,
                                                                            TransformationDictionary transformationDictionary) {
        boolean supported = treeModel.getFunctionName().value().equals("classification")
                && treeModel.getSplitCharacteristic().value().equals("binarySplit")
                && treeModel.getMissingValueStrategy().value().equals("defaultChild")
                && treeModel.getNoTrueChildStrategy().value().equals("returnLastPrediction");
        if (supported == false) {
            // Note: the function name is checked above but is not part of this message.
            throw new UnsupportedOperationException("TreeModel does not support the following parameters yet: "
                    + " splitCharacteristic:" + treeModel.getSplitCharacteristic().value()
                    + " missingValueStrategy:" + treeModel.getMissingValueStrategy().value()
                    + " noTrueChildStrategy:" + treeModel.getNoTrueChildStrategy().value());
        }
        List<VectorRange> fields = getFieldValuesList(treeModel, dataDictionary, transformationDictionary);
        VectorRangesToVectorPMML.VectorRangesToVectorPMMLTreeModel fieldsToVector =
                new VectorRangesToVectorPMML.VectorRangesToVectorPMMLTreeModel(fields);
        Map<String, String> fieldToTypeMap = getFieldToTypeMap(fields);
        EsTreeModel esTreeModel = getEsTreeModel(treeModel, fieldToTypeMap);
        return new ModelAndModelInputEvaluator<>(fieldsToVector, esTreeModel);
    }

    /**
     * Collects every field name referenced by a predicate anywhere in the tree and creates one
     * {@link VectorRange} per field.
     *
     * @param treeModel                the tree whose predicates are scanned
     * @param dataDictionary           passed to {@link ProcessPMMLHelper} to find the raw data field
     * @param transformationDictionary passed to {@link ProcessPMMLHelper} to find derived fields
     * @return one {@link PMMLVectorRange.FieldToValue} per distinct field name; the order follows the
     *         iteration order of a {@link HashSet} and is therefore unspecified
     */
    protected static List<VectorRange> getFieldValuesList(TreeModel treeModel, DataDictionary dataDictionary,
                                                          TransformationDictionary transformationDictionary) {
        // walk the tree model and gather all the field names
        Set<String> fieldNames = new HashSet<>();
        getFieldNamesFromNode(fieldNames, treeModel.getNode());

        // create the actual VectorRange objects (much of this mirrors GLMHelper)
        List<VectorRange> fieldsToValues = new ArrayList<>();
        List<DerivedField> allDerivedFields = ProcessPMMLHelper.getAllDerivedFields(treeModel, transformationDictionary);
        for (String fieldName : fieldNames) {
            // getDerivedFields fills derivedFields as an out-parameter and returns the raw field name
            // (presumably the chain from the raw field to fieldName - confirm against ProcessPMMLHelper)
            List<DerivedField> derivedFields = new ArrayList<>();
            String rawFieldName = ProcessPMMLHelper.getDerivedFields(fieldName, allDerivedFields, derivedFields);
            DataField rawField = ProcessPMMLHelper.getRawDataField(dataDictionary, rawFieldName);
            MiningField miningField = ProcessPMMLHelper.getMiningField(treeModel, rawFieldName);
            DerivedField[] derivedFieldArray = derivedFields.toArray(new DerivedField[derivedFields.size()]);
            fieldsToValues.add(new PMMLVectorRange.FieldToValue(rawField, miningField, derivedFieldArray));
        }
        return fieldsToValues;
    }

    /**
     * Recursively adds the field names used by the predicate of {@code startNode} and of all its
     * descendants to {@code fieldNames}.
     *
     * @param fieldNames the set that receives the names (modified in place)
     * @param startNode  the node at which to start the depth-first walk
     */
    protected static void getFieldNamesFromNode(Set<String> fieldNames, Node startNode) {
        getFieldNamesFromPredicate(fieldNames, startNode.getPredicate());
        for (Node child : startNode.getNodes()) {
            getFieldNamesFromNode(fieldNames, child);
        }
    }

    /**
     * Adds the field names referenced by {@code predicate} to {@code fieldNames}. Compound predicates
     * are descended into; predicates without a field (for example {@link True} / {@link False}) add nothing.
     *
     * @param fieldNames the set that receives the names (modified in place)
     * @param predicate  the predicate to inspect
     */
    protected static void getFieldNamesFromPredicate(Set<String> fieldNames, Predicate predicate) {
        if (predicate instanceof CompoundPredicate) {
            for (Predicate child : ((CompoundPredicate) predicate).getPredicates()) {
                getFieldNamesFromPredicate(fieldNames, child);
            }
        } else if (predicate instanceof SimplePredicate) {
            fieldNames.add(((SimplePredicate) predicate).getField().getValue());
        } else if (predicate instanceof SimpleSetPredicate) {
            fieldNames.add(((SimpleSetPredicate) predicate).getField().getValue());
        }
    }

    /**
     * Converts the PMML tree, starting at its root node, into an {@link EsTreeModel}.
     *
     * @param treeModel      the PMML tree
     * @param fieldToTypeMap maps a field name to the type string used to parse predicate values
     * @return the converted tree
     */
    protected EsTreeModel getEsTreeModel(TreeModel treeModel, Map<String, String> fieldToTypeMap) {
        return new EsTreeModel(convertToEsTreeNode(treeModel.getNode(), fieldToTypeMap));
    }

    /**
     * Builds a lookup from {@link VectorRange#getLastDerivedFieldName()} to {@link VectorRange#getType()}.
     *
     * @param vectorRangeList the ranges to index
     * @return a new mutable map; later entries overwrite earlier ones that share the same name
     */
    public static Map<String, String> getFieldToTypeMap(java.util.List<VectorRange> vectorRangeList) {
        Map<String, String> fieldToTypeMap = new HashMap<>();
        for (VectorRange vectorRange : vectorRangeList) {
            fieldToTypeMap.put(vectorRange.getLastDerivedFieldName(), vectorRange.getType());
        }
        return fieldToTypeMap;
    }

    /**
     * Recursively converts a PMML node, its predicate and all its children.
     *
     * @param node         the PMML node to convert
     * @param fieldTypeMap maps a field name to the type string used to parse predicate values
     * @return the converted node holding an unmodifiable list of converted children
     */
    private EsTreeModel.EsTreeNode convertToEsTreeNode(Node node, Map<String, String> fieldTypeMap) {
        List<EsTreeModel.EsTreeNode> childNodes = new ArrayList<>();
        // the predicate is converted before the children, so an unsupported predicate fails early
        EsTreeModel.EsPredicate predicate = createPredicate(node.getPredicate(), fieldTypeMap);
        for (Node childNode : node.getNodes()) {
            childNodes.add(convertToEsTreeNode(childNode, fieldTypeMap));
        }
        return new EsTreeModel.EsTreeNode(Collections.unmodifiableList(childNodes), predicate, node.getScore());
    }

    /**
     * Converts a PMML predicate into its {@link EsTreeModel.EsPredicate} counterpart.
     *
     * @param predicate    the PMML predicate
     * @param fieldTypeMap maps a field name to the type string used to parse predicate values
     * @return the converted predicate
     * @throws UnsupportedOperationException if the predicate kind, its boolean operator, its field
     *                                       type or its array type is not supported
     */
    private static EsTreeModel.EsPredicate createPredicate(final Predicate predicate, Map<String, String> fieldTypeMap) {
        if (predicate instanceof SimplePredicate) {
            return createSimplePredicate((SimplePredicate) predicate, fieldTypeMap);
        }
        if (predicate instanceof True) {
            return constantPredicate(true);
        }
        if (predicate instanceof False) {
            return constantPredicate(false);
        }
        if (predicate instanceof CompoundPredicate) {
            return createCompoundPredicate((CompoundPredicate) predicate, fieldTypeMap);
        }
        if (predicate instanceof SimpleSetPredicate) {
            return createSimpleSetPredicate((SimpleSetPredicate) predicate);
        }
        throw unsupportedPredicate(predicate);
    }

    /** Creates the exception thrown for every predicate that cannot be converted. */
    private static UnsupportedOperationException unsupportedPredicate(Predicate predicate) {
        return new UnsupportedOperationException("Predicate Type " + predicate.getClass().getName()
                + " for TreeModel not implemented yet.");
    }

    /** Creates a predicate that always evaluates to {@code result} and never lacks values. */
    private static EsTreeModel.EsPredicate constantPredicate(final boolean result) {
        return new EsTreeModel.EsPredicate() {
            @Override
            public boolean match(Map<String, Object> vector) {
                return result;
            }

            @Override
            public boolean notEnoughValues(Map<String, Object> vector) {
                return false;
            }
        };
    }

    /** Parses the predicate value according to the field's type and delegates to {@link #getSimplePredicate}. */
    private static EsTreeModel.EsPredicate createSimplePredicate(SimplePredicate simplePredicate,
                                                                 Map<String, String> fieldTypeMap) {
        String field = simplePredicate.getField().getValue();
        String type = fieldTypeMap.get(field);
        // value is null when the PMML predicate has no value attribute (isMissing / isNotMissing);
        // the numeric branches pass that null on instead of failing inside the parse call
        String value = simplePredicate.getValue();
        String operator = simplePredicate.getOperator().value();
        // equals() instead of ==: the type string is not guaranteed to be the interned literal
        if ("string".equals(type)) {
            return getSimplePredicate(value, field, operator);
        }
        if ("double".equals(type)) {
            return getSimplePredicate(value == null ? null : Double.valueOf(value), field, operator);
        }
        if ("float".equals(type)) {
            return getSimplePredicate(value == null ? null : Float.valueOf(value), field, operator);
        }
        if ("int".equals(type)) {
            return getSimplePredicate(value == null ? null : Integer.valueOf(value), field, operator);
        }
        if ("boolean".equals(type)) {
            return getSimplePredicate(Boolean.parseBoolean(value), field, operator);
        }
        throw new UnsupportedOperationException("Data type " + type + " for TreeModel not implemented yet.");
    }

    /** Converts all child predicates and combines them according to the boolean operator. */
    private static EsTreeModel.EsPredicate createCompoundPredicate(CompoundPredicate compoundPredicate,
                                                                   Map<String, String> fieldTypeMap) {
        List<EsTreeModel.EsPredicate> predicates = new ArrayList<>();
        for (Predicate childPredicate : compoundPredicate.getPredicates()) {
            predicates.add(createPredicate(childPredicate, fieldTypeMap));
        }
        String booleanOperator = compoundPredicate.getBooleanOperator().value();
        if (booleanOperator.equals("and")) {
            return new EsTreeModel.EsCompoundPredicate(predicates) {
                @Override
                protected boolean matchList(Map<String, Object> vector) {
                    // stops at the first child that does not match
                    for (EsTreeModel.EsPredicate childPredicate : predicates) {
                        if (childPredicate.match(vector) == false) {
                            return false;
                        }
                    }
                    return true;
                }
            };
        }
        if (booleanOperator.equals("or")) {
            return new EsTreeModel.EsCompoundPredicate(predicates) {
                @Override
                protected boolean matchList(Map<String, Object> vector) {
                    // stops at the first child that matches
                    for (EsTreeModel.EsPredicate childPredicate : predicates) {
                        if (childPredicate.match(vector)) {
                            return true;
                        }
                    }
                    return false;
                }
            };
        }
        if (booleanOperator.equals("xor")) {
            return new EsTreeModel.EsCompoundPredicate(predicates) {
                @Override
                protected boolean matchList(Map<String, Object> vector) {
                    // true only if exactly one child matches
                    boolean result = false;
                    for (EsTreeModel.EsPredicate childPredicate : predicates) {
                        if (result == false) {
                            result = childPredicate.match(vector);
                        } else if (childPredicate.match(vector)) {
                            // we had true already, xor must return false
                            return false;
                        }
                    }
                    return result;
                }
            };
        }
        if (booleanOperator.equals("surrogate")) {
            return new EsTreeModel.EsCompoundPredicate(predicates) {
                @Override
                protected boolean matchList(Map<String, Object> vector) {
                    // the first child that has enough values decides the result
                    for (EsTreeModel.EsPredicate childPredicate : predicates) {
                        if (childPredicate.notEnoughValues(vector) == false) {
                            return childPredicate.match(vector);
                        }
                    }
                    return false;
                }

                @Override
                public boolean notEnoughValues(Map<String, Object> vector) {
                    boolean notEnoughValues = true;
                    for (EsTreeModel.EsPredicate childPredicate : predicates) {
                        // only one needs to have enough values and then the predicate is defined
                        notEnoughValues = childPredicate.notEnoughValues(vector) && notEnoughValues;
                    }
                    return notEnoughValues;
                }
            };
        }
        // unknown boolean operator
        throw unsupportedPredicate(compoundPredicate);
    }

    /**
     * Converts a set predicate by parsing the PMML array into a typed {@link HashSet}.
     * <p>
     * NOTE(review): the predicate's own boolean operator (isIn / isNotIn) is never read here, so both
     * are converted identically - confirm whether EsSimpleSetPredicate is expected to handle that.
     */
    private static EsTreeModel.EsPredicate createSimpleSetPredicate(SimpleSetPredicate simpleSetPredicate) {
        Array setArray = simpleSetPredicate.getArray();
        String field = simpleSetPredicate.getField().getValue();
        Array.Type arrayType = setArray.getType();
        if (Array.Type.STRING.equals(arrayType)) {
            HashSet<String> valuesSet = new HashSet<>();
            // string entries are expected in the form "a" "b" "c": split on the quote-space-quote separator
            String[] values = setArray.getValue().split("\" \"");
            // trim the leading quote of the first and the trailing quote of the last entry
            values[0] = values[0].substring(1);
            int last = values.length - 1;
            values[last] = values[last].substring(0, values[last].length() - 1);
            checkArraySize(setArray, values);
            Collections.addAll(valuesSet, values);
            return new EsTreeModel.EsSimpleSetPredicate<>(valuesSet, field);
        }
        if (Array.Type.REAL.equals(arrayType)) {
            HashSet<Double> valuesSet = new HashSet<>();
            String[] values = setArray.getValue().split(" ");
            checkArraySize(setArray, values);
            for (String value : values) {
                valuesSet.add(Double.parseDouble(value));
            }
            return new EsTreeModel.EsSimpleSetPredicate<>(valuesSet, field);
        }
        if (Array.Type.INT.equals(arrayType)) {
            HashSet<Integer> valuesSet = new HashSet<>();
            String[] values = setArray.getValue().split(" ");
            checkArraySize(setArray, values);
            for (String value : values) {
                valuesSet.add(Integer.parseInt(value));
            }
            return new EsTreeModel.EsSimpleSetPredicate<>(valuesSet, field);
        }
        throw unsupportedPredicate(simpleSetPredicate);
    }

    /** Verifies that the number of parsed entries equals the array's declared size, if one is declared. */
    private static void checkArraySize(Array setArray, String[] values) {
        Integer declaredSize = setArray.getN();
        // a missing size cannot be verified, so it is skipped rather than unboxed
        if (declaredSize != null && values.length != declaredSize) {
            throw new UnsupportedOperationException("Could not infer values from array value " + setArray.getValue());
        }
    }

    /**
     * Creates a predicate comparing a field's value against {@code value} with the given PMML operator.
     * <p>
     * For {@code isMissing} and {@code isNotMissing} the comparison value is not used; those
     * predicates only check whether the input map holds a non-null entry for {@code field}.
     *
     * @param value    the value to compare against
     * @param field    the name of the field the predicate reads
     * @param operator the PMML operator name ({@code equal}, {@code notEqual}, {@code lessThan},
     *                 {@code lessOrEqual}, {@code greaterThan}, {@code greaterOrEqual},
     *                 {@code isMissing} or {@code isNotMissing})
     * @param <T>      the type of the compared values
     * @return the predicate implementing the operator
     * @throws UnsupportedOperationException if the operator is not one of the names listed above
     */
    protected static <T extends Comparable<T>> EsTreeModel.EsSimplePredicate<T> getSimplePredicate(T value, String field,
                                                                                                    String operator) {
        switch (operator) {
            case "equal":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return value.equals(fieldValue);
                    }
                };
            case "notEqual":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return value.equals(fieldValue) == false;
                    }
                };
            case "lessThan":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return fieldValue.compareTo(value) < 0;
                    }
                };
            case "lessOrEqual":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return fieldValue.compareTo(value) <= 0;
                    }
                };
            case "greaterThan":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return fieldValue.compareTo(value) > 0;
                    }
                };
            case "greaterOrEqual":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        return fieldValue.compareTo(value) >= 0;
                    }
                };
            case "isMissing":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        throw new UnsupportedOperationException("We should never get here!");
                    }

                    @Override
                    public boolean match(Map<String, Object> vector) {
                        // overrides the map-based match so the typed match above is bypassed
                        return vector.get(field) == null;
                    }
                };
            case "isNotMissing":
                return new EsTreeModel.EsSimplePredicate<T>(value, field) {
                    @Override
                    public boolean match(T fieldValue) {
                        throw new UnsupportedOperationException("We should never get here!");
                    }

                    @Override
                    public boolean match(Map<String, Object> vector) {
                        // overrides the map-based match so the typed match above is bypassed
                        return vector.get(field) != null;
                    }
                };
            default:
                throw new UnsupportedOperationException("OOperator " + operator + " not supported for Predicate in TreeModel.");
        }
    }
}