package hex.genmodel.algos.gbm; import hex.genmodel.GenModel; import hex.genmodel.algos.tree.SharedTreeMojoModel; import hex.genmodel.utils.DistributionFamily; import static hex.genmodel.utils.DistributionFamily.*; /** * "Gradient Boosting Machine" MojoModel */ public final class GbmMojoModel extends SharedTreeMojoModel { public DistributionFamily _family; public double _init_f; public GbmMojoModel(String[] columns, String[][] domains) { super(columns, domains); } /** * Corresponds to `hex.tree.drf.DrfMojoModel.score0()` */ @Override public final double[] score0(double[] row, double offset, double[] preds) { super.scoreAllTrees(row, preds); if (_family == bernoulli || _family == modified_huber) { double f = preds[1] + _init_f + offset; preds[2] = _family.linkInv(f); preds[1] = 1.0 - preds[2]; } else if (_family == multinomial) { if (_nclasses == 2) { // 1-tree optimization for binomial preds[1] += _init_f + offset; //offset is not yet allowed, but added here to be future-proof preds[2] = -preds[1]; } GenModel.GBM_rescale(preds); } else { // Regression double f = preds[0] + _init_f + offset; preds[0] = _family.linkInv(f); return preds; } if (_balanceClasses) GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib); preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold); return preds; } @Override public double[] score0(double[] row, double[] preds) { return score0(row, 0.0, preds); } }