package hex.genmodel.algos.drf; import hex.genmodel.GenModel; import hex.genmodel.algos.tree.SharedTreeMojoModel; /** * "Distributed Random Forest" MojoModel */ public final class DrfMojoModel extends SharedTreeMojoModel { protected boolean _binomial_double_trees; public DrfMojoModel(String[] columns, String[][] domains) { super(columns, domains); } /** * Corresponds to `hex.tree.drf.DrfMojoModel.score0()` */ @Override public final double[] score0(double[] row, double offset, double[] preds) { super.scoreAllTrees(row, preds); // Correct the predictions -- see `DRFModel.toJavaUnifyPreds` if (_nclasses == 1) { // Regression preds[0] /= _ntree_groups; } else { // Classification if (_nclasses == 2 && !_binomial_double_trees) { // Binomial model preds[1] /= _ntree_groups; preds[2] = 1.0 - preds[1]; } else { // Multinomial double sum = 0; for (int i = 1; i <= _nclasses; i++) { sum += preds[i]; } if (sum > 0) for (int i = 1; i <= _nclasses; i++) { preds[i] /= sum; } } if (_balanceClasses) GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib); preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold); } return preds; } @Override public double[] score0(double[] row, double[] preds) { return score0(row, 0.0, preds); } }