package hex;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OKeyNotFoundArgumentException;
import water.fvec.Frame;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.PojoUtils;
import water.util.TwoDimTable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
/** Container to hold the metrics for a model as scored on a specific frame.
 *
 * <p>The {@link MetricBuilder} class is used in a hot inner loop of a Big Data pass and,
 * when given a class distribution, can be used to compute confusion matrices and AUCs
 * "on the fly" during model building - or after the fact with a Model and a new
 * Frame to be scored.
 */
public class ModelMetrics extends Keyed<ModelMetrics> {
  public String _description;
  final Key _modelKey;
  final Key _frameKey;
  final ModelCategory _model_category;
  final long _model_checksum;
  // Not final: when constant columns are dropped the frame checksum changes, so GLRM needs to re-assign it.
  long _frame_checksum;
  public final long _scoring_time;

  // Cached fields - resolved lazily from the DKV by model() and frame()
  private transient Model _model;
  private transient Frame _frame;

  public final double _MSE; // Mean Squared Error (Every model is assumed to have this, otherwise leave at NaN)
  public final long _nobs;

  /**
   * @param model the scored model (may be null; identity fields are then left empty)
   * @param frame the frame it was scored on (may be null)
   * @param nobs number of observations scored
   * @param MSE mean squared error
   * @param desc free-form description (may be null)
   */
  public ModelMetrics(Model model, Frame frame, long nobs, double MSE, String desc) {
    super(buildKey(model, frame));
    _description = desc;
    _modelKey = model == null ? null : model._key;
    _frameKey = frame == null ? null : frame._key;
    _model_category = model == null ? null : model._output.getModelCategory();
    _model_checksum = model == null ? 0 : model.checksum();
    _frame_checksum = safeFrameChecksum(frame);
    _MSE = MSE;
    _nobs = nobs;
    _scoring_time = System.currentTimeMillis();
  }

  /**
   * Best-effort frame checksum.
   * @return the frame's checksum, or 0 for a null frame or when the checksum cannot be computed
   */
  private static long safeFrameChecksum(Frame frame) {
    if (frame == null) return 0;
    try {
      return frame.checksum();
    } catch (RuntimeException e) {
      // Deliberately best-effort: the metrics stay usable without a frame checksum.
      Log.debug("Could not compute checksum of frame " + frame._key + ": " + e);
      return 0;
    }
  }

  /** Re-points this object's model/frame identity fields at the given model and frame. */
  private void setModelAndFrameFields(Model model, Frame frame) {
    // These identity fields are final, so they are rewritten reflectively.
    PojoUtils.setField(this, "_modelKey", model == null ? null : model._key);
    PojoUtils.setField(this, "_frameKey", frame == null ? null : frame._key);
    PojoUtils.setField(this, "_model_category", model == null ? null : model._output.getModelCategory());
    PojoUtils.setField(this, "_model_checksum", model == null ? 0 : model.checksum());
    _frame_checksum = safeFrameChecksum(frame);
    // A clone may carry over the source object's cached model/frame; drop them so
    // model() and frame() resolve the new keys instead of returning stale objects.
    _model = null;
    _frame = null;
  }

  /**
   * Utility used by code which creates metrics on a different frame and model than
   * the ones that we want the metrics object to be accessible for. An example is
   * StackedEnsembleModel, which computes the metrics with a metalearner model.
   * @param model model the copy should be attributed to
   * @param frame frame the copy should be attributed to
   * @return a clone of these metrics, keyed and labelled for {@code model} and {@code frame}
   */
  public ModelMetrics deepCloneWithDifferentModelAndFrame(Model model, Frame frame) {
    ModelMetrics m = this.clone();
    m._key = buildKey(model, frame);
    m.setModelAndFrameFields(model, frame);
    return m;
  }

  /** @throws UnsupportedOperationException unless overridden by a metrics class that supports it */
  public long residual_degrees_of_freedom(){throw new UnsupportedOperationException("residual degrees of freedom is not supported for this metric class");}

