package hex.tree;
import hex.ModelMojoWriter;
import hex.glm.GLMModel;
import water.DKV;
import water.Key;
import water.Value;
import water.exceptions.H2OKeyNotFoundArgumentException;
import java.io.IOException;
/**
* Shared Mojo definition file for DRF and GBM models.
*/
public abstract class SharedTreeMojoWriter<
M extends SharedTreeModel<M, P, O>,
P extends SharedTreeModel.SharedTreeParameters,
O extends SharedTreeModel.SharedTreeOutput
> extends ModelMojoWriter<M, P, O> {
public SharedTreeMojoWriter() {}
public SharedTreeMojoWriter(M model) {
super(model);
}
@Override
protected void writeModelData() throws IOException {
assert model._output._treeKeys.length == model._output._ntrees;
int nclasses = model._output.nclasses();
int ntreesPerClass = model.binomialOpt() && nclasses == 2 ? 1 : nclasses;
writekv("n_trees", model._output._ntrees);
writekv("n_trees_per_class", ntreesPerClass);
if (model._output._calib_model != null) {
GLMModel calibModel = model._output._calib_model;
double[] beta = calibModel.beta();
assert beta.length == nclasses; // n-1 coefficients + 1 intercept
writekv("calib_method", "platt");
writekv("calib_glm_beta", beta);
}
for (int i = 0; i < model._output._ntrees; i++) {
for (int j = 0; j < ntreesPerClass; j++) {
Key<CompressedTree> key = model._output._treeKeys[i][j];
Value ctVal = key != null ? DKV.get(key) : null;
if (ctVal == null)
continue; //throw new H2OKeyNotFoundArgumentException("CompressedTree " + key + " not found");
CompressedTree ct = ctVal.get();
assert ct._nclass == nclasses;
// assume ct._seed is useless and need not be persisted
writeblob(String.format("trees/t%02d_%03d.bin", j, i), ct._bits);
if (model._output._treeKeysAux!=null) {
key = model._output._treeKeysAux[i][j];
ctVal = key != null ? DKV.get(key) : null;
if (ctVal != null) {
ct = ctVal.get();
assert ct._nclass == -1;
writeblob(String.format("trees/t%02d_%03d_aux.bin", j, i), ct._bits);
}
}
}
}
}
}