package hex.drf;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log.Tag.Sys;
import water.util.Log;
import hex.gbm.SharedTreeModelBuilder;
import hex.gbm.GBM;
// Class for running DRF from the cmd line
// Run as : java -jar h2o.jar water.Boot -mainClass hex.drf.Runner <runner_args>
// Example: java -jar h2o.jar water.Boot -mainClass hex.drf.Runner -trainFile=smalldata/iris/iris_wheader.csv -testFile= -response=class -cols=sepal_len,sepal_wid,petal_len,petal_wid -ntrees=5 -mtries=3
/**
 * Command-line driver that brings up an H2O node, loads a training (and
 * optionally a test) data set, and trains either a DRF or a GBM model.
 */
public class Runner {
  /**
   * Command-line options.  Every field in this class is also a command-line
   * argument; the initial values below are the defaults.
   */
  public static class OptArgs extends Arguments.Opt {
    String h2oArgs;                   // Extra args for H2O
    int clusterSize=1;                // Stall till cluster gets this big
    final static String defaultTrainFile = "smalldata/gbm_test/ecology_model.csv";
    final static String defaultTestFile  = "smalldata/gbm_test/ecology_eval.csv";
    String trainFile = defaultTrainFile;
    String testFile  = defaultTestFile;
    String response  = "Angaus";
    String cols = "SegSumT,SegTSeas,SegLowFlow,DSDist,DSMaxSlope,USAvgT,USRainDays,USSlope,USNative,DSDam,Method,LocSed";
    boolean regression=false;         // Default to classification (vs regression)
    int min_rows = 1;                 // Smallest number of rows per terminal
    int ntrees = 10;                  // Number of trees
    int depth = 999;                  // Max tree depth
    int nbins = 20;                   // Nominal bins per column histogram
    boolean gbm = false;              // True for GBM, False for DRF
    int mtries = 0;                   // Number of columns to try; zero defaults to Sqrt
    float sample = 0.6666667f;        // Sampling rate
    long seed = 0xae44a87f9edf1cbL;   // Seed handed to DRF
    float learn = 0.1f;               // Learning rate handed to GBM
    // Ratio on test/train split; NaN means "do not split".
    // NOTE(review): any non-NaN value only switches splitting on - the split
    // expression in main(OptArgs) uses a fixed 0.7 threshold.  Confirm intent.
    float splitTestTrain = Float.NaN;
  }

  /**
   * Command-line entry point: parses the arguments, starts H2O, runs the
   * training job, and always broadcasts a cluster shutdown afterwards.
   *
   * @param args runner arguments (one per {@link OptArgs} field)
   * @throws Throwable whatever the training job threw, after printing its stack trace
   */
  public static void main(String[] args) throws Throwable {
    OptArgs ARGS = new Arguments(args).extract(new OptArgs());
    // Bring up the cluster
    H2O.main(splitH2oArgs(ARGS.h2oArgs));
    // Make sure we shutdown on all exit paths
    try {
      main(ARGS);
    } catch( Throwable t ) {
      t.printStackTrace();
      throw t;
    } finally {
      UDPRebooted.T.shutdown.broadcast();
    }
  }

  /**
   * Splits the single {@code h2oArgs} option into individual H2O arguments.
   * One pair of surrounding double quotes is stripped, then the text is split
   * on runs of spaces/tabs.
   *
   * @param raw the option value; may be null
   * @return the individual arguments; empty (never null) when {@code raw} is
   *         null, empty or blank
   */
  private static String[] splitH2oArgs(String raw) {
    if( raw == null ) return new String[0];
    String s = raw;
    // Need at least two chars: a lone '"' is both prefix and suffix, and
    // substring(1,0) would throw.
    if( s.length() >= 2 && s.startsWith("\"") && s.endsWith("\"") )
      s = s.substring(1, s.length()-1);
    s = s.trim();
    // "".split(..) yields [""], which would hand H2O one bogus empty argument.
    if( s.length() == 0 ) return new String[0];
    return s.split("[ \t]+");
  }

