package hex.drf;
import hex.ConfusionMatrix;
import hex.VarImp;
import hex.drf.TreeMeasuresCollector.TreeMeasures;
import hex.drf.TreeMeasuresCollector.TreeSSE;
import hex.drf.TreeMeasuresCollector.TreeVotes;
import hex.gbm.DHistogram;
import hex.gbm.DTree;
import hex.gbm.DTree.DecidedNode;
import hex.gbm.DTree.LeafNode;
import hex.gbm.DTree.TreeModel.CompressedTree;
import hex.gbm.DTree.TreeModel.TreeStats;
import hex.gbm.DTree.UndecidedNode;
import hex.gbm.SharedTreeModelBuilder;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.api.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.*;
import water.util.Log.Tag.Sys;
import java.util.Arrays;
import java.util.Random;
import static hex.drf.TreeMeasuresCollector.asSSE;
import static hex.drf.TreeMeasuresCollector.asVotes;
import static water.util.Utils.div;
import static water.util.Utils.sum;
// Distributed Random Forest (DRF): an ensemble of classification/regression trees built on bagged data
public class DRF extends SharedTreeModelBuilder<DRF.DRFModel> {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
static final boolean DEBUG_DETERMINISTIC = false; // Enable for a deterministic version of DRF: the same seed is used for each execution. Ideally this would be read from a system property.
@API(help = "Columns to randomly select at each level, or -1 for sqrt(#cols)", filter = Default.class, lmin=-1, lmax=100000)
int mtries = -1;
@API(help = "Sample rate, from 0. to 1.0", filter = Default.class, dmin=0, dmax=1, importance=ParamImportance.SECONDARY)
float sample_rate = 0.6666667f;
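// NOTE: with the default rate of ~2/3, each tree trains on roughly 66.7% of the
// rows; the remaining ~1/3 are out-of-bag (OOB) for that tree and drive the
// on-the-fly error estimation.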
@API(help = "Seed for the random number generator (autogenerated)", filter = Default.class)
long seed = -1; // To follow R semantics, each call of RF should use a different seed; -1 means the seed is autogenerated
@API(help = "Check non-contiguous group splits for categorical predictors", filter = Default.class, hide = true)
boolean do_grpsplit = true;
@API(help="Run on one node only; no network overhead but fewer cpus used. Suitable for small datasets.", filter=myClassFilter.class, importance=ParamImportance.SECONDARY)
public boolean build_tree_one_node = false;
class myClassFilter extends DRFCopyDataBoolean { myClassFilter() { super("source"); } }
@API(help = "Computed number of split features", importance=ParamImportance.EXPERT)
protected int _mtry; // FIXME remove and replace by mtries
@API(help = "Autogenerated seed", importance=ParamImportance.EXPERT)
protected long _seed; // FIXME remove and replace by seed
// Fixed seed generator for DRF
private static final Random _seedGenerator = Utils.getDeterRNG(0xd280524ad7fe0602L);
// --- Private data handled only on master node
// Classification or Regression:
// Tree votes/SSE of individual trees on OOB rows
private transient TreeMeasures _treeMeasuresOnOOB;
// Tree votes/SSE per individual feature on permuted OOB rows
private transient TreeMeasures[/*features*/] _treeMeasuresOnSOOB;
// Variable importance based on tree split decisions
private transient float[/*nfeatures*/] _improvPerVar;
/** DRF model holding serialized tree and implementing logic for scoring a row */
public static class DRFModel extends DTree.TreeModel {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
@API(help = "Model parameters", json = true)
private final DRF parameters; // This is used purely for printing values out.
@Override public final DRF get_params() { return parameters; }
@Override public final Request2 job() { return get_params(); }
@API(help = "Number of columns picked at each split") final int mtries;
@API(help = "Sample rate") final float sample_rate;
@API(help = "Seed") final long seed;
// Params that do not affect model quality:
//
public DRFModel(DRF params, Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int mtries, float sample_rate, long seed, int num_folds, float[] priorClassDist, float[] classDist) {
super(key,dataKey,testKey,names,domains,cmDomain,ntrees, max_depth, min_rows, nbins, num_folds, priorClassDist, classDist);
this.parameters = Job.hygiene((DRF) params.clone());
this.mtries = mtries;
this.sample_rate = sample_rate;
this.seed = seed;
}
private DRFModel(DRFModel prior, DTree[] trees, TreeStats tstats) {
super(prior, trees, tstats);
this.parameters = prior.parameters;
this.mtries = prior.mtries;
this.sample_rate = prior.sample_rate;
this.seed = prior.seed;
}
private DRFModel(DRFModel prior, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
super(prior, err, cm, varimp, validAUC);
this.parameters = prior.parameters;
this.mtries = prior.mtries;
this.sample_rate = prior.sample_rate;
this.seed = prior.seed;
}
private DRFModel(DRFModel prior, Key[][] treeKeys, double[] errs, ConfusionMatrix[] cms, TreeStats tstats, VarImp varimp, AUCData validAUC) {
super(prior, treeKeys, errs, cms, tstats, varimp, validAUC);
this.parameters = prior.parameters;
this.mtries = prior.mtries;
this.sample_rate = prior.sample_rate;
this.seed = prior.seed;
}
@Override protected TreeModelType getTreeModelType() { return TreeModelType.DRF; }
@Override protected float[] score0(double data[], float preds[]) {
float[] p = super.score0(data, preds);
int ntrees = ntrees();
if (p.length==1) { if (ntrees>0) div(p, ntrees); } // regression - compute avg over all trees
else { // classification
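// p[1..nclass] holds per-class tree votes; dividing by their sum below turns
// the votes into a class distribution, and ModelUtils.getPrediction picks the
// predicted class from that distribution (essentially an argmax, with ties
// broken deterministically from the row data).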
float s = sum(p);
if (s>0) div(p, s); // unify over all classes
p[0] = ModelUtils.getPrediction(p, data);
}
return p;
}
@Override protected void generateModelDescription(StringBuilder sb) {
DocGen.HTML.paragraph(sb,"mtries: "+mtries+", Sample rate: "+sample_rate+", Seed: "+seed);
if (testKey==null && sample_rate==1f) {
sb.append("<div class=\"alert alert-danger\">There are no out-of-bag data to compute out-of-bag error estimate, since sampling rate is 1!</div>");
}
}
@Override protected void toJavaUnifyPreds(SB bodySb) {
if (isClassifier()) {
bodySb.i().p("float sum = 0;").nl();
bodySb.i().p("for(int i=1; i<preds.length; i++) sum += preds[i];").nl();
bodySb.i().p("if (sum>0) for(int i=1; i<preds.length; i++) preds[i] /= sum;").nl();
} else bodySb.i().p("preds[1] = preds[1]/NTREES;").nl();
}
@Override protected void setCrossValidationError(ValidatedJob job, double cv_error, water.api.ConfusionMatrix cm, AUCData auc, HitRatio hr) {
DRFModel drfm = ((DRF)job).makeModel(this, cv_error, cm.cm == null ? null : new ConfusionMatrix(cm.cm, cms[0].nclasses()), this.varimp, auc);
drfm._have_cv_results = true;
DKV.put(this._key, drfm); //overwrite this model
}
}
public Frame score( Frame fr ) { return ((DRFModel)UKV.get(dest())).score(fr); }
@Override protected Log.Tag.Sys logTag() { return Sys.DRF__; }
@Override protected DRFModel makeModel(Key outputKey, Key dataKey, Key testKey, int ntrees, String[] names, String[][] domains, String[] cmDomain, float[] priorClassDist, float[] classDist) {
return new DRFModel(this,outputKey,dataKey,validation==null?null:testKey,names,domains,cmDomain,ntrees, max_depth, min_rows, nbins, mtries, sample_rate, _seed, n_folds, priorClassDist, classDist);
}
@Override protected DRFModel makeModel( DRFModel model, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
return new DRFModel(model, err, cm, varimp, validAUC);
}
@Override protected DRFModel makeModel( DRFModel model, DTree ktrees[], TreeStats tstats) {
return new DRFModel(model, ktrees, tstats);
}
@Override protected DRFModel updateModel(DRFModel model, DRFModel checkpoint, boolean overwriteCheckpoint) {
// Do not forget to clone the trees if we are not going to overwrite the checkpoint
Key[][] treeKeys = null;
if (!overwriteCheckpoint) throw H2O.unimpl("Cloning of model trees is not implemented yet!");
else treeKeys = checkpoint.treeKeys;
return new DRFModel(model, treeKeys, checkpoint.errs, checkpoint.cms, checkpoint.treeStats, checkpoint.varimp, checkpoint.validAUC);
}
public DRF() { description = "Distributed RF"; ntrees = 50; max_depth = 20; min_rows = 1; }
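// Illustrative usage sketch only (not executed anywhere in this file). Field
// names follow the @API parameters above; 'train' is an assumed, already
// parsed Frame, and source/response/invoke come from the standard H2O-2 Job
// plumbing:
//   DRF drf = new DRF();
//   drf.source = train;                 // training frame
//   drf.response = train.lastVec();     // response column
//   drf.ntrees = 100;                   // override the default of 50
//   drf.sample_rate = 0.7f;             // per-tree row sampling rate
//   drf.invoke();                       // build the model synchronously
//   DRFModel m = UKV.get(drf.dest());   // fetch the finished model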
/** Return the query link to this page */
public static String link(Key k, String content) {
RString rs = new RString("<a href='/2/DRF.query?source=%$key'>%content</a>");
rs.replace("key", k.toString());
rs.replace("content", content);
return rs.toString();
}
// ==========================================================================
/** Build a DRF model.
 *
 * Trees are grown level by level: split all the data according to some
 * criteria (minimize variance at the leaves). Record on each row which
 * split it goes to, and assign a split number to it (for the next pass).
 * On *this* pass, use the split-number to build a per-split histogram,
 * with a per-histogram-bucket variance. */
@Override protected void execImpl() {
try {
logStart();
buildModel(seed);
if (n_folds > 0) CrossValUtils.crossValidate(this);
} finally {
remove(); // Remove Job
// Ugly hack updating job state carried as parameters inside a model
state = UKV.<Job>get(self()).state;
new TAtomic<DRFModel>() {
@Override
public DRFModel atomic(DRFModel m) {
if (m != null) m.get_params().state = state;
return m;
}
}.invoke(dest());
}
}
@Override protected Response redirect() {
return DRFProgressPage.redirect(this, self(), dest());
}
@SuppressWarnings("unused")
@Override protected void init() {
super.init();
// Initialize local variables
_mtry = (mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
( classification ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1)) : mtries;
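// Worked example of the defaults: with _ncols=100 predictors, classification
// uses _mtry = floor(sqrt(100)) = 10 while regression uses _mtry = 100/3 = 33.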
if (!(1 <= _mtry && _mtry <= _ncols)) throw new IllegalArgumentException("Computed mtry should be in interval [1,#cols] but it is " + _mtry);
if (!(0.0 < sample_rate && sample_rate <= 1.0)) throw new IllegalArgumentException("Sample rate should be in interval (0,1] but it is " + sample_rate);
if (DEBUG_DETERMINISTIC && seed == -1) _seed = 0x1321e74a0192470cL; // fixed version of seed
else if (seed == -1) _seed = _seedGenerator.nextLong(); else _seed = seed;
if (sample_rate==1f && validation==null)
Log.warn(Sys.DRF__, "Sample rate is 100% and no validation dataset is specified. There is no OOB data to compute the out-of-bag error estimate!");
if (!classification && do_grpsplit) {
Log.info(Sys.DRF__, "Group splitting not supported for DRF regression. Forcing group splitting to false.");
do_grpsplit = false;
}
}
@Override protected void initAlgo(DRFModel initialModel) {
// Initialize TreeVotes for classification, TreeSSE for regression
if (importance) initTreeMeasurements();
}
@Override protected void initWorkFrame(DRFModel initialModel, Frame fr) {
// Append number of trees participating in on-the-fly scoring
fr.add("OUT_BAG_TREES", response.makeZero());
// Prepare working columns
new SetWrkTask().doAll(fr);
// If there was a checkpoint, recompute the per-tree and oob columns based on
// predictions from the previous trees, but only if OOB validation is requested.
if (validation==null && checkpoint!=null) {
Timer t = new Timer();
// Compute oob votes for each output level
new OOBScorer(_ncols, _nclass, sample_rate, initialModel.treeKeys).doAll(fr);
Log.info(logTag(), "Reconstructing oob stats from checkpointed model took " + t);
}
}
@Override protected DRFModel buildModel( DRFModel model, final Frame fr, String names[], String domains[][], final Timer t_build ) {
// The RNG used to pick split columns
Random rand = createRNG(_seed);
// To be deterministic, draw the random numbers already consumed by the
// checkpointed trees, putting the generator back into the same state.
for (int i=0; i<_ntreesFromCheckpoint; i++) rand.nextLong();
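// buildNextKTrees() draws exactly one long per iteration (the shared k-tree
// seed), so skipping _ntreesFromCheckpoint longs here resumes the sampling
// stream exactly where the checkpointed model left off.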
int tid;
DTree[] ktrees = null;
// Prepare tree statistics
TreeStats tstats = model.treeStats!=null ? model.treeStats : new TreeStats();
// Build trees until we hit the limit
for( tid=0; tid<ntrees; tid++) { // Building the tid-th group of trees
if (tid!=0 || checkpoint==null) { // Skip the initial scoring if the model already exists
model = doScoring(model, fr, ktrees, tid, tstats, tid==0, !hasValidation(), build_tree_one_node);
}
// At each iteration build K trees (K = nclass = response column domain size)
// TODO: parallelize more? Build more than K trees at a time; we would need to take care of the temporary data.
// Idea: launch more DRF at once.
Timer kb_timer = new Timer();
ktrees = buildNextKTrees(fr,_mtry,sample_rate,rand,tid);
Log.info(logTag(), "Tree " + (tid+1) + " was built in " + kb_timer.toString());
if( !Job.isRunning(self()) ) break; // If canceled during building, do not bulkscore
// Check latest predictions
tstats.updateBy(ktrees);
}
if( Job.isRunning(self()) ) { // Perform the final scoring only if the job was not cancelled
model = doScoring(model, fr, ktrees, tid, tstats, true, !hasValidation(), build_tree_one_node);
// Make sure that we did not miss any votes
// assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0/*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!";
}
return model;
}
private void initTreeMeasurements() {
assert importance : "Tree votes should be initialized only if variable importance is requested!";
_improvPerVar = new float[_ncols];
// Preallocate tree votes
if (classification) {
_treeMeasuresOnOOB = new TreeVotes(ntrees);
_treeMeasuresOnSOOB = new TreeVotes[_ncols];
for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeVotes(ntrees);
} else {
_treeMeasuresOnOOB = new TreeSSE(ntrees);
_treeMeasuresOnSOOB = new TreeSSE[_ncols];
for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeSSE(ntrees);
}
}
// /** On-the-fly version for varimp. After generating a new tree, its tree votes are collected on shuffled
// * OOB rows and variable importance is recomputed.
// * <p>
// * The <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
// * <cite>
// * "In every tree grown in the forest, put down the oob cases and count the number of votes cast for the correct class.
// * Now randomly permute the values of variable m in the oob cases and put these cases down the tree.
// * Subtract the number of votes for the correct class in the variable-m-permuted oob data from the number of votes
// * for the correct class in the untouched oob data.
// * The average of this number over all trees in the forest is the raw importance score for variable m."
// * </cite>
// * </p>
// * */
// @Override
// protected VarImp doVarImpCalc(final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
// // Check if we have already serialized 'ktrees'-trees in the model
// assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid="+tid;
// assert _treeMeasuresOnOOB.npredictors()-1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
// // Compute tree votes over shuffled data
// final CompressedTree[/*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys
// final int nclasses = model.nclasses();
// Futures fs = new Futures();
// for (int var=0; var<_ncols; var++) {
// final int variable = var;
// H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() {
// @Override public void compute2() {
// // Compute this tree votes over all data over given variable
// TreeVotes cd = TreeMeasuresCollector.collectVotes(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
// assert cd.npredictors() == 1;
// asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
// tryComplete();
// }
// } : /* regression */ new H2OCountedCompleter() {
// @Override public void compute2() {
// // Compute this tree votes over all data over given variable
// TreeSSE cd = TreeMeasuresCollector.collectSSE(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
// assert cd.npredictors() == 1;
// asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
// tryComplete();
// }
// };
// fs.add(task4var);
// H2O.submitTask(task4var); // Fork computation
// }
// fs.blockForPending(); // Wait for results
// // Compute varimp for individual features (_ncols)
// final float[] varimp = new float[_ncols]; // output variable importance
// final float[] varimpSD = new float[_ncols]; // output variable importance sd
// for (int var=0; var<_ncols; var++) {
// double[/*2*/] imp = classification ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB)) : asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB));
// varimp [var] = (float) imp[0];
// varimpSD[var] = (float) imp[1];
// }
// return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees());
// }
/** Compute relative variable importance for an RF model.
 *
 * See formulas (45) and (35) in Friedman: "Greedy Function Approximation: A Gradient Boosting Machine".
 * The algorithm used here can also compute the individual importance of features per output class. */
@Override protected VarImp doVarImpCalc(DRFModel model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale) {
assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "varimp computation expect model with already serialized trees: tid="+tid;
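// The measure below is relative influence: for each variable, sum the
// squared-error improvement of every split on that variable across all
// trees, then normalize by the total tree count (ntrees * nclasses).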
// Accumulate split improvements from every k-class tree
for (DTree t : ktrees) { // Iterate over trees
if (t!=null) {
for (int n = 0; n< t.len()-t.leaves; n++)
if (t.node(n) instanceof DecidedNode) { // it is split node
DTree.Split split = t.decided(n)._split;
if (split.col()!=-1) // Skip impossible splits ~ leaves
_improvPerVar[split.col()] += split.improvement(); // least squares improvement
}
}
}
// Compute variable importance for all trees in model
float[] varimp = new float[model.nfeatures()];
int ntreesTotal = model.ntrees() * model.nclasses();
int maxVar = 0;
for (int var=0; var<_improvPerVar.length; var++) {
varimp[var] = _improvPerVar[var] / ntreesTotal;
if (varimp[var] > varimp[maxVar]) maxVar = var;
}
// Scale varimp so that the most important variable scores 1
if (scale) {
float maxVal = varimp[maxVar];
for (int var=0; var<varimp.length; var++) varimp[var] /= maxVal;
}
return new VarImp.VarImpRI(varimp);
}
@Override public boolean supportsBagging() { return true; }
/** Fill work columns:
* - classification: set 1 in the corresponding wrk col according to row response
* - regression: copy response into work column (there is only 1 work column) */
private class SetWrkTask extends MRTask2<SetWrkTask> {
@Override public void map( Chunk chks[] ) {
Chunk cy = chk_resp(chks);
for( int i=0; i<cy._len; i++ ) {
if( cy.isNA0(i) ) continue;
if (classification) {
int cls = (int)cy.at80(i);
chk_work(chks,cls).set0(i,1L);
} else {
float pred = (float) cy.at0(i);
chk_work(chks,0).set0(i,pred);
}
}
}
}
// --------------------------------------------------------------------------
// Build the next K random trees (one per class) for the tid-th iteration
private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
// We're going to build K (nclass) trees - each focused on correcting
// errors for a single class.
final DTree[] ktrees = new DTree[_nclass];
// Initial set of histograms. All trees; one leaf per tree (the root
// leaf); all columns
DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];
// Adjust nbins for the top level: the root starts with at least 1024 bins (1<<10)
int adj_nbins = Math.max((1<<(10-0)),nbins);
// Use the same seed for all k-trees in this group. NOTE: this is only to give
// all k-trees a fair (identical) view of the sampled rows.
long rseed = rand.nextLong();
// Initially setup as-if an empty-split had just happened
for( int k=0; k<_nclass; k++ ) {
assert (_distribution!=null && classification) || (_distribution==null && !classification);
if( _distribution == null || _distribution[k] != 0 ) { // Ignore missing classes
// The Boolean Optimization cannot be applied here for RF !
// This optimization assumes the 2nd tree of a 2-class system is the
// inverse of the first. This is false for DRF (and true for GBM) -
// DRF picks a random different set of columns for the 2nd tree.
//if( k==1 && _nclass==2 ) continue;
ktrees[k] = new DRFTree(fr,_ncols,(char)nbins,(char)_nclass,min_rows,mtrys,rseed);
boolean isBinom = classification;
new DRFUndecidedNode(ktrees[k],-1, DHistogram.initialHist(fr,_ncols,adj_nbins,hcs[k][0],min_rows,do_grpsplit,isBinom) ); // The "root" node
}
}
// Sample: mark out-of-bag rows by putting 'OUT_OF_BAG' into the nid(<class>) vector
Timer t_1 = new Timer();
Sample ss[] = new Sample[_nclass];
for( int k=0; k<_nclass; k++)
if (ktrees[k] != null) ss[k] = new Sample((DRFTree)ktrees[k], sample_rate).dfork(0,new Frame(vec_nids(fr,k),vec_resp(fr,k)), build_tree_one_node);
for( int k=0; k<_nclass; k++)
if( ss[k] != null ) ss[k].getResult();
Log.debug(Sys.DRF__, "Sampling took: + " + t_1);
int[] leafs = new int[_nclass]; // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
// ----
// One Big Loop till the ktrees are of proper depth.
// Adds a layer to the trees each pass.
Timer t_2 = new Timer();
int depth=0;
for( ; depth<max_depth; depth++ ) {
if( !Job.isRunning(self()) ) return null;
hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_one_node);
// If we did not make any new splits, then the tree is split-to-death
if( hcs == null ) break;
}
Log.debug(Sys.DRF__, "Tree build took: " + t_2);
// Each tree bottomed-out in a DecidedNode; go 1 more level and insert
// LeafNodes to hold predictions.
Timer t_3 = new Timer();
for( int k=0; k<_nclass; k++ ) {
DTree tree = ktrees[k];
if( tree == null ) continue;
int leaf = leafs[k] = tree.len();
for( int nid=0; nid<leaf; nid++ ) {
if( tree.node(nid) instanceof DecidedNode ) {
DecidedNode dn = tree.decided(nid);
for( int i=0; i<dn._nids.length; i++ ) {
int cnid = dn._nids[i];
if( cnid == -1 || // Bottomed out (predictors or responses known constant)
tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth
(tree.node(cnid) instanceof DecidedNode && // Or not possible to split
((DecidedNode)tree.node(cnid))._split.col()==-1) ) {
LeafNode ln = new DRFLeafNode(tree,nid);
ln._pred = dn.pred(i); // Set prediction into the leaf
dn._nids[i] = ln.nid(); // Mark a leaf here
}
}
// Handle the trivial non-splitting tree
if( nid==0 && dn._split.col() == -1 )
new DRFLeafNode(tree,-1,0);
}
}
} // -- k-trees are done
Log.debug(Sys.DRF__, "Nodes propagation: " + t_3);
// ----
// Move rows into the final leaf rows
Timer t_4 = new Timer();
CollectPreds cp = new CollectPreds(ktrees,leafs).doAll(fr,build_tree_one_node);
if (importance) {
if (classification) asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows); // Track right votes over OOB rows for this tree
else /* regression */ asSSE (_treeMeasuresOnOOB).append(cp.sse, cp.allRows);
}
Log.debug(Sys.DRF__, "CollectPreds done: " + t_4);
// Collect leaves stats
for (int i=0; i<ktrees.length; i++)
if( ktrees[i] != null )
ktrees[i].leaves = ktrees[i].len() - leafs[i];
// DEBUG: Print the generated K trees
//printGenerateTrees(ktrees);
return ktrees;
}
// Read the 'tree' columns, do model-specific math and put the results in the
// fs[] array, and return the sum. Dividing any fs[] element by the sum
// turns the results into a probability distribution.
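// Example (regression, _nclass==1): if 17 of the trees built so far had this
// row out-of-bag, chk_tree holds the sum of their 17 predictions and chk_oobt
// holds the count 17, so the division below yields their average prediction.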
@Override protected float score1( Chunk chks[], float fs[/*nclass*/], int row ) {
float sum=0;
for( int k=0; k<_nclass; k++ ) // Sum across likelihoods
sum+=(fs[k+1]=(float)chk_tree(chks,k).at0(row));
if (_nclass == 1) sum /= (float)chk_oobt(chks).at0(row); // Regression: average over the trees voting for this row (only trees which had this row out-of-bag)
return sum;
}
@Override protected boolean inBagRow(Chunk[] chks, int row) {
return chk_oobt(chks).at80(row) == 0;
}
// Collect and write predictions into leaves.
private class CollectPreds extends MRTask2<CollectPreds> {
/* @IN */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
/* @OUT */ long rightVotes; // number of right votes over OOB rows (performed by this tree) represented by DTree[] _trees
/* @OUT */ long allRows; // number of all OOB rows (sampled by this tree)
/* @OUT */ float sse; // Sum of squares for this tree only
CollectPreds(DTree trees[], int leafs[]) { _trees=trees; }
@Override public void map( Chunk[] chks ) {
final Chunk y = importance ? chk_resp(chks) : null; // Response
final float [] rpred = importance ? new float [1+_nclass] : null; // Row prediction
final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
final Chunk oobt = chk_oobt(chks); // Out-of-bag rows counter over all trees
// Iterate over all rows
for( int row=0; row<oobt._len; row++ ) {
boolean wasOOBRow = false;
// For all trees (i.e., k classes)
for( int k=0; k<_nclass; k++ ) {
final DTree tree = _trees[k];
if( tree == null ) continue; // Empty class is ignored
// If we have all constant responses, then we do not split even the
// root and the residuals should be zero.
if( tree.root() instanceof LeafNode ) continue;
final Chunk nids = chk_nids(chks,k); // Node-ids for this tree/class
final Chunk ct = chk_tree(chks,k); // k-tree working column holding votes for given row
int nid = (int)nids.at80(row); // Get Node to decide from
// Update only out-of-bag rows
// This is an out-of-bag row, but we would still like to track its on-the-fly prediction
if( isOOBRow(nid) ) { // The row should be OOB for all k-trees!
assert k==0 || wasOOBRow : "Something is wrong: OOB computation for k-class trees is broken! All k-trees should agree on which rows are OOB!";
wasOOBRow = true;
nid = oob2Nid(nid);
if( tree.node(nid) instanceof UndecidedNode ) // If we bottomed out the tree
nid = tree.node(nid).pid(); // Then take parent's decision
DecidedNode dn = tree.decided(nid); // Must have a decision point
if( dn._split.col() == -1 ) // Unable to decide?
dn = tree.decided(tree.node(nid).pid()); // Then take parent's decision
int leafnid = dn.ns(chks,row); // Decide down to a leafnode
// Setup Tree(i) - on the fly prediction of i-tree for row-th row
// - for classification: cumulative number of votes for this row
// - for regression: cumulative sum of prediction of each tree - has to be normalized by number of trees
double prediction = ((LeafNode)tree.node(leafnid)).pred(); // Prediction for this k-class and this row
if (importance) rpred[1+k] = (float) prediction; // for both regression and classification
ct.set0(row, (float)(ct.at0(row) + prediction));
// For this tree this row is out-of-bag - i.e., a tree voted for this row
oobt.set0(row, _nclass>1?1:oobt.at0(row)+1); // for regression track number of trees, for classification boolean flag is enough
}
// reset help column for this row and this k-class
nids.set0(row,0);
} /* end of k-trees iteration */
if (importance) {
if (wasOOBRow && !y.isNA0(row)) {
if (classification) {
int treePred = ModelUtils.getPrediction(rpred, data_row(chks,row, rowdata));
int actuPred = (int) y.at80(row);
if (treePred==actuPred) rightVotes++; // No miss !
} else { // regression
float treePred = rpred[1];
float actuPred = (float) y.at0(row);
sse += (actuPred-treePred)*(actuPred-treePred);
}
allRows++;
}
}
}
}
@Override public void reduce(CollectPreds mrt) {
rightVotes += mrt.rightVotes;
allRows += mrt.allRows;
sse += mrt.sse;
}
}
// A standard DTree with a few more bits: support for sampling during
// training, and for replaying the sample later on the identical dataset to
// e.g. compute the out-of-bag error estimate (OOBEE).
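// One seed is stored per chunk (see _seeds below), so the in-bag/OOB decision
// for any row can be regenerated later from its chunk index alone, without
// ever materializing the bag.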
static class DRFTree extends DTree {
final int _mtrys; // Number of columns to choose amongst in splits
final long _seeds[]; // One seed for each chunk, for sampling
final transient Random _rand; // RNG for split decisions & sampling
DRFTree( Frame fr, int ncols, char nbins, char nclass, int min_rows, int mtrys, long seed ) {
super(fr._names, ncols, nbins, nclass, min_rows, seed);
_mtrys = mtrys;
_rand = createRNG(seed);
_seeds = new long[fr.vecs()[0].nChunks()];
for( int i=0; i<_seeds.length; i++ )
_seeds[i] = _rand.nextLong();
}
// Return a deterministic chunk-local RNG. Can be somewhat expensive.
@Override public Random rngForChunk( int cidx ) {
long seed = _seeds[cidx];
return createRNG(seed);
}
}
@Override protected DecidedNode makeDecided( UndecidedNode udn, DHistogram hs[] ) {
return new DRFDecidedNode(udn,hs);
}
// DRF DTree decision node: same as the normal DecidedNode, but specifies a
// decision algorithm given complete histograms on all columns.
// DRF algo: find the lowest error among the mtry randomly chosen columns.
static class DRFDecidedNode extends DecidedNode {
DRFDecidedNode( UndecidedNode n, DHistogram hs[] ) { super(n,hs); }
@Override public DRFUndecidedNode makeUndecidedNode( DHistogram hs[] ) {
return new DRFUndecidedNode(_tree,_nid, hs);
}
// Find the column with the best split (lowest score).
@Override public DTree.Split bestCol( UndecidedNode u, DHistogram hs[] ) {
DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
if( hs == null ) return best;
for( int i=0; i<u._scoreCols.length; i++ ) {
int col = u._scoreCols[i];
DTree.Split s = hs[col].scoreMSE(col);
if( s == null ) continue;
if( s.se() < best.se() ) best = s;
if( s.se() <= 0 ) break; // No point in looking further!
}
return best;
}
}
// DRF DTree undecided node: same as the normal UndecidedNode, but specifies
// a list of columns to score on now, and then decide over later.
// DRF algo: pick mtry columns at random.
static class DRFUndecidedNode extends UndecidedNode {
DRFUndecidedNode( DTree tree, int pid, DHistogram[] hs ) { super(tree,pid, hs); }
// Randomly select mtry columns to 'score' in following pass over the data.
@Override public int[] scoreCols( DHistogram[] hs ) {
DRFTree tree = (DRFTree)_tree;
int[] cols = new int[hs.length];
int len=0;
// Gather all active columns to choose from.
for( int i=0; i<hs.length; i++ ) {
if( hs[i]==null ) continue; // Ignore not-tracked cols
assert hs[i]._min < hs[i]._maxEx && hs[i].nbins() > 1 : "broken histo range "+hs[i];
cols[len++] = i; // Gather active column
}
int choices = len; // Number of columns I can choose from
assert choices > 0;
// Draw up to mtry columns at random without replacement.
for( int i=0; i<tree._mtrys; i++ ) {
if( len == 0 ) break; // Out of choices!
int idx2 = tree._rand.nextInt(len);
int col = cols[idx2]; // The chosen column
cols[idx2] = cols[--len]; // Compress out of array; do not choose again
cols[len] = col; // Swap chosen in just after 'len'
}
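// Example: with choices=10 active columns and mtrys=3, three distinct column
// indices end up swapped into cols[7..9], and cols[len..choices) == cols[7..10)
// is returned below.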
assert choices - len > 0;
return Arrays.copyOfRange(cols,len,choices);
}
}
static class DRFLeafNode extends LeafNode {
DRFLeafNode( DTree tree, int pid ) { super(tree,pid); }
DRFLeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); }
// Insert just the prediction: a single float holding this leaf's prediction.
@Override protected AutoBuffer compress(AutoBuffer ab) { assert !Double.isNaN(pred()); return ab.put4f((float)pred()); }
@Override protected int size() { return 4; }
}
// Deterministic sampling
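// Each row stays in-bag with probability _rate; rows drawing
// rand.nextFloat() >= _rate (or with an NA response) are flagged OUT_OF_BAG.
// Because the RNG is chunk-local and seeded from DRFTree._seeds, re-running
// this task reproduces the identical bag.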
static class Sample extends MRTask2<Sample> {
final DRFTree _tree;
final float _rate;
Sample( DRFTree tree, float rate ) { _tree = tree; _rate = rate; }
@Override public void map( Chunk nids, Chunk ys ) {
Random rand = _tree.rngForChunk(nids.cidx());
for( int row=0; row<nids._len; row++ )
if( rand.nextFloat() >= _rate || Double.isNaN(ys.at0(row)) ) {
nids.set0(row, OUT_OF_BAG); // Flag row as being ignored by sampling
}
}
}
/**
* Cross-Validate a DRF model by building new models on N train/test holdout splits
* @param splits Frames containing train/test splits
* @param cv_preds Array of Frames to store the predictions for each cross-validation run
* @param offsets Array to store the offsets of starting row indices for each cross-validation run
* @param i Which fold of cross-validation to perform
*/
@Override public void crossValidate(Frame[] splits, Frame[] cv_preds, long[] offsets, int i) {
// Train a clone with slightly modified parameters (to account for cross-validation)
DRF cv = (DRF) this.clone();
// cv.importance = false; //Don't compute variable importance for N CV-folds
cv.genericCrossValidation(splits, offsets, i);
cv_preds[i] = ((DRFModel) UKV.get(cv.dest())).score(cv.validation); // cv_preds is escaping the context of this function and needs to be DELETED by the caller!!!
}
}