package hex.genmodel.algos.tree; import hex.genmodel.ModelMojoReader; import java.io.IOException; /** */ public abstract class SharedTreeMojoReader<M extends SharedTreeMojoModel> extends ModelMojoReader<M> { @Override protected void readModelData() throws IOException { // In mojos v=1.0 this info wasn't saved. Integer tpc = readkv("n_trees_per_class"); if (tpc == null) { Boolean bdt = readkv("binomial_double_trees"); // This flag exists only for DRF models tpc = _model.nclasses() == 2 && (bdt == null || !bdt)? 1 : _model.nclasses(); } _model._ntree_groups = readkv("n_trees"); _model._ntrees_per_group = tpc; _model._compressed_trees = new byte[_model._ntree_groups * tpc][]; _model._mojo_version = readkv("mojo_version"); // In mojos v=1.0 this info wasn't saved. if (!_model._mojo_version.equals(1.0)) { _model._compressed_trees_aux = new byte[_model._ntree_groups * tpc][]; } for (int j = 0; j < _model._ntree_groups; j++) for (int i = 0; i < tpc; i++) { String blobName = String.format("trees/t%02d_%03d.bin", i, j); if (!exists(blobName)) continue; _model._compressed_trees[_model.treeIndex(j, i)] = readblob(blobName); if (_model._compressed_trees_aux!=null) { _model._compressed_trees_aux[_model.treeIndex(j, i)] = readblob(String.format("trees/t%02d_%03d_aux.bin", i, j)); } } // Calibration String calibMethod = readkv("calib_method"); if (calibMethod != null) { if (! "platt".equals(calibMethod)) throw new IllegalStateException("Unknown calibration method: " + calibMethod); _model._calib_glm_beta = readkv("calib_glm_beta", new double[0]); } } }