  @Override public String toString() {
    // Report "Binomial" for ModelMetricsBinomial etc.; other class names are shown as-is
    // (the former unconditional substring(12) threw for short or anonymous class names).
    final String prefix = "ModelMetrics";
    String type = getClass().getSimpleName();
    if (type.startsWith(prefix)) type = type.substring(prefix.length());
    StringBuilder sb = new StringBuilder();
    sb.append("Model Metrics Type: ").append(type).append("\n");
    sb.append(" Description: ").append(_description == null ? "N/A" : _description).append("\n");
    sb.append(" model id: ").append(_modelKey).append("\n");
    sb.append(" frame id: ").append(_frameKey).append("\n");
    sb.append(" MSE: ").append((float) _MSE).append("\n");
    sb.append(" RMSE: ").append((float) rmse()).append("\n");
    return sb.toString();
  }

  /** The model these metrics belong to, looked up in the DKV on first use and cached. */
  public Model model() { return _model==null ? (_model=DKV.getGet(_modelKey)) : _model; }
  /** The frame these metrics were computed on, looked up in the DKV on first use and cached. */
  public Frame frame() { return _frame==null ? (_frame=DKV.getGet(_frameKey)) : _frame; }
  public double mse() { return _MSE; }
  public double rmse() { return Math.sqrt(_MSE);}
  public ConfusionMatrix cm() { return null; }
  public float[] hr() { return null; }
  public AUC2 auc_obj() { return null; }

  /**
   * Picks the most trustworthy metrics a model has: cross-validation, else validation, else training.
   * @return the chosen metrics, or null if the model has none
   */
  private static ModelMetrics preferredMetrics(Model model) {
    if (model._output._cross_validation_metrics != null) return model._output._cross_validation_metrics;
    if (model._output._validation_metrics != null) return model._output._validation_metrics;
    return model._output._training_metrics;
  }

  /**
   * Evaluates {@code criterion} on the model's preferred (xval, else validation, else training) metrics.
   * @throws H2OIllegalArgumentException if the model is missing, has no metrics, or the criterion is unknown
   */
  static public double getMetricFromModel(Key<Model> key, String criterion) {
    Model model = DKV.getGet(key);
    if (null == model) throw new H2OIllegalArgumentException("Cannot find model " + key);
    ModelMetrics mm = preferredMetrics(model);
    if (null == mm) throw new H2OIllegalArgumentException("Model " + key + " has no metrics to evaluate criterion: " + criterion);
    return getMetricFromModelMetric(mm, criterion);
  }

  /**
   * Evaluates a named metric (case-insensitive) on {@code mm}, falling back to its confusion matrix.
   * Only public zero-argument methods returning {@code double} are eligible, so an arbitrary
   * criterion string can never trigger an unrelated (possibly side-effecting) method.
   * @throws H2OIllegalArgumentException if the criterion is empty, {@code mm} is null,
   *         no matching metric exists, or evaluating it fails
   */
  static public double getMetricFromModelMetric(ModelMetrics mm, String criterion) {
    if (null == criterion || criterion.isEmpty()) throw new H2OIllegalArgumentException("Need a valid criterion, but got '" + criterion + "'.");
    if (null == mm) throw new H2OIllegalArgumentException("No model metrics available to evaluate criterion: " + criterion);
    String name = criterion.toLowerCase(Locale.ROOT);

    // Prefer a getter on the metrics object itself; otherwise try its confusion matrix.
    Object target = mm;
    Method method = findDoubleGetter(mm, name);
    if (null == method) {
      ConfusionMatrix cm = mm.cm();
      target = cm;
      method = findDoubleGetter(cm, name);
    }
    if (null == method)
      throw new H2OIllegalArgumentException("Failed to find ModelMetrics for criterion: " + criterion);

    try {
      return (Double) method.invoke(target); // unboxing a null result is reported below as well
    } catch (Exception e) {
      Throwable cause = e instanceof InvocationTargetException && e.getCause() != null ? e.getCause() : e;
      throw new H2OIllegalArgumentException(
          "Failed to get metric: " + criterion + " from ModelMetrics object: " + mm,
          "Failed to get metric: " + criterion + " from ModelMetrics object: " + mm + ", criterion: " + method + ", exception: " + cause
      );
    }
  }

  /** @return the public zero-arg method {@code name} on {@code target} if it returns double, else null */
  private static Method findDoubleGetter(Object target, String name) {
    if (target == null) return null;
    try {
      Method method = target.getClass().getMethod(name);
      return returnsDouble(method) ? method : null;
    } catch (NoSuchMethodException | SecurityException e) {
      return null; // not available on this object
    }
  }

  /** @return true if the method's declared return type is double or Double */
  private static boolean returnsDouble(Method method) {
    Class<?> rt = method.getReturnType();
    return rt == double.class || rt == Double.class;
  }

