package water.api;
import hex.*;
import hex.genmodel.utils.DistributionFamily;
import water.*;
import water.api.schemas3.*;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OKeyNotFoundArgumentException;
import water.fvec.Frame;
import water.util.Log;

import java.util.Arrays;
class ModelMetricsHandler extends Handler {
/**
 * Internal representation of a ModelMetrics request: the optional model/frame filters,
 * the scoring options used by the prediction endpoints, and the matching metrics.
 */
public static final class ModelMetricsList extends Iced {
  // Filters; null means "do not filter on this".
  public Model _model;
  public Frame _frame;
  // Result of fetch()/delete().
  public ModelMetrics[] _model_metrics;
  // Names of output frames produced by the prediction endpoints.
  public String _predictions_name;
  public String _deviances_name;
  // Scoring options (see ModelMetricsListSchemaV3 for the user-facing descriptions).
  public boolean _deviances;
  public boolean _reconstruction_error;
  public boolean _reconstruction_error_per_feature;
  public int _deep_features_hidden_layer = -1;
  public String _deep_features_hidden_layer_name = null;
  public boolean _reconstruct_train;
  public boolean _project_archetypes;
  public boolean _reverse_transform;
  public boolean _leaf_node_assignment;
  public int _exemplar_index = -1;

  /**
   * Fetch all ModelMetrics in the DKV that match {@link #_model} and/or {@link #_frame}.
   * Keys that vanish between the global key snapshot and the actual lookup are skipped,
   * so {@link #_model_metrics} never contains null entries.
   * @return this, with {@link #_model_metrics} filled in
   */
  ModelMetricsList fetch() {
    final Key[] modelMetricsKeys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter() {
      @Override public boolean filter(KeySnapshot.KeyInfo k) {
        try {
          if( !Value.isSubclassOf(k._type, ModelMetrics.class) ) return false; // Fast-path cutout
          ModelMetrics mm = DKV.getGet(k._key);
          // If we're filtering by model, keep only metrics computed for that model.
          if( _model != null && !mm.isForModel((Model)DKV.getGet(_model._key)) ) return false;
          // If we're filtering by frame, keep only metrics computed on that frame.
          if( _frame != null && !mm.isForFrame((Frame)DKV.getGet(_frame._key)) ) return false;
        } catch( NullPointerException | ClassCastException ex ) {
          return false; // Handle all kinds of broken racey key updates
        }
        return true;
      }
    }).keys();
    ModelMetrics[] found = new ModelMetrics[modelMetricsKeys.length];
    int n = 0;
    for (Key key : modelMetricsKeys) {
      ModelMetrics mm = DKV.getGet(key);
      if (mm != null) // may have been removed after the snapshot was taken
        found[n++] = mm;
    }
    _model_metrics = (n == found.length) ? found : Arrays.copyOf(found, n);
    return this; // Flow coding
  }

  /**
   * Remove every ModelMetrics matching the model and/or frame filters from the DKV.
   * @return the list of metrics that were deleted
   */
  ModelMetricsList delete() {
    ModelMetricsList matches = fetch();
    for (ModelMetrics mm : matches._model_metrics)
      DKV.remove(mm._key);
    return matches;
  }

  /** Return all the models matching the model&frame filters */
  public Schema list(int version, ModelMetricsList m) {
    return this.schema(version).fillFromImpl(m.fetch());
  }

  /** Create an empty list schema for the given REST API version (only 3 is supported). */
  protected ModelMetricsListSchemaV3 schema(int version) {
    switch (version) {
      case 3: return new ModelMetricsListSchemaV3();
      default: throw H2O.fail("Bad version for ModelMetrics schema: " + version);
    }
  }
} // class ModelMetricsList
/** Schema for a list of ModelMetricsBaseV3.
 * This should be common across all versions of ModelMetrics schemas, so it lives here.
 * TODO: move to water.api.schemas3
 * */