  /**
   * Sanity checks the basic arguments.
   *
   * @param ARGS the parsed options
   * @throws IllegalArgumentException if an option is out of bounds or two
   *         options conflict
   */
  private static void checkArgs(OptArgs ARGS) {
    if( ARGS.ntrees <= 0 || ARGS.ntrees > 100000 ) throw new IllegalArgumentException("ntrees "+ARGS.ntrees+" out of bounds");
    if( ARGS.sample <  0 || ARGS.sample > 1.0f   ) throw new IllegalArgumentException("sample "+ARGS.sample+" out of bounds");
    if( ARGS.learn  <  0 || ARGS.learn  > 1.0f   ) throw new IllegalArgumentException("learn " +ARGS.learn +" out of bounds");
    if( ARGS.nbins  <  2 || ARGS.nbins  > 100000 ) throw new IllegalArgumentException("nbins " +ARGS.nbins +" out of bounds");
    if( ARGS.depth  <= 0                         ) throw new IllegalArgumentException("depth " +ARGS.depth +" out of bounds");
    // The NaN default fails both comparisons, so "no split" passes this check.
    if( ARGS.splitTestTrain < 0 || ARGS.splitTestTrain > 1.0f ) throw new IllegalArgumentException("splitTestTrain "+ARGS.splitTestTrain+" out of bounds");

    // The two checks below compare String references on purpose: a field that
    // is still the very same object as its default constant is treated as
    // "option not given".  (Presumably the argument parser stores a fresh
    // String when the option is supplied - verify against Arguments.)
    boolean trainIsDefault = ARGS.trainFile == OptArgs.defaultTrainFile;
    boolean testIsDefault  = ARGS.testFile  == OptArgs.defaultTestFile;
    // If trainFile is NOT set, you are doing the default file and cannot set testFile.
    if( trainIsDefault && !testIsDefault )
      throw new IllegalArgumentException("Cannot set test file unless also setting train file");
    // If testFile is set, cannot set splitTestTrain
    if( !testIsDefault && !Float.isNaN(ARGS.splitTestTrain) )
      throw new IllegalArgumentException("Cannot have both testFile and splitTestTrain");
  }

  /**
   * Does the work: waits for the cluster, validates the options, loads the
   * data, and trains a GBM or DRF model, blocking until training finishes.
   * May overwrite {@code ARGS.mtries} when it is left at zero.
   *
   * @param ARGS the parsed options
   * @throws IllegalArgumentException if an option is out of bounds or two
   *         options conflict
   */
  static void main(OptArgs ARGS) {
    // Finish building the cluster
    TestUtil.stall_till_cloudsize(ARGS.clusterSize);

    // Sanity check basic args
    checkArgs(ARGS);

    Sys sys = ARGS.gbm ? Sys.GBM__ : Sys.DRF__;
    // Requested columns followed by the response column
    String cs[] = (ARGS.cols+","+ARGS.response).split("[,\t]");

    // Set mtries
    // NOTE(review): cs.length counts the response column as well - confirm
    // that is the intended basis for the default and the upper bound.
    if( ARGS.mtries == 0 ) ARGS.mtries = (int)Math.sqrt(cs.length);
    if( ARGS.mtries <= 0 || ARGS.mtries > cs.length )
      throw new IllegalArgumentException("mtries "+ARGS.mtries+" out of bounds");

    // Load data
    Timer t_load = new Timer();
    Key trainkey = Key.make("train.hex");
    Key testkey  = Key.make( "test.hex");
    Frame train = TestUtil.parseFrame(trainkey,ARGS.trainFile);
    Frame test = null;
    if( !Float.isNaN(ARGS.splitTestTrain) ) {
      // Split the parsed training frame into train.hex / test.hex.
      // NOTE(review): the 0.7 threshold is fixed; ARGS.splitTestTrain is not used here.
      water.exec.Exec2.exec("r=runif(train.hex,-1); test.hex=train.hex[r>=0.7,]; train.hex=train.hex[r<0.7,]").remove_and_unlock();
      train = UKV.get(trainkey);
      test  = UKV.get( testkey);
    } else if( ARGS.testFile.length() != 0 ) {
      // An empty testFile (e.g. "-testFile=") means: train without a test set
      test = TestUtil.parseFrame(testkey,ARGS.testFile);
    }
    Log.info(sys,"Data loaded in "+t_load);

    // Pull out the response vector from the train data
    Vec response = train.subframe(new String[] {ARGS.response}).vecs()[0];

    // Build a Frame with just the requested columns.
    train = train.subframe(cs);
    if( test != null ) test = test.subframe(cs);
    Vec vs[] = train.vecs();
    for( Vec v : vs ) v.min();        // Do rollups
    for( int i=0; i<train.numCols(); i++ )
      Log.info(sys,train._names[i]+", "+vs[i].min()+" - "+vs[i].max()+(vs[i].naCnt()==0?"":(", missing="+vs[i].naCnt())));

    Log.info(sys,"Arguments used:\n"+ARGS.toString());

    Timer t_model = new Timer();
    SharedTreeModelBuilder stmb = ARGS.gbm ? new GBM() : new DRF();
    stmb.source = train;
    stmb.validation = test;
    stmb.classification = !ARGS.regression;
    stmb.response = response;
    stmb.ntrees = ARGS.ntrees;
    stmb.max_depth = ARGS.depth;
    stmb.min_rows = ARGS.min_rows;
    // The "DRF_Model_" prefix is used for GBM runs too.
    stmb.destination_key = Key.make("DRF_Model_" + ARGS.trainFile);
    if( ARGS.gbm ) {
      GBM gbm = (GBM)stmb;
      gbm.learn_rate = ARGS.learn;
    } else {
      DRF drf = (DRF)stmb;
      drf.mtries = ARGS.mtries;
      drf.sample_rate= ARGS.sample;
      drf.seed = ARGS.seed;
    }

    // Invoke the builder and block till the end
    stmb.invoke();
    Log.info(sys,"Model trained in "+t_model);
  }
}