package hex;
import hex.ensemble.StackedEnsemble;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashSet;
import water.util.Log;
import water.util.ReflectionUtils;
import java.lang.reflect.Field;
import java.util.Arrays;
import static hex.Model.Parameters.FoldAssignmentScheme.Modulo;
/**
 * An ensemble of other models, created by <i>stacking</i> with the SuperLearner algorithm or a variation.
 */
public class StackedEnsembleModel extends Model<StackedEnsembleModel,StackedEnsembleModel.StackedEnsembleParameters,StackedEnsembleModel.StackedEnsembleOutput> {

  /** Model category of the base models; inherited from the first base model found. */
  public ModelCategory modelCategory;
  /** Checksum of the ensemble's training frame; every base model's training frame must match it. */
  public long trainingFrameChecksum = -1;
  /** Response column name; inherited from the first base model found. */
  public String responseColumn = null;
  private NonBlockingHashSet<String> names = null; // keep columns as a set for easier comparison
  private NonBlockingHashSet<String> ignoredColumns = null; // keep ignored_columns as a set for easier comparison
  /** Number of cross-validation folds shared by all base models; -1 until inherited. */
  public int nfolds = -1;

  // TODO: add a separate holdout dataset for the ensemble
  // TODO: add a separate overall cross-validation for the ensemble, including _fold_column and FoldAssignmentScheme / _fold_assignment