public static final class ModelMetricsListSchemaV3 extends RequestSchemaV3<ModelMetricsList, ModelMetricsListSchemaV3> {
  // Input fields
  @API(help = "Key of Model of interest (optional)")
  public KeyV3.ModelKeyV3<Model> model;
  @API(help = "Key of Frame of interest (optional)")
  public KeyV3.FrameKeyV3 frame;
  @API(help = "Key of predictions frame, if predictions are requested (optional)", direction = API.Direction.INOUT)
  public KeyV3.FrameKeyV3 predictions_frame;
  @API(help = "Key for the frame containing per-observation deviances (optional)", direction = API.Direction.INOUT)
  public KeyV3.FrameKeyV3 deviances_frame;
  @API(help = "Compute reconstruction error (optional, only for Deep Learning AutoEncoder models)", json = false)
  public boolean reconstruction_error;
  @API(help = "Compute reconstruction error per feature (optional, only for Deep Learning AutoEncoder models)", json = false)
  public boolean reconstruction_error_per_feature;
  @API(help = "Extract Deep Features for given hidden layer (optional, only for Deep Learning models)", json = false)
  public int deep_features_hidden_layer;
  @API(help = "Extract Deep Features for given hidden layer by name (optional, only for Deep Water models)", json = false)
  public String deep_features_hidden_layer_name;
  @API(help = "Reconstruct original training frame (optional, only for GLRM models)", json = false)
  public boolean reconstruct_train;
  @API(help = "Project GLRM archetypes back into original feature space (optional, only for GLRM models)", json = false)
  public boolean project_archetypes;
  @API(help = "Reverse transformation applied during training to model output (optional, only for GLRM models)", json = false)
  public boolean reverse_transform;
  @API(help = "Return the leaf node assignment (optional, only for DRF/GBM models)", json = false)
  public boolean leaf_node_assignment;
  @API(help = "Retrieve all members for a given exemplar (optional, only for Aggregator models)", json = false)
  public int exemplar_index;
  @API(help = "Compute the deviances per row (optional, only for classification or regression models)", json = false)
  public boolean deviances;
  // Output fields
  @API(help = "ModelMetrics", direction = API.Direction.OUTPUT)
  public ModelMetricsBaseV3[] model_metrics;

  /**
   * Copy the request fields into the internal representation.
   * Model and frame keys are resolved to their DKV values; output frame keys are kept as names.
   */
  @Override public ModelMetricsHandler.ModelMetricsList fillImpl(ModelMetricsList mml) {
    // TODO: check for type!
    mml._model = (this.model == null || this.model.key() == null ? null : this.model.key().get());
    mml._frame = (this.frame == null || this.frame.key() == null ? null : this.frame.key().get());
    mml._predictions_name = (null == this.predictions_frame || null == this.predictions_frame.key() ? null : this.predictions_frame.key().toString());
    // deviances_frame is INOUT just like predictions_frame: honor a caller-supplied name.
    mml._deviances_name = (null == this.deviances_frame || null == this.deviances_frame.key() ? null : this.deviances_frame.key().toString());
    mml._reconstruction_error = this.reconstruction_error;
    mml._reconstruction_error_per_feature = this.reconstruction_error_per_feature;
    mml._deep_features_hidden_layer = this.deep_features_hidden_layer;
    mml._deep_features_hidden_layer_name = this.deep_features_hidden_layer_name;
    mml._reconstruct_train = this.reconstruct_train;
    mml._project_archetypes = this.project_archetypes;
    mml._reverse_transform = this.reverse_transform;
    mml._leaf_node_assignment = this.leaf_node_assignment;
    mml._exemplar_index = this.exemplar_index;
    mml._deviances = this.deviances;
    if (model_metrics != null) {
      mml._model_metrics = new ModelMetrics[model_metrics.length];
      // One-to-one copy: slot i of the impl array comes from slot i of the schema array.
      for (int i = 0; i < model_metrics.length; i++)
        mml._model_metrics[i] = (ModelMetrics) model_metrics[i].createImpl();
    }
    return mml;
  }

