package hex.ensemble; import hex.Model; import hex.ModelBuilder; import hex.ModelCategory; import hex.StackedEnsembleModel; import hex.glm.GLM; import hex.glm.GLMModel; import water.*; import water.exceptions.H2OIllegalArgumentException; import water.exceptions.H2OModelBuilderIllegalArgumentException; import water.fvec.Frame; import water.fvec.Vec; import water.util.Log; /** * An ensemble of other models, created by <i>stacking</i> with the SuperLearner algorithm or a variation. */ public class StackedEnsemble extends ModelBuilder<StackedEnsembleModel,StackedEnsembleModel.StackedEnsembleParameters,StackedEnsembleModel.StackedEnsembleOutput> { StackedEnsembleDriver _driver; // The in-progress model being built protected StackedEnsembleModel _model; public StackedEnsemble(boolean startup_once) { super(new StackedEnsembleModel.StackedEnsembleParameters(),startup_once); } /* public StackedEnsemble(Key selfKey, StackedEnsembleModel.StackedEnsembleParameters parms, StackedEnsemble job) { super(selfKey, parms, job == null ? new StackedEnsembleModel.StackedEnsembleOutput(): new StackedEnsembleModel.StackedEnsembleOutput(job)); } */ @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Regression, ModelCategory.Binomial, // ModelCategory.Multinomial, // TODO }; } @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; } @Override public boolean isSupervised() { return true; } @Override protected StackedEnsembleDriver trainModelImpl() { return _driver = new StackedEnsembleDriver(); } public static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) { if (aModel._output.isBinomialClassifier()) { // GLM uses a different column name than the other algos, yay! Vec preds = aModelsPredictions.vec(2); // Predictions column names have been changed. . . levelOneFrame.add(aModel._key.toString(), preds); } else if (aModel._output.isClassifier()) { throw new H2OIllegalArgumentException("Don't yet know how to stack multinomial classifiers: " + aModel._key); } else if (aModel._output.isAutoencoder()) { throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key); } else if (!aModel._output.isSupervised()) { throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key); } else { levelOneFrame.add(aModel._key.toString(), aModelsPredictions.vec("predict")); } } private class StackedEnsembleDriver extends Driver { private Frame prepareLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms) { // TODO: allow the user to name the level one frame Frame levelOneFrame = new Frame(Key.<Frame>make("levelone_" + _model._key.toString())); for (Key<Model> k : _parms._base_models) { Model aModel = DKV.getGet(k); if (null == aModel) { Log.warn("Failed to find base model; skipping: " + k); continue; } if (null == aModel._output._cross_validation_holdout_predictions_frame_id) throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . . Looks like keep_cross_validation_predictions wasn't set when building the models."); // add the predictions for aModel to levelOneFrame // TODO: multinomial classification: Frame aModelsPredictions = aModel._output._cross_validation_holdout_predictions_frame_id.get(); StackedEnsemble.addModelPredictionsToLevelOneFrame(aModel, aModelsPredictions, levelOneFrame); } // for all base_models levelOneFrame.add(_model.responseColumn, _model._parms.train().vec(_model.responseColumn)); // TODO: what if we're running multiple in parallel and have a name collision? Frame old = DKV.getGet(levelOneFrame._key); if (old != null && old instanceof Frame) { Frame oldFrame = (Frame)old; // Remove ALL the columns so we don't delete them in remove_impl. Their // lifetime is controlled by their model. oldFrame.removeAll(); oldFrame.write_lock(_job); oldFrame.update(_job); oldFrame.unlock(_job); } levelOneFrame.delete_and_lock(_job); levelOneFrame.unlock(_job); Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString()); return levelOneFrame; } public void computeImpl() { init(true); if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this); _model = new StackedEnsembleModel(dest(), _parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this)); _model.delete_and_lock(_job); // and clear & write-lock it (smashing any prior) _model.checkAndInheritModelProperties(); Frame levelOneFrame = prepareLevelOneFrame(_parms); // train the metalearner model // TODO: allow types other than GLM // Default Job for just this training Key<Model> metalearnerKey = Key.<Model>make("metalearner_" + _model._key); Job job = new Job<>(metalearnerKey, ModelBuilder.javaName("glm"), "StackingEnsemble metalearner (GLM)"); GLM metaBuilder = ModelBuilder.make("GLM", job, metalearnerKey); metaBuilder._parms._non_negative = true; metaBuilder._parms._train = levelOneFrame._key; metaBuilder._parms._response_column = _model.responseColumn; // TODO: multinomial // TODO: support other families for regression metaBuilder._parms._family = _model.modelCategory == ModelCategory.Regression ? GLMModel.GLMParameters.Family.gaussian : GLMModel.GLMParameters.Family.binomial; metaBuilder.init(false); Job<GLMModel> j = metaBuilder.trainModel(); while (j.isRunning()) { try { _job.update(j._work, "training metalearner"); Thread.sleep(100); } catch (InterruptedException e) {} } Log.info("Finished training metalearner model."); _model._output._metalearner = metaBuilder.get(); _model.doScoreMetrics(_job); // _model._output._model_summary = createModelSummaryTable(model._output); _model.update(_job); _model.unlock(_job); } // computeImpl } }