package hex.genmodel.algos.tree; import hex.genmodel.utils.GenmodelBitSet; import java.io.PrintStream; import java.util.ArrayList; import java.util.BitSet; /** * Node in a tree. * A node (optionally) contains left and right edges to the left and right child nodes. */ class SharedTreeNode { final SharedTreeNode parent; final int subgraphNumber; int nodeNumber; float weight; final int depth; int colId; String colName; boolean leftward; boolean naVsRest; float splitValue = Float.NaN; String[] domainValues; GenmodelBitSet bs; float predValue = Float.NaN; float squaredError = Float.NaN; SharedTreeNode leftChild; SharedTreeNode rightChild; // Whether NA for this colId is reachable to this node. private boolean inclusiveNa; // When a column is categorical, levels that are reachable to this node. // This in particular includes any earlier splits of the same colId. private BitSet inclusiveLevels; /** * Create a new node. * @param p Parent node * @param sn Tree number * @param d Node depth within the tree */ SharedTreeNode(SharedTreeNode p, int sn, int d) { parent = p; subgraphNumber = sn; depth = d; } public int getDepth() { return depth; } public int getNodeNumber() { return nodeNumber; } float getWeight() { return weight; } void setNodeNumber(int id) { nodeNumber = id; } void setWeight(float w) { weight = w; } void setCol(int v1, String v2) { colId = v1; colName = v2; } private int getColId() { return colId; } void setLeftward(boolean v) { leftward = v; } void setNaVsRest(boolean v) { naVsRest = v; } void setSplitValue(float v) { splitValue = v; } void setBitset(String[] v1, GenmodelBitSet v2) { assert (v1 != null); domainValues = v1; bs = v2; } void setPredValue(float v) { predValue = v; } void setSquaredError(float v) { squaredError = v; } /** * Calculate whether the NA value for a particular colId can reach this node. * @param colIdToFind Column id to find * @return true if NA of colId reaches this node, false otherwise */ private boolean findInclusiveNa(int colIdToFind) { if (parent == null) { return true; } else if (parent.getColId() == colIdToFind) { return inclusiveNa; } return parent.findInclusiveNa(colIdToFind); } private boolean calculateChildInclusiveNa(boolean includeThisSplitEdge) { return findInclusiveNa(colId) && includeThisSplitEdge; } /** * Find the set of levels for a particular categorical column that can reach this node. * A null return value implies the full set (i.e. every level). * @param colIdToFind Column id to find * @return Set of levels */ private BitSet findInclusiveLevels(int colIdToFind) { if (parent == null) { return null; } if (parent.getColId() == colIdToFind) { return inclusiveLevels; } return parent.findInclusiveLevels(colIdToFind); } private boolean calculateIncludeThisLevel(BitSet inheritedInclusiveLevels, int i) { if (inheritedInclusiveLevels == null) { // If there is no prior split history for this column, then treat the // inherited set as complete. return true; } else if (inheritedInclusiveLevels.get(i)) { // Allow levels that flowed into this node. return true; } // Filter out levels that were already discarded from a previous split. return false; } /** * Calculate the set of levels that flow through to a child. * @param includeAllLevels naVsRest dictates include all (inherited) levels * @param discardAllLevels naVsRest dictates discard all levels * @param nodeBitsetDoesContain true if the GenmodelBitset from the compressed_tree * @return Calculated set of levels */ private BitSet calculateChildInclusiveLevels(boolean includeAllLevels, boolean discardAllLevels, boolean nodeBitsetDoesContain) { BitSet inheritedInclusiveLevels = findInclusiveLevels(colId); BitSet childInclusiveLevels = new BitSet(); for (int i = 0; i < domainValues.length; i++) { // Calculate whether this level should flow into this child node. boolean includeThisLevel = false; { if (discardAllLevels) { includeThisLevel = false; } else if (includeAllLevels) { includeThisLevel = calculateIncludeThisLevel(inheritedInclusiveLevels, i); } else if (bs.isInRange(i) && bs.contains(i) == nodeBitsetDoesContain) { includeThisLevel = calculateIncludeThisLevel(inheritedInclusiveLevels, i); } } if (includeThisLevel) { childInclusiveLevels.set(i); } } return childInclusiveLevels; } void setLeftChild(SharedTreeNode v) { leftChild = v; boolean childInclusiveNa = calculateChildInclusiveNa(leftward); v.setInclusiveNa(childInclusiveNa); if (! isBitset()) { return; } BitSet childInclusiveLevels = calculateChildInclusiveLevels(naVsRest, false, false); v.setInclusiveLevels(childInclusiveLevels); } void setRightChild(SharedTreeNode v) { rightChild = v; boolean childInclusiveNa = calculateChildInclusiveNa(!leftward); v.setInclusiveNa(childInclusiveNa); if (! isBitset()) { return; } BitSet childInclusiveLevels = calculateChildInclusiveLevels(false, naVsRest, true); v.setInclusiveLevels(childInclusiveLevels); } void setInclusiveNa(boolean v) { inclusiveNa = v; } private boolean getInclusiveNa() { return inclusiveNa; } private void setInclusiveLevels(BitSet v) { inclusiveLevels = v; } private BitSet getInclusiveLevels() { return inclusiveLevels; } public String getName() { return "Node " + nodeNumber; } public void print() { System.out.println(" Node " + nodeNumber); System.out.println(" weight: " + weight); System.out.println(" depth: " + depth); System.out.println(" colId: " + colId); System.out.println(" colName: " + ((colName != null) ? colName : "")); System.out.println(" leftward: " + leftward); System.out.println(" naVsRest: " + naVsRest); System.out.println(" splitVal: " + splitValue); System.out.println(" isBitset: " + isBitset()); System.out.println(" predValue: " + predValue); System.out.println(" squaredErr: " + squaredError); System.out.println(" leftChild: " + ((leftChild != null) ? leftChild.getName() : "")); System.out.println(" rightChild: " + ((rightChild != null) ? rightChild.getName() : "")); } void printEdges() { if (leftChild != null) { System.out.println(" " + getName() + " ---left---> " + leftChild.getName()); leftChild.printEdges(); } if (rightChild != null) { System.out.println(" " + getName() + " ---right--> " + rightChild.getName()); rightChild.printEdges(); } } private String getDotName() { return "SG_" + subgraphNumber + "_Node_" + nodeNumber; } private boolean isBitset() { return (domainValues != null); } public static String escapeQuotes(String s) { return s.replace("\"", "\\\""); } private void printDotNode(PrintStream os, boolean detail) { os.print("\"" + getDotName() + "\""); os.print(" ["); if (leftChild==null && rightChild==null) { os.print("label=\""); os.print(predValue); } else if (isBitset()) { os.print("shape=box,label=\""); os.print(escapeQuotes(colName)); } else { assert(! Float.isNaN(splitValue)); os.print("shape=box,label=\""); os.print(escapeQuotes(colName) + " < " + splitValue); } if (detail) { os.print("\\n\\nN" + getNodeNumber() + "\\n"); if (leftChild != null || rightChild != null) { if (!Float.isNaN(predValue)) { os.print("\\nPred: " + predValue); } } if (!Float.isNaN(squaredError)) { os.print("\\nSE: " + squaredError); } os.print("\\nW: " + getWeight()); if (naVsRest) { os.print("\\n" + "nasVsRest"); } if (leftChild != null) { os.print("\\n" + "L: N" + leftChild.getNodeNumber()); } if (rightChild != null) { os.print("\\n" + "R: N" + rightChild.getNodeNumber()); } } os.print("\"]"); os.println(""); } /** * Recursively print nodes at a particular depth level in the tree. Useful to group them so they render properly. * @param os output stream * @param levelToPrint level number * @param detail include addtional node detail information */ void printDotNodesAtLevel(PrintStream os, int levelToPrint, boolean detail) { if (getDepth() == levelToPrint) { printDotNode(os, detail); return; } assert (getDepth() < levelToPrint); if (leftChild != null) { leftChild.printDotNodesAtLevel(os, levelToPrint, detail); } if (rightChild != null) { rightChild.printDotNodesAtLevel(os, levelToPrint, detail); } } private void printDotEdgesCommon(PrintStream os, int maxLevelsToPrintPerEdge, ArrayList<String> arr, SharedTreeNode child) { if (isBitset()) { BitSet childInclusiveLevels = child.getInclusiveLevels(); int total = childInclusiveLevels.cardinality(); if ((total > 0) && (total <= maxLevelsToPrintPerEdge)) { for (int i = childInclusiveLevels.nextSetBit(0); i >= 0; i = childInclusiveLevels.nextSetBit(i+1)) { arr.add(domainValues[i]); } } else { arr.add(total + " levels"); } } os.print("label=\""); for (String s : arr) { os.print(escapeQuotes(s) + "\\n"); } os.print("\""); os.println("]"); } /** * Recursively print all edges in the tree. * @param os output stream * @param maxLevelsToPrintPerEdge Limit the number of individual categorical level names printed per edge */ void printDotEdges(PrintStream os, int maxLevelsToPrintPerEdge) { assert (leftChild == null) == (rightChild == null); if (leftChild != null) { os.print("\"" + getDotName() + "\"" + " -> " + "\"" + leftChild.getDotName() + "\"" + " ["); ArrayList<String> arr = new ArrayList<>(); if (leftChild.getInclusiveNa()) { arr.add("[NA]"); } if (naVsRest) { arr.add("[Not NA]"); } else { if (! isBitset()) { arr.add("<"); } } printDotEdgesCommon(os, maxLevelsToPrintPerEdge, arr, leftChild); } if (rightChild != null) { os.print("\"" + getDotName() + "\"" + " -> " + "\"" + rightChild.getDotName() + "\"" + " ["); ArrayList<String> arr = new ArrayList<>(); if (rightChild.getInclusiveNa()) { arr.add("[NA]"); } if (! naVsRest) { if (! isBitset()) { arr.add(">="); } } printDotEdgesCommon(os, maxLevelsToPrintPerEdge, arr, rightChild); } } }