package hex.tree.drf;
import hex.tree.SharedTreeModel;
import water.Key;
import water.util.MathUtils;
import water.util.SBPrintStream;
public class DRFModel extends SharedTreeModel<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
/**
 * Parameters for Distributed Random Forest models.
 * Adds the DRF-only settings declared below to the shared tree parameters and
 * overrides a few inherited defaults in the constructor.
 */
public static class DRFParameters extends SharedTreeModel.SharedTreeParameters {
  /** When {@code false}, {@code DRFModel.binomialOpt()} returns {@code true}. */
  public boolean _binomial_double_trees = false;

  /**
   * Number of columns to use per split. The default depends on the algorithm
   * and problem (classification/regression).
   */
  public int _mtries = -1;

  /** Creates the parameters with DRF-specific defaults (these can differ from SharedTreeModel's defaults). */
  public DRFParameters() {
    super();
    _min_rows = 1;
    _max_depth = 20;
    _sample_rate = 0.632f;
    _mtries = -1;
  }

  /** @return the short algorithm name, {@code "DRF"} */
  public String algoName() {
    return "DRF";
  }

  /** @return the human-readable algorithm name */
  public String fullName() {
    return "Distributed Random Forest";
  }

  /** @return the fully qualified class name of {@link DRFModel} */
  public String javaName() {
    return DRFModel.class.getName();
  }
}
/** Output type for DRF models; declares no members beyond those of the shared tree output. */
public static class DRFOutput extends SharedTreeModel.SharedTreeOutput {
  /**
   * @param b the {@code DRF} instance forwarded unchanged to the superclass constructor
   */
  public DRFOutput(DRF b) {
    super(b);
  }
}
/**
 * Creates a DRF model by forwarding all arguments to the superclass constructor.
 *
 * @param selfKey key of this model
 * @param parms   DRF parameters
 * @param output  DRF output
 */
public DRFModel(Key<DRFModel> selfKey, DRFParameters parms, DRFOutput output) {
  super(selfKey, parms, output);
}
/**
 * @return {@code true} unless the {@code _binomial_double_trees} parameter is set
 */
@Override
protected boolean binomialOpt() {
  final boolean doubleTrees = _parms._binomial_double_trees;
  return !doubleTrees;
}
/**
 * Scores one row by delegating to the superclass and then post-processing {@code preds} in place.
 * <ul>
 *   <li>One class (regression): {@code preds[0]} is divided by the model's tree count.</li>
 *   <li>Two classes with {@link #binomialOpt()} enabled: {@code preds[1]} is divided by the
 *       tree count and {@code preds[2]} is set to {@code 1 - preds[1]}.</li>
 *   <li>Otherwise: {@code preds} is divided by its sum when that sum is positive.</li>
 * </ul>
 * The divisions by tree count are skipped when the model has no trees.
 *
 * NOTE(review): the divisor is {@code _output._ntrees}, not the {@code ntrees} argument;
 * confirm this is intended if callers can pass a smaller {@code ntrees}.
 *
 * @param data   row values, passed through to the superclass
 * @param preds  prediction array, modified in place and returned
 * @param weight passed through to the superclass
 * @param offset passed through to the superclass
 * @param ntrees passed through to the superclass
 * @return the same {@code preds} array
 */
@Override
protected double[] score0(double[] data, double[] preds, double weight, double offset, int ntrees) {
  super.score0(data, preds, weight, offset, ntrees);

  final int treeCount = _output._ntrees;
  final boolean hasTrees = treeCount >= 1;
  final int classCount = _output.nclasses();

  // Regression: average over all trees.
  if (classCount == 1) {
    if (hasTrees) {
      preds[0] /= treeCount;
    }
    return preds;
  }

  // Binomial shortcut: average probability goes in slot 1, its complement in slot 2.
  if (classCount == 2 && binomialOpt()) {
    if (hasTrees) {
      preds[1] /= treeCount;
    }
    preds[2] = 1.0 - preds[1];
    return preds;
  }

  // Remaining classification cases: scale by the total, unless the total is not positive.
  final double total = MathUtils.sum(preds);
  if (total > 0) {
    MathUtils.div(preds, total);
  }
  return preds;
}
/**
 * Writes into {@code body} the generated-code statements that post-process {@code preds},
 * mirroring what {@code score0} does at runtime.
 * <ul>
 *   <li>One class (regression): emits a division of {@code preds[0]} by the tree count.</li>
 *   <li>Two classes with {@link #binomialOpt()} enabled: emits a division of {@code preds[1]}
 *       by the tree count and sets {@code preds[2]} to {@code 1.0 - preds[1]}.</li>
 *   <li>Otherwise: emits a loop that divides {@code preds[1..]} by their sum when positive.</li>
 * </ul>
 * For classification it then emits the optional probability correction (when
 * {@code _balance_classes} is set) and the final label assignment into {@code preds[0]}.
 *
 * The tree-count divisions are only emitted when the model has at least one tree, the same
 * guard {@code score0} applies. Without it the generated code would contain {@code /= 0}.
 *
 * @param body stream receiving the generated statements
 */
@Override
protected void toJavaUnifyPreds(SBPrintStream body) {
  final int treeCount = _output._ntrees;
  if (_output.nclasses() == 1) { // Regression
    // Skip the average for a tree-less model: "preds[0] /= 0" would turn a zero accumulator into NaN.
    if (treeCount >= 1) {
      body.ip("preds[0] /= " + treeCount + ";").nl();
    }
  } else { // Classification
    if (_output.nclasses() == 2 && binomialOpt()) { // Kept the initial prediction for binomial
      // Same guard as the regression branch; the complement is always emitted, as in score0.
      if (treeCount >= 1) {
        body.ip("preds[1] /= " + treeCount + ";").nl();
      }
      body.ip("preds[2] = 1.0 - preds[1];").nl();
    } else {
      body.ip("double sum = 0;").nl();
      body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
      body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
    }
    if (_parms._balance_classes) {
      body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
    }
    body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
  }
}
/**
 * @return a new {@code DrfMojoWriter} constructed from this model
 */
@Override
public DrfMojoWriter getMojo() {
  final DrfMojoWriter writer = new DrfMojoWriter(this);
  return writer;
}
}