  /**
   * Copy the internal representation back into this schema, converting each ModelMetrics
   * into its version-3 schema. A null metrics array becomes an empty one.
   */
  @Override public ModelMetricsListSchemaV3 fillFromImpl(ModelMetricsList mml) {
    // TODO: this is failing in PojoUtils with an IllegalAccessException. Why? Different class loaders?
    // PojoUtils.copyProperties(this, m, PojoUtils.FieldNaming.CONSISTENT);
    // Shouldn't need to do this manually. . .
    this.model = (mml._model == null ? null : new KeyV3.ModelKeyV3(mml._model._key));
    this.frame = (mml._frame == null ? null : new KeyV3.FrameKeyV3(mml._frame._key));
    this.predictions_frame = (mml._predictions_name == null ? null : new KeyV3.FrameKeyV3(Key.<Frame>make(mml._predictions_name)));
    this.deviances_frame = (mml._deviances_name == null ? null : new KeyV3.FrameKeyV3(Key.<Frame>make(mml._deviances_name)));
    this.reconstruction_error = mml._reconstruction_error;
    this.reconstruction_error_per_feature = mml._reconstruction_error_per_feature;
    this.deep_features_hidden_layer = mml._deep_features_hidden_layer;
    this.deep_features_hidden_layer_name = mml._deep_features_hidden_layer_name;
    this.reconstruct_train = mml._reconstruct_train;
    this.project_archetypes = mml._project_archetypes;
    this.reverse_transform = mml._reverse_transform;
    this.leaf_node_assignment = mml._leaf_node_assignment;
    this.exemplar_index = mml._exemplar_index;
    this.deviances = mml._deviances;
    if (null != mml._model_metrics) {
      this.model_metrics = new ModelMetricsBaseV3[mml._model_metrics.length];
      for (int i = 0; i < mml._model_metrics.length; i++) {
        ModelMetrics mm = mml._model_metrics[i];
        this.model_metrics[i] = (ModelMetricsBaseV3) SchemaServer.schema(3, mm.getClass()).fillFromImpl(mm);
      }
    } else {
      this.model_metrics = new ModelMetricsBaseV3[0];
    }
    return this;
  }
} // ModelMetricsListSchemaV3
// TODO: almost identical to ModelsHandler; refactor
/**
 * Look up a ModelMetrics object in the DKV.
 * @param key DKV key of the ModelMetrics
 * @return the ModelMetrics stored under {@code key}
 * @throws IllegalArgumentException if the key is null, not present, or does not hold a ModelMetrics
 */
public static ModelMetrics getFromDKV(Key key) {
  if (null == key)
    throw new IllegalArgumentException("Got null key.");
  Value v = DKV.get(key);
  if (null == v)
    throw new IllegalArgumentException("Did not find key: " + key.toString());
  Iced ice = v.get();
  if (! (ice instanceof ModelMetrics))
    throw new IllegalArgumentException("Expected a ModelMetrics for key: " + key.toString() + "; got a: " + (ice == null ? "null" : ice.getClass()));
  return (ModelMetrics)ice;
}
/**
 * Return the ModelMetrics matching the request's model and/or frame filters.
 * The request schema is filled in place with the matches and handed back.
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelMetricsListSchemaV3 fetch(int version, ModelMetricsListSchemaV3 s) {
  final ModelMetricsList matches = s.createAndFillImpl().fetch();
  return s.fillFromImpl(matches);
}
/**
 * Delete every ModelMetrics matching the request's model and/or frame filters.
 * The request schema is filled in place with the deleted metrics and handed back.
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelMetricsListSchemaV3 delete(int version, ModelMetricsListSchemaV3 s) {
  final ModelMetricsList removed = s.createAndFillImpl().delete();
  return s.fillFromImpl(removed);
}
/**
 * Score a frame with the given model and return just the metrics.
 * <p>
 * NOTE: ModelMetrics are now always being created by model.score. . .
 * The predictions frame produced by scoring is removed immediately; only the
 * metrics created as a side effect are returned.
 * @throws H2OIllegalArgumentException if the model or frame argument is missing
 * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelMetricsListSchemaV3 score(int version, ModelMetricsListSchemaV3 s) {
  // parameters checking (report this endpoint, "score", in the errors):
  if (null == s.model) throw new H2OIllegalArgumentException("model", "score", s.model);
  if (null == DKV.get(s.model.name)) throw new H2OKeyNotFoundArgumentException("model", "score", s.model.name);
  if (null == s.frame) throw new H2OIllegalArgumentException("frame", "score", s.frame);
  if (null == DKV.get(s.frame.name)) throw new H2OKeyNotFoundArgumentException("frame", "score", s.frame.name);
  ModelMetricsList parms = s.createAndFillImpl();
  parms._model.score(parms._frame, parms._predictions_name).remove(); // throw away predictions, keep metrics as a side-effect
  // fetch() fills and returns s itself, so the result is never null.
  ModelMetricsListSchemaV3 mm = this.fetch(version, s);
  // TODO: for now only binary predictors write an MM object.
  if (null == mm.model_metrics || 0 == mm.model_metrics.length) {
    Log.warn("Score() did not return a ModelMetrics for model: " + s.model + " on frame: " + s.frame);
  }
  return mm;
}
/**
 * Impl-side counterpart of ModelMetricsMakerSchemaV3 (used by the make() endpoint):
 * names of the predictions and actuals frames, the optional class domain and
 * distribution, and the resulting ModelMetrics.
 */
