package hex;
import hex.genmodel.GenModel;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;
import water.util.TwoDimTable;
import java.util.Arrays;
public class ModelMetricsMultinomial extends ModelMetricsSupervised {
public final float[] _hit_ratios; // Hit ratios
public final ConfusionMatrix _cm;
public final double _logloss;
public final double _mean_per_class_error;
public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, ConfusionMatrix cm, float[] hr, double logloss) {
super(model, frame, nobs, mse, domain, sigma);
_cm = cm;
_hit_ratios = hr;
_logloss = logloss;
_mean_per_class_error = cm==null || cm.tooLarge() ? Double.NaN : cm.mean_per_class_error();
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append(" logloss: " + (float)_logloss + "\n");
sb.append(" mean_per_class_error: " + (float)_mean_per_class_error + "\n");
sb.append(" hit ratios: " + Arrays.toString(_hit_ratios) + "\n");
if (cm() != null) {
if (cm().nclasses() <= 20)
sb.append(" CM: " + cm().toASCII());
else
sb.append(" CM: too large to print.\n");
}
return sb.toString();
}
public double logloss() { return _logloss; }
public double mean_per_class_error() { return _mean_per_class_error; }
@Override public ConfusionMatrix cm() { return _cm; }
@Override public float[] hr() { return _hit_ratios; }
public static ModelMetricsMultinomial getFromDKV(Model model, Frame frame) {
ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
if (! (mm instanceof ModelMetricsMultinomial))
throw new H2OIllegalArgumentException("Expected to find a Multinomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(),
"Expected to find a ModelMetricsMultinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + mm.getClass());
return (ModelMetricsMultinomial) mm;
}
public static void updateHits(double w, int iact, double[] ds, double[] hits) {
updateHits(w, iact,ds,hits,null);
}
public static void updateHits(double w, int iact, double[] ds, double[] hits, double[] priorClassDistribution) {
if (iact == ds[0]) { hits[0]++; return; }
double before = ArrayUtils.sum(hits);
// Use getPrediction logic to see which top K labels we would have predicted
// Pick largest prob, assign label, then set prob to 0, find next-best label, etc.
double[] ds_copy = Arrays.copyOf(ds, ds.length); //don't modify original ds!
ds_copy[1+(int)ds[0]] = 0;
for (int k=1; k<hits.length; ++k) {
final int pred_labels = GenModel.getPrediction(ds_copy, priorClassDistribution, ds, 0.5 /*ignored*/); //use tie-breaking of getPrediction
ds_copy[1+pred_labels] = 0; //next iteration, we'll find the next-best label
if (pred_labels==iact) {
hits[k]+=w;
break;
}
}
// must find at least one hit if K == n_classes
if (hits.length == ds.length-1) {
double after = ArrayUtils.sum(hits);
if (after == before) hits[hits.length-1]+=w; //assume worst case
}
}
public static TwoDimTable getHitRatioTable(float[] hits) {
String tableHeader = "Top-" + hits.length + " Hit Ratios";
String[] rowHeaders = new String[hits.length];
for (int k=0; k<hits.length; ++k)
rowHeaders[k] = Integer.toString(k+1);
String[] colHeaders = new String[]{"Hit Ratio"};
String[] colTypes = new String[]{"float"};
String[] colFormats = new String[]{"%f"};
String colHeaderForRowHeaders = "K";
TwoDimTable table = new TwoDimTable(tableHeader, null/*tableDescription*/, rowHeaders, colHeaders, colTypes, colFormats, colHeaderForRowHeaders);
for (int k=0; k<hits.length; ++k)
table.set(k, 0, hits[k]);
return table;
}
/**
* Build a Multinomial ModelMetrics object from per-class probabilities (in Frame preds - no labels!), from actual labels, and a given domain for all possible labels (maybe more than what's in labels)
* @param perClassProbs Frame containing predicted per-class probabilities (and no predicted labels)
* @param actualLabels A Vec containing the actual labels (can be for fewer labels than what's in domain, since the predictions can be for a small subset of the data)
* @return ModelMetrics object
*/
static public ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels) {
String[] names = perClassProbs.names();
String[] label = actualLabels.domain();
String[] union = ArrayUtils.union(names, label, true);
if (union.length == names.length + label.length)
throw new IllegalArgumentException("Column names of per-class-probabilities and categorical domain of actual labels have no common values!");
return make(perClassProbs, actualLabels, perClassProbs.names());
}
/**
* Build a Multinomial ModelMetrics object from per-class probabilities (in Frame preds - no labels!), from actual labels, and a given domain for all possible labels (maybe more than what's in labels)
* @param perClassProbs Frame containing predicted per-class probabilities (and no predicted labels)
* @param actualLabels A Vec containing the actual labels (can be for fewer labels than what's in domain, since the predictions can be for a small subset of the data)
* @param domain Ordered list of factor levels for which the probabilities are given (perClassProbs[i] are the per-observation probabilities for belonging to class domain[i])
* @return ModelMetrics object
*/
static public ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels, String[] domain) {
Scope.enter();
Vec _labels = actualLabels.toCategoricalVec();
if (_labels == null || perClassProbs == null)
throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
if (_labels.length() != perClassProbs.numRows())
throw new IllegalArgumentException("Both arguments must have the same length for multinomial metrics (" + _labels.length() + "!=" + perClassProbs.numRows() + ")!");
for (Vec p : perClassProbs.vecs()) {
if (!p.isNumeric())
throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for multinomial metrics.");
if (p.min() < 0 || p.max() > 1)
throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for multinomial metrics.");
}
int nclasses = perClassProbs.numCols();
if (domain.length!=nclasses)
throw new IllegalArgumentException("Given domain has " + domain.length + " classes, but predictions have " + nclasses + " columns (per-class probabilities) for multinomial metrics.");
_labels = _labels.adaptTo(domain);
Frame predsLabel = new Frame(perClassProbs);
predsLabel.add("labels", _labels);
MetricBuilderMultinomial mb = new MultinomialMetrics((_labels.domain())).doAll(predsLabel)._mb;
_labels.remove();
ModelMetricsMultinomial mm = (ModelMetricsMultinomial)mb.makeModelMetrics(null, predsLabel, null, null);
mm._description = "Computed on user-given predictions and labels.";
Scope.exit();
return mm;
}
// helper to build a ModelMetricsMultinomial for a N-class problem from a Frame that contains N per-class probability columns, and the actual label as the (N+1)-th column
private static class MultinomialMetrics extends MRTask<MultinomialMetrics> {
public MultinomialMetrics(String[] domain) { this.domain = domain; }
String[] domain;
private MetricBuilderMultinomial _mb;
@Override public void map(Chunk[] chks) {
_mb = new MetricBuilderMultinomial(domain.length, domain);
Chunk actuals = chks[chks.length-1];
double [] ds = new double[chks.length];
for (int i=0;i<chks[0]._len;++i) {
for (int c=1;c<chks.length;++c)
ds[c] = chks[c-1].atd(i); //per-class probs - user-given
ds[0] = GenModel.getPrediction(ds, null, ds, 0.5 /*ignored*/);
_mb.perRow(ds, new float[]{actuals.at8(i)}, null);
}
}
@Override public void reduce(MultinomialMetrics mrt) { _mb.reduce(mrt._mb); }
}
public static class MetricBuilderMultinomial<T extends MetricBuilderMultinomial<T>> extends MetricBuilderSupervised<T> {
double[/*nclasses*/][/*nclasses*/] _cm;
double[/*K*/] _hits; // the number of hits for hitratio, length: K
int _K; // TODO: Let user set K
double _logloss;
public MetricBuilderMultinomial( int nclasses, String[] domain ) {
super(nclasses,domain);
_cm = domain.length > ConfusionMatrix.MAX_CM_CLASSES ? null : new double[domain.length][domain.length];
_K = Math.min(10,_nclasses);
_hits = new double[_K];
}
public transient double [] _priorDistribution;
// Passed a float[] sized nclasses+1; ds[0] must be a prediction. ds[1...nclasses-1] must be a class
// distribution;
@Override public double[] perRow(double ds[], float[] yact, Model m) { return perRow(ds, yact, 1, 0, m); }
@Override public double[] perRow(double ds[], float[] yact, double w, double o, Model m) {
if (_cm == null) return ds;
if( Float .isNaN(yact[0]) ) return ds; // No errors if actual is missing
if(ArrayUtils.hasNaNs(ds)) return ds;
if(w == 0 || Double.isNaN(w)) return ds;
final int iact = (int)yact[0];
_count++;
_wcount += w;
_wY += w*iact;
_wYY += w*iact*iact;
// Compute error
double err = iact+1 < ds.length ? 1-ds[iact+1] : 1; // Error: distance from predicting ycls as 1.0
_sumsqe += w*err*err; // Squared error
assert !Double.isNaN(_sumsqe);
// Plain Olde Confusion Matrix
_cm[iact][(int)ds[0]]++; // actual v. predicted
// Compute hit ratio
if( _K > 0 && iact < ds.length-1)
updateHits(w,iact,ds,_hits,m != null?m._output._priorClassDist:_priorDistribution);
// Compute log loss
_logloss += w*MathUtils.logloss(err);
return ds; // Flow coding
}
@Override public void reduce( T mb ) {
if (_cm == null) return;
super.reduce(mb);
assert mb._K == _K;
ArrayUtils.add(_cm, mb._cm);
_hits = ArrayUtils.add(_hits, mb._hits);
_logloss += mb._logloss;
}
@Override public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
double mse = Double.NaN;
double logloss = Double.NaN;
float[] hr = new float[_K];
ConfusionMatrix cm = new ConfusionMatrix(_cm, _domain);
double sigma = weightedSigma();
if (_wcount > 0) {
if (_hits != null) {
for (int i = 0; i < hr.length; i++) hr[i] = (float) (_hits[i] / _wcount);
for (int i = 1; i < hr.length; i++) hr[i] += hr[i - 1];
}
mse = _sumsqe / _wcount;
logloss = _logloss / _wcount;
}
ModelMetricsMultinomial mm = new ModelMetricsMultinomial(m, f, _count, mse, _domain, sigma, cm, hr, logloss);
if (m!=null) m.addModelMetrics(mm);
return mm;
}
}
}