package water.util;

import java.util.*;

/**
 * Shared static code to support modeling, prediction, and scoring.
 *
 * <p>Used by interpreted models as well as by generated model code.</p>
 *
 * <p><strong>WARNING:</strong> The class should have no other H2O dependencies
 * since it is provided for generated code as h2o-model.jar which contains
 * only a few files. Only {@code java.util} is used here.</p>
 */
public class ModelUtils {

  /** Default classification thresholds: 0.00 to 1.00 inclusive, in steps of 0.01 (101 values). */
  public static float[] DEFAULT_THRESHOLDS = new float[] {
    0.00f, 0.01f, 0.02f, 0.03f, 0.04f, 0.05f, 0.06f, 0.07f, 0.08f, 0.09f,
    0.10f, 0.11f, 0.12f, 0.13f, 0.14f, 0.15f, 0.16f, 0.17f, 0.18f, 0.19f,
    0.20f, 0.21f, 0.22f, 0.23f, 0.24f, 0.25f, 0.26f, 0.27f, 0.28f, 0.29f,
    0.30f, 0.31f, 0.32f, 0.33f, 0.34f, 0.35f, 0.36f, 0.37f, 0.38f, 0.39f,
    0.40f, 0.41f, 0.42f, 0.43f, 0.44f, 0.45f, 0.46f, 0.47f, 0.48f, 0.49f,
    0.50f, 0.51f, 0.52f, 0.53f, 0.54f, 0.55f, 0.56f, 0.57f, 0.58f, 0.59f,
    0.60f, 0.61f, 0.62f, 0.63f, 0.64f, 0.65f, 0.66f, 0.67f, 0.68f, 0.69f,
    0.70f, 0.71f, 0.72f, 0.73f, 0.74f, 0.75f, 0.76f, 0.77f, 0.78f, 0.79f,
    0.80f, 0.81f, 0.82f, 0.83f, 0.84f, 0.85f, 0.86f, 0.87f, 0.88f, 0.89f,
    0.90f, 0.91f, 0.92f, 0.93f, 0.94f, 0.95f, 0.96f, 0.97f, 0.98f, 0.99f,
    1.00f
  };

  /**
   * Returns the best prediction from an array of per-class probabilities.
   *
   * <p>If the maximum is unique, its class is returned. If several classes share the maximum,
   * one of them is chosen pseudo-randomly but deterministically from a hash of {@code data};
   * when {@code data} is {@code null}, the first tied class is returned.</p>
   *
   * @param preds prediction distribution; element 0 is ignored (placeholder for a label) and
   *              elements 1..K hold the class probabilities, so the length is number of classes + 1
   * @param data  row values used to break ties (typically the scored row); may be {@code null}
   * @return the best prediction (index of class, zero-based)
   */
  public static int getPrediction(float[] preds, double[] data) {
    final int best = argMax(preds);
    final int ties = countTiesAfter(preds, best);
    if (ties == 0) return best - 1; // unique maximum: zero-based class index

    // Tie-breaking logic: hash the row to pick one of the (ties + 1) tied classes.
    long hash = 0;
    if (data != null) {
      for (double d : data) {
        // Drop the 6 least significant mantissa bits (raw layout: 1b sign, 11b exponent, 52b mantissa).
        hash ^= Double.doubleToRawLongBits(d) >> 6;
      }
    }
    // Only the low 32 bits of the hash are used. Java's % keeps the sign of the dividend, so
    // take the absolute value to get an index in [0, ties]; otherwise every negative hash would
    // select the first tied class.
    final int nth = Math.abs((int) hash % (ties + 1));
    return nthTiedLabel(preds, preds[best], nth);
  }

