package hex.tree; import water.*; public abstract class DTreeScorer<T extends DTreeScorer<T>> extends MRTask<T> { protected final int _ncols; protected final int _nclass; protected final int _skip; protected final Key[][] _treeKeys; protected transient CompressedTree[][] _trees; protected SharedTree _st; public DTreeScorer(int ncols, int nclass, SharedTree st, Key[][] treeKeys) { _ncols = ncols; _nclass = nclass; _treeKeys = treeKeys; _st = st; _skip = _st.numSpecialCols(); } protected int ntrees() { return _trees.length; } @Override protected final void setupLocal() { int ntrees = _treeKeys.length; _trees = new CompressedTree[ntrees][]; for (int t=0; t<ntrees; t++) { Key[] treek = _treeKeys[t]; _trees[t] = new CompressedTree[treek.length]; // FIXME remove get by introducing fetch class for all trees for (int i=0; i<treek.length; i++) if (treek[i]!=null) _trees[t][i] = DKV.get(treek[i]).get(); } } protected void score0(double data[], double preds[], CompressedTree[] ts) { scoreTree(data, preds, ts); } /** Score given tree on the row of data. * @param data row of data * @param preds array to hold resulting prediction * @param ts a tree representation (single regression tree, or multi tree) */ public static void scoreTree(double data[], double preds[], CompressedTree[] ts) { for( int c=0; c<ts.length; c++ ) if( ts[c] != null ) preds[ts.length==1?0:c+1] += ts[c].score(data); } }