/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dt;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import ml.shifu.guagua.io.Bytable;
/**
* A binary tree node.
*
* <p>
* {@link #left} and {@link #right} are the children of a node. Other attributes are attached as fields.
* {@link #predict} holds the node's predict info with predict value and classification value. To predict with a tree
* is to find a leaf node and get its predict.
*
* <p>
* A tree can be represented by its root node alone; node ids start from 1 at the root and continue as 2, 3 ...
*
* @author Zhang David (pengzhang@paypal.com)
*/
public class Node implements Bytable {

    /**
     * Default root index is 1. Others are 2, 3, 4, 5 ...
     */
    public static final int ROOT_INDEX = 1;

    /**
     * If node with such index, means such node is invalid.
     */
    public static final int INVALID_INDEX = -1;

    /**
     * Node id, start from 1, 2, 3 ...
     */
    private int id;

    /**
     * Feature split for such node, if leaf node means no split. Node is split by numeric feature or categorical
     * feature. Please check {@link Split} for details.
     */
    private Split split;

    /**
     * Left child, if leaf, left is null.
     */
    private Node left;

    /**
     * Right child, if leaf, right is null.
     */
    private Node right;

    /**
     * Predict value and probability for such node which is collected from workers.
     */
    private Predict predict;

    /**
     * Ratio of # of weighted instances in such node over # of all weighted instances
     */
    private double wgtCntRatio;

    /**
     * Gain for such node, such value can be computed from different {@link Impurity} like {@link Entropy},
     * {@link Variance}. Gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity.
     */
    private double gain;

    /**
     * Impurity value for such node, such value can be computed from different {@link Impurity} like {@link Entropy},
     * {@link Variance}.
     */
    private double impurity;

    /**
     * Predict value for left child, null if leaf.
     */
    private Predict leftPredict;

    /**
     * Left impurity value, 0 if leaf.
     */
    private double leftImpurity;

    /**
     * Predict value for right child, null if leaf.
     */
    private Predict rightPredict;

    /**
     * Impurity for right node.
     */
    private double rightImpurity;

    /**
     * 'isLeaf' is used to set a flag not to extend this tree.
     */
    private boolean isLeaf;

    /**
     * Creates a node whose id is {@link #ROOT_INDEX}.
     */
    public Node() {
        this(ROOT_INDEX);
    }

    /**
     * Creates a node with the given id and no children, split or predict.
     *
     * @param id
     *            the node id
     */
    public Node(int id) {
        this.id = id;
    }

    /**
     * Creates a node with the given id and children.
     *
     * @param id
     *            the node id
     * @param left
     *            the left child, may be null
     * @param right
     *            the right child, may be null
     */
    public Node(int id, Node left, Node right) {
        this.id = id;
        this.left = left;
        this.right = right;
    }

    /**
     * Creates a node with predict info, impurity and the leaf flag.
     *
     * @param id
     *            the node id
     * @param predict
     *            the predict info of this node
     * @param impurity
     *            the impurity of this node
     * @param isLeaf
     *            the flag telling the tree not to be extended at this node
     */
    public Node(int id, Predict predict, double impurity, boolean isLeaf) {
        this.id = id;
        this.predict = predict;
        this.impurity = impurity;
        this.isLeaf = isLeaf;
    }

    /**
     * Creates a node with split, children, predict info, gain and impurity.
     *
     * @param id
     *            the node id
     * @param split
     *            the feature split of this node
     * @param left
     *            the left child, may be null
     * @param right
     *            the right child, may be null
     * @param predict
     *            the predict info of this node
     * @param gain
     *            the gain of this node
     * @param impurity
     *            the impurity of this node
     */
    public Node(int id, Split split, Node left, Node right, Predict predict, double gain, double impurity) {
        this.id = id;
        this.split = split;
        this.left = left;
        this.right = right;
        this.predict = predict;
        this.gain = gain;
        this.impurity = impurity;
    }

    /**
     * @return the id
     */
    public int getId() {
        return id;
    }

    /**
     * @param id
     *            the id to set
     */
    public void setId(int id) {
        this.id = id;
    }

    /**
     * @return the split
     */
    public Split getSplit() {
        return split;
    }

    /**
     * @param split
     *            the split to set
     */
    public void setSplit(Split split) {
        this.split = split;
    }

    /**
     * @return the left
     */
    public Node getLeft() {
        return left;
    }

    /**
     * @param left
     *            the left to set
     */
    public void setLeft(Node left) {
        this.left = left;
    }

