package hex.singlenoderf;
import dontweave.gson.JsonObject;
import hex.ConfusionMatrix;
import hex.FrameTask;
import hex.VarImp;
import hex.drf.DRF;
import water.*;
import water.Timer;
import water.api.AUCData;
import water.api.Constants;
import water.api.DocGen;
import water.api.ParamImportance;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.*;
import java.util.*;
import static water.util.MRUtils.sampleFrameStratified;
public class SpeeDRF extends Job.ValidatedJob {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
public static DocGen.FieldDoc[] DOC_FIELDS;
public static final String DOC_GET = "SpeeDRF";
@API(help = "Number of trees", filter = Default.class, json = true, lmin = 1, lmax = Integer.MAX_VALUE, importance = ParamImportance.CRITICAL)
public int ntrees = 50;
@API(help = "Number of features to randomly select at each split.", filter = Default.class, json = true, lmin = -1, lmax = Integer.MAX_VALUE, importance = ParamImportance.SECONDARY)
public int mtries = -1;
@API(help = "Max Depth", filter = Default.class, json = true, lmin = 0, lmax = Integer.MAX_VALUE, importance = ParamImportance.CRITICAL)
public int max_depth = 20;
@API(help = "Split Criterion Type", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
public Tree.SelectStatType select_stat_type = Tree.SelectStatType.ENTROPY;
// @API(help = "Use local data. Auto-enabled if data does not fit in a single node.") /*, filter = Default.class, json = true, importance = ParamImportance.EXPERT) */
// public boolean local_mode = false;
/* Legacy parameter: */
public double[] class_weights = null;
@API(help = "Sampling Strategy", filter = Default.class, json = true, importance = ParamImportance.SECONDARY)
public Sampling.Strategy sampling_strategy = Sampling.Strategy.RANDOM;
@API(help = "Sampling Rate at each split.", filter = Default.class, json = true, dmin = 0, dmax = 1, importance = ParamImportance.EXPERT)
public double sample_rate = 0.67;
// @API(help ="Score each iteration", filter = Default.class, json = true, importance = ParamImportance.SECONDARY)
public boolean score_each_iteration = false;
@API(help = "Create the Score POJO", filter = Default.class, json = true, importance = ParamImportance.EXPERT)
public boolean score_pojo = true;
/*Imbalanced Classes*/
/**
* For imbalanced data, balance training data class counts via
* over/under-sampling. This can result in improved predictive accuracy.
*/
@API(help = "Balance training data class counts via over/under-sampling (for imbalanced data)", filter = Default.class, json = true, importance = ParamImportance.EXPERT)
public boolean balance_classes = false;
/**
* When classes are balanced, limit the resulting dataset size to the
* specified multiple of the original dataset size.
*/
@API(help = "Maximum relative size of the training data after balancing class counts (can be less than 1.0)", filter = Default.class, json = true, dmin=1e-3, importance = ParamImportance.EXPERT)
public float max_after_balance_size = Float.POSITIVE_INFINITY;
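// e.g. max_after_balance_size = 2.0f caps the balanced training frame at twice the original row count (see setStrat()).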
@API(help = "Out of bag error estimate", filter = Default.class, json = true, importance = ParamImportance.SECONDARY)
public boolean oobee = true;
@API(help = "Variable Importance", filter = Default.class, json = true)
public boolean importance = false;
public Key _modelKey = dest();
/* Advanced settings */
@API(help = "bin limit", filter = Default.class, json = true, lmin = 0, lmax = 65534, importance = ParamImportance.EXPERT)
public int nbins = 1024;
@API(help = "seed", filter = Default.class, json = true, importance = ParamImportance.EXPERT)
public long seed = -1;
@API(help = "Tree splits and extra statistics printed to stdout.", filter = Default.class, json = true, importance = ParamImportance.EXPERT)
public boolean verbose = false;
@API(help = "split limit", importance = ParamImportance.EXPERT)
public int _exclusiveSplitLimit = 0;
private static Random _seedGenerator = Utils.getDeterRNG(new Random().nextLong()); // previously seeded with the fixed constant 0xd280524ad7fe0602L
private boolean regression;
public DRFParams drfParams;
private long use_seed;
Tree.StatType stat_type;
/** Return the query link to this page */
public static String link(Key k, String content) {
RString rs = new RString("<a href='/2/SpeeDRF.query?source=%$key'>%content</a>");
rs.replace("key", k.toString());
rs.replace("content", content);
return rs.toString();
}
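// Example: link(fr._key, "Run SpeeDRF") renders an anchor that opens the SpeeDRF query
// page with `source` preset to that frame's key (`fr` here stands for any training Frame).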
protected SpeeDRFModel makeModel( SpeeDRFModel model, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
return new SpeeDRFModel(model, err, cm, varimp, validAUC);
}
@Override protected void queryArgumentValueSet(Argument arg, java.util.Properties inputArgs) {
super.queryArgumentValueSet(arg, inputArgs);
if (arg._name.equals("classification")) {
arg._hideInQuery = true;
}
if (arg._name.equals("balance_classes")) {
arg.setRefreshOnChange();
if(regression) {
arg.disable("Class balancing is only for classification.");
}
}
// Regression is selected if classification is false and vice-versa.
if (arg._name.equals("classification")) {
regression = !this.classification;
}
// Regression only accepts the MSE stat type.
if (arg._name.equals("select_stat_type")) {
if(regression) {
arg.disable("Minimize MSE for regression.");
}
}
// Class weights require the source dataset and response column to be specified, and are invalid for regression
if (arg._name.equals("class_weights")) {
if (source == null || response == null) {
arg.disable("Requires source and response to be specified.");
}
if (regression) {
arg.disable("No class weights for regression.");
}
}
// Force Random sampling when building regression trees (stratified strategies are classification-only).
if (arg._name.equals("sampling_strategy")) {
arg.setRefreshOnChange();
if (regression) {
arg.disable("Random Sampling for regression trees.");
}
}
// Variable Importance disabled in SpeeDRF regression currently
if (arg._name.equals("importance")) {
if (regression) {
arg.disable("Variable Importance not supported in SpeeDRF regression.");
}
}
// max_after_balance_size requires balance_classes to be enabled
if(classification) {
if(arg._name.equals("max_after_balance_size") && !balance_classes) {
arg.disable("Requires balance classes flag to be set.", inputArgs);
}
}
}
// All precondition checks go here
@Override protected void init() {
super.init();
assert 0 <= ntrees && ntrees < 1000000; // Sanity check
// Not enough rows to run
if (source.numRows() - response.naCnt() <= 0)
throw new IllegalArgumentException("Response column contains too many NAs; no rows left to train on!");
if (classification && !(response.isEnum() || response.isInt()))
throw new IllegalArgumentException("Classification cannot be performed on a float column!");
if(classification) {
if (0.0f > sample_rate || sample_rate > 1.0f)
throw new IllegalArgumentException("Sampling rate must be in [0,1] but found " + sample_rate);
}
if(regression) throw new IllegalArgumentException("SpeeDRF does not currently support regression.");
}
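/** Read-lock the source (and validation) frames, build the forest, and optionally run
* n-fold cross-validation; regardless of outcome, push the final job state into the
* stored model and clean up. */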
@Override protected void execImpl() {
SpeeDRFModel rf_model;
try {
source.read_lock(self());
if (validation != null && validation != source) validation.read_lock(self());
buildForest();
if (n_folds > 0) CrossValUtils.crossValidate(this);
} catch (JobCancelledException ex){
rf_model = UKV.get(dest());
state = JobState.CANCELLED; //for JSON REST response
rf_model.get_params().state = state; //for parameter JSON on the HTML page
Log.info("Random Forest was cancelled.");
} catch(Exception ex) {
ex.printStackTrace();
throw new RuntimeException(ex);
} finally {
source.unlock(self());
if (validation != null && validation != source) validation.unlock(self());
remove();
state = UKV.<Job>get(self()).state;
// Argh, this is horrible: atomically push the final job state into the stored model
new TAtomic<SpeeDRFModel>() {
@Override
public SpeeDRFModel atomic(SpeeDRFModel m) {
if (m != null) m.get_params().state = state;
return m;
}
}.invoke(dest());
emptyLTrash();
cleanup();
}
}
@Override protected Response redirect() { return SpeeDRFProgressPage.redirect(this, self(), dest()); }
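/** Full training pipeline: prepare the train/test frames, optionally rebalance class
* counts, initialize the model, fan the DRFTask out to every node, score the finished
* forest, and (for classification) optionally compute variable importance and tree
* statistics. */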
private void buildForest() {
logStart();
SpeeDRFModel model = null;
try {
Frame train = setTrain();
Frame test = setTest();
Vec resp = regression ? null : train.lastVec().toEnum();
if (resp != null) gtrash(resp);
float[] priorDist = setPriorDist(train);
train = setStrat(train, test, resp);
model = initModel(train, test, priorDist);
model.start_training(null);
model.write_lock(self());
drfParams = DRFParams.create(train.find(resp), model.N, model.max_depth, (int) train.numRows(), model.nbins,
model.statType, use_seed, model.weights, mtries, model.sampling_strategy, (float) sample_rate, model.strata_samples, model.verbose ? 100 : 1, _exclusiveSplitLimit, true, regression);
DRFTask tsk = new DRFTask(self(), train, drfParams, model._key, model.src_key);
tsk.validateInputData(train);
tsk.invokeOnAllNodes();
Log.info("Tree building complete. Scoring...");
model = UKV.get(dest());
model.scoreAllTrees(test == null ? train : test, resp);
// Launch a Variable Importance Task
if (importance && !regression) {
Log.info("Scoring complete. Performing Variable Importance Calculations.");
model.current_status = "Performing Variable Importance Calculation.";
Timer VITimer = new Timer();
model.variableImportanceCalc(train, resp);
Log.info("Variable Importance on "+(train.numCols()-1)+" variables and "+ ntrees +" trees done in " + VITimer);
}
Log.info("Generating Tree Stats");
JsonObject trees = new JsonObject();
trees.addProperty(Constants.TREE_COUNT, model.size());
if( model.size() > 0 ) {
trees.add(Constants.TREE_DEPTH, model.depth().toJson());
trees.add(Constants.TREE_LEAVES, model.leaves().toJson());
}
model.generateHTMLTreeStats(new StringBuilder(), trees);
model.current_status = "Model Complete";
} finally {
if (model != null) {
model.unlock(self());
model.stop_training();
}
}
}
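/** Populate a fresh SpeeDRFModel with the resolved input parameters (seed, mtry, stat
* type, sampling strategy, class weights) and empty output slots (no trees, no
* confusion matrix, error initialized to -1). */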
public SpeeDRFModel initModel(Frame train, Frame test, float[] priorDist) {
setStatType();
setSeed(seed);
if (mtries == -1) setMtry(regression, train.numCols() - 1);
Key src_key = source._key;
int src_ncols = source.numCols();
SpeeDRFModel model = new SpeeDRFModel(dest(), src_key, train, regression ? null : train.lastVec().domain(), this, priorDist);
// Model INPUTS
model.src_key = src_key.toString();
model.verbose = verbose; model.verbose_output = new String[]{""};
model.validation = test != null;
model.confusion = null;
model.zeed = use_seed;
model.cmDomain = getCMDomain();
model.nbins = nbins;
model.max_depth = max_depth;
model.oobee = validation == null && oobee;
model.statType = regression ? Tree.StatType.MSE : stat_type;
model.testKey = validation == null ? null : validation._key;
model.importance = importance;
model.regression = regression;
model.features = src_ncols;
model.sampling_strategy = regression ? Sampling.Strategy.RANDOM : sampling_strategy;
model.sample = (float) sample_rate;
model.weights = regression ? null : class_weights;
model.time = 0;
model.N = ntrees;
model.useNonLocal = true;
if (!regression) model.setModelClassDistribution(new MRUtils.ClassDist(train.lastVec()).doAll(train.lastVec()).rel_dist());
model.resp_min = (int) train.lastVec().min();
model.mtry = mtries;
int csize = H2O.CLOUD.size();
model.local_forests = new Key[csize][];
for (int i = 0; i < csize; i++) model.local_forests[i] = new Key[0];
model.node_split_features = new int[csize];
model.t_keys = new Key[0];
model.dtreeKeys = new Key[ntrees][regression ? 1 : model.classes()];
for( Key tkey : model.t_keys ) assert DKV.get(tkey)!=null;
model.jobKey = self();
model.score_pojo = score_pojo;
model.current_status = "Initializing Model";
// Model OUTPUTS
model.varimp = null;
model.validAUC = null;
model.cms = new ConfusionMatrix[1];
model.errs = new double[]{-1.0};
return model;
}
private void setStatType() {
if (regression) { stat_type = Tree.StatType.MSE; return; } // regression always minimizes MSE
stat_type = select_stat_type == Tree.SelectStatType.ENTROPY ? Tree.StatType.ENTROPY : Tree.StatType.GINI;
if (select_stat_type == Tree.SelectStatType.TWOING) stat_type = Tree.StatType.TWOING;
}
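// A seed of -1 means "auto": draw a fresh seed from the shared generator; otherwise
// re-seed the deterministic RNG with the user-supplied value and draw the working seed from it.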
private void setSeed(long s) {
if (s == -1) { seed = _seedGenerator.nextLong(); use_seed = seed; }
else {
_seedGenerator = Utils.getDeterRNG(s);
use_seed = _seedGenerator.nextLong();
}
}
private void setMtry(boolean reg, int numCols) { mtries = reg ? (int) Math.floor((float) (numCols) / 3.0f) : (int) Math.floor(Math.sqrt(numCols)); }
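// Example for setMtry(): with 100 predictor columns, classification selects
// floor(sqrt(100)) = 10 features per split, while regression would select floor(100/3) = 33.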
private Frame setTrain() {
// toEnum is TRUE for classification (i.e., when regression is FALSE)
Frame train = FrameTask.DataInfo.prepareFrame(source, response, ignored_cols, !regression, false, false);
if (train.lastVec().masterVec() != null && train.lastVec() != response) gtrash(train.lastVec());
return train;
}
private Frame setTest() {
if (validation == null) return null;
// Keep only the ignored columns that actually exist in the validation frame.
ArrayList<Integer> v_ignored_cols = new ArrayList<Integer>();
for (int ignored_col : ignored_cols)
if (validation.find(source.names()[ignored_col]) != -1) v_ignored_cols.add(ignored_col);
int[] v_ignored = new int[v_ignored_cols.size()];
for (int i = 0; i < v_ignored.length; ++i) v_ignored[i] = v_ignored_cols.get(i);
Frame test = FrameTask.DataInfo.prepareFrame(validation, validation.vecs()[validation.find(source.names()[source.find(response)])], v_ignored, !regression, false, false);
if (test != null && test.lastVec().masterVec() != null) gtrash(test.lastVec());
return test;
}
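/** For classification with balance_classes, replace the training frame with a stratified
* over/under-sample capped at max_after_balance_size * numRows; also warn when the
* validation response domain is not a subset of the training domain. */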
private Frame setStrat(Frame train, Frame test, Vec resp) {
Frame fr = train;
float[] trainSamplingFactors;
if (classification && balance_classes) {
assert resp != null : "Regression called and stratified sampling was invoked to balance classes!";
// Handle imbalanced classes by stratified over/under-sampling
// initWorkFrame sets the modeled class distribution, and model.score() corrects the probabilities back using the distribution ratios
int response_idx = fr.find(_responseName);
fr.replace(response_idx, resp);
trainSamplingFactors = new float[resp.domain().length]; //leave initialized to 0 -> will be filled up below
Frame stratified = sampleFrameStratified(fr, resp, trainSamplingFactors, (long) (max_after_balance_size * fr.numRows()), use_seed, true, false);
if (stratified != fr) {
fr = stratified;
gtrash(stratified);
}
}
// Check that the test/train data are consistent; warn if not
if(classification && validation != null) {
assert resp != null : "Classification requested but the training response was not converted to an enum!";
Vec testresp = test.lastVec().toEnum();
gtrash(testresp);
if (!isSubset(testresp.domain(), resp.domain())) {
Log.warn("Test set domain: " + Arrays.toString(testresp.domain()) + " \nTrain set domain: " + Arrays.toString(resp.domain()));
Log.warn("Train and Validation data have inconsistent response columns! Test data has a response not found in the Train data!");
}
}
return fr;
}
private float[] setPriorDist(Frame train) { return classification ? new MRUtils.ClassDist(train.lastVec()).doAll(train.lastVec()).rel_dist() : null; }
public Frame score( Frame fr ) { return ((SpeeDRFModel)UKV.get(dest())).score(fr); }
private boolean isSubset(String[] sub, String[] container) {
HashSet<String> hs = new HashSet<String>();
Collections.addAll(hs, container);
for (String s: sub) {
if (!hs.contains(s)) return false;
}
return true;
}
public final static class DRFTask extends DRemoteTask {
// private final SpeeDRFModel _rfmodel;
/** Key of the RF model, which holds the dataset being worked on, the classification
* column, and the training columns. */
private final Key _rfmodel;
/** Job representing this DRF execution. */
private final Key _jobKey;
/** RF parameters. */
private final DRFParams _params;
private final Frame _fr;
private final String _key;
DRFTask(Key jobKey, Frame fr, DRFParams params, Key rfmodel, String src_key) {
_jobKey = jobKey; _fr = fr; _params = params; _rfmodel = rfmodel; _key = src_key;
}
/** Inhale the data, build a DataAdapter, and kick off the computation. */
@Override public final void lcompute() {
final DataAdapter dapt = DABuilder.create(_params, _rfmodel).build(_fr, _params._useNonLocalData);
if (dapt == null) {
tryComplete();
return;
}
Data localData = Data.make(dapt);
int numSplitFeatures = howManySplitFeatures();
int ntrees = howManyTrees();
int[] rowsPerChunks = howManyRPC(_fr);
updateRFModel(_rfmodel, numSplitFeatures);
updateRFModelStatus(_rfmodel, "Building Forest");
updateRFModelLocalForests(_rfmodel, ntrees);
Log.info("Dispalying local forest stats:");
SpeeDRF.build(_jobKey, _rfmodel, _params, localData, ntrees, numSplitFeatures, rowsPerChunks);
tryComplete();
}
static void updateRFModel(Key modelKey, final int numSplitFeatures) {
final int idx = H2O.SELF.index();
new TAtomic<SpeeDRFModel>() {
@Override public SpeeDRFModel atomic(SpeeDRFModel old) {
if(old == null) return null;
SpeeDRFModel newModel = (SpeeDRFModel)old.clone();
newModel.node_split_features[idx] = numSplitFeatures;
return newModel;
}
}.invoke(modelKey);
}
static void updateRFModelLocalForests(Key modelKey, final int num_trees) {
final int selfIdx = H2O.SELF.index();
new TAtomic<SpeeDRFModel>() {
@Override public SpeeDRFModel atomic(SpeeDRFModel old) {
if (old == null) return null;
SpeeDRFModel newModel = (SpeeDRFModel)old.clone();
newModel.local_forests[selfIdx] = new Key[num_trees];
return newModel;
}
}.invoke(modelKey);
}
static void updateRFModelStatus(Key modelKey, final String status) {
new TAtomic<SpeeDRFModel>() {
@Override public SpeeDRFModel atomic(SpeeDRFModel old) {
if(old == null) return null;
SpeeDRFModel newModel = (SpeeDRFModel)old.clone();
newModel.current_status = status;
return newModel;
}
}.invoke(modelKey);
}
/** Number of features examined at each split; when the user leaves mtries at -1 this
* has already been resolved to sqrt(#features) (or #features/3 for regression) by setMtry(). */
private int howManySplitFeatures() {
return _params.num_split_features;
}
/** Figure the number of trees to make locally, so the total hits ntrees.
* Divide equally amongst all the nodes that actually have data. First:
* compute how many nodes have data. Give each Node ntrees/#nodes worth of
* trees. Round down for later nodes, and round up for earlier nodes.
*/
private int howManyTrees() {
Frame fr = _fr;
final long num_chunks = fr.anyVec().nChunks();
final int num_nodes = H2O.CLOUD.size();
HashSet<H2ONode> nodes = new HashSet<H2ONode>();
for( int i=0; i<num_chunks; i++ ) {
nodes.add(fr.anyVec().chunkKey(i).home_node());
if( nodes.size() == num_nodes ) // All of nodes covered?
break; // That means we are done.
}
H2ONode[] array = nodes.toArray(new H2ONode[nodes.size()]);
Arrays.sort(array);
// Give each H2ONode ntrees/#nodes worth of trees. Round down for later nodes,
// and round up for earlier nodes
int ntrees = _params.num_trees / nodes.size();
if( Arrays.binarySearch(array, H2O.SELF) < _params.num_trees - ntrees*nodes.size() )
++ntrees;
return ntrees;
}
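// Worked example for howManyTrees(): with num_trees = 50 spread over 3 data-holding
// nodes, each node builds 50/3 = 16 trees; the 50 - 16*3 = 2 nodes that sort earliest
// in the H2ONode ordering build one extra tree (17) so the total still hits 50.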
private int[] howManyRPC(Frame fr) {
int[] result = new int[fr.anyVec().nChunks()];
for(int i = 0; i < result.length; ++i) {
result[i] = fr.anyVec().chunkLen(i);
}
return result;
}
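/** Validate the response cardinality and the split-feature count, then decide whether
* every node can afford to pull all non-local data into memory. */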
private void validateInputData(Frame fr) {
Vec[] vecs = fr.vecs();
Vec c = vecs[vecs.length-1];
if (!_params.regression) {
final int classes = c.cardinality();
if (!(2 <= classes && classes <= 254))
throw new IllegalArgumentException("Response contains " + classes + " classes, but the algorithm supports between 2 and 254 levels");
}
if (_params.num_split_features != -1 && (_params.num_split_features < 1 || _params.num_split_features > vecs.length - 1))
throw new IllegalArgumentException("Number of split features must be in [1," + (vecs.length - 1) + "] but found " + _params.num_split_features);
ChunkAllocInfo cai = new ChunkAllocInfo();
boolean can_load_all = canLoadAll(fr, cai);
if (_params._useNonLocalData && !can_load_all) {
String heap_warning = "This algorithm requires loading of all data from remote nodes." +
"\nThe node " + cai.node + " requires " + PrettyPrint.bytes(cai.requiredMemory) + " more memory to load all data and perform computation but there is only " + PrettyPrint.bytes(cai.availableMemory) + " of available memory." +
"\n\nPlease provide more memory for JVMs \n\n-OR-\n\n Try Big Data Random Forest: ";
Log.warn(heap_warning);
throw new IllegalArgumentException(heap_warning + DRF.link(Key.make(_key), "Big Data Random Forest") );
}
if (can_load_all) {
_params._useNonLocalData = true;
Log.info("Enough available free memory to compute on all data. Pulling all data locally and then launching RF.");
}
}
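/** Heuristic memory check: estimate the bytes needed to pull every non-local chunk to
* this node (plus roughly numRows * numCols of RF bookkeeping) and verify each cloud
* member has that much headroom within 80% of its max heap; on failure, fills `cai`
* with the offending node and the memory shortfall. */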
private boolean canLoadAll(final Frame fr, ChunkAllocInfo cai) {
int nchks = fr.anyVec().nChunks();
long localBytes = 0L;
for (int i = 0; i < nchks; ++i) {
Key k = fr.anyVec().chunkKey(i);
if (k.home()) {
localBytes += fr.anyVec().chunkForChunkIdx(i).byteSize();
}
}
long memForNonLocal = fr.byteSize() - localBytes;
// Also must add in the RF internal data structure overhead
memForNonLocal += fr.numRows() * fr.numCols();
for(int i = 0; i < H2O.CLOUD._memary.length; i++) {
HeartBeat hb = H2O.CLOUD._memary[i]._heartbeat;
long nodeFreeMemory = (long)(hb.get_max_mem() * 0.8); // treat 80% of max heap as usable (OVERHEAD_MAGIC variant disabled)
Log.debug(Log.Tag.Sys.RANDF, i + ": computed available mem: " + PrettyPrint.bytes(nodeFreeMemory));
Log.debug(Log.Tag.Sys.RANDF, i + ": remote chunks require: " + PrettyPrint.bytes(memForNonLocal));
if (nodeFreeMemory - memForNonLocal <= 0 || (nodeFreeMemory <= TWO_HUNDRED_MB && memForNonLocal >= ONE_FIFTY_MB)) {
Log.info("Node free memory raw: "+nodeFreeMemory);
cai.node = H2O.CLOUD._memary[i];
cai.availableMemory = nodeFreeMemory;
cai.requiredMemory = memForNonLocal;
return false;
}
}
return true;
}
/** Helper POJO to store required chunk allocation. */
private static class ChunkAllocInfo {
H2ONode node;
long availableMemory;
long requiredMemory;
}
static final float OVERHEAD_MAGIC = 3/8.f; // memory overhead magic
static final long TWO_HUNDRED_MB = 200 * 1024 * 1024;
static final long ONE_FIFTY_MB = 150 * 1024 * 1024;
@Override
public void reduce(DRemoteTask drt) { }
}
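// Additive constants mixed into the node-level RNG seed and each per-tree seed in
// build(), presumably to decorrelate trees derived from the same base seed.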
private static final long ROOT_SEED_ADD = 0x026244fd935c5111L;
private static final long TREE_SEED_INIT = 0x1321e74a0192470cL;
/** Build random forest for data stored on this node. */
public static void build(
final Key jobKey,
final Key modelKey,
final DRFParams drfParams,
final Data localData,
int ntrees,
int numSplitFeatures,
int[] rowsPerChunks) {
Timer t_alltrees = new Timer();
Tree[] trees = new Tree[ntrees];
Log.info(Log.Tag.Sys.RANDF,"Building "+ntrees+" trees");
Log.info(Log.Tag.Sys.RANDF,"Number of split features: "+ numSplitFeatures);
Log.info(Log.Tag.Sys.RANDF,"Starting RF computation with "+ localData.rows()+" rows ");
Random rnd = Utils.getRNG(localData.seed() + ROOT_SEED_ADD);
Sampling sampler = createSampler(drfParams, rowsPerChunks);
byte producerId = (byte) H2O.SELF.index();
for (int i = 0; i < ntrees; ++i) {
long treeSeed = rnd.nextLong() + TREE_SEED_INIT; // make sure that enough bits are initialized
trees[i] = new Tree(jobKey, modelKey, localData, producerId, drfParams.max_depth, drfParams.stat_type, numSplitFeatures, treeSeed,
i, drfParams._exclusiveSplitLimit, sampler, drfParams._verbose, drfParams.regression, !drfParams._useNonLocalData, ((SpeeDRFModel)UKV.get(modelKey)).score_pojo);
}
Log.info("Invoking the tree build tasks on all nodes.");
DRemoteTask.invokeAll(trees);
Log.info(Log.Tag.Sys.RANDF,"All trees ("+ntrees+") done in "+ t_alltrees);
}
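/** Only random sampling is constructed here; stratified balancing, when requested,
* is applied to the training frame up front (see setStrat / sampleFrameStratified). */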
static Sampling createSampler(final DRFParams params, int[] rowsPerChunks) {
switch(params.sampling_strategy) {
case RANDOM : return new Sampling.Random(params.sample, rowsPerChunks);
default:
assert false : "Unsupported sampling strategy";
return null;
}
}
/** RF execution parameters. */
public final static class DRFParams extends Iced {
/** Total number of trees */
int num_trees;
/** If true, build trees in parallel (default: true) */
boolean parallel;
/** Maximum depth for trees (default MaxInt) */
int max_depth;
/** Split statistic */
Tree.StatType stat_type;
/** Index of the column holding the response (default: #features-1) */
int classcol;
/** Utilized sampling method */
Sampling.Strategy sampling_strategy;
/** Proportion of observations to use for building each individual tree (default: .67)*/
float sample;
/** Limit of the cardinality of a feature before we bin. */
int bin_limit;
/** Weights of the response classes (default: uniform) */
double[] class_weights;
/** Arity under which we may use exclusive splits */
public int _exclusiveSplitLimit;
/** Output warnings and info */
public int _verbose;
/** Number of features which are tried at each split
* If it is equal to -1 then it is computed as sqrt(num of usable columns) */
int num_split_features;
/** Defined strata samples for each class */
float[] strata_samples;
/** Utilize not only local data but try to use data from other nodes. */
boolean _useNonLocalData;
/** Number of rows per chunk - used to replay sampling */
int _numrows;
/** Pseudo random seed initializing RF algorithm */
long seed;
/** Build regression trees if true */
boolean regression;
public static DRFParams create(int col, int ntrees, int depth, int numrows, int binLimit,
Tree.StatType statType, long seed, double[] classWt,
int numSplitFeatures, Sampling.Strategy samplingStrategy, float sample,
float[] strataSamples, int verbose, int exclusiveSplitLimit,
boolean useNonLocalData, boolean regression) {
DRFParams drfp = new DRFParams();
drfp.num_trees = ntrees;
drfp.max_depth = depth;
drfp.sample = sample;
drfp.bin_limit = binLimit;
drfp.stat_type = statType;
drfp.seed = seed;
drfp.class_weights = classWt;
drfp.num_split_features = numSplitFeatures;
drfp.sampling_strategy = samplingStrategy;
drfp.strata_samples = strataSamples;
drfp._numrows = numrows;
drfp._useNonLocalData = useNonLocalData;
drfp._exclusiveSplitLimit = exclusiveSplitLimit;
drfp._verbose = verbose;
drfp.classcol = col;
drfp.regression = regression;
drfp.parallel = true;
return drfp;
}
}
/**
* Cross-Validate a SpeeDRF model by building new models on N train/test holdout splits
* @param splits Frames containing train/test splits
* @param cv_preds Array of Frames to store the predictions for each cross-validation run
* @param offsets Array to store the offsets of starting row indices for each cross-validation run
* @param i Which fold of cross-validation to perform
*/
@Override public void crossValidate(Frame[] splits, Frame[] cv_preds, long[] offsets, int i) {
// Train a clone with slightly modified parameters (to account for cross-validation)
final SpeeDRF cv = (SpeeDRF) this.clone();
cv.genericCrossValidation(splits, offsets, i);
cv_preds[i] = ((SpeeDRFModel) UKV.get(cv.dest())).score(cv.validation);
new TAtomic<SpeeDRFModel>() {
@Override public SpeeDRFModel atomic(SpeeDRFModel m) {
if (!keep_cross_validation_splits && /*paranoid*/ cv.dest().toString().contains("xval")) {
m.get_params().source = null;
m.get_params().validation=null;
m.get_params().response=null;
}
return m;
}
}.invoke(cv.dest());
}
}