package hex.gbm; import static hex.gbm.SharedTreeModelBuilder.createRNG; import hex.ConfusionMatrix; import hex.VarImp; import hex.gbm.DTree.TreeModel.CompressedTree; import hex.gbm.DTree.TreeModel.TreeVisitor; import water.*; import water.api.*; import water.api.Request.API; import water.fvec.Chunk; import water.license.LicenseManager; import water.util.*; import water.util.Utils.IcedBitSet; import java.util.*; /** A Decision Tree, laid over a Frame of Vecs, and built distributed. This class defines an explicit Tree structure, as a collection of {@code DTree} {@code Node}s. The Nodes are numbered with a unique {@code _nid}. Users need to maintain their own mapping from their data to a {@code _nid}, where the obvious technique is to have a Vec of {@code _nid}s (ints), one per each element of the data Vecs. Each {@code Node} has a {@code DHistogram}, describing summary data about the rows. The DHistogram requires a pass over the data to be filled in, and we expect to fill in all rows for Nodes at the same depth at the same time. i.e., a single pass over the data will fill in all leaf Nodes' DHistograms at once. @author Cliff Click */ public class DTree extends Iced { final String[] _names; // Column names final int _ncols; // Active training columns final char _nbins; // Max number of bins to split over final char _nclass; // #classes, or 1 for regression trees final int _min_rows; // Fewest allowed rows in any split final long _seed; // RNG seed; drives sampling seeds if necessary private Node[] _ns; // All the nodes in the tree. Node 0 is the root. int _len; // Resizable array public DTree( String[] names, int ncols, char nbins, char nclass, int min_rows ) { this(names,ncols,nbins,nclass,min_rows,-1); } public DTree( String[] names, int ncols, char nbins, char nclass, int min_rows, long seed ) { _names = names; _ncols = ncols; _nbins=nbins; _nclass=nclass; _min_rows = min_rows; _ns = new Node[1]; _seed = seed; } public final Node root() { return _ns[0]; } // One-time local init after wire transfer void init_tree( ) { for( int j=0; j<_len; j++ ) _ns[j]._tree = this; } // Return Node i public final Node node( int i ) { if( i >= _len ) throw new ArrayIndexOutOfBoundsException(i); return _ns[i]; } public final UndecidedNode undecided( int i ) { return (UndecidedNode)node(i); } public final DecidedNode decided( int i ) { return ( DecidedNode)node(i); } // Get a new node index, growing innards on demand private synchronized int newIdx(Node n) { if( _len == _ns.length ) _ns = Arrays.copyOf(_ns,_len<<1); _ns[_len] = n; return _len++; } // Return a deterministic chunk-local RNG. Can be kinda expensive. // Override this in, e.g. Random Forest algos, to get a per-chunk RNG public Random rngForChunk( int cidx ) { throw H2O.fail(); } public final int len() { return _len; } public final void len(int len) { _len = len; } // Public stats about tree public int leaves; public int depth; // -------------------------------------------------------------------------- // Abstract node flavor public static abstract class Node extends Iced { transient protected DTree _tree; // Make transient, lest we clone the whole tree final protected int _pid; // Parent node id, root has no parent and uses -1 final protected int _nid; // My node-ID, 0 is root Node( DTree tree, int pid, int nid ) { _tree = tree; _pid=pid; tree._ns[_nid=nid] = this; } Node( DTree tree, int pid ) { _tree = tree; _pid=pid; _nid = tree.newIdx(this); } // Recursively print the decision-line from tree root to this child. StringBuilder printLine(StringBuilder sb ) { if( _pid==-1 ) return sb.append("[root]"); DecidedNode parent = _tree.decided(_pid); parent.printLine(sb).append(" to "); return parent.printChild(sb,_nid); } abstract public StringBuilder toString2(StringBuilder sb, int depth); abstract protected AutoBuffer compress(AutoBuffer ab); abstract protected int size(); public final int nid() { return _nid; } public final int pid() { return _pid; } } // -------------------------------------------------------------------------- // Records a column, a bin to split at within the column, and the MSE. public static class Split extends Iced { final int _col, _bin; // Column to split, bin where being split final IcedBitSet _bs; // For binary y and categorical x (with >= 4 levels), split into 2 non-contiguous groups final byte _equal; // Split is 0: <, 1: == with single split point, 2: == with group split (<= 32 levels), 3: == with group split (> 32 levels) final double _se0, _se1; // Squared error of each subsplit final long _n0, _n1; // Rows in each final split final double _p0, _p1; // Predicted value for each split public Split( int col, int bin, IcedBitSet bs, byte equal, double se0, double se1, long n0, long n1, double p0, double p1 ) { _col = col; _bin = bin; _bs = bs; _equal = equal; _n0 = n0; _n1 = n1; _se0 = se0; _se1 = se1; _p0 = p0; _p1 = p1; } public final double se() { return _se0+_se1; } public final int col() { return _col; } public final int bin() { return _bin; } public final long rowsLeft () { return _n0; } public final long rowsRight() { return _n1; } /** Returns empirical improvement in mean-squared error. * * <p>Formula for node splitting space into two subregions R1,R2 with predictions y1, y2:</p> * <code>i2(R1,R2) ~ w1*w2 / (w1+w2) * (y1 - y2)^2</code> * * * <p>For more information see (35), (45) in the paper * <a href="www-stat.stanford.edu/~jhf/ftp/trebst.pdf"><i>J. Friedman - Greedy Function Approximation: A Gradient boosting machine</i></a></p> */ public final float improvement() { double d = (_p0-_p1); return (float) ( d*d*_n0*_n1 / (_n0+_n1) ); } // Split-at dividing point. Don't use the step*bin+bmin, due to roundoff // error we can have that point be slightly higher or lower than the bin // min/max - which would allow values outside the stated bin-range into the // split sub-bins. Always go for a value which splits the nearest two // elements. float splat(DHistogram hs[]) { DHistogram h = hs[_col]; assert _bin > 0 && _bin < h.nbins(); if( _equal == 1 ) { assert h.bins(_bin)!=0; return h.binAt(_bin); } // Find highest non-empty bin below the split int x=_bin-1; while( x >= 0 && h.bins(x)==0 ) x--; // Find lowest non-empty bin above the split int n=_bin; while( n < h.nbins() && h.bins(n)==0 ) n++; // Lo is the high-side of the low non-empty bin, rounded to int for int columns // Hi is the low -side of the hi non-empty bin, rounded to int for int columns // Example: Suppose there are no empty bins, and we are splitting an // integer column at 48.4 (more than nbins, so step != 1.0, perhaps // step==1.8). The next lowest non-empty bin is from 46.6 to 48.4, and // we set lo=48.4. The next highest non-empty bin is from 48.4 to 50.2 // and we set hi=48.4. Since this is an integer column, we round lo to // 48 (largest integer below the split) and hi to 49 (smallest integer // above the split). Finally we average them, and split at 48.5. float lo = h.binAt(x+1); float hi = h.binAt(n ); if( h._isInt > 0 ) lo = h._step==1 ? lo-1 : (float)Math.floor(lo); if( h._isInt > 0 ) hi = h._step==1 ? hi : (float)Math.ceil (hi); return (lo+hi)/2.0f; } // Split a DHistogram. Return null if there is no point in splitting // this bin further (such as there's fewer than min_row elements, or zero // error in the response column). Return an array of DHistograms (one // per column), which are bounded by the split bin-limits. If the column // has constant data, or was not being tracked by a prior DHistogram // (for being constant data from a prior split), then that column will be // null in the returned array. public DHistogram[] split( int way, char nbins, int min_rows, DHistogram hs[], float splat ) { long n = way==0 ? _n0 : _n1; if( n < min_rows || n <= 1 ) return null; // Too few elements double se = way==0 ? _se0 : _se1; if( se <= 1e-30 ) return null; // No point in splitting a perfect prediction // Build a next-gen split point from the splitting bin int cnt=0; // Count of possible splits DHistogram nhists[] = new DHistogram[hs.length]; // A new histogram set for( int j=0; j<hs.length; j++ ) { // For every column in the new split DHistogram h = hs[j]; // old histogram of column if( h == null ) continue; // Column was not being tracked? int adj_nbins = Math.max(h.nbins()>>1,nbins); // min & max come from the original column data, since splitting on an // unrelated column will not change the j'th columns min/max. // Tighten min/max based on actual observed data for tracked columns float min, maxEx; if( h._bins == null ) { // Not tracked this last pass? min = h._min; // Then no improvement over last go maxEx = h._maxEx; } else { // Else pick up tighter observed bounds min = h.find_min(); // Tracked inclusive lower bound if( h.find_maxIn() == min ) continue; // This column will not split again maxEx = h.find_maxEx(); // Exclusive max } // Tighter bounds on the column getting split: exactly each new // DHistogram's bound are the bins' min & max. if( _col==j ) { if( _equal != 0 ) { // Equality split; no change on unequals-side if( way == 1 ) continue; // but know exact bounds on equals-side - and this col will not split again } else { // Less-than split if( h._bins[_bin]==0 ) throw H2O.unimpl(); // Here I should walk up & down same as split() above. float split = splat; if( h._isInt > 0 ) split = (float)Math.ceil(split); if( way == 0 ) maxEx= split; else min = split; } } if( Utils.equalsWithinOneSmallUlp(min, maxEx) ) continue; // This column will not split again if( Float.isInfinite(adj_nbins/(maxEx-min)) ) continue; if( h._isInt > 0 && !(min+1 < maxEx ) ) continue; // This column will not split again if( min > maxEx ) continue; // Happens for all-NA subsplits assert min < maxEx && n > 1 : ""+min+"<"+maxEx+" n="+n; nhists[j] = DHistogram.make(h._name,adj_nbins,h._isInt,min,maxEx,n,min_rows,h._doGrpSplit,h.isBinom()); cnt++; // At least some chance of splitting } return cnt == 0 ? null : nhists; } public static StringBuilder ary2str( StringBuilder sb, int w, long xs[] ) { sb.append('['); for( long x : xs ) UndecidedNode.p(sb,x,w).append(","); return sb.append(']'); } public static StringBuilder ary2str( StringBuilder sb, int w, float xs[] ) { sb.append('['); for( float x : xs ) UndecidedNode.p(sb,x,w).append(","); return sb.append(']'); } public static StringBuilder ary2str( StringBuilder sb, int w, double xs[] ) { sb.append('['); for( double x : xs ) UndecidedNode.p(sb,(float)x,w).append(","); return sb.append(']'); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("{"+_col+"/"); UndecidedNode.p(sb,_bin,2); sb.append(", se0=").append(_se0); sb.append(", se1=").append(_se1); sb.append(", n0=" ).append(_n0 ); sb.append(", n1=" ).append(_n1 ); return sb.append("}").toString(); } } // -------------------------------------------------------------------------- // An UndecidedNode: Has a DHistogram which is filled in (in parallel // with other histograms) in a single pass over the data. Does not contain // any split-decision. public static abstract class UndecidedNode extends Node { public transient DHistogram[] _hs; public final int _scoreCols[]; // A list of columns to score; could be null for all public UndecidedNode( DTree tree, int pid, DHistogram[] hs ) { super(tree,pid); assert hs.length==tree._ncols; _scoreCols = scoreCols(_hs=hs); } // Pick a random selection of columns to compute best score. // Can return null for 'all columns'. abstract public int[] scoreCols( DHistogram[] hs ); // Make the parent of this Node use a -1 NID to prevent the split that this // node otherwise induces. Happens if we find out too-late that we have a // perfect prediction here, and we want to turn into a leaf. public void do_not_split( ) { if( _pid == -1 ) return; // skip root DecidedNode dn = _tree.decided(_pid); for( int i=0; i<dn._nids.length; i++ ) if( dn._nids[i]==_nid ) { dn._nids[i] = -1; return; } throw H2O.fail(); } @Override public String toString() { final String colPad=" "; final int cntW=4, mmmW=4, menW=5, varW=5; final int colW=cntW+1+mmmW+1+mmmW+1+menW+1+varW; StringBuilder sb = new StringBuilder(); sb.append("Nid# ").append(_nid).append(", "); printLine(sb).append("\n"); if( _hs == null ) return sb.append("_hs==null").toString(); final int ncols = _hs.length; for( int j=0; j<ncols; j++ ) if( _hs[j] != null ) p(sb,_hs[j]._name+String.format(", %4.1f-%4.1f",_hs[j]._min,_hs[j]._maxEx),colW).append(colPad); sb.append('\n'); for( int j=0; j<ncols; j++ ) { if( _hs[j] == null ) continue; p(sb,"cnt" ,cntW).append('/'); p(sb,"min" ,mmmW).append('/'); p(sb,"max" ,mmmW).append('/'); p(sb,"mean",menW).append('/'); p(sb,"var" ,varW).append(colPad); } sb.append('\n'); // Max bins int nbins=0; for( int j=0; j<ncols; j++ ) if( _hs[j] != null && _hs[j].nbins() > nbins ) nbins = _hs[j].nbins(); for( int i=0; i<nbins; i++ ) { for( int j=0; j<ncols; j++ ) { DHistogram h = _hs[j]; if( h == null ) continue; if( i < h.nbins() && h._bins != null ) { p(sb, h.bins(i),cntW).append('/'); p(sb, h.binAt(i),mmmW).append('/'); p(sb, h.binAt(i+1),mmmW).append('/'); p(sb, h.mean(i),menW).append('/'); p(sb, h.var (i),varW).append(colPad); } else { p(sb,"",colW).append(colPad); } } sb.append('\n'); } sb.append("Nid# ").append(_nid); return sb.toString(); } static private StringBuilder p(StringBuilder sb, String s, int w) { return sb.append(Log.fixedLength(s,w)); } static private StringBuilder p(StringBuilder sb, long l, int w) { return p(sb,Long.toString(l),w); } static private StringBuilder p(StringBuilder sb, double d, int w) { String s = Double.isNaN(d) ? "NaN" : ((d==Float.MAX_VALUE || d==-Float.MAX_VALUE || d==Double.MAX_VALUE || d==-Double.MAX_VALUE) ? " -" : (d==0?" 0":Double.toString(d))); if( s.length() <= w ) return p(sb,s,w); s = String.format("% 4.2f",d); if( s.length() > w ) s = String.format("%4.1f",d); if( s.length() > w ) s = String.format("%4.0f",d); return p(sb,s,w); } @Override public StringBuilder toString2(StringBuilder sb, int depth) { for( int d=0; d<depth; d++ ) sb.append(" "); return sb.append("Undecided\n"); } @Override protected AutoBuffer compress(AutoBuffer ab) { throw H2O.fail(); } @Override protected int size() { throw H2O.fail(); } } // -------------------------------------------------------------------------- // Internal tree nodes which split into several children over a single // column. Includes a split-decision: which child does this Row belong to? // Does not contain a histogram describing how the decision was made. public static abstract class DecidedNode extends Node { public final Split _split; // Split: col, equal/notequal/less/greater, nrows, MSE public final float _splat; // Split At point: lower bin-edge of split // _equals\_nids[] \ 0 1 // ----------------+---------- // F | < >= // T | != == public final int _nids[]; // Children NIDS for the split LEFT, RIGHT transient byte _nodeType; // Complex encoding: see the compressed struct comments transient int _size = 0; // Compressed byte size of this subtree // Make a correctly flavored Undecided public abstract UndecidedNode makeUndecidedNode(DHistogram hs[]); // Pick the best column from the given histograms public abstract Split bestCol( UndecidedNode u, DHistogram hs[] ); public DecidedNode( UndecidedNode n, DHistogram hs[] ) { super(n._tree,n._pid,n._nid); // Replace Undecided with this DecidedNode _nids = new int[2]; // Split into 2 subsets _split = bestCol(n,hs); // Best split-point for this tree if( _split._col == -1 ) { // No good split? // Happens because the predictor columns cannot split the responses - // which might be because all predictor columns are now constant, or // because all responses are now constant. _splat = Float.NaN; Arrays.fill(_nids,-1); return; } _splat = (_split._equal == 0 || _split._equal == 1) ? _split.splat(hs) : -1; // Split-at value (-1 for group-wise splits) final char nbins = _tree._nbins; final int min_rows = _tree._min_rows; for( int b=0; b<2; b++ ) { // For all split-points // Setup for children splits DHistogram nhists[] = _split.split(b,nbins,min_rows,hs,_splat); assert nhists==null || nhists.length==_tree._ncols; _nids[b] = nhists == null ? -1 : makeUndecidedNode(nhists)._nid; } } // Bin #. public int bin( Chunk chks[], int row ) { float d = (float)chks[_split._col].at0(row); // Value to split on for this row if( Float.isNaN(d) ) // Missing data? return 0; // NAs always to bin 0 // Note that during *scoring* (as opposed to training), we can be exposed // to data which is outside the bin limits. if(_split._equal == 0) return d < _splat ? 0 : 1; else if(_split._equal == 1) return d != _splat ? 0 : 1; else return _split._bs.contains((int)d) ? 1 : 0; // return _split._equal ? (d != _splat ? 0 : 1) : (d < _splat ? 0 : 1); } public int ns( Chunk chks[], int row ) { return _nids[bin(chks,row)]; } public double pred( int nid ) { return nid==0 ? _split._p0 : _split._p1; } @Override public String toString() { if( _split._col == -1 ) return "Decided has col = -1"; int col = _split._col; if( _split._equal == 1 ) return _tree._names[col]+" != "+_splat+"\n"+ _tree._names[col]+" == "+_splat+"\n"; else if( _split._equal == 2 || _split._equal == 3 ) return _tree._names[col]+" != "+_split._bs.toString()+"\n"+ _tree._names[col]+" == "+_split._bs.toString()+"\n"; return _tree._names[col]+" < "+_splat+"\n"+ _splat+" <="+_tree._names[col]+"\n"; } StringBuilder printChild( StringBuilder sb, int nid ) { int i = _nids[0]==nid ? 0 : 1; assert _nids[i]==nid : "No child nid "+nid+"? " +Arrays.toString(_nids); sb.append("[").append(_tree._names[_split._col]); sb.append(_split._equal != 0 ? (i==0 ? " != " : " == ") : (i==0 ? " < " : " >= ")); sb.append((_split._equal == 2 || _split._equal == 3) ? _split._bs.toString() : _splat).append("]"); return sb; } @Override public StringBuilder toString2(StringBuilder sb, int depth) { for( int i=0; i<_nids.length; i++ ) { for( int d=0; d<depth; d++ ) sb.append(" "); sb.append(_nid).append(" "); if( _split._col < 0 ) sb.append("init"); else { sb.append(_tree._names[_split._col]); sb.append(_split._equal != 0 ? (i==0 ? " != " : " == ") : (i==0 ? " < " : " >= ")); sb.append((_split._equal == 2 || _split._equal == 3) ? _split._bs.toString() : _splat).append("\n"); } if( _nids[i] >= 0 && _nids[i] < _tree._len ) _tree.node(_nids[i]).toString2(sb,depth+1); } return sb; } // Size of this subtree; sets _nodeType also @Override public final int size(){ if( _size != 0 ) return _size; // Cached size assert _nodeType == 0:"unexpected node type: " + _nodeType; if(_split._equal != 0) _nodeType |= _split._equal == 1 ? 4 : (_split._equal == 2 ? 8 : 12); // int res = 7; // 1B node type + flags, 2B colId, 4B float split val // 1B node type + flags, 2B colId, 4B split val/small group or (2B offset + 2B size) + large group int res = _split._equal == 3 ? 7 + _split._bs.numBytes() : 7; Node left = _tree.node(_nids[0]); int lsz = left.size(); res += lsz; if( left instanceof LeafNode ) _nodeType |= (byte)(48 << 0*2); else { int slen = lsz < 256 ? 0 : (lsz < 65535 ? 1 : (lsz<(1<<24) ? 2 : 3)); _nodeType |= slen; // Set the size-skip bits res += (slen+1); // } Node rite = _tree.node(_nids[1]); if( rite instanceof LeafNode ) _nodeType |= (byte)(48 << 1*2); res += rite.size(); assert (_nodeType&0x33) != 51; assert res != 0; return (_size = res); } // Compress this tree into the AutoBuffer @Override public AutoBuffer compress(AutoBuffer ab) { int pos = ab.position(); if( _nodeType == 0 ) size(); // Sets _nodeType & _size both ab.put1(_nodeType); // Includes left-child skip-size bits assert _split._col != -1; // Not a broken root non-decision? ab.put2((short)_split._col); // Save split-at-value or group if(_split._equal == 0 || _split._equal == 1) ab.put4f(_splat); else if(_split._equal == 2) { /* byte[] ary = MemoryManager.malloc1(4); for(int i = 0; i < 4; i++) ary[i] = _split._bs._val[i]; ab.putA1(ary, 4); */ ab.putA1(_split._bs._val, 4); } else { assert _split._equal == 3; ab.put2((char)_split._bs._offset); ab.put2((char)_split._bs.numBytes()); ab.putA1(_split._bs._val, _split._bs.numBytes()); } Node left = _tree.node(_nids[0]); if( (_nodeType&48) == 0 ) { // Size bits are optional for left leaves ! int sz = left.size(); if(sz < 256) ab.put1( sz); else if (sz < 65535) ab.put2((short)sz); else if (sz < (1<<24)) ab.put3( sz); else ab.put4( sz); // 1<<31-1 } // now write the subtree in left.compress(ab); Node rite = _tree.node(_nids[1]); rite.compress(ab); assert _size == ab.position()-pos:"reported size = " + _size + " , real size = " + (ab.position()-pos); return ab; } } public static abstract class LeafNode extends Node { public double _pred; public LeafNode( DTree tree, int pid ) { super(tree,pid); } public LeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); } @Override public String toString() { return "Leaf#"+_nid+" = "+_pred; } @Override public final StringBuilder toString2(StringBuilder sb, int depth) { for( int d=0; d<depth; d++ ) sb.append(" "); sb.append(_nid).append(" "); return sb.append("pred=").append(_pred).append("\n"); } public final double pred() { return _pred; } public final void pred(double pred) { _pred = pred; } } static public final boolean isRootNode(Node n) { return n._pid == -1; } // -------------------------------------------------------------------------- public static abstract class TreeModel extends water.Model { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help="Expected max trees") public final int N; @API(help="MSE rate as trees are added") public final double [] errs; @API(help="Keys of actual trees built") public final Key [/*N*/][/*nclass*/] treeKeys; // Always filled, but 2-binary classifiers can contain null for 2nd class @API(help="Maximum tree depth") public final int max_depth; @API(help="Fewest allowed observations in a leaf") public final int min_rows; @API(help="Bins in the histograms") public final int nbins; // For classification models, we'll do a Confusion Matrix right in the // model (for now - really should be separate). @API(help="Testing key for cm and errs") public final Key testKey; // Confusion matrix per each generated tree or null @API(help="Confusion Matrix computed on training dataset, cm[actual][predicted]") public final ConfusionMatrix cms[/*CM-per-tree*/]; @API(help="Confusion matrix domain.") public final String[] cmDomain; @API(help="Variable importance for individual input variables.") public final VarImp varimp; // NOTE: in future we can have an array of different variable importance measures (per method) @API(help="Tree statistics") public final TreeStats treeStats; @API(help="AUC for validation dataset") public final AUCData validAUC; @API(help="Whether this is transformed from speedrf") public boolean isFromSpeeDRF=false; private final int num_folds; private transient volatile CompressedTree[/*N*/][/*nclasses OR 1 for regression*/] _treeBitsCache; public TreeModel( Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int num_folds, float[] priorClassDist, float[] classDist) { this(key, dataKey, testKey, names, domains, cmDomain, ntrees, max_depth, min_rows, nbins, num_folds, priorClassDist, classDist, new Key[0][], new ConfusionMatrix[0], new double[0], null, null, null); } private TreeModel( Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int num_folds, float[] priorClassDist, float[] classDist, Key[][] treeKeys, ConfusionMatrix[] cms, double[] errs, TreeStats treeStats, VarImp varimp, AUCData validAUC) { super(key,dataKey,names,domains,priorClassDist, classDist); this.N = ntrees; this.max_depth = max_depth; this.min_rows = min_rows; this.nbins = nbins; this.num_folds = num_folds; this.treeKeys = treeKeys; this.treeStats = treeStats; this.cmDomain = cmDomain!=null ? cmDomain : new String[0];; this.testKey = testKey; this.cms = cms; this.errs = errs; this.varimp = varimp; this.validAUC = validAUC; } // Simple copy ctor, null value of parameter means copy from prior-model protected TreeModel(TreeModel prior, Key[][] treeKeys, double[] errs, ConfusionMatrix[] cms, TreeStats tstats, VarImp varimp, AUCData validAUC) { super(prior._key,prior._dataKey,prior._names,prior._domains, prior._priorClassDist,prior._modelClassDist,prior.training_start_time,prior.training_duration_in_ms); this.N = prior.N; this.testKey = prior.testKey; this.max_depth = prior.max_depth; this.min_rows = prior.min_rows; this.nbins = prior.nbins; this.cmDomain = prior.cmDomain; this.num_folds = prior.num_folds; if (treeKeys != null) this.treeKeys = treeKeys; else this.treeKeys = prior.treeKeys; if (errs != null) this.errs = errs; else this.errs = prior.errs; if (cms != null) this.cms = cms; else this.cms = prior.cms; if (tstats != null) this.treeStats = tstats; else this.treeStats = prior.treeStats; if (varimp != null) this.varimp = varimp; else this.varimp = prior.varimp; if (validAUC != null) this.validAUC = validAUC; else this.validAUC = prior.validAUC; } // Additional copy ctors to update specific fields public TreeModel(TreeModel prior, DTree[] tree, double err, ConfusionMatrix cm, TreeStats tstats) { this(prior, append(prior.treeKeys, tree), Utils.append(prior.errs, err), Utils.append(prior.cms, cm), tstats, null, null); } public TreeModel(TreeModel prior, DTree[] tree, TreeStats tstats) { this(prior, append(prior.treeKeys, tree), null, null, tstats, null, null); } public TreeModel(TreeModel prior, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) { this(prior, null, Utils.append(prior.errs, err), Utils.append(prior.cms, cm), null, varimp, validAUC); } public enum TreeModelType { UNKNOWN, GBM, DRF, } protected TreeModelType getTreeModelType() { return TreeModelType.UNKNOWN; } /** Returns Producer if the model is under construction else null. * <p>The implementation looks for writer lock. If it is present, then returns true.</p> * * <p>WARNING: the method is strictly for UI used, does not provide any atomicity!!!</p>*/ private final Key getProducer() { return FetchProducer.fetch(_key); } private final boolean isProduced() { return getProducer()!=null; } private static final class FetchProducer extends DTask<FetchProducer> { final private Key _key; private Key _producer; public static Key fetch(Key key) { FetchProducer fp = new FetchProducer(key); if (key.home()) fp.compute2(); else fp = RPC.call(key.home_node(), fp).get(); return fp._producer; } private FetchProducer(Key k) { _key = k; } @Override public void compute2() { Lockable l = UKV.get(_key); _producer = l!=null && l._lockers!=null && l._lockers.length > 0 ? l._lockers[0] : null; tryComplete(); } @Override public byte priority() { return H2O.ATOMIC_PRIORITY; } } private static final Key[][] append(Key[][] prior, DTree[] tree ) { if (tree==null) return prior; prior = Arrays.copyOf(prior, prior.length+1); Key ts[] = prior[prior.length-1] = new Key[tree.length]; for( int c=0; c<tree.length; c++ ) if( tree[c] != null ) { ts[c] = tree[c].save(); } return prior; } /** Number of trees in current model. */ public int ntrees() { return treeKeys.length; } // Most recent ConfusionMatrix @Override public ConfusionMatrix cm() { ConfusionMatrix[] cms = this.cms; // Avoid race update; read it once if(cms != null && cms.length > 0){ int n = cms.length-1; while(n > 0 && cms[n] == null)--n; return cms[n] == null?null:cms[n]; } else return null; } @Override public VarImp varimp() { return varimp; } @Override public double mse() { if(errs != null && errs.length > 0){ int n = errs.length-1; while(n > 0 && Double.isNaN(errs[n]))--n; return errs[n]; } else return Double.NaN; } @Override protected float[] score0(double data[], float preds[]) { // Prefetch trees into the local cache if it is necessary // Invoke scoring Arrays.fill(preds,0); for( int tidx=0; tidx<treeKeys.length; tidx++ ) score0(data, preds, tidx); return preds; } /** Returns i-th tree represented by an array of k-trees. */ public final CompressedTree[] ctree(int tidx) { if (_treeBitsCache==null) { synchronized(this) { if (_treeBitsCache==null) _treeBitsCache = new CompressedTree[ntrees()][]; } } if (_treeBitsCache[tidx]==null) { synchronized(this) { if (_treeBitsCache[tidx]==null) { Key[] k = treeKeys[tidx]; CompressedTree[] ctree = new CompressedTree[nclasses()]; for (int i = 0; i < nclasses(); i++) // binary classifiers can contains null for second tree if (k[i]!=null) ctree[i] = UKV.get(k[i]); _treeBitsCache[tidx] = ctree; } } } return _treeBitsCache[tidx]; } // Score per line per tree public void score0(double data[], float preds[], int treeIdx) { CompressedTree ts[] = ctree(treeIdx); DTreeUtils.scoreTree(data, preds, ts); } /** Delete model trees */ public void delete_trees() { Futures fs = new Futures(); delete_trees(fs); fs.blockForPending(); } public Futures delete_trees(Futures fs) { for (int tid = 0; tid < treeKeys.length; tid++) /* over all trees */ for (int cid = 0; cid < treeKeys[tid].length; cid++) /* over all classes */ // 2-binary classifiers can contain null for the second if (treeKeys[tid][cid]!=null) DKV.remove(treeKeys[tid][cid], fs); return fs; } // If model is deleted then all trees has to be delete as well @Override public Futures delete_impl(Futures fs) { delete_trees(fs); super.delete_impl(fs); return fs; } @Override public ModelAutobufferSerializer getModelSerializer() { // Return a serializer which knows how to serialize keys return new ModelAutobufferSerializer() { @Override protected AutoBuffer postSave(Model m, AutoBuffer ab) { int ntrees = treeKeys.length; ab.put4(ntrees); for (int i=0; i<ntrees; i++) { CompressedTree[] ts = ctree(i); ab.putA(ts); } return ab; } @Override protected AutoBuffer postLoad(Model m, AutoBuffer ab) { int ntrees = ab.get4(); Futures fs = new Futures(); for (int i=0; i<ntrees; i++) { CompressedTree[] ts = ab.getA(CompressedTree.class); for (int j=0; j<ts.length; j++) { Key k = ((TreeModel) m).treeKeys[i][j]; assert k == null && ts[j] == null || k != null && ts[j] != null : "Incosistency in model serialization: key is null but model is not null, OR vice versa!"; if (k!=null) { UKV.put(k, ts[j], fs); } } } fs.blockForPending(); return ab; } }; } public void generateHTML(String title, StringBuilder sb) { DocGen.HTML.title(sb,title); sb.append("<div class=\"alert\">").append("Actions: "); if (_dataKey != null) sb.append(Inspect2.link("Inspect training data ("+_dataKey.toString()+")", _dataKey)).append(", "); sb.append(Predict.link(_key,"Score on dataset")).append(", "); if (_dataKey != null) sb.append(UIUtils.builderModelLink(this.getClass(), _dataKey, responseName(), "Compute new model")).append(", "); sb.append(UIUtils.qlink(SaveModel.class, "model", _key, "Save model")).append(", "); if (isProduced()) { // looks at locker field and check W-locker guy sb.append("<i class=\"icon-stop\"></i> ").append(Cancel.link(getProducer(), "Stop training this model")); } else { sb.append("<i class=\"icon-play\"></i> ").append(UIUtils.builderLink(this.getClass(), _dataKey, responseName(), this._key, "Continue training this model")); } sb.append("</div>"); DocGen.HTML.paragraph(sb,"Model Key: "+_key); DocGen.HTML.paragraph(sb,"Max depth: "+max_depth+", Min rows: "+min_rows+", Nbins:"+nbins+", Trees: " + ntrees()); generateModelDescription(sb); sb.append("</pre>"); String[] domain = cmDomain; // Domain of response col // Generate a display using the last scored Model. Not all models are // scored immediately (since scoring can be a big part of model building). ConfusionMatrix cm = null; int last = cms.length-1; while( last > 0 && cms[last]==null ) last--; cm = 0 <= last && last < cms.length ? cms[last] : null; // Display the CM if( cm != null && domain != null ) { // Top row of CM assert cm._arr.length==domain.length; DocGen.HTML.title(sb,"Scoring"); if( testKey == null ) { if (_have_cv_results) sb.append("<div class=\"alert\">Reported on ").append(num_folds).append("-fold cross-validated training data</div>"); else { sb.append("<div class=\"alert\">Reported on ").append(title.contains("DRF") ? "out-of-bag" : "training").append(" data"); if (num_folds > 0) sb.append(" (cross-validation results are being computed - please reload this page later)"); sb.append("."); if (_priorClassDist!=null && _modelClassDist!=null) sb.append("<br />Data were resampled to balance class distribution."); sb.append("</div>"); } } else { RString rs = new RString("<div class=\"alert\">Reported on <a href='Inspect2.html?src_key=%$key'>%key</a></div>"); rs.replace("key", testKey); DocGen.HTML.paragraph(sb,rs.toString()); } if (validAUC == null) { //AUC shows the CM already // generate HTML for CM DocGen.HTML.section(sb, "Confusion Matrix"); cm.toHTML(sb, domain); } } if( errs != null ) { if (!isClassifier() && num_folds > 0) { if (_have_cv_results) DocGen.HTML.section(sb, num_folds + "-fold cross-validated Mean Squared Error: " + String.format("%5.3f", errs[errs.length-1])); else DocGen.HTML.section(sb, num_folds + "-fold cross-validated Mean Squared Error is being computed - please reload this page later."); } DocGen.HTML.section(sb,"Mean Squared Error by Tree"); DocGen.HTML.arrayHead(sb); sb.append("<tr style='min-width:60px'><th>Trees</th>"); last = errs.length-1-(_have_cv_results?1:0); // for regressor reports all errors (except for cross-validated result) for( int i=last; i>=0; i-- ) sb.append("<td style='min-width:60px'>").append(i).append("</td>"); sb.append("</tr>"); sb.append("<tr><th class='warning'>MSE</th>"); for( int i=last; i>=0; i-- ) sb.append(!Double.isNaN(errs[i]) ? String.format("<td style='min-width:60px'>%5.3f</td>",errs[i]) : "<td style='min-width:60px'>---</td>"); sb.append("</tr>"); DocGen.HTML.arrayTail(sb); } // Show AUC for binary classifiers if (validAUC != null) generateHTMLAUC(sb); // Show tree stats if (treeStats != null) generateHTMLTreeStats(sb); // Show variable importance if (varimp != null) { generateHTMLVarImp(sb); } printCrossValidationModelsHTML(sb); } static final String NA = "---"; protected void generateHTMLTreeStats(StringBuilder sb) { DocGen.HTML.section(sb,"Tree stats"); DocGen.HTML.arrayHead(sb); sb.append("<tr><th> </th>").append("<th>Min</th><th>Mean</th><th>Max</th></tr>"); boolean valid = treeStats.isValid(); sb.append("<tr><th>Depth</th>") .append("<td>").append(valid ? treeStats.minDepth : NA).append("</td>") .append("<td>").append(valid ? treeStats.meanDepth : NA).append("</td>") .append("<td>").append(valid ? treeStats.maxDepth : NA).append("</td></tr>"); sb.append("<th>Leaves</th>") .append("<td>").append(valid ? treeStats.minLeaves : NA).append("</td>") .append("<td>").append(valid ? treeStats.meanLeaves : NA).append("</td>") .append("<td>").append(valid ? treeStats.maxLeaves : NA).append("</td></tr>"); DocGen.HTML.arrayTail(sb); } protected void generateHTMLVarImp(StringBuilder sb) { if (varimp!=null) { // Set up variable names for importance varimp.setVariables(Arrays.copyOf(_names, _names.length-1)); varimp.toHTML(this, sb); } } protected void generateHTMLAUC(StringBuilder sb) { validAUC.toHTML(sb); } public static class TreeStats extends Iced { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help="Minimal tree depth.") public int minDepth = Integer.MAX_VALUE; @API(help="Maximum tree depth.") public int maxDepth = Integer.MIN_VALUE; @API(help="Average tree depth.") public float meanDepth; @API(help="Minimal num. of leaves.") public int minLeaves = Integer.MAX_VALUE; @API(help="Maximum num. of leaves.") public int maxLeaves = Integer.MIN_VALUE; @API(help="Average num. of leaves.") public float meanLeaves; transient long sumDepth = 0; transient long sumLeaves = 0; transient int numTrees = 0; public boolean isValid() { return minDepth <= maxDepth; } public void updateBy(DTree[] ktrees) { if (ktrees==null) return; for (int i=0; i<ktrees.length; i++) { DTree tree = ktrees[i]; if( tree == null ) continue; if (minDepth > tree.depth) minDepth = tree.depth; if (maxDepth < tree.depth) maxDepth = tree.depth; if (minLeaves > tree.leaves) minLeaves = tree.leaves; if (maxLeaves < tree.leaves) maxLeaves = tree.leaves; sumDepth += tree.depth; sumLeaves += tree.leaves; numTrees++; meanDepth = ((float)sumDepth / numTrees); meanLeaves = ((float)sumLeaves / numTrees); } } public void setNumTrees(int i) { numTrees = i; } } // -------------------------------------------------------------------------- // Highly compressed tree encoding: // tree: 1B nodeType, 2B colId, 4B splitVal, left-tree-size, left, right // nodeType: (from lsb): // 2 bits (1,2) skip-tree-size-size, // 2 bits (4,8) operator flag (0 --> <, 1 --> ==, 2 --> small (4B) group, 3 --> big (var size) group), // 1 bit ( 16) left leaf flag, // 1 bit ( 32) left leaf type flag (0: subtree, 1: small cat, 2: big cat, 3: float) // 1 bit ( 64) right leaf flag, // 1 bit (128) right leaf type flag (0: subtree, 1: small cat, 2: big cat, 3: float) // left, right: tree | prediction // prediction: 4 bytes of float public static class CompressedTree extends Iced { final byte [] _bits; final int _nclass; final long _seed; public CompressedTree( byte [] bits, int nclass, long seed ) { _bits = bits; _nclass = nclass; _seed = seed; } public float score( final double row[] ) { AutoBuffer ab = new AutoBuffer(_bits); while(true) { int nodeType = ab.get1(); int colId = ab.get2(); if( colId == 65535 ) return scoreLeaf(ab); // boolean equal = ((nodeType&4)==4); int equal = (nodeType&12) >> 2; assert (equal >= 0 && equal <= 3): "illegal equal value " + equal+" at "+ab.position()+" in bitpile "+Arrays.toString(_bits); // Extract value or group to split on float splitVal = -1; boolean grpContains = false; if(equal == 0 || equal == 1) { splitVal = ab.get4f(); } else { int off = (equal == 3) ? ab.get2() : 0; // number of zero-bits skipped during serialization int sz = (equal == 3) ? ab.get2() : 4; // size of serialized bitset (part containing some non-zeros) in bytes int idx = (int)row[colId]; // the input value driving decision if(Double.isNaN(row[colId]) || idx < off ) { grpContains = false; ab.skip(sz); } else { idx = idx - off; int bbskip = idx >> 3; if (sz-bbskip>0) { ab.skip(bbskip); grpContains = (ab.get1() & ((byte)1 << (idx % 8))) != 0; ab.skip(sz-bbskip-1); } else { // value is not in bit set at all (it is even out of value) grpContains = false; ab.skip(sz); } } } // Compute the amount to skip. int lmask = nodeType & 0x33; int rmask = (nodeType & 0xC0) >> 2; int skip = 0; switch(lmask) { case 0: skip = ab.get1(); break; case 1: skip = ab.get2(); break; case 2: skip = ab.get3(); break; case 3: skip = ab.get4(); break; case 16: skip = _nclass < 256?1:2; break; // Small leaf case 48: skip = 4; break; // skip the prediction default: assert false:"illegal lmask value " + lmask+" at "+ab.position()+" in bitpile "+Arrays.toString(_bits); } // WARNING: Generated code has to be consistent with this code: // - Double.NaN < 3.7f => return false => BUT left branch has to be selected (i.e., ab.position()) // - Double.NaN != 3.7f => return true => left branch has to be select selected (i.e., ab.position()) if( !Double.isNaN(row[colId]) ) { // NaNs always go to bin 0 if( ( equal==0 && ((float)row[colId]) >= splitVal) || ( equal==1 && ((float)row[colId]) == splitVal) || ( (equal==2 || equal==3) && grpContains )) { ab.position(ab.position()+skip); // Skip to the right subtree lmask = rmask; // And set the leaf bits into common place } } /* else Double.isNaN() is true => use left branch */ if( (lmask&16)==16 ) return scoreLeaf(ab); } } private float scoreLeaf( AutoBuffer ab ) { return ab.get4f(); } public Random rngForChunk( int cidx ) { Random rand = createRNG(_seed); // Argh - needs polishment for( int i=0; i<cidx; i++ ) rand.nextLong(); long seed = rand.nextLong(); return createRNG(seed); } } /** Abstract visitor class for serialized trees.*/ public static abstract class TreeVisitor<T extends Exception> { // Override these methods to get walker behavior. protected void pre ( int col, float fcmp, IcedBitSet gcmp, int equal ) throws T { } protected void mid ( int col, float fcmp, int equal ) throws T { } protected void post( int col, float fcmp, int equal ) throws T { } protected void leaf( float pred ) throws T { } long result( ) { return 0; } // Override to return simple results protected final TreeModel _tm; protected final CompressedTree _ct; private final AutoBuffer _ts; protected int _depth; // actual depth protected int _nodes; // number of visited nodes public TreeVisitor( TreeModel tm, CompressedTree ct ) { _tm = tm; _ts = new AutoBuffer((_ct=ct)._bits); } // Call either the single-class leaf or the full-prediction leaf private final void leaf2( int mask ) throws T { assert (mask==0 || ( (mask&16)==16 && (mask&32)==32) ) : "Unknown mask: " + mask; // Is a leaf or a special leaf on the top of tree leaf(_ts.get4f()); } public final void visit() throws T { int nodeType = _ts.get1(); int col = _ts.get2(); if( col==65535 ) { leaf2(nodeType); return; } // float fcmp = _ts.get4f(); // boolean equal = ((nodeType&4)==4); int equal = (nodeType&12) >> 2; // Extract value or group to split on float fcmp = -1; IcedBitSet gcmp = null; if(equal == 0 || equal == 1) fcmp = _ts.get4f(); else { int off = (equal == 3) ? _ts.get2() : 0; int sz = (equal == 3) ? _ts.get2() : 4; byte[] buf = MemoryManager.malloc1(sz); _ts.read(buf, 0, sz); gcmp = new IcedBitSet(buf, sz << 3, off); } // Compute the amount to skip. int lmask = nodeType & 0x33; int rmask = (nodeType & 0xC0) >> 2; int skip = 0; switch(lmask) { case 0: skip = _ts.get1(); break; case 1: skip = _ts.get2(); break; case 2: skip = _ts.get3(); break; case 3: skip = _ts.get4(); break; case 16: skip = _ct._nclass < 256?1:2; break; // Small leaf case 48: skip = 4; break; // skip is always 4 for direct leaves (see DecidedNode.size() and LeafNode.size() methods) default: assert false:"illegal lmask value " + lmask; } pre(col,fcmp,gcmp,equal); // Pre-walk _depth++; if( (lmask & 0x10)==16 ) leaf2(lmask); else visit(); mid(col,fcmp,equal); // Mid-walk if( (rmask & 0x10)==16 ) leaf2(rmask); else visit(); _depth--; post(col,fcmp,equal); _nodes++; } } StringBuilder toString(final String res, CompressedTree ct, final StringBuilder sb ) { new TreeVisitor<RuntimeException>(this,ct) { @Override protected void pre( int col, float fcmp, IcedBitSet gcmp, int equal ) { for( int i=0; i<_depth; i++ ) sb.append(" "); if(equal == 2 || equal == 3) sb.append(_names[col]).append("==").append(gcmp.toString()).append('\n'); else sb.append(_names[col]).append(equal==1?"==":"< ").append(fcmp).append('\n'); } @Override protected void leaf( float pred ) { for( int i=0; i<_depth; i++ ) sb.append(" "); sb.append(res).append("=").append(pred).append(";\n"); } }.visit(); return sb; } // For GBM: learn_rate. For DRF: mtries, sample_rate, seed. abstract protected void generateModelDescription(StringBuilder sb); // Determine whether feature is licensed. private boolean isFeatureAllowed() { boolean featureAllowed = false; try { if (treeStats.numTrees <= 10) { featureAllowed = true; } else { if (getTreeModelType() == TreeModelType.GBM) { featureAllowed = H2O.licenseManager.isFeatureAllowed(LicenseManager.FEATURE_GBM_SCORING); } else if (getTreeModelType() == TreeModelType.DRF) { featureAllowed = H2O.licenseManager.isFeatureAllowed(LicenseManager.FEATURE_RF_SCORING); } } } catch (Exception xe) {} return featureAllowed; } public void toJavaHtml( StringBuilder sb ) { if( treeStats == null ) return; // No trees yet sb.append("<br /><br /><div class=\"pull-right\"><a href=\"#\" onclick=\'$(\"#javaModel\").toggleClass(\"hide\");\'" + "class=\'btn btn-inverse btn-mini\'>Java Model</a></div><br /><div class=\"hide\" id=\"javaModel\">"); boolean featureAllowed = isFeatureAllowed(); if (! featureAllowed) { sb.append("<br/><div id=\'javaModelWarningBlock\' class=\"alert\" style=\"background:#eedd20;color:#636363;text-shadow:none;\">"); sb.append("<b>You have requested a premium feature (> 10 trees) and your H<sub>2</sub>O software is unlicensed.</b><br/><br/>"); sb.append("Please enter your email address below, and we will send you a trial license shortly.<br/>"); sb.append("This will also temporarily enable downloading Java models.<br/>"); sb.append("<form class=\'form-inline\'><input id=\"emailForJavaModel\" class=\"span5\" type=\"text\" placeholder=\"Email\"/> "); sb.append("<a href=\"#\" onclick=\'processJavaModelLicense();\' class=\'btn btn-inverse\'>Send</a></form></div>"); sb.append("<div id=\"javaModelSource\" class=\"hide\">"); } if( ntrees() * treeStats.meanLeaves > 5000 ) { String modelName = JCodeGen.toJavaId(_key.toString()); sb.append("<pre style=\"overflow-y:scroll;\"><code class=\"language-java\">"); sb.append("/* Java code is too large to display, download it directly.\n"); sb.append(" To obtain the code please invoke in your terminal:\n"); sb.append(" curl http:/").append(H2O.SELF.toString()).append("/h2o-model.jar > h2o-model.jar\n"); sb.append(" curl http:/").append(H2O.SELF.toString()).append("/2/").append(this.getClass().getSimpleName()).append("View.java?_modelKey=").append(_key).append(" > ").append(modelName).append(".java\n"); sb.append(" javac -cp h2o-model.jar -J-Xmx2g -J-XX:MaxPermSize=128m ").append(modelName).append(".java\n"); if (GEN_BENCHMARK_CODE) sb.append(" java -cp h2o-model.jar:. -Xmx2g -XX:MaxPermSize=256m -XX:ReservedCodeCacheSize=256m ").append(modelName).append('\n'); sb.append("*/"); sb.append("</code></pre>"); } else { sb.append("<pre style=\"overflow-y:scroll;\"><code class=\"language-java\">"); DocGen.HTML.escape(sb, toJava()); sb.append("</code></pre>"); } if (!featureAllowed) sb.append("</div>"); // close license blog sb.append("</div>"); sb.append("<script type=\"text/javascript\">$(document).ready(showOrHideJavaModel);</script>"); } @Override protected SB toJavaInit(SB sb, SB fileContextSB) { sb = super.toJavaInit(sb, fileContextSB); String modelName = JCodeGen.toJavaId(_key.toString()); // Generate main method with benchmark if (GEN_BENCHMARK_CODE) { sb.i().p("/**").nl(); sb.i().p(" * Sample program harness providing an example of how to call predict().").nl(); sb.i().p(" */").nl(); sb.i().p("public static void main(String[] args) throws Exception {").nl(); sb.i(1).p("int iters = args.length > 0 ? Integer.valueOf(args[0]) : DEFAULT_ITERATIONS;").nl(); sb.i(1).p(modelName).p(" model = new ").p(modelName).p("();").nl(); sb.i(1).p("model.bench(iters, DataSample.DATA, new float[NCLASSES+1], NTREES);").nl(); sb.i().p("}").nl(); sb.di(1); sb.p(TO_JAVA_BENCH_FUNC); } JCodeGen.toStaticVar(sb, "NTREES", ntrees(), "Number of trees in this model."); JCodeGen.toStaticVar(sb, "NTREES_INTERNAL", ntrees()*nclasses(), "Number of internal trees in this model (= NTREES*NCLASSES)."); if (GEN_BENCHMARK_CODE) JCodeGen.toStaticVar(sb, "DEFAULT_ITERATIONS", 10000, "Default number of iterations."); // Generate a data in separated class since we do not want to influence size of constant pool of model class if (GEN_BENCHMARK_CODE) { if( _dataKey != null ) { Value dataval = DKV.get(_dataKey); if (dataval != null) { water.fvec.Frame frdata = dataval.get(); water.fvec.Frame frsub = frdata.subframe(_names); JCodeGen.toClass(fileContextSB, "// Sample of data used by benchmark\nclass DataSample", "DATA", frsub, 10, "Sample test data."); } } } return sb; } // Convert Tree model to Java @Override protected void toJavaPredictBody( final SB bodySb, final SB classCtxSb, final SB fileCtxSb) { // AD-HOC maximal number of trees in forest - in fact constant pool size for Forest class (all UTF String + references to static classes). // TODO: in future this parameter can be a parameter for generator, as well as maxIters final int maxfsize = 4000; int fidx = 0; // forest index int treesInForest = 0; SB forest = new SB(); // divide trees into small forests per 100 trees /* DEBUG line */ bodySb.i().p("// System.err.println(\"Row (gencode.predict): \" + java.util.Arrays.toString(data));").nl(); bodySb.i().p("java.util.Arrays.fill(preds,0f);").nl(); if (isFromSpeeDRF) { bodySb.i().p("// Call forest predicting class ").p(0).nl(); bodySb.i().p("preds").p(" =").p(" Forest_").p(fidx).p("_class_").p(0).p(".predict(data, maxIters - " + fidx * maxfsize + ");").nl(); } for( int c=0; c<nclasses(); c++ ) { toJavaForestBegin(bodySb, forest, c, fidx++, maxfsize); for( int i=0; i < treeKeys.length; i++ ) { CompressedTree cts[] = ctree(i); if( cts[c] == null ) continue; if (!isFromSpeeDRF) { forest.i().p("if (iters-- > 0) pred").p(" +=").p(" Tree_").p(i).p("_class_").p(c).p(".predict(data);").nl(); } else { forest.i().p("pred[(int)").p(" Tree_").p(i).p("_class_").p(c).p(".predict(data) + 1] += 1;").nl(); } // append representation of tree predictor toJavaTreePredictFct(fileCtxSb, cts[c], i, c); if (++treesInForest == maxfsize) { toJavaForestEnd(bodySb, forest, c, fidx); toJavaForestBegin(bodySb, forest, c, fidx++, maxfsize); treesInForest = 0; } } toJavaForestEnd(bodySb, forest, c, fidx); treesInForest = 0; fidx = 0; } fileCtxSb.p(forest); toJavaUnifyPreds(bodySb); toJavaFillPreds0(bodySb); } /* Numeric type used in generated code to hold predicted value between the calls. */ static final String PRED_TYPE = "float"; private void toJavaForestBegin(SB predictBody, SB forest, int c, int fidx, int maxTreesInForest) { // ugly hack here if (!isFromSpeeDRF) { predictBody.i().p("// Call forest predicting class ").p(c).nl(); predictBody.i().p("preds[").p(c + 1).p("] +=").p(" Forest_").p(fidx).p("_class_").p(c).p(".predict(data, maxIters - " + fidx * maxTreesInForest + ");").nl(); forest.i().p("// Forest representing a subset of trees scoring class ").p(c).nl(); forest.i().p("class Forest_").p(fidx).p("_class_").p(c).p(" {").nl().ii(1); forest.i().p("public static ").p(PRED_TYPE).p(" predict(double[] data, int maxIters) {").nl().ii(1); forest.i().p(PRED_TYPE).p(" pred = 0;").nl(); forest.i().p("int iters = maxIters;").nl(); } else { forest.i().p("// Forest representing a subset of trees scoring class ").p(c).nl(); forest.i().p("class Forest_").p(fidx).p("_class_").p(c).p(" {").nl().ii(1); forest.i().p("public static ").p(PRED_TYPE).p("[] predict(double[] data, int maxIters) {").nl().ii(1); forest.i().p(PRED_TYPE).p("[] pred = new float["+(nclasses()+1)+"];").nl(); forest.i().p("java.util.Arrays.fill(pred,0f);").nl(); forest.i().p("int iters = maxIters;").nl(); } } private void toJavaForestEnd(SB predictBody, SB forest, int c, int fidx) { if (!isFromSpeeDRF) { forest.i().p("return pred;").nl(); forest.i().p("}").di(1).nl(); // end of function forest.i().p("}").di(1).nl(); // end of forest classs } else { if (c ==0) { forest.i().p("float sum = 0;").nl(); forest.i().p("for (int i=1; i <= " + nclasses() + "; i++) {").p("sum += pred[i];").p("}").nl(); forest.i().p("for (int i=1; i <= " + nclasses() + "; i++) {").p("pred[i] /= sum;").p("}").nl(); } forest.i().p("return pred;").nl(); forest.i().p("}").di(1).nl(); // end of function forest.i().p("}").di(1).nl(); // end of forest classs } } // Produce prediction code for one tree protected void toJavaTreePredictFct(final SB sb, final CompressedTree cts, int treeIdx, int classIdx) { // generate top-level class definition sb.nl(); sb.i().p("// Tree predictor for ").p(treeIdx).p("-tree and ").p(classIdx).p("-class").nl(); sb.i().p("class Tree_").p(treeIdx).p("_class_").p(classIdx).p(" {").nl().ii(1); new TreeJCodeGen(this,cts, sb).generate(); sb.i().p("}").nl(); // close the class } @Override protected String toJavaDefaultMaxIters() { return String.valueOf(this.N); } } // Build a compressed-tree struct public TreeModel.CompressedTree compress() { int sz = root().size(); if( root() instanceof LeafNode ) sz += 3; // Oops - tree-stump AutoBuffer ab = new AutoBuffer(sz); if( root() instanceof LeafNode ) // Oops - tree-stump ab.put1(0).put2((char)65535); // Flag it special so the decompress doesn't look for top-level decision root().compress(ab); // Compress whole tree assert ab.position() == sz; return new TreeModel.CompressedTree(ab.buf(),_nclass,_seed); } /** Save this tree into DKV store under default random Key. */ public Key save() { return save(defaultTreeKey()); } /** Save this tree into DKV store under the given Key. */ public Key save(Key k) { CompressedTree ts = compress(); UKV.put(k, ts); return k; } private Key defaultTreeKey() { return Key.makeSystem("__Tree_"+Key.rand()); } private static final SB TO_JAVA_BENCH_FUNC = new SB(). nl(). p(" /**").nl(). p(" * Run a predict() benchmark with the generated model and some synthetic test data.").nl(). p(" *").nl(). p(" * @param iters number of iterations to run; each iteration predicts on every sample (i.e. row) in the test data").nl(). p(" * @param data test data to predict on").nl(). p(" * @param preds output predictions").nl(). p(" * @param ntrees number of trees").nl(). p(" */").nl(). p(" public void bench(int iters, double[][] data, float[] preds, int ntrees) {").nl(). p(" System.out.println(\"Iterations: \" + iters);").nl(). p(" System.out.println(\"Data rows : \" + data.length);").nl(). p(" System.out.println(\"Trees : \" + ntrees + \"x\" + (preds.length-1));").nl(). nl(). p(" long startMillis;").nl(). p(" long endMillis;").nl(). p(" long deltaMillis;").nl(). p(" double deltaSeconds;").nl(). p(" double samplesPredicted;").nl(). p(" double samplesPredictedPerSecond;").nl(). p(" System.out.println(\"Starting timing phase of \"+iters+\" iterations...\");").nl(). nl(). p(" startMillis = System.currentTimeMillis();").nl(). p(" for (int i=0; i<iters; i++) {").nl(). p(" // Uncomment the nanoTime logic for per-iteration prediction times.").nl(). p(" // long startTime = System.nanoTime();").nl(). nl(). p(" for (double[] row : data) {").nl(). p(" predict(row, preds);").nl(). p(" // System.out.println(java.util.Arrays.toString(preds) + \" : \" + (DOMAINS[DOMAINS.length-1]!=null?(DOMAINS[DOMAINS.length-1][(int)preds[0]]+\"~\"+DOMAINS[DOMAINS.length-1][(int)row[row.length-1]]):(preds[0] + \" ~ \" + row[row.length-1])) );").nl(). p(" }").nl(). nl(). p(" // long ttime = System.nanoTime()-startTime;").nl(). p(" // System.out.println(i+\". iteration took \" + (ttime) + \"ns: scoring time per row: \" + ttime/data.length +\"ns, scoring time per row and tree: \" + ttime/data.length/ntrees + \"ns\");").nl(). nl(). p(" if ((i % 1000) == 0) {").nl(). p(" System.out.println(\"finished \"+i+\" iterations (of \"+iters+\")...\");").nl(). p(" }").nl(). p(" }").nl(). p(" endMillis = System.currentTimeMillis();").nl(). nl(). p(" deltaMillis = endMillis - startMillis;").nl(). p(" deltaSeconds = (double)deltaMillis / 1000.0;").nl(). p(" samplesPredicted = data.length * iters;").nl(). p(" samplesPredictedPerSecond = samplesPredicted / deltaSeconds;").nl(). p(" System.out.println(\"finished in \"+deltaSeconds+\" seconds.\");").nl(). p(" System.out.println(\"samplesPredicted: \" + samplesPredicted);").nl(). p(" System.out.println(\"samplesPredictedPerSecond: \" + samplesPredictedPerSecond);").nl(). p(" }").nl(). nl(); static class TreeJCodeGen extends TreeVisitor<RuntimeException> { public static final int MAX_NODES = (1 << 12) / 4; // limit for a number decision nodes final byte _bits[] = new byte [100]; final float _fs [] = new float[100]; final SB _sbs [] = new SB [100]; final int _nodesCnt[] = new int [100]; SB _sb; SB _csb; SB _grpsplit; int _subtrees = 0; int _grpcnt = 0; public TreeJCodeGen(TreeModel tm, CompressedTree ct, SB sb) { super(tm, ct); _sb = sb; _csb = new SB(); _grpsplit = new SB(); } // code preamble protected void preamble(SB sb, int subtree) throws RuntimeException { String subt = subtree>0?String.valueOf(subtree):""; sb.i().p("static final ").p(TreeModel.PRED_TYPE).p(" predict").p(subt).p("(double[] data) {").nl().ii(1); // predict method for one tree sb.i().p(TreeModel.PRED_TYPE).p(" pred = "); } // close the code protected void closure(SB sb) throws RuntimeException { sb.p(";").nl(); sb.i(1).p("return pred;").nl().di(1); sb.i().p("}").nl(); // sb.p(_grpsplit).di(1); } @Override protected void pre( int col, float fcmp, IcedBitSet gcmp, int equal ) { if(equal == 2 || equal == 3 && gcmp != null) { _grpsplit.i(1).p("// ").p(gcmp.toString()).nl(); _grpsplit.i(1).p("public static final byte[] GRPSPLIT").p(_grpcnt).p(" = new byte[] ").p(gcmp.toStrArray()).p(";").nl(); } if( _depth > 0 ) { int b = _bits[_depth-1]; assert b > 0 : Arrays.toString(_bits)+"\n"+_sb.toString(); if( b==1 ) _bits[_depth-1]=3; if( b==1 || b==2 ) _sb.p('\n').i(_depth).p("?"); if( b==2 ) _sb.p(' ').pj(_fs[_depth-1]); // Dump the leaf containing float value if( b==2 || b==3 ) _sb.p('\n').i(_depth).p(":"); } if (_nodes>MAX_NODES) { _sb.p("predict").p(_subtrees).p("(data)"); _nodesCnt[_depth] = _nodes; _sbs[_depth] = _sb; _sb = new SB(); _nodes = 0; preamble(_sb, _subtrees); _subtrees++; } // All NAs are going always to the left _sb.p(" (Double.isNaN(data[").p(col).p("]) || "); if(equal == 0 || equal == 1) if (!_tm.isFromSpeeDRF) { _sb.p("(float) data[").p(col).p(" /* ").p(_tm._names[col]).p(" */").p("] ").p(equal == 1 ? "!= " : "< ").pj(fcmp); // then left and then right (left is !=) } else { _sb.p("(float) data[").p(col).p(" /* ").p(_tm._names[col]).p(" */").p("] ").p(equal == 1 ? "!= " : "<= ").pj(fcmp); // then left and then right (left is !=) } else { _sb.p("!water.genmodel.GeneratedModel.grpContains(GRPSPLIT").p(_grpcnt).p(", ").p(gcmp._offset).p(", (int) data[").p(col).p(" /* ").p(_tm._names[col]).p(" */").p("])"); _grpcnt++; } assert _bits[_depth]==0; _bits[_depth]=1; } @Override protected void leaf( float pred ) { assert _depth==0 || _bits[_depth-1] > 0 : Arrays.toString(_bits); // it can be degenerated tree if( _depth==0) { // it is de-generated tree _sb.pj(pred); } else if( _bits[_depth-1] == 1 ) { // No prior leaf; just memorize this leaf _bits[_depth-1]=2; _fs[_depth-1]=pred; } else { // Else==2 (prior leaf) or 3 (prior tree) if( _bits[_depth-1] == 2 ) _sb.p(" ? ").pj(_fs[_depth-1]).p(" "); else _sb.p('\n').i(_depth); _sb.p(": ").pj(pred); } } @Override protected void post( int col, float fcmp, int equal ) { _sb.p(')'); _bits[_depth]=0; if (_sbs[_depth]!=null) { closure(_sb); _csb.p(_sb); _sb = _sbs[_depth]; _nodes = _nodesCnt[_depth]; _sbs[_depth] = null; } } public void generate() { preamble(_sb, _subtrees++); // TODO: Need to pass along group split BitSet visit(); closure(_sb); _sb.p(_grpsplit).di(1); _sb.p(_csb); } } }