/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.io.BufferedInputStream; import java.io.DataInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.zip.GZIPInputStream; import ml.shifu.shifu.core.dtrain.CommonConstants; /** * {@link IndependentTreeModel} depends no other classes which is easy to deploy model in production. * * <p> * {@link #loadFromStream(InputStream)} should be the only interface to load a tree model object. * * <p> * To predict data for tree model, call {@link #compute(Map)} or {@link #compute(double[])} * * @author Zhang David (pengzhang@paypal.com) */ public class IndependentTreeModel { private static final char MERGE_CATEGORY_DELIMITER = '^'; /** * Mapping for (ColumnNum, ColumnName) */ private Map<Integer, String> numNameMapping; /** * Mapping for (ColumnName, ColumnNum) */ private Map<String, Integer> nameNumMapping; /** * Mapping for (ColumnNum, Category List) for categorical feature */ private Map<Integer, List<String>> categoricalColumnNameNames; /** * Mapping for (ColumnNum, Map(Category, CategoryIndex) for categorical feature */ private Map<Integer, Map<String, Integer>> columnCategoryIndexMapping; /** * Mapping for (ColumnNum, index in double[] array) */ private Map<Integer, Integer> columnNumIndexMapping; /** * A list of tree models, can be RF or GBT */ private List<TreeNode> trees; /** * Weights per each tree in {@link #trees} */ private List<Double> weights; /** * If it is for GBT */ private boolean isGBDT = false; /** * If model is for classification */ private boolean isClassification = false; /** * GBT model results is not in [0, 1], set {@link #isConvertToProb} to true will normalize model score to [0, 1] */ private boolean isConvertToProb = false; /** * {@link #lossStr} is used to validate, if continuous model training but different loss type, should be failed. * TODO add validation */ private String lossStr; /** * RF or GBT */ private String algorithm; /** * # of input node */ private int inputNode; /** * Model version */ private int version; /** * For numerical columns, mean value is used for null replacement */ private Map<Integer, Double> numericalMeanMapping; public IndependentTreeModel(Map<Integer, Double> numericalMeanMapping, Map<Integer, String> numNameMapping, Map<Integer, List<String>> categoricalColumnNameNames, Map<Integer, Map<String, Integer>> columnCategoryIndexMapping, Map<Integer, Integer> columnNumIndexMapping, List<TreeNode> trees, List<Double> weights, boolean isGBDT, boolean isClassification, boolean isConvertToProb, String lossStr, String algorithm, int inputNode, int version) { this.numericalMeanMapping = numericalMeanMapping; this.numNameMapping = numNameMapping; this.nameNumMapping = new HashMap<String, Integer>(); for(Entry<Integer, String> entry: this.numNameMapping.entrySet()) { this.nameNumMapping.put(entry.getValue(), entry.getKey()); } this.categoricalColumnNameNames = categoricalColumnNameNames; this.columnCategoryIndexMapping = columnCategoryIndexMapping; this.columnNumIndexMapping = columnNumIndexMapping; this.trees = trees; this.weights = weights; this.isGBDT = isGBDT; this.isClassification = isClassification; this.isConvertToProb = isConvertToProb; this.lossStr = lossStr; this.algorithm = algorithm; this.inputNode = inputNode; this.version = version; } /** * Given double array data, compute score values of tree model. * * @param data * data array includes only effective column data, numeric value is real value, categorical feature value * is index of binCategoryList. * @return if classification mode, return array of all scores of trees * if regression of RF, return array with only one element which is avg score of all tree model scores * if regression of GBT, return array with only one element which is score of the GBT model */ public double[] compute(double[] data) { double predictSum = 0d; double weightSum = 0d; double[] scores = new double[this.trees.size()]; for(int i = 0; i < this.trees.size(); i++) { TreeNode treeNode = this.trees.get(i); Double weight = this.weights.get(i); weightSum += weight; double score = predictNode(treeNode.getNode(), data); scores[i] = score; predictSum += score * weight; } if(this.isClassification) { return scores; } else { double finalPredict; if(this.isGBDT) { if(this.isConvertToProb) { finalPredict = convertToProb(predictSum); } else { finalPredict = predictSum; } } else { finalPredict = predictSum / weightSum; } return new double[] { finalPredict }; } } /** * Given {@code dataMap} with format (columnName, value), compute score values of tree model. * * <p> * No any alert or exception if your {@code dataMap} doesn't contain features included in the model, such case will * be treated as missing value case. Please make sure feature names in keys of {@code dataMap} are consistent with * names in model. * * <p> * In {@code dataMap}, numerical value can be (String, Double) format or (String, String) format, they will all be * parsed to Double; categorical value are all converted to (String, String). If value not in our categorical list, * it will also be treated missing value. * * @param dataMap * {@code dataMap} for (columnName, value), numeric value can be double/String, categorical feature can * be int(index) or category value. if not set or set to null, such feature will be treated as missing * value. For numerical value, if it cannot be parsed successfully, it will also be treated as missing. * @return if classification mode, return array of all scores of trees * if regression of RF, return array with only one element which is average score of all tree model scores * if regression of GBT, return array with only one element which is score of the GBT model */ public final double[] compute(Map<String, Object> dataMap) { double predictSum = 0d; double weightSum = 0d; double[] scores = new double[this.trees.size()]; double[] data = convertDataMapToDoubleArray(dataMap); for(int i = 0; i < this.trees.size(); i++) { TreeNode treeNode = this.trees.get(i); Double weight = this.weights.get(i); weightSum += weight; double score = predictNode(treeNode.getNode(), data); scores[i] = score; predictSum += score * weight; } if(this.isClassification) { return scores; } else { double finalPredict; if(this.isGBDT) { if(this.isConvertToProb) { finalPredict = convertToProb(predictSum); } else { finalPredict = predictSum; } } else { finalPredict = predictSum / weightSum; } return new double[] { finalPredict }; } } /** * Covert score to probability value which are in [0, 1], for GBT regression, scores can not be [0, 1]. Round score * to 1.0E19 to avoid NaN in final return result. * * @param score * the raw score * @return score after sigmoid transform. */ public double convertToProb(double score) { // sigmoid function to covert to [0, 1], TODO, how to make it configuable for users return 1 / (1 + Math.min(1.0E19, Math.exp(-score))); } private double predictNode(Node topNode, double[] data) { Node currNode = topNode; // go until leaf while(currNode.getSplit() != null && !currNode.isRealLeaf()) { Split split = currNode.getSplit(); double value = data[this.columnNumIndexMapping.get(split.getColumnNum())]; if(split.getFeatureType().isNumerical()) { // value is real numeric value and no need to transform to binLowestValue if(value < split.getThreshold()) { currNode = currNode.getLeft(); } else { currNode = currNode.getRight(); } } else if(split.getFeatureType().isCategorical()) { short indexValue = -1; int categoricalSize = categoricalColumnNameNames.get(split.getColumnNum()).size(); if(Double.compare(value, 0d) < 0 || Double.compare(value, categoricalSize) >= 0) { indexValue = (short) categoricalSize; } else { // value is category index + 0.1d is to avoid 0.9999999 converted to 0, is there? indexValue = (short) (value + 0.1d); } Set<Short> childCategories = split.getLeftOrRightCategories(); if(split.isLeft()) { if(childCategories.contains(indexValue)) { currNode = currNode.getLeft(); } else { currNode = currNode.getRight(); } } else { if(childCategories.contains(indexValue)) { currNode = currNode.getRight(); } else { currNode = currNode.getLeft(); } } } } if(this.isClassification) { return currNode.getPredict().getClassValue(); } else { return currNode.getPredict().getPredict(); } } private double[] convertDataMapToDoubleArray(Map<String, Object> dataMap) { double[] data = new double[this.columnNumIndexMapping.size()]; for(Entry<Integer, Integer> entry: this.columnNumIndexMapping.entrySet()) { double value = 0d; Integer columnNum = entry.getKey(); String columnName = this.numNameMapping.get(columnNum); Object obj = dataMap.get(columnName); if(this.categoricalColumnNameNames.containsKey(columnNum)) { // categorical column double indexValue = -1d; int categoricalSize = categoricalColumnNameNames.get(columnNum).size(); if(obj == null) { // no matter set it to null or not set it in dataMap, it will be treated as missing value, last one // is missing value category indexValue = categoricalSize; } else { Integer intIndex = columnCategoryIndexMapping.get(columnNum).get(obj.toString()); if(intIndex == null || intIndex < 0 || intIndex >= categoricalSize) { // last one is for invalid category intIndex = categoricalSize; } indexValue = intIndex; } value = indexValue; } else { // numerical column if(obj == null || ((obj instanceof String) && ((String) obj).length() == 0)) { // no matter set it to null or not set it in dataMap, it will be treated as missing value, last one // is missing value category value = this.numericalMeanMapping.get(columnNum) == null ? 0d : this.numericalMeanMapping .get(columnNum); } else { if(obj instanceof Number) { value = ((Number) obj).doubleValue(); } else { try { value = Double.parseDouble(obj.toString()); } catch (NumberFormatException e) { // not valid double value for numerical feature, using default value value = this.numericalMeanMapping.get(columnNum) == null ? 0d : this.numericalMeanMapping .get(columnNum); } } } if(Double.isNaN(value)) { value = this.numericalMeanMapping.get(columnNum) == null ? 0d : this.numericalMeanMapping .get(columnNum); } } Integer index = entry.getValue(); if(index != null && index < data.length) { data[index] = value; } } return data; } /** * @return the lossStr */ public String getLossStr() { return lossStr; } /** * @param lossStr * the lossStr to set */ public void setLossStr(String lossStr) { this.lossStr = lossStr; } /** * @return the numNameMapping */ public Map<Integer, String> getNumNameMapping() { return numNameMapping; } /** * @return the categoricalColumnNameNames */ public Map<Integer, List<String>> getCategoricalColumnNameNames() { return categoricalColumnNameNames; } /** * @return the columnCategoryIndexMapping */ public Map<Integer, Map<String, Integer>> getColumnCategoryIndexMapping() { return columnCategoryIndexMapping; } /** * @return the columnNumIndexMapping */ public Map<Integer, Integer> getColumnNumIndexMapping() { return columnNumIndexMapping; } /** * @return the trees */ public List<TreeNode> getTrees() { return trees; } /** * @return the weights */ public List<Double> getWeights() { return weights; } /** * @return the isGBDT */ public boolean isGBDT() { return isGBDT; } /** * @return the isClassification */ public boolean isClassification() { return isClassification; } /** * @return the isConvertToProb */ public boolean isConvertToProb() { return isConvertToProb; } /** * @param numNameMapping * the numNameMapping to set */ public void setNumNameMapping(Map<Integer, String> numNameMapping) { this.numNameMapping = numNameMapping; } /** * @param categoricalColumnNameNames * the categoricalColumnNameNames to set */ public void setCategoricalColumnNameNames(Map<Integer, List<String>> categoricalColumnNameNames) { this.categoricalColumnNameNames = categoricalColumnNameNames; } /** * @param columnCategoryIndexMapping * the columnCategoryIndexMapping to set */ public void setColumnCategoryIndexMapping(Map<Integer, Map<String, Integer>> columnCategoryIndexMapping) { this.columnCategoryIndexMapping = columnCategoryIndexMapping; } /** * @param columnNumIndexMapping * the columnNumIndexMapping to set */ public void setColumnNumIndexMapping(Map<Integer, Integer> columnNumIndexMapping) { this.columnNumIndexMapping = columnNumIndexMapping; } /** * @param trees * the trees to set */ public void setTrees(List<TreeNode> trees) { this.trees = trees; } /** * @param weights * the weights to set */ public void setWeights(List<Double> weights) { this.weights = weights; } /** * @param isGBDT * the isGBDT to set */ public void setGBDT(boolean isGBDT) { this.isGBDT = isGBDT; } /** * @param isClassification * the isClassification to set */ public void setClassification(boolean isClassification) { this.isClassification = isClassification; } /** * @param isConvertToProb * the isConvertToProb to set */ public void setConvertToProb(boolean isConvertToProb) { this.isConvertToProb = isConvertToProb; } /** * @return the algorithm */ public String getAlgorithm() { return algorithm; } /** * @param algorithm * the algorithm to set */ public void setAlgorithm(String algorithm) { this.algorithm = algorithm; } /** * @return the inputNode */ public int getInputNode() { return inputNode; } /** * @param inputNode * the inputNode to set */ public void setInputNode(int inputNode) { this.inputNode = inputNode; } /** * Load model instance from stream like model0.gbt or model0.rf, by default not to convert gbt score to [0, 1] * * @param input * the input stream * @return the tree model instance * @throws IOException * any exception in load input stream */ public static IndependentTreeModel loadFromStream(InputStream input) throws IOException { return loadFromStream(input, false); } /** * Load model instance from stream like model0.gbt or model0.rf. User can specify to use raw score or score after * sigmoid transfrom by isConvertToProb. * * @param input * the input stream * @param isConvertToProb * if convert score to probability (if to transfrom raw score by sigmoid) * @return the tree model instance * @throws IOException * any exception in load input stream */ public static IndependentTreeModel loadFromStream(InputStream input, boolean isConvertToProb) throws IOException { DataInputStream dis = null; // check if gzip or not try { byte[] header = new byte[2]; BufferedInputStream bis = new BufferedInputStream(input); bis.mark(2); int result = bis.read(header); bis.reset(); int ss = (header[0] & 0xff) | ((header[1] & 0xff) << 8); if(result != -1 && ss == GZIPInputStream.GZIP_MAGIC) { dis = new DataInputStream(new GZIPInputStream(bis)); } else { dis = new DataInputStream(bis); } } catch (java.io.IOException e) { dis = new DataInputStream(input); } int version = dis.readInt(); String algorithm = dis.readUTF(); String lossStr = dis.readUTF(); boolean isClassification = dis.readBoolean(); boolean isOneVsAll = dis.readBoolean(); int inputNode = dis.readInt(); Map<Integer, Double> numericalMeanMapping = new HashMap<Integer, Double>(); Map<Integer, String> columnIndexNameMapping = new HashMap<Integer, String>(); int size = dis.readInt(); for(int i = 0; i < size; i++) { int columnIndex = dis.readInt(); double mean = dis.readDouble(); numericalMeanMapping.put(columnIndex, mean); } size = dis.readInt(); for(int i = 0; i < size; i++) { int columnIndex = dis.readInt(); String columnName = dis.readUTF(); columnIndexNameMapping.put(columnIndex, columnName); } Map<Integer, List<String>> categoricalColumnNameNames = new HashMap<Integer, List<String>>(); Map<Integer, Map<String, Integer>> columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>(); size = dis.readInt(); for(int i = 0; i < size; i++) { int columnIndex = dis.readInt(); int categoryListSize = dis.readInt(); Map<String, Integer> categoryIndexMapping = new HashMap<String, Integer>(); List<String> categories = new ArrayList<String>(); for(int j = 0; j < categoryListSize; j++) { String category = dis.readUTF(); // categories is merged category list categories.add(category); if(category.contains("" + MERGE_CATEGORY_DELIMITER)) { // merged category should be flatten, use split function this class to avoid depending on guava jar String[] splits = split(category, MERGE_CATEGORY_DELIMITER); for(String str: splits) { categoryIndexMapping.put(str, j); } } else { categoryIndexMapping.put(category, j); } } categoricalColumnNameNames.put(columnIndex, categories); columnCategoryIndexMapping.put(columnIndex, categoryIndexMapping); } Map<Integer, Integer> columnMapping = new HashMap<Integer, Integer>(); int columnMappingSize = dis.readInt(); for(int i = 0; i < columnMappingSize; i++) { columnMapping.put(dis.readInt(), dis.readInt()); } int treeNum = dis.readInt(); List<TreeNode> trees = new ArrayList<TreeNode>(treeNum); List<Double> weights = new ArrayList<Double>(treeNum); for(int i = 0; i < treeNum; i++) { TreeNode treeNode = new TreeNode(); treeNode.readFields(dis); trees.add(treeNode); weights.add(treeNode.getLearningRate()); } // if one vs all, even multiple classification, treated as regression return new IndependentTreeModel(numericalMeanMapping, columnIndexNameMapping, categoricalColumnNameNames, columnCategoryIndexMapping, columnMapping, trees, weights, CommonConstants.GBT_ALG_NAME.equalsIgnoreCase(algorithm), isClassification && !isOneVsAll, isConvertToProb, lossStr, algorithm, inputNode, version); } /** * Manual split function to avoid depending on guava. * * <p> * Some examples: "^"=>[, ]; ""=>[]; "a"=>[a]; "abc"=>[abc]; "a^"=>[a, ]; "^b"=>[, b]; "^^b"=>[, , b] * * @param str * the string to be split * @param delimiter * the delimiter * @return split string array */ private static String[] split(String str, char delimiter) { if(str == null || str.length() == 0) { return new String[] { "" }; } List<String> categories = new ArrayList<String>(); int begin = 0; for(int i = 0; i < str.length(); i++) { if(str.charAt(i) == delimiter) { categories.add(str.substring(begin, i)); begin = i + 1; } if(i == str.length() - 1) { categories.add(str.substring(begin, str.length())); } } return categories.toArray(new String[0]); } /** * @return the numericalMeanMapping */ public Map<Integer, Double> getNumericalMeanMapping() { return numericalMeanMapping; } /** * @param numericalMeanMapping * the numericalMeanMapping to set */ public void setNumericalMeanMapping(Map<Integer, Double> numericalMeanMapping) { this.numericalMeanMapping = numericalMeanMapping; } /** * @return the version */ public int getVersion() { return version; } /** * @param version * the version to set */ public void setVersion(int version) { this.version = version; } }