    /**
     * @return the right
     */
    public Node getRight() {
        return right;
    }

    /**
     * @param right
     *            the right to set
     */
    public void setRight(Node right) {
        this.right = right;
    }

    /**
     * @return the predict
     */
    public Predict getPredict() {
        return predict;
    }

    /**
     * @param predict
     *            the predict to set
     */
    public void setPredict(Predict predict) {
        this.predict = predict;
    }

    /**
     * @return the ratio of weighted instances in this node over all weighted instances
     */
    public double getWgtCntRatio() {
        return wgtCntRatio;
    }

    /**
     * @param wgtCntRatio
     *            the wgtCntRatio to set
     */
    public void setWgtCntRatio(double wgtCntRatio) {
        this.wgtCntRatio = wgtCntRatio;
    }

    /**
     * @return the gain
     */
    public double getGain() {
        return gain;
    }

    /**
     * @param gain
     *            the gain to set
     */
    public void setGain(double gain) {
        this.gain = gain;
    }

    /**
     * @return the impurity
     */
    public double getImpurity() {
        return impurity;
    }

    /**
     * @param impurity
     *            the impurity to set
     */
    public void setImpurity(double impurity) {
        this.impurity = impurity;
    }

    /**
     * @return the leftPredict
     */
    public Predict getLeftPredict() {
        return leftPredict;
    }

    /**
     * @param leftPredict
     *            the leftPredict to set
     */
    public void setLeftPredict(Predict leftPredict) {
        this.leftPredict = leftPredict;
    }

    /**
     * @return the leftImpurity
     */
    public double getLeftImpurity() {
        return leftImpurity;
    }

    /**
     * @param leftImpurity
     *            the leftImpurity to set
     */
    public void setLeftImpurity(double leftImpurity) {
        this.leftImpurity = leftImpurity;
    }

    /**
     * @return the rightPredict
     */
    public Predict getRightPredict() {
        return rightPredict;
    }

    /**
     * @param rightPredict
     *            the rightPredict to set
     */
    public void setRightPredict(Predict rightPredict) {
        this.rightPredict = rightPredict;
    }

    /**
     * @return the rightImpurity
     */
    public double getRightImpurity() {
        return rightImpurity;
    }

    /**
     * @param rightImpurity
     *            the rightImpurity to set
     */
    public void setRightImpurity(double rightImpurity) {
        this.rightImpurity = rightImpurity;
    }

    /**
     * @param isLeaf
     *            the isLeaf to set
     */
    public void setLeaf(boolean isLeaf) {
        this.isLeaf = isLeaf;
    }

    /**
     * @return the leaf flag as set by {@link #setLeaf(boolean)} or the constructor; it does not look at the children,
     *         see {@link #isRealLeaf()} for that
     */
    boolean isLeaf() {
        return this.isLeaf;
    }

    /**
     * Check if node is real for leaf. No matter the leaf flag, this will check whether left and right exist.
     *
     * @return if it is real leaf node
     */
    public boolean isRealLeaf() {
        return this.left == null && this.right == null;
    }

    /**
     * Level of a node in the tree computed from its index: the position of the highest set bit plus one, so index 1
     * is level 1, indexes 2-3 are level 2, indexes 4-7 are level 3 and so on.
     *
     * @param nodeIndex
     *            the node index, expected to be positive
     * @return the level of the node with such index
     */
    public static int indexToLevel(int nodeIndex) {
        return Integer.numberOfTrailingZeros(Integer.highestOneBit(nodeIndex)) + 1;
    }

    /**
     * According to node index and topNode, find the exact node.
     *
     * <p>
     * The bits of {@code index} below its highest set bit describe the path from the root: a 0 bit means go to the
     * left child ({@code 2 * id}), a 1 bit means go to the right child ({@code 2 * id + 1}).
     *
     * @param topNode
     *            the top node of the tree
     * @param index
     *            the index to be searched
     * @return the node with such index, or null if not found (including a null {@code topNode}, a non-positive
     *         {@code index}, or a missing node anywhere on the path)
     */
    public static Node getNode(Node topNode, int index) {
        if(topNode == null || index < ROOT_INDEX) {
            return null;
        }
        assert topNode.id == ROOT_INDEX : "topNode should be the root node";

        // number of edges between the root and the node with such index
        int steps = 31 - Integer.numberOfLeadingZeros(index);
        Node current = topNode;
        // stop as soon as a child on the path is missing instead of dereferencing null
        for(int bit = steps - 1; bit >= 0 && current != null; bit--) {
            boolean goRight = ((index >>> bit) & 1) == 1;
            current = goRight ? current.getRight() : current.getLeft();
        }
        return current;
    }