  /** Orders models by a metric on their preferred metrics; each model's value is computed once. */
  private static class MetricsComparator implements Comparator<Key<Model>> {
    private final String _sort_by;
    private final boolean decreasing;
    private final Map<Key<Model>, Double> cachedValues = new HashMap<>();

    public MetricsComparator(String sort_by, boolean decreasing) {
      this._sort_by = sort_by;
      this.decreasing = decreasing;
    }

    private double metricFor(Key<Model> key) {
      Double value = cachedValues.get(key);
      if (value == null) {
        value = getMetricFromModel(key, _sort_by);
        cachedValues.put(key, value);
      }
      return value;
    }

    @Override public int compare(Key<Model> key1, Key<Model> key2) {
      double c1 = metricFor(key1);
      double c2 = metricFor(key2);
      return decreasing ? Double.compare(c2, c1) : Double.compare(c1, c2);
    }
  }

  /** Orders models by a metric computed on a given frame, scoring the model if the metrics are missing. */
  private static class MetricsComparatorForFrame implements Comparator<Key<Model>> {
    private final String _sort_by;
    private final boolean decreasing;
    private final Frame frame;
    private final Map<Key<Model>, ModelMetrics> cachedMetrics = new HashMap<>();

    public MetricsComparatorForFrame(Frame frame, String sort_by, boolean decreasing) {
      this._sort_by = sort_by;
      this.decreasing = decreasing;
      this.frame = frame;
    }

    /**
     * Finds the metrics of {@code modelKey} on {@link #frame}: cache, then DKV, then by scoring.
     * @throws H2OKeyNotFoundArgumentException if the model or its metrics cannot be found
     */
    private ModelMetrics findMetricsForModel(Key<Model> modelKey) {
      ModelMetrics mm = cachedMetrics.get(modelKey);
      if (null != mm) {
        return mm;
      }
      Model model = modelKey.get();
      if (null == model) {
        Log.warn("Tried to compare metrics for a model which was not found in the DKV: " + modelKey);
        throw new H2OKeyNotFoundArgumentException(modelKey.toString());
      }
      mm = ModelMetrics.getFromDKV(model, this.frame);
      if (null == mm) {
        // score() is expected to publish the metrics to the DKV. The predictions frame itself
        // is not needed here, so delete it right away instead of leaking it.
        Frame preds = model.score(this.frame);
        if (preds != null) preds.delete();
        mm = ModelMetrics.getFromDKV(model, this.frame);
        if (null == mm) {
          Log.warn("Tried to compare metrics for a model/frame combination which was not found in the DKV: (" + modelKey + ", " + frame._key + ")");
          throw new H2OKeyNotFoundArgumentException(modelKey.toString());
        }
      }
      cachedMetrics.put(modelKey, mm);
      return mm;
    }

    @Override public int compare(Key<Model> key1, Key<Model> key2) {
      double c1 = getMetricFromModelMetric(findMetricsForModel(key1), _sort_by);
      double c2 = getMetricFromModelMetric(findMetricsForModel(key2), _sort_by);
      return decreasing ? Double.compare(c2, c1) : Double.compare(c1, c2);
    }
  }

  /**
   * Lists the criteria that {@link #getMetricFromModel} can evaluate for a model: the lower-cased
   * names of public zero-argument double-returning methods on its preferred metrics and on their
   * confusion matrix that evaluate to a non-null value.
   * @return the criterion names; empty if the model has no metrics
   * @throws H2OIllegalArgumentException if the model cannot be found
   */
  public static Set<String> getAllowedMetrics(Key<Model> key) {
    Model model = DKV.getGet(key);
    if (null == model) throw new H2OIllegalArgumentException("Cannot find model " + key);
    Set<String> res = new HashSet<>();
    ModelMetrics m = preferredMetrics(model);
    if (m == null) return res;

    Set<String> excluded = new HashSet<>(Arrays.asList("makeSchema", "hr", "cm", "auc_obj", "remove", "nobs"));
    addEvaluableDoubleMethods(m, excluded, res);
    addEvaluableDoubleMethods(m.cm(), excluded, res);
    return res;
  }

