package hex.tree;

import hex.*;
import hex.glm.GLMModel;
import hex.util.LinearAlgebraUtils;
import water.*;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.*;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;

import static hex.ModelCategory.Binomial;
import static hex.genmodel.GenModel.createAuxKey;
import static hex.glm.GLMModel.GLMParameters.Family.binomial;
public abstract class SharedTreeModel<
M extends SharedTreeModel<M, P, O>,
P extends SharedTreeModel.SharedTreeParameters,
O extends SharedTreeModel.SharedTreeOutput
> extends Model<M, P, O> implements Model.LeafNodeAssignment, Model.GetMostImportantFeatures {
@Override
public String[] getMostImportantFeatures(int n) {
if (_output == null) return null;
TwoDimTable vi = _output._variable_importances;
if (vi==null) return null;
n = Math.min(n, vi.getRowHeaders().length);
String[] res = new String[n];
System.arraycopy(vi.getRowHeaders(), 0, res, 0, n);
return res;
}
@Override public ToEigenVec getToEigenVec() { return LinearAlgebraUtils.toEigen; }
public abstract static class SharedTreeParameters extends Model.Parameters {
public int _ntrees=50; // Number of trees in the final model. Grid Search, comma sep values:50,100,150,200
public int _max_depth = 5; // Maximum tree depth. Grid Search, comma sep values:5,7
public double _min_rows = 10; // Fewest allowed observations in a leaf (in R called 'nodesize'). Grid Search, comma sep values
public int _nbins = 20; // Numerical (real/int) cols: Build a histogram of this many bins, then split at the best point
public int _nbins_cats = 1024; // Categorical (factor) cols: Build a histogram of this many bins, then split at the best point
public double _min_split_improvement = 1e-5; // Minimum relative improvement in squared error reduction for a split to happen
public enum HistogramType { AUTO, UniformAdaptive, Random, QuantilesGlobal, RoundRobin }
public HistogramType _histogram_type = HistogramType.AUTO; // What type of histogram to use for finding optimal split points
public double _r2_stopping = Double.MAX_VALUE; // Stop when the r^2 metric equals or exceeds this value
public int _nbins_top_level = 1<<10; // Maximum number of bins at the root (top) level for real-valued columns
public boolean _build_tree_one_node = false;
public int _score_tree_interval = 0; // score every so many trees (no matter what)
public int _initial_score_interval = 4000; // Length of the initial period (in ms) during which every iteration is scored; replaces the former hard-coded 4000 ms
public int _score_interval = 4000; // Minimum time (in ms) between subsequent scoring events; replaces the former hard-coded 4000 ms
public double _sample_rate = 0.632; //fraction of rows to sample for each tree
public double[] _sample_rate_per_class; //fraction of rows to sample for each tree, per class
public boolean _calibrate_model = false; // Calibrate predicted probabilities using Platt scaling
public Key<Frame> _calibration_frame;
public Frame calib() { return _calibration_frame == null ? null : _calibration_frame.get(); }
@Override public long progressUnits() { return _ntrees + (_histogram_type==HistogramType.QuantilesGlobal || _histogram_type==HistogramType.RoundRobin ? 1 : 0); }
public double _col_sample_rate_change_per_level = 1.0; // Relative change of the column sampling rate for every level
public double _col_sample_rate_per_tree = 1.0; // Fraction of columns to sample for each tree
/** Fields which cannot be modified if a checkpoint is specified.
 * FIXME: should be defined in Schema API annotation
 */
private static final String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = { "_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};
protected String[] getCheckpointNonModifiableFields() {
return CHECKPOINT_NON_MODIFIABLE_FIELDS;
}
/** Validates these (actual) parameters against the parameters of the requested
 * checkpoint. If a non-modifiable field was changed, an API exception is thrown.
 *
 * @param checkpointParameters parameters of the checkpointed model
 */
public void validateWithCheckpoint(SharedTreeParameters checkpointParameters) {
for (Field fAfter : this.getClass().getFields()) {
// only look at non-modifiable fields
if (ArrayUtils.contains(getCheckpointNonModifiableFields(),fAfter.getName())) {
for (Field fBefore : checkpointParameters.getClass().getFields()) {
if (fBefore.equals(fAfter)) {
try {
if (!PojoUtils.equals(this, fAfter, checkpointParameters, checkpointParameters.getClass().getField(fAfter.getName()))) {
throw new H2OIllegalArgumentException(fAfter.getName(), "TreeBuilder", "Field " + fAfter.getName() + " cannot be modified if checkpoint is specified!");
}
} catch (NoSuchFieldException e) {
throw new H2OIllegalArgumentException(fAfter.getName(), "TreeBuilder", "Field " + fAfter.getName() + " is not supported by checkpoint!");
}
}
}
}
}
}
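// Illustrative usage (hypothetical variable names): when restarting training from a
// checkpoint, the builder is expected to validate the new parameters against the
// checkpointed ones before adding more trees, e.g.
//   SharedTreeModel<?, ?, ?> checkpointedModel = DKV.getGet(newParams._checkpoint);
//   newParams.validateWithCheckpoint((SharedTreeParameters) checkpointedModel._parms);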
}
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
switch(_output.getModelCategory()) {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
default: throw H2O.unimpl();
}
}
public abstract static class SharedTreeOutput extends Model.Output {
/** InitF value (for zero trees)
 * f0 = mean(yi) for gaussian
 * f0 = log(ybar / (1 - ybar)) for bernoulli, where ybar = mean(yi)
 *
 * For GBM bernoulli, the initial prediction for 0 trees is
 * p = 1 / (1 + exp(-f0))
 *
 * From this, the MSE for 0 trees (the null model) can be computed as
 * mean((yi - p)^2)
 * */
public double _init_f;
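// Worked example (illustrative numbers only): for a bernoulli response with mean(y) = 0.3,
//   f0 = log(0.3 / 0.7) ≈ -0.8473
//   p  = 1 / (1 + exp(-f0)) = 0.3
//   null-model MSE = mean((yi - p)^2) = 0.3*(1 - 0.3)^2 + 0.7*(0 - 0.3)^2 = 0.21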
/** Number of trees actually in the model (as opposed to requested) */
public int _ntrees;
/** More in-depth tree statistics */
public final TreeStats _treeStats;
/** Trees get big, so store each one separately in the DKV. */
public Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeys;
public Key<CompressedTree>[/*_ntrees*/][/*_nclass*/] _treeKeysAux;
public ScoreKeeper[/*ntrees+1*/] _scored_train;
public ScoreKeeper[/*ntrees+1*/] _scored_valid;
public ScoreKeeper[] scoreKeepers() {
ArrayList<ScoreKeeper> skl = new ArrayList<>();
ScoreKeeper[] ska = _validation_metrics != null ? _scored_valid : _scored_train;
for( ScoreKeeper sk : ska )
if (!sk.isEmpty())
skl.add(sk);
return skl.toArray(new ScoreKeeper[skl.size()]);
}
/** Training time */
public long[/*ntrees+1*/] _training_time_ms = {System.currentTimeMillis()};
/**
* Variable importances computed during training
*/
public TwoDimTable _variable_importances;
public VarImp _varimp;
public GLMModel _calib_model;
public SharedTreeOutput( SharedTree b) {
super(b);
_ntrees = 0; // No trees yet
_treeKeys = new Key[_ntrees][]; // No tree keys yet
_treeKeysAux = new Key[_ntrees][]; // No tree keys yet
_treeStats = new TreeStats();
_scored_train = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
_scored_valid = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
_modelClassDist = _priorClassDist;
}
// Append next set of K trees
public void addKTrees( DTree[] trees) {
// DEBUG: Print the generated K trees
//SharedTree.printGenerateTrees(trees);
assert nclasses()==trees.length;
// Compress trees and record tree-keys
_treeKeys = Arrays.copyOf(_treeKeys ,_ntrees+1);
_treeKeysAux = Arrays.copyOf(_treeKeysAux ,_ntrees+1);
Key[] keys = _treeKeys[_ntrees] = new Key[trees.length];
Key[] keysAux = _treeKeysAux[_ntrees] = new Key[trees.length];
Futures fs = new Futures();
for( int i=0; i<nclasses(); i++ ) if( trees[i] != null ) {
CompressedTree ct = trees[i].compress(_ntrees,i,_domains);
DKV.put(keys[i]=ct._key,ct,fs);
_treeStats.updateBy(trees[i]); // Update tree shape stats
CompressedTree ctAux = new CompressedTree(trees[i]._abAux.buf(),-1,-1,-1,-1,_domains);
keysAux[i] = ctAux._key = Key.make(createAuxKey(ct._key.toString()));
DKV.put(ctAux);
}
_ntrees++;
// 1-based for errors; _scored_train[0] is for zero trees, not 1 tree
_scored_train = ArrayUtils.copyAndFillOf(_scored_train, _ntrees+1, new ScoreKeeper());
_scored_valid = _scored_valid != null ? ArrayUtils.copyAndFillOf(_scored_valid, _ntrees+1, new ScoreKeeper()) : null;
_training_time_ms = ArrayUtils.copyAndFillOf(_training_time_ms, _ntrees+1, System.currentTimeMillis());
fs.blockForPending();
}
public CompressedTree ctree( int tnum, int knum ) { return _treeKeys[tnum][knum].get(); }
public String toStringTree ( int tnum, int knum ) { return ctree(tnum,knum).toString(this); }
}
public SharedTreeModel(Key<M> selfKey, P parms, O output) {
super(selfKey, parms, output);
}
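/** For every row of the given frame, records the decision path taken through each tree of the
 * model. The result frame has one categorical column per (tree, class) pair, named e.g. "T1"
 * when there is a single tree per iteration, or "T1.C1", "T1.C2", ... when there is one tree
 * per class; the column levels are the per-tree decision paths. */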
public Frame scoreLeafNodeAssignment(Frame frame, Key<Frame> destination_key) {
Frame adaptFrm = new Frame(frame);
adaptTestForTrain(adaptFrm, true, false);
int classTrees = 0;
for (int i = 0; i < _output._treeKeys[0].length; ++i) {
if (_output._treeKeys[0][i] != null) classTrees++;
}
final int outputcols = _output._treeKeys.length * classTrees;
final String[] names = new String[outputcols];
int col = 0;
for (int tidx = 0; tidx < _output._treeKeys.length; tidx++) {
Key[] keys = _output._treeKeys[tidx];
for (int c = 0; c < keys.length; c++) {
if (keys[c] != null) {
names[col++] = "T" + (tidx + 1) + (keys.length == 1 ? "" : (".C" + (c + 1)));
}
}
}
Frame res = new MRTask() {
@Override public void map(Chunk chks[], NewChunk[] idx ) {
double[] input = new double[chks.length];
String[] output = new String[outputcols];
for( int row=0; row<chks[0]._len; row++ ) {
for( int i=0; i<chks.length; i++ )
input[i] = chks[i].atd(row);
int col=0;
for( int tidx=0; tidx<_output._treeKeys.length; tidx++ ) {
Key[] keys = _output._treeKeys[tidx];
for (Key key : keys) {
if (key != null) {
String pred = DKV.get(key).<CompressedTree>get().getDecisionPath(input);
output[col++] = pred;
}
}
}
assert(col==outputcols);
for (int i=0; i<outputcols; ++i)
idx[i].addStr(output[i]);
}
}
}.doAll(outputcols, Vec.T_STR, adaptFrm).outputFrame(destination_key, names, null);
Vec vv;
Vec[] nvecs = new Vec[res.vecs().length];
for(int c=0;c<res.vecs().length;++c) {
vv = res.vec(c);
try {
nvecs[c] = vv.toCategoricalVec();
} catch (Exception e) {
VecUtils.deleteVecs(nvecs, c);
throw e;
}
}
res.delete();
res = new Frame(destination_key, names, nvecs);
DKV.put(res);
return res;
}
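/** If a calibration model was trained (Platt scaling), scores the predicted probability column
 * with the calibration GLM and appends the calibrated probabilities to the prediction frame as
 * additional "cal_*" columns. Only supported for binomial models. */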
@Override
protected Frame postProcessPredictions(Frame predictFr) {
if (_output._calib_model == null)
return predictFr;
if (_output.getModelCategory() == Binomial) {
Key<Frame> calibInputKey = Key.make();
Frame calibOutput = null;
try {
Frame calibInput = new Frame(calibInputKey, new String[]{"p"}, new Vec[]{predictFr.vec(1)});
calibOutput = _output._calib_model.score(calibInput);
assert calibOutput._names.length == 3;
Vec[] calPredictions = calibOutput.remove(new int[]{1, 2});
// append calibrated probabilities to the prediction frame
predictFr.write_lock();
for (int i = 0; i < calPredictions.length; i++)
predictFr.add("cal_" + predictFr.name(1 + i), calPredictions[i]);
return predictFr.update();
} finally {
DKV.remove(calibInputKey);
if (calibOutput != null)
calibOutput.remove();
}
} else
throw H2O.unimpl("Calibration is only supported for binomial models");
}
@Override protected double[] score0(double[] data, double[] preds, double weight, double offset) {
return score0(data, preds, weight, offset, _output._treeKeys.length);
}
@Override protected double[] score0(double[/*ncols*/] data, double[/*nclasses+1*/] preds) {
return score0(data, preds, 1.0, 0.0);
}
protected double[] score0(double[] data, double[] preds, double weight, double offset, int ntrees) {
// Score the row with the first `ntrees` trees, accumulating the per-class contributions into preds
Arrays.fill(preds,0);
for( int tidx=0; tidx<ntrees; tidx++ )
score0(data, preds, tidx);
return preds;
}
// Score a single row with a single tree index (one compressed tree per class): for models with
// one output tree per iteration (regression) the contribution is accumulated into preds[0],
// otherwise into preds[c+1]
private void score0(double[] data, double[] preds, int treeIdx) {
Key[] keys = _output._treeKeys[treeIdx];
for( int c=0; c<keys.length; c++ ) {
if (keys[c] != null) {
double pred = DKV.get(keys[c]).<CompressedTree>get().score(data);
assert (!Double.isInfinite(pred));
preds[keys.length == 1 ? 0 : c + 1] += pred;
}
}
}
/** Performs a deep clone of the given model; compressed trees are copied into new DKV keys, model metrics are not cloned. */
protected M deepClone(Key<M> result) {
M newModel = IcedUtils.deepCopy(self());
newModel._key = result;
// Do not clone model metrics
newModel._output.clearModelMetrics();
newModel._output._training_metrics = null;
newModel._output._validation_metrics = null;
// Clone trees
Key[][] treeKeys = newModel._output._treeKeys;
for (int i = 0; i < treeKeys.length; i++) {
for (int j = 0; j < treeKeys[i].length; j++) {
if (treeKeys[i][j] == null) continue;
CompressedTree ct = DKV.get(treeKeys[i][j]).get();
CompressedTree newCt = IcedUtils.deepCopy(ct);
newCt._key = CompressedTree.makeTreeKey(i, j);
DKV.put(treeKeys[i][j] = newCt._key,newCt);
}
}
// Clone Aux info
Key[][] treeKeysAux = newModel._output._treeKeysAux;
if (treeKeysAux!=null) {
for (int i = 0; i < treeKeysAux.length; i++) {
for (int j = 0; j < treeKeysAux[i].length; j++) {
if (treeKeysAux[i][j] == null) continue;
CompressedTree ct = DKV.get(treeKeysAux[i][j]).get();
CompressedTree newCt = IcedUtils.deepCopy(ct);
newCt._key = Key.make(createAuxKey(treeKeys[i][j].toString()));
DKV.put(treeKeysAux[i][j] = newCt._key,newCt);
}
}
}
return newModel;
}
@Override protected Futures remove_impl( Futures fs ) {
for (Key[] ks : _output._treeKeys)
for (Key k : ks)
if( k != null ) k.remove(fs);
for (Key[] ks : _output._treeKeysAux)
for (Key k : ks)
if( k != null ) k.remove(fs);
if (_output._calib_model != null)
_output._calib_model.remove(fs);
return super.remove_impl(fs);
}
/** Write out K/V pairs */
@Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
for (Key<CompressedTree>[] ks : _output._treeKeys)
for (Key<CompressedTree> k : ks)
ab.putKey(k);
for (Key<CompressedTree>[] ks : _output._treeKeysAux)
for (Key<CompressedTree> k : ks)
ab.putKey(k);
return super.writeAll_impl(ab);
}
@Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
for (Key<CompressedTree>[] ks : _output._treeKeys)
for (Key<CompressedTree> k : ks)
ab.getKey(k,fs);
for (Key<CompressedTree>[] ks : _output._treeKeysAux)
for (Key<CompressedTree> k : ks)
ab.getKey(k,fs);
return super.readAll_impl(ab,fs);
}
@SuppressWarnings("unchecked") // `M` is really the type of `this`
private M self() { return (M)this; }
//--------------------------------------------------------------------------------------------------------------------
// Serialization into a POJO
//--------------------------------------------------------------------------------------------------------------------
// Override in subclasses to provide some top-level model-specific goodness
@Override protected boolean toJavaCheckTooBig() {
// If the total number of leaves in the forest exceeds ~1 million, don't try to render the POJO code in the browser.
return _output==null || _output._treeStats._num_trees * _output._treeStats._mean_leaves > 1000000;
}
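/** When true, the tree for the second class of a binomial model is skipped in the generated
 * POJO (see the {@code binomialOpt() && c == 1 && nclass == 2} checks below); subclasses may
 * override this to emit trees for both classes. */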
protected boolean binomialOpt() { return true; }
@Override protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
sb.nl();
sb.ip("public boolean isSupervised() { return true; }").nl();
sb.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
sb.ip("public int nclasses() { return " + _output.nclasses() + "; }").nl();
return sb;
}
@Override protected void toJavaPredictBody(SBPrintStream body,
CodeGeneratorPipeline classCtx,
CodeGeneratorPipeline fileCtx,
final boolean verboseCode) {
final int nclass = _output.nclasses();
body.ip("java.util.Arrays.fill(preds,0);").nl();
final String mname = JCodeGen.toJavaId(_key.toString());
// Generate one Forest class per tree index; each forest contains one real tree per class
for (int t=0; t < _output._treeKeys.length; t++) {
// Generate score method for given tree
toJavaForestName(body.i(),mname,t).p(".score0(data,preds);").nl();
final int treeIdx = t;
fileCtx.add(new CodeGenerator() {
@Override
public void generate(JCodeSB out) {
try {
// Generate a class implementing a tree
out.nl();
toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
for (int c = 0; c < nclass; c++) {
if (_output._treeKeys[treeIdx][c] == null) continue;
if (!(binomialOpt() && c == 1 && nclass == 2)) // Binomial optimization
toJavaTreeName(out.ip("preds[").p(nclass == 1 ? 0 : c + 1).p("] += "), mname, treeIdx, c).p(".score0(fdata);").nl();
}
out.di(1).ip("}").nl(); // end of function
out.di(1).ip("}").nl(); // end of forest class
// Generate the pre-tree classes afterwards
for (int c = 0; c < nclass; c++) {
if (_output._treeKeys[treeIdx][c] == null) continue;
if (!(binomialOpt() && c == 1 && nclass == 2)) { // Binomial optimization
String javaClassName = toJavaTreeName(new SB(), mname, treeIdx, c).toString();
CompressedTree ct = _output.ctree(treeIdx, c);
SB sb = new SB();
new TreeJCodeGen(SharedTreeModel.this, ct, sb, javaClassName, verboseCode).generate();
out.p(sb);
}
}
} catch (Throwable t) {
t.printStackTrace();
throw new IllegalArgumentException("Internal error creating the POJO.", t);
}
}
});
}
toJavaUnifyPreds(body);
}
protected abstract void toJavaUnifyPreds(SBPrintStream body);
protected <T extends JCodeSB> T toJavaTreeName(T sb, String mname, int t, int c ) {
return (T) sb.p(mname).p("_Tree_").p(t).p("_class_").p(c);
}
protected <T extends JCodeSB> T toJavaForestName(T sb, String mname, int t ) {
return (T) sb.p(mname).p("_Forest_").p(t);
}
}