/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import ml.shifu.guagua.io.Bytable; import ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy; /** * {@link TreeNode} is used to wrapper node and tree index. With tree id and node in {@link TreeNode}. * * <p> * {@link #features} is for sub-sampling of such node. For feature sub-sampling, {@link FeatureSubsetStrategy} includes * ALL, HALF and ONETHIRD. * * @author Zhang David (pengzhang@paypal.com) * * @see Node * @see FeatureSubsetStrategy */ public class TreeNode implements Bytable { /** * Tree id */ private int treeId; /** * Node to be wrappered */ private Node node; /** * nodeNum so far in the tree */ private int nodeNum; /** * Store weighted cnt of root node (id = 1) for further computing, it is no meaning full it current node is not ROOT * node */ private double rootWgtCnt = -1; /** * Sub-sampling features which is used in tree growth. */ private List<Integer> features; /** * Learning rate for current model. This is very useful for GBT, since may be first 100 trees learning rate is 0.04, * later it is changed to 0.03. */ private double learningRate = 1d; public TreeNode() { } public TreeNode(int treeId, Node node) { this.treeId = treeId; this.node = node; this.nodeNum = 1; this.learningRate = 1d; } public TreeNode(int treeId, Node node, double learningRate) { this.treeId = treeId; this.node = node; this.nodeNum = 1; this.learningRate = learningRate; } public TreeNode(int treeId, Node node, int nodeNum, double learningRate) { this.treeId = treeId; this.node = node; this.nodeNum = nodeNum; this.learningRate = learningRate; } public TreeNode(int treeId, Node node, List<Integer> features, double learningRate) { this.treeId = treeId; this.node = node; this.features = features; this.learningRate = learningRate; } /** * @return the treeId */ public int getTreeId() { return treeId; } /** * @return the node */ public Node getNode() { return node; } /** * @param treeId * the treeId to set */ public void setTreeId(int treeId) { this.treeId = treeId; } /** * @param node * the node to set */ public void setNode(Node node) { this.node = node; } /** * @return the features */ public List<Integer> getFeatures() { return features; } /** * @param features * the features to set */ public void setFeatures(List<Integer> features) { this.features = features; } /** * @return the nodeNum */ public int getNodeNum() { return nodeNum; } /** * Increase node number */ public void incrNodeNum() { nodeNum += 1; } /** * @param nodeNum * the nodeNum to set */ public void setNodeNum(int nodeNum) { this.nodeNum = nodeNum; } /** * @return the rootWgtCnt */ public double getRootWgtCnt() { return rootWgtCnt; } /** * @param rootWgtCnt * the rootWgtCnt to set */ public void setRootWgtCnt(double rootWgtCnt) { this.rootWgtCnt = rootWgtCnt; } /** * @return the learningRate */ public double getLearningRate() { return learningRate; } /** * @param learningRate * the learningRate to set */ public void setLearningRate(double learningRate) { this.learningRate = learningRate; } @Override public void write(DataOutput out) throws IOException { this.writeWithoutFeatures(out); if(features == null) { out.writeInt(0); } else { out.writeInt(features.size()); for(Integer index: features) { out.writeInt(index); } } } /** * This is serialization version to serialize TreeNode without sub-sampling features. * * @param out * output stream * @throws IOException * any io exception. */ public void writeWithoutFeatures(DataOutput out) throws IOException { out.writeInt(treeId); out.writeInt(nodeNum); this.node.write(out); out.writeDouble(this.learningRate); if(this.node.getId() == Node.ROOT_INDEX) { out.writeDouble(this.rootWgtCnt); } } @Override public void readFields(DataInput in) throws IOException { this.readFieldsWithoutFeatures(in); int len = in.readInt(); this.features = new ArrayList<Integer>(); for(int i = 0; i < len; i++) { this.features.add(in.readInt()); } } /** * This is serialization version to de-serialize TreeNode without sub-sampling features. * * @param in * input stream * @throws IOException * any io exception. */ public void readFieldsWithoutFeatures(DataInput in) throws IOException { this.treeId = in.readInt(); this.nodeNum = in.readInt(); this.node = new Node(); this.node.readFields(in); this.learningRate = in.readDouble(); if(this.node.getId() == Node.ROOT_INDEX) { this.rootWgtCnt = in.readDouble(); } } /** * Compute tree model feature importance. * * @return a map with (column_id, feature_importance.) */ public Map<Integer, Double> computeFeatureImportance() { Map<Integer, Double> importances = new HashMap<Integer, Double>(); preOrder(importances, node); return importances; } private void preOrder(Map<Integer, Double> importances, Node node) { if(node == null) { return; } computeImportance(importances, node); preOrder(importances, node.getLeft()); preOrder(importances, node.getRight()); } private void computeImportance(Map<Integer, Double> importances, Node node) { if(!node.isRealLeaf()) { int featureId = node.getSplit().getColumnNum(); if(!importances.containsKey(featureId)) { importances.put(featureId, node.getGain()); } else { importances.put(featureId, importances.get(featureId) + node.getGain()); } } } @Override public String toString() { return "TreeNode [treeId=" + treeId + ", node=" + node.getId() + ", features=" + features + "]"; } }