package water.api; import water.util.Log; import static water.util.ModelUtils.getPredictions; import water.Func; import water.MRTask2; import water.UKV; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.Vec; import water.util.Utils; import java.util.Arrays; import java.util.Random; public class HitRatio extends Func { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. public static final String DOC_GET = "Hit Ratio"; @API(help = "", required = true, filter = Default.class, json=true) public Frame actual; @API(help="Column of the actual results (will display vertically)", required=true, filter=actualVecSelect.class, json=true) public Vec vactual; class actualVecSelect extends VecClassSelect { actualVecSelect() { super("actual"); } } @API(help = "", required = true, filter = Default.class, json=true) public Frame predict; @API(help = "Max. number of labels (K) to use for hit ratio computation", required = false, filter = Default.class, json = true) private int max_k = 10; public void set_max_k(int k) { max_k = k; } @API(help = "Random number seed for breaking ties between equal probabilities", required = false, filter = Default.class, json = true) private long seed = new Random().nextLong(); @API(help="domain of the actual response") private String [] actual_domain; @API(help="Hit ratios for k=1...K") private float[] hit_ratios; // public float[] hit_ratios() { return hit_ratios; } public HitRatio() {} @Override protected void init() throws IllegalArgumentException { // Input handling if( actual==null || predict==null ) throw new IllegalArgumentException("Missing actual or predict!"); if( vactual==null ) throw new IllegalArgumentException("Missing vactual!"); if (vactual.length() != predict.anyVec().length()) throw new IllegalArgumentException("Both arguments must have the same length!"); if (!vactual.isInt()) throw new IllegalArgumentException("Actual column must be integer class labels!"); } @Override protected void execImpl() { Vec va = null; try { va = vactual.toEnum(); // always returns TransfVec actual_domain = va._domain; if (max_k > predict.numCols()-1) { Log.warn("Reducing Hitratio Top-K value to maximum value allowed: " + String.format("%,d", predict.numCols() - 1)); max_k = predict.numCols() - 1; } final Frame actual_predict = new Frame(predict.names().clone(), predict.vecs().clone()); actual_predict.replace(0, va); // place actual labels in first column hit_ratios = new HitRatioTask(max_k, seed).doAll(actual_predict).hit_ratios(); } finally { // Delete adaptation vectors if (va!=null) UKV.remove(va._key); } } @Override public boolean toHTML( StringBuilder sb ) { if (hit_ratios==null) return false; sb.append("<div>"); DocGen.HTML.section(sb, "Hit Ratio for Multi-Class Classification"); DocGen.HTML.paragraph(sb, "(Frequency of actual class label to be among the top-K predicted class labels)"); DocGen.HTML.arrayHead(sb); sb.append("<th>K</th>"); sb.append("<th>Hit Ratio</th>"); for (int k = 1; k<=max_k; ++k) sb.append("<tr><td>" + k + "</td><td>" + String.format("%.3f", hit_ratios[k-1]*100.) + "%</td></tr>"); DocGen.HTML.arrayTail(sb); return true; } public void toASCII( StringBuilder sb ) { if (hit_ratios==null) return; sb.append("K Hit-ratio\n"); for (int k = 1; k<=max_k; ++k) sb.append(k + " " + String.format("%.3f", hit_ratios[k-1]*100.) + "%\n"); } /** * Update hit counts for given set of actual label and predicted labels * This is to be called for each predicted row * @param hits Array of length K, counting the number of hits (entries will be incremented) * @param actual_label 1 actual label * @param pred_labels K predicted labels */ static void updateHits(long[] hits, int actual_label, int[] pred_labels) { assert(hits != null); for (long h : hits) assert(h >= 0); assert(pred_labels != null); assert(actual_label >= 0); assert(hits.length == pred_labels.length); //find the first occurrence of the actual label and increment all counters from there on //do nothing if no hit for (int k=0;k<pred_labels.length;++k) { if (pred_labels[k] == actual_label) { while (k<pred_labels.length) hits[k++]++; } } } // Compute CMs for different thresholds via MRTask2 private static class HitRatioTask extends MRTask2<HitRatioTask> { /* @OUT CMs */ public final float[] hit_ratios() { float[] hit_ratio = new float[_K]; if (_count == 0) return new float[_K]; for (int i=0;i<_K;++i) { hit_ratio[i] = ((float)_hits[i])/_count; } return hit_ratio; } /* IN K */ final private int _K; /* IN Seed */ private long _seed; /* Helper */ private long[] _hits; //the number of hits, length: K private long _count; //the number of scored rows HitRatioTask(int K, long seed) { _K = K; _seed = seed; } @Override public void map( Chunk[] cs ) { _hits = new long[_K]; Arrays.fill(_hits, 0); // pseudo-random tie breaking needs some bits to work with final double[] tieBreaker = new double [] { new Random(_seed).nextDouble(), new Random(_seed+1).nextDouble(), new Random(_seed+2).nextDouble(), new Random(_seed+3).nextDouble() }; float [] preds = new float[cs.length]; // rows for( int r=0; r < cs[0]._len; r++ ) { if (cs[0].isNA0(r)) { _count--; continue; } final int actual_label = (int)cs[0].at80(r); //predict K labels for(int p=1; p < cs.length; p++) preds[p] = (float)cs[p].at0(r); final int[] pred_labels = getPredictions(_K, preds, tieBreaker); if (actual_label < cs.length-1) updateHits(_hits, actual_label, pred_labels); } _count += cs[0]._len; } @Override public void reduce( HitRatioTask other ) { assert(other._K == _K); _hits = Utils.add(_hits, other._hits); _count += other._count; } } }