  /**
   * Create labels from per-class probabilities with pseudo-random tie-breaking, if needed.
   *
   * <p>Labels are returned in order of decreasing probability. If a group of tied classes fits
   * completely into the remaining slots, its labels are emitted in ascending order. Otherwise the
   * needed number of them is picked using a hash of {@code data} and the slot index; when
   * {@code data} is {@code null}, the lowest-numbered tied classes are taken.
   * {@code 0.0f} and {@code -0.0f} are treated as equal probabilities.</p>
   *
   * @param numK  Number of top probabilities to make labels for (at most {@code preds.length - 1})
   * @param preds Predictions (first element is ignored here: placeholder for a label)
   * @param data  Data to break ties (typically, the test set data for this row); may be {@code null}
   * @return Array of {@code numK} predicted labels (zero-based class indices)
   */
  public static int[] getPredictions(int numK, float[] preds, double[] data) {
    assert numK <= preds.length - 1;
    final int[] labels = new int[numK];
    // Nothing to rank; also avoids calling lastKey() on an empty map below.
    if (numK == 0) return labels;

    // Sorted mapping from probability to label(s), highest probability first. The explicit
    // comparisons (rather than Float.compare) make 0.0f and -0.0f compare as equal.
    final TreeMap<Float, List<Integer>> probToLabels =
        new TreeMap<Float, List<Integer>>(new Comparator<Float>() {
          @Override
          public int compare(Float o1, Float o2) {
            if (o1 > o2) return -1;
            if (o2 > o1) return 1;
            return 0;
          }
        });

    for (int c = 1; c < preds.length; ++c) {
      final Float prob = preds[c];
      final int label = c - 1;
      assert prob >= 0 && prob <= 1 : "prob is not inside [0,1]: " + prob;
      final List<Integer> tied = probToLabels.get(prob);
      if (tied != null) {
        tied.add(label); // add all ties
        continue;
      }
      // Add prob to the top K probs only if either:
      // 1) we don't have K probs yet, or
      // 2) prob is greater than the smallest stored prob (which is then evicted below).
      if (probToLabels.size() < numK || prob > probToLabels.lastKey()) {
        final List<Integer> li = new ArrayList<Integer>();
        li.add(label);
        probToLabels.put(prob, li);
      }
      // Keep the map small: only the best numK distinct probabilities are ever needed.
      if (probToLabels.size() > numK) probToLabels.remove(probToLabels.lastKey());
    }
    assert !probToLabels.isEmpty();
    assert probToLabels.size() <= numK; // at most numK probabilities, maybe fewer if there are ties

    int i = 0; // which label we are filling in
    while (i < numK && !probToLabels.isEmpty()) {
      // Highest remaining probability and its candidate labels.
      final Map.Entry<Float, List<Integer>> top = probToLabels.pollFirstEntry();
      final List<Integer> indices = top.getValue();
      if (i + indices.size() <= numK) {
        for (Integer id : indices) labels[i++] = id;
      } else {
        // Tie-breaking logic: pick numK-i classes (indices) from the list of indices.
        // If data == null, pick the first numK-i indices; otherwise break ties pseudo-randomly.
        while (i < numK) {
          assert !indices.isEmpty();
          long hash = 0;
          if (data != null) {
            for (double d : data) {
              // Drop 6 least significant bits of mantissa (layout: 1b sign, 11b exponent, 52b mantissa).
              hash ^= Double.doubleToRawLongBits(d + i) >> 6;
            }
          }
          // abs of the remainder (not of the hash) cannot overflow for any hash value.
          labels[i++] = indices.remove((int) Math.abs(hash % indices.size()));
        }
        assert i == numK;
      }
    }
    assert i == numK;
    return labels;
  }

  /**
   * Returns the best prediction from an array of per-class probabilities, breaking ties by row.
   *
   * <p>If several classes share the maximum, the tied class at position
   * {@code Math.abs(row % (ties + 1))} (counting tied classes in class order) is returned, so
   * consecutive rows cycle through the tied classes.</p>
   *
   * @param preds prediction distribution; element 0 is ignored (placeholder for a label) and
   *              elements 1..K hold the class probabilities
   * @param row   row number used to break ties
   * @return the best prediction (index of class, zero-based)
   */
  public static int getPrediction(float[] preds, int row) {
    final int best = argMax(preds);
    final int ties = countTiesAfter(preds, best);
    if (ties == 0) return best - 1; // unique maximum: zero-based class index
    // abs keeps the position in [0, ties] even for a negative row number.
    return nthTiedLabel(preds, preds[best], Math.abs(row % (ties + 1)));
  }

  /**
   * Index into {@code preds} (starting at 1) of the first class with the maximal probability.
   * Returns 1 when there are fewer than two class entries.
   */
  private static int argMax(float[] preds) {
    int best = 1;
    for (int c = 2; c < preds.length; c++) {
      if (preds[best] < preds[c]) best = c; // strict: the first of equal maxima wins
    }
    return best;
  }