    /**
     * Left index according to current id.
     *
     * @param id
     *            current id
     * @return left index
     */
    public static int leftIndex(int id) {
        return id << 1;
    }

    /**
     * Right index according to current id.
     *
     * @param id
     *            current id
     * @return right index
     */
    public static int rightIndex(int id) {
        return (id << 1) + 1;
    }

    /**
     * Parent index according to current id.
     *
     * @param id
     *            current id
     * @return parent index
     */
    public static int parentIndex(int id) {
        return id >>> 1;
    }

    /**
     * If current node is ROOT or not
     *
     * @param node
     *            node to be checked
     * @return true if ROOT, false if not ROOT or null
     */
    public static boolean isRootNode(Node node) {
        return node != null && node.getId() == ROOT_INDEX;
    }

    /**
     * Serializes this node and, recursively, its children.
     *
     * <p>
     * Layout: id, gain (as float), wgtCntRatio (as float), optional split, real-leaf flag, optional predict (only
     * written for real leaves), optional left subtree, optional right subtree. Every optional part is preceded by a
     * boolean presence flag. Other fields (impurity, left/right predict and impurity, leaf flag) are not written.
     *
     * @param out
     *            the output to write to
     * @throws IOException
     *             if writing to {@code out} fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(id);
        // cast to float to save space; precision beyond float is lost on the wire
        out.writeFloat((float) gain);
        out.writeFloat((float) wgtCntRatio);

        if(split == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            split.write(out);
        }

        // only store needed predict info: predict is written for real leaves only
        boolean isRealLeaf = isRealLeaf();
        out.writeBoolean(isRealLeaf);
        if(isRealLeaf) {
            if(predict == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                predict.write(out);
            }
        }

        if(left == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            left.write(out);
        }

        if(right == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            right.write(out);
        }
    }

    /**
     * Restores this node and, recursively, its children from the layout produced by {@link #write(DataOutput)}.
     *
     * <p>
     * Every serialized field is overwritten: split, predict, left and right are reset to null when they are absent
     * in the stream, so an instance that is reused for several records never keeps state of a previous record.
     * Fields that are not part of the serialized layout are left untouched.
     *
     * @param in
     *            the input to read from
     * @throws IOException
     *             if reading from {@code in} fails
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.id = in.readInt();
        this.gain = in.readFloat();
        this.wgtCntRatio = in.readFloat();

        if(in.readBoolean()) {
            this.split = new Split();
            this.split.readFields(in);
        } else {
            this.split = null;
        }

        // the predict presence flag only exists in the stream when the real-leaf flag is true
        boolean isRealLeaf = in.readBoolean();
        if(isRealLeaf && in.readBoolean()) {
            this.predict = new Predict();
            this.predict.readFields(in);
        } else {
            this.predict = null;
        }

        if(in.readBoolean()) {
            this.left = new Node();
            this.left.readFields(in);
        } else {
            this.left = null;
        }

        if(in.readBoolean()) {
            this.right = new Node();
            this.right.readFields(in);
        } else {
            this.right = null;
        }
    }

    @Override
    public String toString() {
        return "Node [id=" + id + ", split=" + split + ", left=" + left + ", right=" + right + ", predict=" + predict
                + ", gain=" + gain + ", impurity=" + impurity + ", leftPredict=" + leftPredict + ", leftImpurity="
                + leftImpurity + ", rightPredict=" + rightPredict + ", rightImpurity=" + rightImpurity + "]";
    }

    /**
     * Renders the tree rooted at this node, one node per line in pre-order (node, left subtree, right subtree).
     *
     * @return the multi-line text with id, split and predict of every node
     */
    public String toTree() {
        StringBuilder buffer = new StringBuilder();
        appendTree(buffer);
        return buffer.toString();
    }

    /**
     * Appends this node and its subtrees to a shared buffer so that the text is not copied at every tree level.
     */
    private void appendTree(StringBuilder buffer) {
        buffer.append("[id=").append(id).append(", split=").append(split).append(", predict=").append(predict)
                .append("]\n");
        if(this.left != null) {
            this.left.appendTree(buffer);
        }
        if(this.right != null) {
            this.right.appendTree(buffer);
        }
    }
}