  /** Adds to {@code res} the lower-cased names of double getters on {@code target} that evaluate successfully. */
  private static void addEvaluableDoubleMethods(Object target, Set<String> excluded, Set<String> res) {
    if (target == null) return;
    for (Method meth : target.getClass().getMethods()) {
      // Only consider double-returning zero-arg methods, so nothing unrelated is ever invoked.
      if (excluded.contains(meth.getName()) || meth.getParameterTypes().length != 0 || !returnsDouble(meth)) continue;
      try {
        if (meth.invoke(target) != null) res.add(meth.getName().toLowerCase(Locale.ROOT));
      } catch (Exception e) {
        // Not computable for this model: not offered as a criterion.
      }
    }
  }

  /**
   * Return a new list of models sorted on their xval, validation or training metrics, by the named criterion.
   * The criterion (metric) can be such things as "auc", "mse", "hr", "err", "err_count",
   * "accuracy", "specificity", "recall", "precision", "mcc", "max_per_class_error", "f1", "f2", "f0point5". . .
   * @param sort_by criterion by which we should sort
   * @param decreasing sort by decreasing metrics or not
   * @param modelKeys keys of models to sort
   * @return keys of the models, sorted by the criterion
   */
  public static List<Key<Model>> sortModelsByMetric(String sort_by, boolean decreasing, List<Key<Model>>modelKeys) {
    List<Key<Model>> sorted = new ArrayList<>(modelKeys);
    Collections.sort(sorted, new MetricsComparator(sort_by, decreasing));
    return sorted;
  }

  /**
   * Return a new list of models sorted on metrics computed on the given frame, by the named criterion.
   * The criterion (metric) can be such things as "auc", "mse", "hr", "err", "err_count",
   * "accuracy", "specificity", "recall", "precision", "mcc", "max_per_class_error", "f1", "f2", "f0point5". . .
   * @param frame frame on which to compute the metrics; looked up in the DKV first to see if it was previously computed
   * @param sort_by criterion by which we should sort
   * @param decreasing sort by decreasing metrics or not
   * @param modelKeys keys of models to sort
   * @return keys of the models, sorted by the criterion
   */
  public static List<Key<Model>> sortModelsByMetric(Frame frame, String sort_by, boolean decreasing, List<Key<Model>>modelKeys) {
    List<Key<Model>> sorted = new ArrayList<>(modelKeys);
    Collections.sort(sorted, new MetricsComparatorForFrame(frame, sort_by, decreasing));
    return sorted;
  }

  /** Variable-importance table for {@code vi}; null if {@code vi} (or its importances) is null or empty. */
  public static TwoDimTable calcVarImp(VarImp vi) {
    if (vi == null) return null;
    return calcVarImp(vi._varimp, vi._names);
  }

  /** Float overload of {@link #calcVarImp(double[], String[])}; null for null or empty input. */
  public static TwoDimTable calcVarImp(final float[] rel_imp, String[] coef_names) {
    if (rel_imp == null) return null;
    double[] dbl_rel_imp = new double[rel_imp.length];
    for (int i = 0; i < dbl_rel_imp.length; ++i) {
      dbl_rel_imp[i] = rel_imp[i];
    }
    return calcVarImp(dbl_rel_imp, coef_names);
  }

  /** Variable-importance table with the default header and column names; null for null or empty input. */
  public static TwoDimTable calcVarImp(final double[] rel_imp, String[] coef_names) {
    return calcVarImp(rel_imp, coef_names, "Variable Importances", new String[]{"Relative Importance", "Scaled Importance", "Percentage"});
  }

