package hex.genmodel.algos.tree; import hex.genmodel.MojoModel; import hex.genmodel.algos.drf.DrfMojoModel; import hex.genmodel.algos.gbm.GbmMojoModel; import hex.genmodel.utils.ByteBufferWrapper; import hex.genmodel.utils.GenmodelBitSet; import java.util.Arrays; import java.util.HashMap; /** * Common ancestor for {@link DrfMojoModel} and {@link GbmMojoModel}. * See also: `hex.tree.SharedTreeModel` and `hex.tree.TreeVisitor` classes. */ public abstract class SharedTreeMojoModel extends MojoModel { private static final int NsdNaVsRest = NaSplitDir.NAvsREST.value(); private static final int NsdNaLeft = NaSplitDir.NALeft.value(); private static final int NsdLeft = NaSplitDir.Left.value(); protected Number _mojo_version; /** * {@code _ntree_groups} is the number of trees requested by the user. For * binomial case or regression this is also the total number of trees * trained; however in multinomial case each requested "tree" is actually * represented as a group of trees, with {@code _ntrees_per_group} trees * in each group. Each of these individual trees assesses the likelihood * that a given observation belongs to class A, B, C, etc. of a * multiclass response. */ protected int _ntree_groups; protected int _ntrees_per_group; /** * Array of binary tree data, each tree being a {@code byte[]} array. The * trees are logically grouped into a rectangular grid of dimensions * {@link #_ntree_groups} x {@link #_ntrees_per_group}, however physically * they are stored as 1-dimensional list, and an {@code [i, j]} logical * tree is mapped to the index {@link #treeIndex(int, int)}. */ protected byte[][] _compressed_trees; /** * Array of auxiliary binary tree data, each being a {@code byte[]} array. */ protected byte[][] _compressed_trees_aux; /** * GLM's beta used for calibrating output probabilities using Platt Scaling. */ protected double[] _calib_glm_beta; /** * Highly efficient (critical path) tree scoring * * Given a tree (in the form of a byte array) and the row of input data, compute either this tree's * predicted value when `computeLeafAssignment` is false, or the the decision path within the tree (but no more * than 64 levels) when `computeLeafAssignment` is true. * * Note: this function is also used from the `hex.tree.CompressedTree` class in `h2o-algos` project. */ @SuppressWarnings("ConstantConditions") // Complains that the code is too complex. Well duh! public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment, String[][] domains) { ByteBufferWrapper ab = new ByteBufferWrapper(tree); GenmodelBitSet bs = null; long bitsRight = 0; int level = 0; while (true) { int nodeType = ab.get1U(); int colId = ab.get2(); if (colId == 65535) return ab.get4f(); int naSplitDir = ab.get1U(); boolean naVsRest = naSplitDir == NsdNaVsRest; boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft; int lmask = (nodeType & 51); int equal = (nodeType & 12); // Can be one of 0, 8, 12 assert equal != 4; // no longer supported float splitVal = -1; if (!naVsRest) { // Extract value or group to split on if (equal == 0) { // Standard float-compare test (either < or ==) splitVal = ab.get4f(); // Get the float to compare } else { // Bitset test if (bs == null) bs = new GenmodelBitSet(0); if (equal == 8) bs.fill2(tree, ab); else bs.fill3(tree, ab); } } // This logic: // // double d = row[colId]; // if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) ) || (domains != null && domains[colId] != null && domains[colId].length <= (int)d) // ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) { // Really does this: // // if (value is NaN or value is not in the range of the bitset or is outside the domain map length (but an integer) ) { // if (leftward) { // go left // } // else { // go right // } // } // else { // if (naVsRest) { // go left // } // else { // if (numeric) { // if (value < split value) { // go left // } // else { // go right // } // } // else { // if (value not in bitset) { // go left // } // else { // go right // } // } // } // } double d = row[colId]; if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) ) || (domains != null && domains[colId] != null && domains[colId].length <= (int)d) ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) { // go RIGHT switch (lmask) { case 0: ab.skip(ab.get1U()); break; case 1: ab.skip(ab.get2()); break; case 2: ab.skip(ab.get3()); break; case 3: ab.skip(ab.get4()); break; case 16: ab.skip(nclasses < 256? 1 : 2); break; // Small leaf case 48: ab.skip(4); break; // skip the prediction default: assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree); } if (computeLeafAssignment && level < 64) bitsRight |= 1 << level; lmask = (nodeType & 0xC0) >> 2; // Replace leftmask with the rightmask } else { // go LEFT if (lmask <= 3) ab.skip(lmask + 1); } level++; if ((lmask & 16) != 0) { if (computeLeafAssignment) { bitsRight |= 1 << level; // mark the end of the tree return Double.longBitsToDouble(bitsRight); } else { return ab.get4f(); } } } } public static String getDecisionPath(double leafAssignment) { long l = Double.doubleToRawLongBits(leafAssignment); StringBuilder sb = new StringBuilder(); int pos = 0; for (int i = 0; i < 64; ++i) { boolean right = ((l>>i) & 0x1L) == 1; sb.append(right? "R" : "L"); if (right) pos = i; } return sb.substring(0, pos); } //------------------------------------------------------------------------------------------------------------------ // Computing a Tree Graph //------------------------------------------------------------------------------------------------------------------ private void computeTreeGraph(SharedTreeSubgraph sg, SharedTreeNode node, byte[] tree, ByteBufferWrapper ab, HashMap<Integer, AuxInfo> auxMap, int nclasses) { int nodeType = ab.get1U(); int colId = ab.get2(); if (colId == 65535) { float leafValue = ab.get4f(); node.setPredValue(leafValue); return; } String colName = getNames()[colId]; node.setCol(colId, colName); int naSplitDir = ab.get1U(); boolean naVsRest = naSplitDir == NsdNaVsRest; boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft; node.setLeftward(leftward); node.setNaVsRest(naVsRest); int lmask = (nodeType & 51); int equal = (nodeType & 12); // Can be one of 0, 8, 12 assert equal != 4; // no longer supported if (!naVsRest) { // Extract value or group to split on if (equal == 0) { // Standard float-compare test (either < or ==) float splitVal = ab.get4f(); // Get the float to compare node.setSplitValue(splitVal); } else { // Bitset test GenmodelBitSet bs = new GenmodelBitSet(0); if (equal == 8) bs.fill2(tree, ab); else bs.fill3(tree, ab); node.setBitset(getDomainValues(colId), bs); } } AuxInfo auxInfo = auxMap.get(node.getNodeNumber()); // go RIGHT { ByteBufferWrapper ab2 = new ByteBufferWrapper(tree); ab2.skip(ab.position()); switch (lmask) { case 0: ab2.skip(ab2.get1U()); break; case 1: ab2.skip(ab2.get2()); break; case 2: ab2.skip(ab2.get3()); break; case 3: ab2.skip(ab2.get4()); break; case 16: ab2.skip(nclasses < 256 ? 1 : 2); break; // Small leaf case 48: ab2.skip(4); break; // skip the prediction default: assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree); } int lmask2 = (nodeType & 0xC0) >> 2; // Replace leftmask with the rightmask SharedTreeNode newNode = sg.makeRightChildNode(node); newNode.setWeight(auxInfo.weightR); newNode.setNodeNumber(auxInfo.nidR); newNode.setPredValue(auxInfo.predR); newNode.setSquaredError(auxInfo.sqErrR); if ((lmask2 & 16) != 0) { float leafValue = ab2.get4f(); newNode.setPredValue(leafValue); auxInfo.predR = leafValue; } else { computeTreeGraph(sg, newNode, tree, ab2, auxMap, nclasses); } } // go LEFT { ByteBufferWrapper ab2 = new ByteBufferWrapper(tree); ab2.skip(ab.position()); if (lmask <= 3) ab2.skip(lmask + 1); SharedTreeNode newNode = sg.makeLeftChildNode(node); newNode.setWeight(auxInfo.weightL); newNode.setNodeNumber(auxInfo.nidL); newNode.setPredValue(auxInfo.predL); newNode.setSquaredError(auxInfo.sqErrL); if ((lmask & 16) != 0) { float leafValue = ab2.get4f(); newNode.setPredValue(leafValue); auxInfo.predL = leafValue; } else { computeTreeGraph(sg, newNode, tree, ab2, auxMap, nclasses); } } if (node.getNodeNumber() == 0) { float p = (float)(((double)auxInfo.predL*(double)auxInfo.weightL + (double)auxInfo.predR*(double)auxInfo.weightR)/((double)auxInfo.weightL + (double)auxInfo.weightR)); if (Math.abs(p) < 1e-7) p = 0; node.setPredValue(p); node.setSquaredError(auxInfo.sqErrR + auxInfo.sqErrL); node.setWeight(auxInfo.weightL + auxInfo.weightR); } checkConsistency(auxInfo, node); } /** * Compute a graph of the forest. * * @return A graph of the forest. */ public SharedTreeGraph _computeGraph(int treeToPrint) { SharedTreeGraph g = new SharedTreeGraph(); if (treeToPrint >= _ntree_groups) { throw new IllegalArgumentException("Tree " + treeToPrint + " does not exist (max " + _ntree_groups + ")"); } int j; if (treeToPrint >= 0) { j = treeToPrint; } else { j = 0; } for (; j < _ntree_groups; j++) { for (int i = 0; i < _ntrees_per_group; i++) { String className = ""; { String[] domainValues = getDomainValues(getResponseIdx()); if (domainValues != null) { className = ", Class " + domainValues[i]; } } int itree = treeIndex(j, i); SharedTreeSubgraph sg = g.makeSubgraph("Tree " + j + className); SharedTreeNode node = sg.makeRootNode(); node.setSquaredError(Float.NaN); node.setPredValue(Float.NaN); byte[] tree = _compressed_trees[itree]; ByteBufferWrapper ab = new ByteBufferWrapper(tree); ByteBufferWrapper abAux = new ByteBufferWrapper(_compressed_trees_aux[itree]); HashMap<Integer, AuxInfo> auxMap = new HashMap<>(); while (abAux.hasRemaining()) { AuxInfo auxInfo = new AuxInfo(abAux); auxMap.put(auxInfo.nid, auxInfo); } computeTreeGraph(sg, node, tree, ab, auxMap, _nclasses); } if (treeToPrint >= 0) { break; } } return g; } static class AuxInfo { AuxInfo(ByteBufferWrapper abAux) { // node ID nid = abAux.get4(); // parent node ID pid = abAux.get4(); //sum of observation weights (typically, that's just the count of observations) weightL = abAux.get4f(); weightR = abAux.get4f(); //predicted values predL = abAux.get4f(); predR = abAux.get4f(); //squared error sqErrL = abAux.get4f(); sqErrR = abAux.get4f(); //node IDs (consistent with tree construction) nidL = abAux.get4(); nidR = abAux.get4(); } @Override public String toString() { return "nid: " + nid + "\n" + "pid: " + pid + "\n" + "nidL: " + nidL + "\n" + "nidR: " + nidR + "\n" + "weightL: " + weightL + "\n" + "weightR: " + weightR + "\n" + "predL: " + predL + "\n" + "predR: " + predR + "\n" + "sqErrL: " + sqErrL + "\n" + "sqErrR: " + sqErrR + "\n"; } public int nid, pid, nidL, nidR; public float weightL, weightR, predL, predR, sqErrL, sqErrR; } void checkConsistency(AuxInfo auxInfo, SharedTreeNode node) { boolean ok = true; ok &= (auxInfo.nid == node.getNodeNumber()); double sum = 0; if (node.leftChild!=null) { ok &= (auxInfo.nidL == node.leftChild.getNodeNumber()); ok &= (auxInfo.weightL == node.leftChild.getWeight()); ok &= (auxInfo.predL == node.leftChild.predValue); ok &= (auxInfo.sqErrL == node.leftChild.squaredError); sum += node.leftChild.getWeight(); } if (node.rightChild!=null) { ok &= (auxInfo.nidR == node.rightChild.getNodeNumber()); ok &= (auxInfo.weightR == node.rightChild.getWeight()); ok &= (auxInfo.predR == node.rightChild.predValue); ok &= (auxInfo.sqErrR == node.rightChild.squaredError); sum += node.rightChild.getWeight(); } if (node.parent!=null) { ok &= (auxInfo.pid == node.parent.getNodeNumber()); ok &= (Math.abs(node.getWeight() - sum) < 1e-5 * (node.getWeight() + sum)); } if (!ok) { System.out.println("\nTree inconsistency found:"); node.print(); node.leftChild.print(); node.rightChild.print(); System.out.println(auxInfo.toString()); } } //------------------------------------------------------------------------------------------------------------------ // Private //------------------------------------------------------------------------------------------------------------------ protected SharedTreeMojoModel(String[] columns, String[][] domains) { super(columns, domains); } /** * Score all trees and fill in the `preds` array. */ protected void scoreAllTrees(double[] row, double[] preds) { java.util.Arrays.fill(preds, 0); for (int i = 0; i < _ntrees_per_group; i++) { int k = _nclasses == 1? 0 : i + 1; for (int j = 0; j < _ntree_groups; j++) { int itree = treeIndex(j, i); // Skip all empty trees if (_compressed_trees[itree] == null) continue; if (_mojo_version.equals(1.0)) { //First version preds[k] += scoreTree0(_compressed_trees[itree], row, _nclasses, false); } else if (_mojo_version.equals(1.1)) { //Second version preds[k] += scoreTree1(_compressed_trees[itree], row, _nclasses, false); } else if (_mojo_version.equals(1.2)) { //CURRENT VERSION preds[k] += scoreTree(_compressed_trees[itree], row, _nclasses, false, _domains); } } } } protected int treeIndex(int groupIndex, int classIndex) { return classIndex * _ntree_groups + groupIndex; } // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE // DO NOT CHANGE THE CODE BELOW THIS LINE ///////////////////////////////////////////////////// /** * SET IN STONE FOR MOJO VERSION "1.00" - DO NOT CHANGE * @param tree * @param row * @param nclasses * @param computeLeafAssignment * @return */ @SuppressWarnings("ConstantConditions") // Complains that the code is too complex. Well duh! public static double scoreTree0(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) { ByteBufferWrapper ab = new ByteBufferWrapper(tree); GenmodelBitSet bs = null; // Lazily set on hitting first group test long bitsRight = 0; int level = 0; while (true) { int nodeType = ab.get1U(); int colId = ab.get2(); if (colId == 65535) return ab.get4f(); int naSplitDir = ab.get1U(); boolean naVsRest = naSplitDir == NsdNaVsRest; boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft; int lmask = (nodeType & 51); int equal = (nodeType & 12); // Can be one of 0, 8, 12 assert equal != 4; // no longer supported float splitVal = -1; if (!naVsRest) { // Extract value or group to split on if (equal == 0) { // Standard float-compare test (either < or ==) splitVal = ab.get4f(); // Get the float to compare } else { // Bitset test if (bs == null) bs = new GenmodelBitSet(0); if (equal == 8) bs.fill2(tree, ab); else bs.fill3_1(tree, ab); } } double d = row[colId]; if (Double.isNaN(d)? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains0((int)d))) { // go RIGHT switch (lmask) { case 0: ab.skip(ab.get1U()); break; case 1: ab.skip(ab.get2()); break; case 2: ab.skip(ab.get3()); break; case 3: ab.skip(ab.get4()); break; case 16: ab.skip(nclasses < 256? 1 : 2); break; // Small leaf case 48: ab.skip(4); break; // skip the prediction default: assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree); } if (computeLeafAssignment && level < 64) bitsRight |= 1 << level; lmask = (nodeType & 0xC0) >> 2; // Replace leftmask with the rightmask } else { // go LEFT if (lmask <= 3) ab.skip(lmask + 1); } level++; if ((lmask & 16) != 0) { if (computeLeafAssignment) { bitsRight |= 1 << level; // mark the end of the tree return Double.longBitsToDouble(bitsRight); } else { return ab.get4f(); } } } } /** * SET IN STONE FOR MOJO VERSION "1.10" - DO NOT CHANGE * @param tree * @param row * @param nclasses * @param computeLeafAssignment * @return */ @SuppressWarnings("ConstantConditions") // Complains that the code is too complex. Well duh! public static double scoreTree1(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) { ByteBufferWrapper ab = new ByteBufferWrapper(tree); GenmodelBitSet bs = null; long bitsRight = 0; int level = 0; while (true) { int nodeType = ab.get1U(); int colId = ab.get2(); if (colId == 65535) return ab.get4f(); int naSplitDir = ab.get1U(); boolean naVsRest = naSplitDir == NsdNaVsRest; boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft; int lmask = (nodeType & 51); int equal = (nodeType & 12); // Can be one of 0, 8, 12 assert equal != 4; // no longer supported float splitVal = -1; if (!naVsRest) { // Extract value or group to split on if (equal == 0) { // Standard float-compare test (either < or ==) splitVal = ab.get4f(); // Get the float to compare } else { // Bitset test if (bs == null) bs = new GenmodelBitSet(0); if (equal == 8) bs.fill2(tree, ab); else bs.fill3_1(tree, ab); } } double d = row[colId]; if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) ) ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) { // go RIGHT switch (lmask) { case 0: ab.skip(ab.get1U()); break; case 1: ab.skip(ab.get2()); break; case 2: ab.skip(ab.get3()); break; case 3: ab.skip(ab.get4()); break; case 16: ab.skip(nclasses < 256? 1 : 2); break; // Small leaf case 48: ab.skip(4); break; // skip the prediction default: assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree); } if (computeLeafAssignment && level < 64) bitsRight |= 1 << level; lmask = (nodeType & 0xC0) >> 2; // Replace leftmask with the rightmask } else { // go LEFT if (lmask <= 3) ab.skip(lmask + 1); } level++; if ((lmask & 16) != 0) { if (computeLeafAssignment) { bitsRight |= 1 << level; // mark the end of the tree return Double.longBitsToDouble(bitsRight); } else { return ab.get4f(); } } } } @Override public boolean calibrateClassProbabilities(double[] preds) { if (_calib_glm_beta == null) return false; assert _nclasses == 2; // only supported for binomial classification assert preds.length == _nclasses + 1; double p = GLM_logitInv((preds[1] * _calib_glm_beta[0]) + _calib_glm_beta[1]); preds[1] = 1 - p; preds[2] = p; return true; } }