  /** Number of classes after index {@code best} whose probability equals {@code preds[best]}. */
  private static int countTiesAfter(float[] preds, int best) {
    int ties = 0;
    for (int c = best + 1; c < preds.length; c++) {
      if (preds[c] == preds[best]) ties++;
    }
    return ties;
  }

  /**
   * Zero-based label of the {@code nth} (0-based) class, in class order, whose probability equals
   * {@code value}.
   */
  private static int nthTiedLabel(float[] preds, float value, int nth) {
    for (int c = 1; c < preds.length; c++) {
      if (preds[c] == value && --nth < 0) return c - 1;
    }
    throw new IllegalStateException("Should Not Reach Here");
  }

  /**
   * Correct a given list of class probabilities produced as a prediction by a model back to prior
   * class distribution.
   *
   * <p>The implementation is based on Eq. (27) in <a href="http://gking.harvard.edu/files/0s.pdf">the paper</a>.
   * Each probability is scaled by {@code priorClassDist[k] / modelClassDist[k]}, skipping classes
   * where either fraction is 0. The result is then renormalized to sum to 1 if the sum is
   * positive.</p>
   *
   * <p><strong>Note:</strong> {@code scored} is modified in place and returned.</p>
   *
   * @param scored         list of class probabilities beginning at index 1
   * @param priorClassDist original class distribution
   * @param modelClassDist class distribution used for model building (e.g., data was oversampled)
   * @return corrected list of probabilities (the same {@code scored} array)
   */
  public static float[] correctProbabilities(float[] scored, float[] priorClassDist, float[] modelClassDist) {
    double probSum = 0;
    for (int c = 1; c < scored.length; c++) {
      final double originalFraction = priorClassDist[c - 1];
      final double oversampledFraction = modelClassDist[c - 1];
      assert !Double.isNaN(scored[c]);
      if (originalFraction != 0 && oversampledFraction != 0) {
        scored[c] *= originalFraction / oversampledFraction;
      }
      probSum += scored[c];
    }
    if (probSum > 0) {
      for (int c = 1; c < scored.length; ++c) scored[c] /= probSum;
    }
    return scored;
  }

  /**
   * Sample out-of-bag rows with the given rate, using the given sampler.
   *
   * <p>Returns an array of sampled rows. The first element holds the number of sampled rows, and
   * the row numbers follow from index 1. The returned array can be longer than the number of
   * sampled rows + 1.</p>
   *
   * @param nrows   number of rows to sample from
   * @param rate    sampling rate; a row is out-of-bag when {@code sampler.nextFloat() >= rate}
   * @param sampler random "dice"
   * @return an array containing the sampled rows; the first element holds the number of sampled rows
   */
  public static int[] sampleOOBRows(int nrows, float rate, Random sampler) {
    // About (1 - rate) * nrows rows are expected; allow 20% slack plus the count slot. Clamped so
    // that an out-of-range rate (> 1) cannot produce a negative array size.
    final int expected = Math.round((1f - rate) * nrows * 1.2f + 0.5f);
    return sampleOOBRows(nrows, rate, sampler, new int[Math.max(2, 2 + expected)]);
  }

  /**
   * In-situ version of {@link #sampleOOBRows(int, float, Random)}.
   *
   * @param nrows   number of rows to sample from
   * @param rate    sampling rate
   * @param sampler random "dice"
   * @param oob     an initial array to hold sampled rows; it is zero-filled and may be
   *                reallocated (any length, including 0, is accepted)
   * @return an array containing the sampled rows; the first element holds the number of sampled rows
   *
   * @see #sampleOOBRows(int, float, Random)
   */
  public static int[] sampleOOBRows(int nrows, float rate, Random sampler, int[] oob) {
    int oobcnt = 0; // number of out-of-bag rows found so far
    Arrays.fill(oob, 0);
    for (int row = 0; row < nrows; row++) {
      if (sampler.nextFloat() >= rate) { // it is an out-of-bag row
        // Grow before writing so a short caller-supplied array cannot overflow. The new length
        // (~1.2 * nrows + 2) has room for every row plus the count slot, so it only grows once.
        if (1 + oobcnt >= oob.length) oob = Arrays.copyOf(oob, Math.round(1.2f * nrows + 0.5f) + 2);
        oob[1 + oobcnt++] = row;
      }
    }
    if (oob.length == 0) oob = new int[1]; // room for the count when nothing was sampled
    oob[0] = oobcnt;
    return oob;
  }
}