public static final class ModelMetricsMaker extends Iced {
// Name of the predictions frame (presumably a DKV key name, as in make() — verify)
public String _predictions_frame;
// Name of the actuals frame (presumably a DKV key name, as in make() — verify)
public String _actuals_frame;
// Class labels; null for regression
public String[] _domain;
// Distribution family used for regression metrics
public DistributionFamily _distribution;
// Resulting metrics
public ModelMetrics _model_metrics;
}
/**
 * Request/response schema for make().
 * predictions_frame and actuals_frame are DKV key names, looked up via DKV.getGet in make().
 * The domain selects the problem type:
 * null means regression, 2 labels means binomial, and more than 2 labels means multinomial.
 * distribution is passed to ModelMetricsRegression.make for regression.
 */
public static final class ModelMetricsMakerSchemaV3 extends SchemaV3<ModelMetricsMaker, ModelMetricsMakerSchemaV3> {
@API(help="Predictions Frame.", direction=API.Direction.INOUT)
public String predictions_frame;
@API(help="Actuals Frame.", direction=API.Direction.INOUT)
public String actuals_frame;
@API(help="Domain (for classification).", direction=API.Direction.INOUT)
public String[] domain;
@API(help="Distribution (for regression).", direction=API.Direction.INOUT, values = { "gaussian", "poisson", "gamma", "laplace" })
public DistributionFamily distribution;
// Filled in by make() with the computed regression/binomial/multinomial metrics
@API(help="Model Metrics.", direction=API.Direction.OUTPUT)
public ModelMetricsBaseV3 model_metrics;
}
/**
 * Make a model metrics object from actual and predicted values.
 * <ul>
 *   <li>{@code domain == null}: regression; the predictions frame must have exactly 1 column.</li>
 *   <li>{@code domain.length == 2}: binomial; 1 column holding the class-1 probabilities.</li>
 *   <li>{@code domain.length > 2}: multinomial; one probability column per class label.</li>
 * </ul>
 * @throws H2OIllegalArgumentException on missing arguments, wrong column counts,
 *         or a domain with fewer than 2 class labels
 * @throws H2OKeyNotFoundArgumentException if a named frame is not in the DKV
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelMetricsMakerSchemaV3 make(int version, ModelMetricsMakerSchemaV3 s) {
  // parameters checking:
  if (null == s.predictions_frame) throw new H2OIllegalArgumentException("predictions_frame", "make", s.predictions_frame);
  Frame pred = DKV.getGet(s.predictions_frame);
  if (null == pred) throw new H2OKeyNotFoundArgumentException("predictions_frame", "make", s.predictions_frame);
  if (null == s.actuals_frame) throw new H2OIllegalArgumentException("actuals_frame", "make", s.actuals_frame);
  Frame act = DKV.getGet(s.actuals_frame);
  if (null == act) throw new H2OKeyNotFoundArgumentException("actuals_frame", "make", s.actuals_frame);
  if (s.domain == null) {
    if (pred.numCols() != 1) {
      throw new H2OIllegalArgumentException("predictions_frame", "make", "For regression problems (domain=null), the predictions_frame must have exactly 1 column.");
    }
    ModelMetricsRegression mm = ModelMetricsRegression.make(pred.anyVec(), act.anyVec(), s.distribution);
    s.model_metrics = new ModelMetricsRegressionV3().fillFromImpl(mm);
  } else if (s.domain.length == 2) {
    if (pred.numCols() != 1) {
      throw new H2OIllegalArgumentException("predictions_frame", "make", "For domains with 2 class labels, the predictions_frame must have exactly one column containing the class-1 probabilities.");
    }
    ModelMetricsBinomial mm = ModelMetricsBinomial.make(pred.anyVec(), act.anyVec(), s.domain);
    s.model_metrics = new ModelMetricsBinomialV3().fillFromImpl(mm);
  } else if (s.domain.length > 2) {
    if (pred.numCols() != s.domain.length) {
      throw new H2OIllegalArgumentException("predictions_frame", "make", "For domains with " + s.domain.length + " class labels, the predictions_frame must have exactly " + s.domain.length + " columns containing the class-probabilities.");
    }
    ModelMetricsMultinomial mm = ModelMetricsMultinomial.make(pred, act.anyVec(), s.domain);
    s.model_metrics = new ModelMetricsMultinomialV3().fillFromImpl(mm);
  } else {
    // A domain with 0 or 1 labels is a user input error, not an unimplemented feature.
    throw new H2OIllegalArgumentException("domain", "make", "For classification problems, the domain must contain at least 2 class labels, but got " + s.domain.length + ".");
  }
  return s;
}
/**
 * Score a frame with the given model and return the metrics AND the prediction frame.
 * <p>
 * Runs asynchronously as a Job and returns the job description. Plain scoring and
 * deep-features extraction (by hidden layer index or by layer name) are supported;
 * per-row deviances are not.
 * @throws H2OIllegalArgumentException if model/frame is missing or deviances are requested
 * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public JobV3 predictAsync(int version, final ModelMetricsListSchemaV3 s) {
  // parameters checking:
  if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model);
  if (null == DKV.get(s.model.name)) throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);
  if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame);
  if (null == DKV.get(s.frame.name)) throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);
  if (s.deviances || null != s.deviances_frame) throw new H2OIllegalArgumentException("deviances", "not supported for async", s.deviances_frame);
  final ModelMetricsList parms = s.createAndFillImpl();
  //predict2 does not return modelmetrics, so cannot handle deeplearning: reconstruction_error (anomaly) or GLRM: reconstruct and archetypes
  //predict2 can handle deeplearning: deepfeatures and predict
  // Hidden layer indices are 0-based, so index 0 is a valid deep-features request.
  // The same flag drives both the output name and the work done in compute2().
  final boolean deepFeatures = s.deep_features_hidden_layer >= 0 || s.deep_features_hidden_layer_name != null;
  if (null == parms._predictions_name) {
    if (deepFeatures) {
      parms._predictions_name = "deep_features" + Key.make().toString().substring(0, 5) + "_" +
          parms._model._key.toString() + "_on_" + parms._frame._key.toString();
    } else if (parms._exemplar_index >= 0) {
      parms._predictions_name = "members_" + parms._model._key.toString() + "_for_exemplar_" + parms._exemplar_index;
    } else {
      parms._predictions_name = "predictions" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
    }
  }
  final Job<Frame> j = new Job(Key.make(parms._predictions_name), Frame.class.getName(), "prediction");
  H2O.H2OCountedCompleter work = new H2O.H2OCountedCompleter() {
    @Override
    public void compute2() {
      if (!deepFeatures) {
        parms._model.score(parms._frame, parms._predictions_name, j, true);
      }
      else if (s.deep_features_hidden_layer_name != null) {
        Frame predictions = null;
        try {
          predictions = ((Model.DeepFeatures) parms._model).scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer_name, j);
        } catch(IllegalArgumentException e) {
          Log.warn(e.getMessage());
          throw e;
        }
        if (predictions!=null) {
          predictions = new Frame(Key.<Frame>make(parms._predictions_name), predictions.names(), predictions.vecs());
          DKV.put(predictions._key, predictions);
        }
      }
      else {
        Frame predictions = ((Model.DeepFeatures) parms._model).scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer, j);
        predictions = new Frame(Key.<Frame>make(parms._predictions_name), predictions.names(), predictions.vecs());
        DKV.put(predictions._key, predictions);
      }
      tryComplete();
    }
  };
  j.start(work, parms._frame.anyVec().nChunks());
  return new JobV3().fillFromImpl(j);
}
/**
 * Score a frame with the given model and return the metrics AND the prediction frame.
 * <p>
 * Besides plain scoring (optionally with per-row deviances for supervised models), the
 * request flags select special outputs:
 * <ul>
 *   <li>autoencoder reconstruction error or deep features, for Model.DeepFeatures models;</li>
 *   <li>archetype projection or training-frame reconstruction, for Model.GLRMArchetypes models;</li>
 *   <li>leaf node assignment, for Model.LeafNodeAssignment models;</li>
 *   <li>exemplar members, for Model.ExemplarMembers models (no frame required).</li>
 * </ul>
 * Type requirements are checked explicitly rather than with {@code assert}, because
 * assertions are usually disabled in production.
 * @throws H2OIllegalArgumentException on missing/invalid arguments or an unsuitable model type
 * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public ModelMetricsListSchemaV3 predict(int version, ModelMetricsListSchemaV3 s) {
  // parameters checking:
  if (s.model == null) throw new H2OIllegalArgumentException("model", "predict", null);
  if (DKV.get(s.model.name) == null) throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);
  // Aggregator doesn't need a Frame to 'predict'
  if (s.exemplar_index < 0) {
    if (s.frame == null) throw new H2OIllegalArgumentException("frame", "predict", null);
    if (DKV.get(s.frame.name) == null) throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);
  }
  ModelMetricsList parms = s.createAndFillImpl();
  Frame predictions;
  Frame deviances = null;
  if (!s.reconstruction_error && !s.reconstruction_error_per_feature && s.deep_features_hidden_layer < 0 &&
      !s.project_archetypes && !s.reconstruct_train && !s.leaf_node_assignment && s.exemplar_index < 0) {
    // Plain scoring, optionally followed by per-row deviances.
    if (null == parms._predictions_name)
      parms._predictions_name = "predictions" + Key.make().toString().substring(0,5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
    predictions = parms._model.score(parms._frame, parms._predictions_name);
    if (s.deviances) {
      if (!parms._model.isSupervised())
        throw new H2OIllegalArgumentException("Deviances can only be computed for supervised models.");
      if (null == parms._deviances_name)
        parms._deviances_name = "deviances" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
      deviances = parms._model.computeDeviances(parms._frame, predictions, parms._deviances_name);
    }
  } else {
    if (s.deviances)
      throw new H2OIllegalArgumentException("Cannot compute deviances in combination with other special predictions.");
    if (Model.DeepFeatures.class.isAssignableFrom(parms._model.getClass())) {
      if (s.reconstruction_error || s.reconstruction_error_per_feature) {
        if (s.deep_features_hidden_layer >= 0)
          throw new H2OIllegalArgumentException("Can only compute either reconstruction error OR deep features.", "");
        if (null == parms._predictions_name)
          parms._predictions_name = "reconstruction_error" + Key.make().toString().substring(0,5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
        predictions = ((Model.DeepFeatures) parms._model).scoreAutoEncoder(parms._frame, Key.make(parms._predictions_name), parms._reconstruction_error_per_feature);
      } else {
        if (s.deep_features_hidden_layer < 0)
          throw new H2OIllegalArgumentException("Deep features hidden layer index must be >= 0.", "");
        if (null == parms._predictions_name)
          parms._predictions_name = "deep_features" + Key.make().toString().substring(0,5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
        predictions = ((Model.DeepFeatures) parms._model).scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer);
      }
      // Re-key the result under the requested/generated name and publish it.
      predictions = new Frame(Key.<Frame>make(parms._predictions_name), predictions.names(), predictions.vecs());
      DKV.put(predictions._key, predictions);
    } else if(Model.GLRMArchetypes.class.isAssignableFrom(parms._model.getClass())) {
      if(s.project_archetypes) {
        if (parms._predictions_name == null)
          parms._predictions_name = "reconstructed_archetypes_" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_of_" + parms._frame._key.toString();
        predictions = ((Model.GLRMArchetypes) parms._model).scoreArchetypes(parms._frame, Key.<Frame>make(parms._predictions_name), s.reverse_transform);
      } else {
        // Previously an assert: with assertions disabled any other special request
        // silently fell through to a reconstruction.
        if (!s.reconstruct_train)
          throw new H2OIllegalArgumentException("GLRM models only support project_archetypes or reconstruct_train here.", "Unsupported special prediction requested for a GLRM model.");
        if (parms._predictions_name == null)
          parms._predictions_name = "reconstruction_" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_of_" + parms._frame._key.toString();
        predictions = ((Model.GLRMArchetypes) parms._model).scoreReconstruction(parms._frame, Key.<Frame>make(parms._predictions_name), s.reverse_transform);
      }
    } else if(s.leaf_node_assignment) {
      if (!Model.LeafNodeAssignment.class.isAssignableFrom(parms._model.getClass()))
        throw new H2OIllegalArgumentException("Leaf node assignment requires a DRF or GBM model.", "Model must implement Model.LeafNodeAssignment.");
      if (null == parms._predictions_name)
        parms._predictions_name = "leaf_node_assignment" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
      predictions = ((Model.LeafNodeAssignment) parms._model).scoreLeafNodeAssignment(parms._frame, Key.<Frame>make(parms._predictions_name));
    } else if(s.exemplar_index >= 0) {
      if (!Model.ExemplarMembers.class.isAssignableFrom(parms._model.getClass()))
        throw new H2OIllegalArgumentException("Exemplar members require an Aggregator model.", "Model must implement Model.ExemplarMembers.");
      if (null == parms._predictions_name)
        parms._predictions_name = "members_" + parms._model._key.toString() + "_for_exemplar_" + parms._exemplar_index;
      predictions = ((Model.ExemplarMembers) parms._model).scoreExemplarMembers(Key.<Frame>make(parms._predictions_name), parms._exemplar_index);
    }
    else throw new H2OIllegalArgumentException("Requires a Deep Learning, GLRM, DRF or GBM model.", "Model must implement specific methods.");
  }
  ModelMetricsListSchemaV3 mm = this.fetch(version, s);
  // TODO: for now only binary predictors write an MM object.
  // For the others cons one up here to return the predictions frame.
  if (null == mm)
    mm = new ModelMetricsListSchemaV3();
  mm.predictions_frame = new KeyV3.FrameKeyV3(predictions._key);
  if (parms._leaf_node_assignment) //don't show metrics in leaf node assignments are made
    mm.model_metrics = null;
  if (deviances !=null)
    mm.deviances_frame = new KeyV3.FrameKeyV3(deviances._key);
  if (null == mm.model_metrics || 0 == mm.model_metrics.length) {
    // There was no response in the test set -> cannot make a model_metrics object
  } else {
    mm.model_metrics[0].predictions = new FrameV3(predictions, 0, 100); // TODO: Should call schema(version)
  }
  return mm;
}
}