package hex.tree; import java.util.Random; import hex.genmodel.algos.tree.SharedTreeMojoModel; import water.*; import water.util.IcedBitSet; import water.util.SB; //--------------------------------------------------------------------------- // Note: this description seems to be out-of-date // // Highly compressed tree encoding: // tree: 1B nodeType, 2B colId, 4B splitVal, left-tree-size, left, right // nodeType: (from lsb): // 2 bits (1,2) skip-tree-size-size, // 2 bits (4,8) operator flag (0 --> <, 1 --> ==, 2 --> small (4B) group, 3 --> big (var size) group), // 1 bit ( 16) left leaf flag, // 1 bit ( 32) left leaf type flag (0: subtree, 1: small cat, 2: big cat, 3: float) // 1 bit ( 64) right leaf flag, // 1 bit (128) right leaf type flag (0: subtree, 1: small cat, 2: big cat, 3: float) // left, right: tree | prediction // prediction: 4 bytes of float (or 1 or 2 bytes of class prediction) // public class CompressedTree extends Keyed<CompressedTree> { final byte [] _bits; final int _nclass; // Number of classes being predicted (for an integer prediction tree) final long _seed; final String[][] _domains; public CompressedTree(byte[] bits, int nclass, long seed, int tid, int cls, String[][] domains) { super(makeTreeKey(tid, cls)); _bits = bits; _nclass = nclass; _seed = seed; _domains = domains; } public double score(final double row[]) { return SharedTreeMojoModel.scoreTree(_bits, row, _nclass, false, _domains); } public String getDecisionPath(final double row[]) { double d = SharedTreeMojoModel.scoreTree(_bits, row, _nclass, true, _domains); return SharedTreeMojoModel.getDecisionPath(d); } public Random rngForChunk(int cidx) { Random rand = new Random(_seed); for (int i = 0; i < cidx; i++) rand.nextLong(); long seed = rand.nextLong(); return new Random(seed); } public String toString(SharedTreeModel.SharedTreeOutput tm) { final String[] names = tm._names; final SB sb = new SB(); new TreeVisitor<RuntimeException>(this) { @Override protected void pre(int col, float fcmp, IcedBitSet gcmp, int equal, int naSplitDirInt) { if (naSplitDirInt == DhnasdNaVsRest) sb.p("!Double.isNaN(" + sb.i().p(names[col]).p(")")); else if (naSplitDirInt == DhnasdNaLeft) sb.p("Double.isNaN(" + sb.i().p(names[col]).p(") || ")); else if (equal==1) sb.p("!Double.isNaN(" + sb.i().p(names[col]).p(") && ")); if (naSplitDirInt != DhnasdNaVsRest) { sb.i().p(names[col]).p(' '); if (equal == 0) sb.p("< ").p(fcmp); else if (equal == 1) sb.p("!=").p(fcmp); else sb.p("in ").p(gcmp); } sb.ii(1).nl(); } @Override protected void post(int col, float fcmp, int equal) { sb.di(1); } @Override protected void leaf(float pred) { sb.i().p("return ").p(pred).nl(); } }.visit(); return sb.toString(); } public static Key<CompressedTree> makeTreeKey(int treeId, int clazz) { return Key.makeSystem("tree_" + treeId + "_" + clazz + "_" + Key.rand()); } @Override protected long checksum_impl() { throw new UnsupportedOperationException(); } }