  /**
   * Builds a variable-importance table with rows sorted by decreasing relative importance.
   * The three columns are the relative importance, the importance scaled by the maximum,
   * and the fraction of the total importance.
   * @param rel_imp relative importances; null or empty yields null
   * @param coef_names row names; if null, "C1".."Cn" are generated
   * @param table_header table title
   * @param col_headers the three column headers
   * @return the table, or null if there are no importances
   */
  public static TwoDimTable calcVarImp(final double[] rel_imp, String[] coef_names, String table_header, String[] col_headers) {
    if (rel_imp == null || rel_imp.length == 0) return null;
    if (coef_names == null) {
      coef_names = new String[rel_imp.length];
      for (int i = 0; i < coef_names.length; i++)
        coef_names[i] = "C" + (i + 1);
    }
    // Sort indices in descending order by relative importance (negation keeps NaNs last).
    Integer[] sorted_idx = new Integer[rel_imp.length];
    for (int i = 0; i < sorted_idx.length; i++) sorted_idx[i] = i;
    Arrays.sort(sorted_idx, new Comparator<Integer>() {
      @Override public int compare(Integer idx1, Integer idx2) {
        return Double.compare(-rel_imp[idx1], -rel_imp[idx2]);
      }
    });
    double total = 0;
    double max = rel_imp[sorted_idx[0]];
    String[] sorted_names = new String[rel_imp.length];
    double[][] sorted_imp = new double[rel_imp.length][3];
    // First pass: sum up relative importances (in sorted order) and fill the first two columns.
    int j = 0;
    for (int i : sorted_idx) {
      total += rel_imp[i];
      sorted_names[j] = coef_names[i];
      sorted_imp[j][0] = rel_imp[i];          // Relative importance
      sorted_imp[j++][1] = rel_imp[i] / max;  // Scaled importance
    }
    // Second pass: the share of the total, which is only known after the first pass.
    j = 0;
    for (int i : sorted_idx)
      sorted_imp[j++][2] = rel_imp[i] / total;
    String[] col_types = new String[3];
    String[] col_formats = new String[3];
    Arrays.fill(col_types, "double");
    Arrays.fill(col_formats, "%5f");
    return new TwoDimTable(table_header, null, sorted_names, col_headers, col_types, col_formats, "Variable",
        new String[rel_imp.length][], sorted_imp);
  }

  private static Key<ModelMetrics> buildKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
    return Key.make("modelmetrics_" + model_key + "@" + model_checksum + "_on_" + frame_key + "@" + frame_checksum);
  }

  /** DKV key of the metrics of {@code model} on {@code frame}; null if either is null. */
  public static Key<ModelMetrics> buildKey(Model model, Frame frame) {
    return frame==null || model == null ? null : buildKey(model._key, model.checksum(), frame._key, frame.checksum());
  }

  public boolean isForModel(Model m) { return _model_checksum == m.checksum(); }
  public boolean isForFrame(Frame f) { return _frame_checksum == f.checksum(); }

  /** Looks up previously published metrics of {@code model} on {@code frame}; null if absent. */
  public static ModelMetrics getFromDKV(Model model, Frame frame) {
    return DKV.getGet(buildKey(model, frame));
  }

  @Override protected long checksum_impl() { return _frame_checksum * 13 + _model_checksum * 17; }

  /** Class used to compute AUCs, CMs & HRs "on the fly" during other passes
   *  over Big Data. This class is intended to be embedded in other MRTask
   *  objects. The {@code perRow} method is called once-per-scored-row, and
   *  the {@code reduce} method called once per MRTask.reduce, and the {@code
   *  <init>} called once per MRTask.map.
   */
  public static abstract class MetricBuilder<T extends MetricBuilder<T>> extends Iced {
    transient public double[] _work;
    public double _sumsqe;      // Sum-squared-error
    public long _count;
    public double _wcount;
    public double _wY; // (Weighted) sum of the response
    public double _wYY; // (Weighted) sum of the squared response

    /** Weighted standard deviation of the response; 0 when at most one row was counted. */
    public double weightedSigma() {
      // A sample-variance correction of _count/(_count-1) is deliberately not applied: it depends on the
      // number of actual rows rather than the weighted count, and omitting it makes the result (and R^2)
      // invariant to globally scaling the weights.
      double sampleCorrection = 1;
      return _count <= 1 ? 0 : Math.sqrt(sampleCorrection*(_wYY/_wcount - (_wY*_wY)/(_wcount*_wcount)));
    }

    abstract public double[] perRow(double[] ds, float[] yact, Model m);

    /** Weighted/offset variant; this base implementation only supports weight 1 and offset 0. */
    public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model m) {
      assert(weight==1 && offset == 0);
      return perRow(ds, yact, m);
    }

    /** Merges another builder's partial sums into this one. */
    public void reduce( T mb ) {
      _sumsqe += mb._sumsqe;
      _count += mb._count;
      _wcount += mb._wcount;
      _wY += mb._wY;
      _wYY += mb._wYY;
    }

    public void postGlobal() {}

    /**
     * Having computed a MetricBuilder, this method fills in a ModelMetrics.
     * @param m Model
     * @param f Scored Frame
     * @param adaptedFrame Adapted Frame
     * @param preds Predictions of m on f (optional)
     * @return Filled Model Metrics object
     */
    public abstract ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds);
  }
}