  public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
    super(selfKey, parms, output);
  }

  /** Parameters for a {@link StackedEnsembleModel}. */
  public static class StackedEnsembleParameters extends Model.Parameters {
    public String algoName() { return "StackedEnsemble"; }
    public String fullName() { return "Stacked Ensemble"; }
    public String javaName() { return StackedEnsembleModel.class.getName(); }
    @Override public long progressUnits() { return 1; } // TODO

    /** Strategies for choosing which base models to stack; only {@code choose_all} exists so far. */
    public static enum SelectionStrategy { choose_all }

    // TODO: make _selection_strategy an object:
    /** How do we choose which models to stack? */
    public SelectionStrategy _selection_strategy;

    /** Which models can we choose from? */
    public Key<Model> _base_models[] = new Key[0];
  }

  /** Output of a {@link StackedEnsembleModel}. */
  public static class StackedEnsembleOutput extends Model.Output {
    public StackedEnsembleOutput() { super(); }
    public StackedEnsembleOutput(StackedEnsemble b) { super(b); }
    public StackedEnsembleOutput(Job job) { _job = job; }

    // The metalearner model (e.g., a GLM that has a coefficient for each of the base_learners).
    public Model _metalearner;
  }

  /**
   * For StackedEnsemble we call score on all the base_models and then combine the results
   * with the metalearner to create the final predictions frame.
   *
   * <p>Each base model scores its own adapted copy of {@code fr}. The base predictions, plus the
   * response column when {@code adaptFrm} contains it, are collected into a "level one" frame,
   * which is then adapted for and scored by the metalearner. When {@code computeMetrics} is set,
   * the metalearner's metrics are cloned and registered for the (this, fr) pair.
   *
   * @see Model#predictScoreImpl(Frame, Frame, String, Job, boolean)
   * @param fr the frame being scored
   * @param adaptFrm Already adapted frame
   * @param destination_key key for the returned predictions frame
   * @param j job passed through to the scoring tasks
   * @param computeMetrics whether to compute and store model metrics
   * @return A Frame containing the prediction column, and class distribution
   */
  protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics) {
    // Build up the names & domains. (Named predNames so it does not shadow the `names` field.)
    String[] predNames = makeScoringNames();
    String[][] domains = new String[predNames.length][];
    domains[0] = predNames.length == 1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : adaptFrm.lastVec().domain();

    // TODO: optimize these DKV lookups:
    Frame levelOneFrame = new Frame(Key.<Frame>make("preds_levelone_" + this._key.toString() + fr._key));

    // TODO: don't score models that have 0 coefficients / aren't used by the metalearner.
    for (Key<Model> baseKey : this._parms._base_models) {
      Model base = baseKey.get(); // TODO: cacheme!

      // adapt fr for each base_model
      // TODO: cache: don't need to call base.adaptTestForTrain() if the
      // base_model has the same names and domains as one we've seen before.
      // Such base_models can share their adapted frame.
      Frame adaptedFrame = new Frame(fr);
      base.adaptTestForTrain(adaptedFrame, true, computeMetrics);

      // TODO: parallel scoring for the base_models
      BigScore baseBs = (BigScore) base.makeBigScoreTask(domains, predNames, adaptedFrame, computeMetrics, true, j).doAll(predNames.length, Vec.T_NUM, adaptedFrame);
      // NOTE(review): every base model's predictions use the same key, so each iteration's frame
      // header replaces the previous one under that key. The prediction Vecs are never removed
      // here — possible leak, confirm.
      Frame basePreds = baseBs.outputFrame(Key.<Frame>make("preds_base_" + this._key.toString() + fr._key), predNames, domains);
      StackedEnsemble.addModelPredictionsToLevelOneFrame(base, basePreds, levelOneFrame);
      Model.cleanup_adapt(adaptedFrame, fr);
    }

    // The response column is only present when the scored frame carries it. Plain prediction on
    // unlabeled data would otherwise hand a null Vec to Frame.add.
    Vec responseVec = adaptFrm.vec(this.responseColumn);
    if (null != responseVec)
      levelOneFrame.add(this.responseColumn, responseVec);

    // TODO: what if we're running multiple in parallel and have a name collision?
    DKV.put(levelOneFrame);
    Log.info("Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString());

    // Score the dataset, building the class distribution & predictions
    Model metalearner = this._output._metalearner;
    Frame levelOneAdapted = new Frame(levelOneFrame);
    metalearner.adaptTestForTrain(levelOneAdapted, true, computeMetrics);
    DKV.put(levelOneAdapted);

    String[] metaNames = metalearner.makeScoringNames();
    String[][] metaDomains = new String[metaNames.length][];
    metaDomains[0] = metaNames.length == 1 ? null : !computeMetrics ? metalearner._output._domains[metalearner._output._domains.length-1] : levelOneAdapted.lastVec().domain();
    BigScore metaBs = (BigScore)metalearner.makeBigScoreTask(metaDomains, metaNames, levelOneAdapted, computeMetrics, true, j).
            doAll(metaNames.length, Vec.T_NUM, levelOneAdapted);

    if (computeMetrics) {
      ModelMetrics mmMetalearner = metaBs._mb.makeModelMetrics(metalearner, levelOneFrame, levelOneAdapted, metaBs.outputFrame());
      // This has just stored a ModelMetrics object for the (metalearner, preds_levelone) Model/Frame pair.
      // We need to be able to look it up by the (this, fr) pair.
      // The ModelMetrics object for the metalearner will be removed when the metalearner is removed.
      ModelMetrics mmStackedEnsemble = mmMetalearner.deepCloneWithDifferentModelAndFrame(this, fr);
      this.addModelMetrics(mmStackedEnsemble);
    }

    // NOTE(review): levelOneFrame / levelOneAdapted stay in the DKV after this call — confirm
    // whether they are cleaned up elsewhere.
    Model.cleanup_adapt(levelOneAdapted, levelOneFrame);
    return metaBs.outputFrame(Key.<Frame>make(destination_key), metaNames, metaDomains);
  }

  /**
   * Should never be called: the code paths that normally go here should call predictScoreImpl().
   * @see Model#score0(double[], double[])
   * @throws UnsupportedOperationException always
   */
  @Override
  protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/]) {
    throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
  }

  /**
   * Creates a metric builder matching this ensemble's model category.
   * Only Binomial and Regression are supported; anything else throws {@code H2O.unimpl()}.
   */
  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    switch (_output.getModelCategory()) {
      case Binomial:
        return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      // case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain);
      case Regression:
        return new ModelMetricsRegression.MetricBuilderRegression();
      default:
        throw H2O.unimpl();
    }
  }

  /**
   * Scores {@code frame} with metrics enabled and returns the ModelMetrics stored in the DKV
   * for the (this, frame) pair.
   * NOTE(review): the predictions frame returned by predictScoreImpl is discarded here and
   * not removed — confirm whether that leaks.
   */
  public ModelMetrics doScoreMetricsOneFrame(Frame frame, Job job) {
    this.predictScoreImpl(frame, new Frame(frame), null, job, true);
    return ModelMetrics.getFromDKV(this, frame);
  }

  /** Computes training metrics, and validation metrics when a validation frame is set. */
  public void doScoreMetrics(Job job) {
    this._output._training_metrics = doScoreMetricsOneFrame(this._parms.train(), job);
    if (null != this._parms.valid()) {
      this._output._validation_metrics = doScoreMetricsOneFrame(this._parms.valid(), job);
    }
  }

  /**
   * Determines the DistributionFamily a model was trained with.
   *
   * <ul>
   *   <li>DRF: inferred from the model output (bernoulli or gaussian).</li>
   *   <li>Models whose parameters have a {@code _family} field (GLM): mapped from the GLM family.</li>
   *   <li>Otherwise: read from the model's {@code _dist} field, falling back to
   *       {@code _parms._distribution}, with AUTO resolved from the model output.</li>
   * </ul>
   *
   * @throws H2OIllegalArgumentException if the family cannot be determined (including multinomial)
   */
  private DistributionFamily distributionFamily(Model aModel) {
    // TODO: hack alert: In DRF, _parms._distribution is always set to multinomial. Yay.
    if (aModel instanceof DRFModel) {
      if (aModel._output.isBinomialClassifier())
        return DistributionFamily.bernoulli;
      if (aModel._output.isClassifier())
        throw new H2OIllegalArgumentException("Don't know how to set the distribution for a multinomial Random Forest classifier.");
      return DistributionFamily.gaussian;
    }

    try {
      Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
      Field distributionField = (familyField != null ? null : ReflectionUtils.findNamedField(aModel, "_dist"));

      if (null != familyField) {
        // GLM only, for now
        GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
        if (thisFamily == GLMModel.GLMParameters.Family.binomial) {
          return DistributionFamily.bernoulli;
        }

        try {
          return Enum.valueOf(DistributionFamily.class, thisFamily.toString());
        }
        catch (IllegalArgumentException e) {
          throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + thisFamily);
        }
      }

      if (null != distributionField) {
        Distribution distribution = ((Distribution)distributionField.get(aModel));
        DistributionFamily distributionFamily;
        if (null != distribution)
          distributionFamily = distribution.distribution;
        else
          distributionFamily = aModel._parms._distribution;

        // NOTE: If the algo does smart guessing of the distribution family we need to duplicate the logic here.
        if (distributionFamily == DistributionFamily.AUTO) {
          if (aModel._output.isBinomialClassifier())
            distributionFamily = DistributionFamily.bernoulli;
          else if (aModel._output.isClassifier())
            throw new H2OIllegalArgumentException("Don't know how to determine the distribution for a multinomial classifier.");
          else
            distributionFamily = DistributionFamily.gaussian;
        } // DistributionFamily.AUTO

        return distributionFamily;
      }

      throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
    }
    catch (H2OIllegalArgumentException e) {
      // Already a user-facing error raised above. Propagate it unchanged: re-wrapping with
      // e.toString() would prefix the message with the exception class name.
      throw e;
    }
    catch (Exception e) {
      // Reflection failures (e.g. IllegalAccessException) or an unexpected field type.
      throw new H2OIllegalArgumentException(e.toString(), e.toString());
    }
  }

  /**
   * Checks that a base model was trained with the cross-validation settings the ensemble relies on.
   * The requirements are at least 2 folds, Modulo fold assignment, and kept cross-validation predictions.
   *
   * @throws H2OIllegalArgumentException if any requirement is not met
   */
  private static void checkCrossValidationSettings(Model aModel) {
    // TODO: loosen this iff _parms._valid or if we add a separate holdout dataset for the ensemble
    if (aModel._parms._nfolds < 2)
      throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
    // TODO: loosen this iff it's consistent, like if we have a _fold_column
    if (aModel._parms._fold_assignment != Modulo)
      throw new H2OIllegalArgumentException("Base model does not use Modulo for cross-validation: " + aModel._parms._fold_assignment);
    if (! aModel._parms._keep_cross_validation_predictions)
      throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + aModel._key);
  }

  /**
   * Validates the base models and inherits shared properties from the first one found.
   *
   * <p>The inherited properties are: supervised flag, model category, distribution, domains,
   * column names, ignored columns, response column and nfolds. Every later base model must agree
   * with them and use the same training frame. Every base model must also satisfy
   * {@link #checkCrossValidationSettings(Model)}. Base model keys that cannot be found in the DKV
   * are skipped with a warning.
   *
   * @throws H2OIllegalArgumentException if no base models are given, none can be found, or they are inconsistent
   */
  public void checkAndInheritModelProperties() {
    if (null == _parms._base_models || 0 == _parms._base_models.length)
      throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");

    boolean beenHere = false; // true once a base model has been found and its properties inherited
    trainingFrameChecksum = _parms.train().checksum();

    for (Key<Model> k : _parms._base_models) {
      Model aModel = DKV.getGet(k);
      if (null == aModel) {
        Log.warn("Failed to find base model; skipping: " + k);
        continue;
      }

      if (beenHere) {
        // check that the base models are all consistent
        if (_output._isSupervised ^ aModel.isSupervised())
          throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of supervised and unsupervised models: " + Arrays.toString(_parms._base_models));

        if (modelCategory != aModel._output.getModelCategory())
          throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(_parms._base_models));

        Frame aTrainingFrame = aModel._parms.train();
        long aTrainingFrameChecksum = aTrainingFrame.checksum();
        if (trainingFrameChecksum != aTrainingFrameChecksum)
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different training frames. Found checksums: " + trainingFrameChecksum + " and: " + aTrainingFrameChecksum + ".");

        NonBlockingHashSet<String> aNames = new NonBlockingHashSet<>();
        aNames.addAll(Arrays.asList(aModel._output._names));
        if (! aNames.equals(this.names))
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different column lists. Found: " + this.names + " and: " + aNames + ".");

        NonBlockingHashSet<String> anIgnoredColumns = new NonBlockingHashSet<>();
        if (null != aModel._parms._ignored_columns)
          anIgnoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
        if (! anIgnoredColumns.equals(this.ignoredColumns))
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different ignored_column lists. Found: " + this.ignoredColumns + " and: " + anIgnoredColumns + ".");

        if (! responseColumn.equals(aModel._parms._response_column))
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns. Found: " + responseColumn + " and: " + aModel._parms._response_column + ".");

        if (_output._domains.length != aModel._output._domains.length)
          throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different numbers of domains (categorical levels): " + Arrays.toString(_parms._base_models));

        if (nfolds != aModel._parms._nfolds)
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds. Found: " + nfolds + " and: " + aModel._parms._nfolds + ".");

        checkCrossValidationSettings(aModel);

        // In GLM, we get _family instead of _distribution.
        // Further, we have Family.binomial instead of DistributionFamily.bernoulli.
        // We also handle DistributionFamily.AUTO in distributionFamily()
        //
        // Hack alert: DRF only does Bernoulli and Gaussian, so only compare _domains.length above.
        if (! (aModel instanceof DRFModel)) {
          DistributionFamily baseFamily = distributionFamily(aModel);
          DistributionFamily ensembleFamily = distributionFamily(this);
          if (baseFamily != ensembleFamily)
            Log.warn("Base models are inconsistent; they use different distributions: " + ensembleFamily + " and: " + baseFamily + ". Is this intentional?");
        }

        // TODO: If we're set to DistributionFamily.AUTO then GLM might auto-conform the response column
        // giving us inconsistencies.
      } else {
        // !beenHere: this is the first base_model. It must meet the same cross-validation
        // requirements as the others, because its holdout predictions are stacked too.
        checkCrossValidationSettings(aModel);

        _output._isSupervised = aModel.isSupervised();
        this.modelCategory = aModel._output.getModelCategory();
        this._dist = new Distribution(distributionFamily(aModel));
        _output._domains = Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);

        // TODO: set _parms._train to aModel._parms.train()

        _output.setNames(aModel._output._names);
        this.names = new NonBlockingHashSet<>();
        this.names.addAll(Arrays.asList(aModel._output._names));

        this.ignoredColumns = new NonBlockingHashSet<>();
        if (null != aModel._parms._ignored_columns)
          this.ignoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));

        // If the client has set _ignored_columns for the StackedEnsemble make sure it's
        // consistent with the base_models:
        if (null != this._parms._ignored_columns) {
          NonBlockingHashSet<String> ensembleIgnoredColumns = new NonBlockingHashSet<>();
          ensembleIgnoredColumns.addAll(Arrays.asList(this._parms._ignored_columns));
          if (! ensembleIgnoredColumns.equals(this.ignoredColumns))
            throw new H2OIllegalArgumentException("A StackedEnsemble takes its ignored_columns list from the base models. An inconsistent list of ignored_columns was specified for the ensemble model.");
        }

        responseColumn = aModel._parms._response_column;

        if (! responseColumn.equals(_parms._response_column))
          throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model. Found: " + responseColumn + " and: " + _parms._response_column);

        nfolds = aModel._parms._nfolds;
        _parms._distribution = aModel._parms._distribution;
        beenHere = true;
      }
    } // for all base_models

    // Use beenHere rather than the last lookup result: a missing *last* key must not hide
    // earlier base models that were found.
    if (! beenHere)
      throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + _parms._base_models.length + " were specified but none of those were found: " + Arrays.toString(_parms._base_models));
  }

  // TODO: Are we leaking anything?
  /** Removes the metalearner from the DKV (when present) along with this model. */
  @Override protected Futures remove_impl(Futures fs ) {
    if (_output._metalearner != null)
      DKV.remove(_output._metalearner._key, fs);

    return super.remove_impl(fs);
  }
}