package hex.singlenoderf;

import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.Log.Tag.Sys;
import water.util.ModelUtils;
import water.util.Utils;
import java.util.Arrays;
import java.util.Random;
import hex.VarImp;

/**
 * Confusion Matrix. Incrementally computes a Confusion Matrix for a forest
 * of Trees, vs a given input dataset. The set of Trees can grow over time. Each
 * request from the Confusion compute on any new trees (if any), and report a
 * matrix. Cheap if all trees already computed.
 *
 * <p>Runs as a distributed {@link MRTask2}: {@link #map(Chunk[])} votes every tree on
 * every local row (replaying the training-time sampling to honor out-of-bag scoring),
 * and {@link #reduce(CMTask)} merges the per-chunk results.</p>
 */
public class CMTask extends MRTask2<CMTask> {
  public double[] _classWt;
  public boolean _computeOOB;          // if true, score only rows the tree did NOT train on
  public int _treesUsed;               // number of trees requested for scoring
  public Key _modelKey;
  public Key _datakey;
  public int _classcol;                // index of the response column (last column of the frame)
  public CM _matrix;                   // the reduced confusion matrix
  public float _sum; //sum of squares Sum_ti((f_ti - delta(oti,i))^2) AKA brier score ~ classification mse
  public CM[] _localMatrices;
  public long[] _errorsPerTree;        // per-tree misprediction counts (classification only)
  public SpeeDRFModel _model;
  public int[] _modelDataMap;          // maps model columns to data columns
  public Frame _data;
  public int _N;                       // dimension of the CM (number of aligned classes)
  public long _cms[][][];              // per-threshold 2x2 matrices for binomial AUC/ROC
  public VarImp _varimp;
  public int[] _oobs;
  public Key[][] _remoteChunksKeys;    // per-node list of chunk keys homed to that node
  public float _ss; // Sum of squares
  public int _rowcnt; // Rows used in scoring for regression
  public boolean _score_new_tree_only; // if set, map() scores only the newest tree

  /** Data to replay the sampling algorithm */
  private long[] _chunk_row_mapping;
  /** Number of rows at each node */
  private int[] _rowsPerNode;
  /** Computed mapping of model prediction classes to confusion matrix classes */
  private int[] _model_classes_mapping;
  /** Computed mapping of data prediction classes to confusion matrix classes */
  private int[] _data_classes_mapping;
  /** Difference between model cmin and CM cmin */
  private int _cmin_model_mapping;
  /** Difference between data cmin and CM cmin */
  private int _cmin_data_mapping;
  // NOTE(review): _rand appears unused in this file — candidate for removal (verify no serialization dependency).
  transient private Random _rand;

  /** Confusion matrix
   * @param model the ensemble used to classify
   * @param treesToUse number of trees of the ensemble to score with
   * @param computeOOB whether to score only out-of-bag rows
   * @param fr frame to score; its last column is taken as the response
   * @param resp response vector of the model (used only to align class domains) */
  private CMTask(SpeeDRFModel model, int treesToUse, boolean computeOOB, Frame fr, Vec resp) {
    _modelKey = model._key;
    _datakey = model._dataKey;
    _classcol = fr.numCols() - 1; //model.test_frame == null ? (model.fr.numCols() - 1) : (model.test_frame.numCols() - 1);
    _treesUsed = treesToUse;
    _computeOOB = computeOOB;
    _model = model;
    _varimp = null;
    _ss = 0.f;
    _data = fr;
    shared_init(resp);
  }

  /** Build a CMTask and run it over the whole frame; returns the completed task. */
  public static CMTask scoreTask(SpeeDRFModel model, int treesToUse, boolean computeOOB, Frame fr, Vec resp) {
    CMTask tsk = new CMTask(model, treesToUse, computeOOB, fr, resp);
    tsk.doAll(fr);
    return tsk;
  }

  /** Shared init: pre-compute local data for new Confusions, for remote Confusions*/
  private void shared_init(Vec resp) {
    /* For reproducibility we can control the randomness in the computation of the
       confusion matrix. The default seed when deserializing is 42. */
    // _data = _model.test_frame == null ? _model.fr : _model.test_frame;
    if (_model.validation) _computeOOB = false;   // OOB scoring only makes sense on the training frame
    _modelDataMap = _model.colMap(_data);
    assert !_computeOOB || _model._dataKey.equals(_datakey) : !_computeOOB + " || " + _model._dataKey + " equals " + _datakey;
    Vec respModel = resp;
    Vec respData = _data.vecs()[_classcol];
    int model_max = (int)respModel.max();
    int model_min = (int)respModel.min();
    int data_max = (int)respData.max();
    int data_min = (int)respData.min();
    if (respModel._domain!=null) {
      // Enum response: merge the two (sorted, unique) domains and record the index mappings.
      assert respData._domain != null;
      _model_classes_mapping = new int[respModel._domain.length];
      _data_classes_mapping = new int[respData._domain.length];
      // compute mapping
      _N = alignEnumDomains(respModel._domain, respData._domain, _model_classes_mapping, _data_classes_mapping);
    } else {
      // Integer response: align by offset from the common minimum class value.
      assert respData._domain == null;
      _model_classes_mapping = null;
      _data_classes_mapping = null;
      // compute mapping
      _cmin_model_mapping = model_min - Math.min(model_min, data_min);
      _cmin_data_mapping = data_min - Math.min(model_min, data_min);
      _N = Math.max(model_max, data_max) - Math.min(model_min, data_min) + 1;
    }
    assert _N > 0; // You know...it is good to be sure
    init();
  }

  /** Pre-compute the chunk/row layout of the frame across the cloud: which chunks are
   *  homed here, how many rows live on each node, and every node's chunk keys. This
   *  layout is needed to replay the per-chunk sampling seeds during OOB scoring. */
  public void init() {
    // Make a mapping from chunk# to row# just for chunks on this node
    // First compute the number of chunks homed to this node
    int total_home = 0;
    for (int i = 0; i < _data.anyVec().nChunks(); ++i) {
      if (_data.anyVec().chunkKey(i).home()) {
        total_home++;
      }
    }
    // Now generate the mapping
    _chunk_row_mapping = new long[total_home];
    // NOTE(review): `off` is never used below — candidate for removal.
    int off=0;
    int cidx=0;
    for (int i = 0; i < _data.anyVec().nChunks(); ++i) {
      if (_data.anyVec().chunkKey(i).home()) {
        _chunk_row_mapping[cidx++] = _data.anyVec().chunk2StartElem(i);
      }
    }
    // Initialize number of rows per node
    _rowsPerNode = new int[H2O.CLOUD.size()];
    long chunksCount = _data.anyVec().nChunks();
    for(int ci=0; ci<chunksCount; ci++) {
      Key cKey = _data.anyVec().chunkKey(ci);
      _rowsPerNode[cKey.home_node().index()] += _data.anyVec().chunkLen(ci);
    }
    // Collect, per cloud node, the keys of all chunks homed to that node.
    _remoteChunksKeys = new Key[H2O.CLOUD.size()][];
    int[] _remoteChunksCounter = new int[H2O.CLOUD.size()];
    for (int i = 0; i < _data.anyVec().nChunks(); ++i) {
      _remoteChunksCounter[_data.anyVec().chunkKey(i).home(H2O.CLOUD)]++;
    }
    for (int i = 0; i < H2O.CLOUD.size(); ++i)
      _remoteChunksKeys[i] = new Key[_remoteChunksCounter[i]];
    int[] cnter = new int[H2O.CLOUD.size()];
    for (int i = 0; i < _data.anyVec().nChunks(); ++i) {
      int node_idx = _data.anyVec().chunkKey(i).home(H2O.CLOUD);
      _remoteChunksKeys[node_idx][cnter[node_idx]++] = _data.anyVec().chunkKey(i);
    }
  }

  /** Sum of chunk lengths of the producer node's chunks preceding {@code chunkKey}
   *  in that node's key list (row offset used when replaying a remote tree's seed).
   * NOTE(review): `chunkLen(i)` is called with the position inside the per-node key
   * array, not the global chunk index of remoteCKeys[i] — looks suspicious; confirm
   * against the producer-side sampling code before changing. */
  private int producerRemoteRows(byte treeProducerID, Key chunkKey) {
    Key[] remoteCKeys = _remoteChunksKeys[treeProducerID];
    int off = 0;
    for (int i=0; i<remoteCKeys.length; i++) {
      if (chunkKey.equals(remoteCKeys[i])) return off;
      off += _data.anyVec().chunkLen(i);
    }
    return off;
  }

  /** Vote every tree on every row of this chunk set, skipping in-bag rows when OOB
   *  scoring is on, then assemble the votes into a per-chunk confusion matrix
   *  (classification) or accumulate squared error (regression). */
  @Override public void map(Chunk[] chks) {
    final int rows = chks[0]._len;
    final int cmin = _model.resp_min;
    short numClasses = (short)_model.classes();
    _cms = new long[ModelUtils.DEFAULT_THRESHOLDS.length][2][2];
    // Votes: we vote each tree on each row, holding on to the votes until the end
    int[][] votes = new int[rows][_N];
    int[][] localVotes = _computeOOB ? new int[rows][_N] : null;
    // Errors per tree
    _errorsPerTree = new long[_model.treeCount()];
    // Replay the Data.java's "sample_fair" sampling algorithm to exclude data
    // we trained on during voting.
    for( int ntree = 0; ntree < _model.treeCount(); ntree++ ) {
      if (_score_new_tree_only) ntree = _model.treeCount() - 1; // jump straight to the newest tree
      long treeSeed = _model.seed(ntree);
      byte producerId = _model.producerId(ntree);
      int init_row = (int)chks[0]._start;
      boolean isLocalTree = _computeOOB && isLocalTree(producerId); // tree is local
      boolean isRemote = true;
      for (long a_chunk_row_mapping : _chunk_row_mapping) {
        if (chks[0]._start == a_chunk_row_mapping) {
          isRemote = false;
          break;
        }
      }
      boolean isRemoteTreeChunk = _computeOOB && isRemote; // this is chunk which was used for construction the tree by another node
      if (isRemoteTreeChunk) init_row = _rowsPerNode[producerId] + (int)chks[0]._start + producerRemoteRows(producerId, chks[0]._vec.chunkKey(chks[0].cidx()));
      /* NOTE: Before changing used generator think about which kind of random generator you need:
       * if always deterministic or non-deterministic version - see hex.rf.Utils.get{Deter}RNG */
      // DEBUG: if( _computeOOB && (isLocalTree || isRemoteTreeChunk)) System.err.println(treeSeed + " : " + init_row + " (CM) " + isRemoteTreeChunk);
      long seed = Sampling.chunkSampleSeed(treeSeed, init_row);
      Random rand = Utils.getDeterRNG(seed);
      // Now for all rows, classify & vote!
      ROWS: for( int row = 0; row < rows; row++ ) {
        // ------ THIS CODE is crucial and serve to replay the same sequence
        // of random numbers as in the method Data.sampleFair()
        // Skip row used during training if OOB is computed
        float sampledItem = rand.nextFloat();   // must be drawn for EVERY row to keep the RNG stream aligned
        // Bail out of broken rows with NA in class column.
        // Do not skip yet the rows with NAs in the rest of columns
        if( chks[_classcol].isNA0(row)) continue;
        if( _computeOOB && (isLocalTree || isRemoteTreeChunk) ) { // if OOBEE is computed then we need to take into account utilized sampling strategy
          if (sampledItem < _model.sample) continue; // in-bag row: the tree trained on it, skip
        }
        // --- END OF CRUCIAL CODE ---
        // Predict with this tree - produce 0-based class index
        if (!_model.regression) {
          int prediction = (int)_model.classify0(ntree, chks, row, _modelDataMap, numClasses, false /*Not regression*/);
          if( prediction >= numClasses ) continue; // Junk row cannot be predicted
          // Check tree miss
          int alignedPrediction = alignModelIdx(prediction);
          int alignedData = alignDataIdx((int) chks[_classcol].at80(row) - cmin);
          if (alignedPrediction != alignedData) { _errorsPerTree[ntree]++; }
          votes[row][alignedPrediction]++; // Vote the row
          // if (isLocalTree) localVotes[row][alignedPrediction]++; // Vote
        } else {
          // Regression: accumulate squared error of the single-tree prediction.
          float pred = _model.classify0(ntree, chks, row, _modelDataMap, (short) 0, true /*regression*/);
          float actual = chks[_classcol].at80(row);
          float delta = actual - pred;
          _ss += delta * delta;
          _rowcnt++;
        }
      }
    }
    if(!_model.regression) {
      // Assemble the votes-per-class into predictions & score each row
      _matrix = computeCM(votes, chks, false /*Do the _cms once*/, _model.get_params().balance_classes); // Make a confusion matrix for this chunk
      if (localVotes!=null) {
        // NOTE(review): localVotes is never incremented (the voting line above is
        // commented out), so this local matrix is computed from all-zero votes.
        _localMatrices = new CM[H2O.CLOUD.size()];
        _localMatrices[H2O.SELF.index()] = computeCM(localVotes, chks, true /*Don't compute the _cms again!*/, _model.get_params().balance_classes);
      }
    }
  }

  /** Per-variable standard deviation of vote differences (for variable importance).
   * NOTE(review): the mean is accumulated with `d / vote_diffs.length` (number of
   * variables) rather than `d / vote_diffs[var].length` (number of samples for this
   * variable) — suspected bug when the two lengths differ; confirm before fixing. */
  public static float[] computeVarImpSD(long[][] vote_diffs) {
    float[] res = new float[vote_diffs.length];
    for (int var = 0; var < vote_diffs.length; ++var) {
      float mean_diffs = 0.f;
      float r = 0.f;
      for (long d: vote_diffs[var]) mean_diffs += (float) d / (float) vote_diffs.length;
      for (long d: vote_diffs[var]) {
        r += (d - mean_diffs) * (d - mean_diffs);
      }
      r *= 1.f / (float)vote_diffs[var].length;
      res[var] = (float) Math.sqrt(r);
    }
    return res;
  }

  /** Returns true if tree was produced by this node.
   * Note: chunkKey is key stored at this local node */
  private boolean isLocalTree(byte treeProducerId) {
    assert _computeOOB : "Calling this method makes sense only for oobee";
    int idx = H2O.SELF.index();
    return idx == treeProducerId;
  }

  /** Reduction combines the confusion matrices. */
  @Override public void reduce(CMTask drt) {
    if (!_model.regression) {
      if (_matrix == null) {
        _matrix = drt._matrix;
      } else {
        _matrix = _matrix.add(drt._matrix);
      }
      _sum += drt._sum;
      // Reduce tree errors
      long[] ept1 = _errorsPerTree;
      long[] ept2 = drt._errorsPerTree;
      if (ept1 == null) _errorsPerTree = ept2;
      else if (ept2 != null) {
        // NOTE(review): when ept1 is grown via Arrays.copyOf, the new array is never
        // stored back into _errorsPerTree, so the merged tail counts are dropped —
        // suspected bug; confirm whether unequal lengths can occur in practice.
        if (ept1.length < ept2.length) ept1 = Arrays.copyOf(ept1, ept2.length);
        for (int i = 0; i < ept2.length; i++) ept1[i] += ept2[i];
      }
      if (_cms!=null) for (int i = 0; i < _cms.length; i++) Utils.add(_cms[i], drt._cms[i]);
      if (_oobs != null) for (int i = 0; i < _oobs.length; ++i) _oobs[i] += drt._oobs[i];
    } else {
      _ss += drt._ss;
      _rowcnt += drt._rowcnt;
    }
  }

  /** Transforms 0-based class produced by model to CF zero-based */
  private int alignModelIdx(int modelClazz) {
    if (_model_classes_mapping!=null) return _model_classes_mapping[modelClazz];
    else return modelClazz + _cmin_model_mapping;
  }

  /** Transforms 0-based class from input data to CF zero-based */
  private int alignDataIdx(int dataClazz) {
    if (_data_classes_mapping!=null) return _data_classes_mapping[dataClazz];
    else return dataClazz + _cmin_data_mapping;
  }

  /** Merge model and data predictor domain to produce domain for CM.
   * The domain is expected to be ordered and containing unique values.
   * Fills modelMapping/dataMapping with each level's index in the merged domain
   * (classic sorted-merge walk) and returns the merged domain's size. */
  public static int alignEnumDomains(final String[] modelDomain, final String[] dataDomain, int[] modelMapping, int[] dataMapping) {
    assert modelMapping!=null && modelMapping.length == modelDomain.length;
    assert dataMapping!=null && dataMapping.length == dataDomain.length;
    int idx = 0, idxM = 0, idxD = 0;
    while(idxM!=modelDomain.length || idxD!=dataDomain.length) {
      if (idxM==modelDomain.length) { dataMapping[idxD++] = idx++; continue; }  // model domain exhausted
      if (idxD==dataDomain.length) { modelMapping[idxM++] = idx++; continue; }  // data domain exhausted
      int c = modelDomain[idxM].compareTo(dataDomain[idxD]);
      if (c < 0) { modelMapping[idxM] = idx; idxM++; }
      else if (c > 0) { dataMapping[idxD] = idx; idxD++; }
      else { // strings are identical
        modelMapping[idxM] = idx;
        dataMapping[idxD] = idx;
        idxM++; idxD++;
      }
      idx++;
    }
    return idx;
  }

  /** Compute the merged CM domain of two response vectors (enum or integer). */
  public static String[] domain(final Vec modelCol, final Vec dataCol) {
    int[] modelEnumMapping = null;
    int[] dataEnumMapping = null;
    int N;
    if (modelCol._domain!=null) {
      assert dataCol._domain != null;
      modelEnumMapping = new int[modelCol._domain.length];
      dataEnumMapping = new int[dataCol._domain.length];
      N = alignEnumDomains(modelCol._domain, dataCol._domain, modelEnumMapping, dataEnumMapping);
    } else {
      assert dataCol._domain == null;
      N = (int) (Math.max(modelCol.max(), dataCol.max()) - Math.min(modelCol.min(), dataCol.min()) + 1);
    }
    return domain(N, modelCol, dataCol, modelEnumMapping, dataEnumMapping);
  }

  /** Materialize the merged CM domain of size N as strings: enum levels placed via
   *  the precomputed mappings, or consecutive integers for numeric responses. */
  public static String[] domain(int N, final Vec modelCol, final Vec dataCol, int[] modelEnumMapping, int[] dataEnumMapping) {
    String[] result = new String[N];
    String[] modelDomain = modelCol._domain;
    String[] dataDomain = dataCol._domain;
    if (modelDomain!=null) {
      assert dataDomain!=null;
      assert modelEnumMapping!=null && modelEnumMapping.length == modelDomain.length;
      assert dataEnumMapping!=null && dataEnumMapping.length == dataDomain.length;
      for (int i = 0; i < modelDomain.length; i++) result[modelEnumMapping[i]] = modelDomain[i];
      for (int i = 0; i < dataDomain.length; i++) result[dataEnumMapping [i]] = dataDomain[i];
    } else {
      assert dataDomain==null;
      int dmin = (int) Math.min(modelCol.min(), dataCol.min());
      int dmax = (int) Math.max(modelCol.max(), dataCol.max());
      for (int i = dmin; i <= dmax; i++) result[i-dmin] = String.valueOf(i);
    }
    return result;
  }

  /** Compute confusion matrix domain based on model and data key. */
  public String[] domain(Vec modelResp) {
    return domain(_N, modelResp, _data.vecs()[_classcol], _model_classes_mapping, _data_classes_mapping);
  }

  /** Return number of classes - in fact dimension of CM. */
  public final int dimension() { return _N; }

  /** Confusion matrix representation. */
  static class CM extends Iced {
    /** The Confusion Matrix - a NxN matrix of [actual] -vs- [predicted] classes,
        referenced as _matrix[actual][predicted]. Each row in the dataset is voted on by
        all trees, and the majority vote is the predicted class for the row. Each row
        thus gets 1 entry in the matrix.*/
    protected long _matrix[][];
    /** Number of mistaken assignments. */
    protected long _errors;
    /** Number of rows used for building the matrix.*/
    protected long _rows;
    /** Number of skipped rows. Rows can contain bad data, or can be skipped by
        selecting only out-of-back rows */
    protected long _skippedRows;
    /** Classification error: mistakes over scored rows. */
    public float classError() { return _errors / (float) _rows; }
    /** Return number of rows used for CM computation */
    public long rows() { return _rows; }
    /** Return number of skipped rows during CM computation
     *  The number includes in-bag rows if oobee is used. */
    public long skippedRows(){ return _skippedRows; }
    /** Add a confusion matrix. */
    public CM add(final CM cm) {
      if (cm!=null) {
        if( _matrix == null ) _matrix = cm._matrix; // Take other work straight-up
        else Utils.add(_matrix,cm._matrix);
        _rows += cm._rows;
        _errors += cm._errors;
        _skippedRows += cm._skippedRows;
      }
      return this;
    }
    /** Text form of the confusion matrix */
    @Override public String toString() {
      if( _matrix == null ) return "no trees";
      int N = _matrix.length;
      final int K = N + 1;
      // Per-class error rate, rounded to 2 decimal places.
      double[] e2c = new double[N];
      for( int i = 0; i < N; i++ ) {
        long err = -_matrix[i][i];
        for( int j = 0; j < N; j++ ) err += _matrix[i][j];
        e2c[i] = Math.round((err / (double) (err + _matrix[i][i])) * 100) / (double) 100;
      }
      // Build the table of cells: header row/column plus err/class column.
      String[][] cms = new String[K][K + 1];
      cms[0][0] = "";
      for( int i = 1; i < K; i++ ) cms[0][i] = "" + (i - 1);
      cms[0][K] = "err/class";
      for( int j = 1; j < K; j++ ) cms[j][0] = "" + (j - 1);
      for( int j = 1; j < K; j++ ) cms[j][K] = "" + e2c[j - 1];
      for( int i = 1; i < K; i++ )
        for( int j = 1; j < K; j++ ) cms[j][i] = "" + _matrix[j - 1][i - 1];
      // Pad all cells to the widest cell for column alignment.
      int maxlen = 0;
      for( int i = 0; i < K; i++ )
        for( int j = 0; j < K + 1; j++ ) maxlen = Math.max(maxlen, cms[i][j].length());
      for( int i = 0; i < K; i++ )
        for( int j = 0; j < K + 1; j++ ) cms[i][j] = pad(cms[i][j], maxlen);
      String s = "";
      for( int i = 0; i < K; i++ ) {
        for( int j = 0; j < K + 1; j++ ) s += cms[i][j];
        s += "\n";
      }
      return s;
    }
    /** Pad a string with spaces. */
    private String pad(String s, int l){
      String p="";
      for(int i=0; i<l-s.length();i++)p+=" ";
      return " "+p+s;
    }
  }

  /** Immutable, publishable form of a finished confusion matrix, carrying the
   *  class domain, per-tree errors and the per-threshold binomial matrices. */
  public static class CMFinal extends CM {
    final protected Key _SpeeDRFModelKey;
    final protected String[] _domain;
    final protected long [] _errorsPerTree;
    final protected boolean _computedOOB;
    final protected long[][][] _cms;
    protected boolean _valid;
    final protected float _sum;
    /** No-arg constructor builds an invalid placeholder (see {@link #make()}). */
    private CMFinal() {
      _valid = false;
      _SpeeDRFModelKey = null;
      _domain = null;
      _errorsPerTree = null;
      _computedOOB = false;
      _sum = 0.f;
      _cms = null;
    }
    private CMFinal(CM cm, Key SpeeDRFModelKey, String[] domain, long[] errorsPerTree, boolean computedOOB, boolean valid, float sum, long[][][] cms) {
      _matrix = cm._matrix;
      _errors = cm._errors;
      _rows = cm._rows;
      _skippedRows = cm._skippedRows;
      _SpeeDRFModelKey = SpeeDRFModelKey;
      _domain = domain;
      _errorsPerTree = errorsPerTree;
      _computedOOB = computedOOB;
      _valid = valid;
      _sum = sum;
      _cms = cms;
    }
    /** Make non-valid confusion matrix */
    public static CMFinal make() {
      return new CMFinal();
    }
    /** Create a new confusion matrix. */
    public static CMFinal make(CM cm, SpeeDRFModel model, String[] domain, long[] errorsPerTree, boolean computedOOB, float sum, long[][][] cms) {
      return new CMFinal(cm, model._key, domain, errorsPerTree, computedOOB, true, sum, cms);
    }
    public String[] domain() { return _domain; }
    public int dimension() { return _matrix.length; }
    public long matrix(int i, int j) { return _matrix[i][j]; }
    public boolean valid() { return _valid; }
    /** Mean squared error: accumulated Brier-style sum over scored rows. */
    public float mse() { return _sum / (float) _rows; }
    /** Output information about this RF. */
    public final void report() {
      double err = classError();
      assert _valid : "Trying to report status of invalid CM!";
      SpeeDRFModel model = UKV.get(_SpeeDRFModelKey);
      String s =
            " Type of random forest: classification\n"
          + " Number of trees: " + model.size() + "\n"
          + "No of variables tried at each split: " + model.mtry + "\n"
          + " Estimate of err. rate: " + Math.round(err * 10000) / 100 + "% (" + err + ")\n"
          + " OOBEE: " + (_computedOOB ? "YES (sampling rate: "+model.sample*100+"%)" : "NO")+ "\n"
          + " Confusion matrix:\n" + toString() + "\n"
          + " CM domain: " + Arrays.toString(_domain) + "\n"
          + " Avg tree depth (min, max): " + model.depth() + "\n"
          + " Avg tree leaves (min, max): " + model.leaves() + "\n"
          + " Validated on (rows): " + rows() + "\n"
          + " Rows skipped during validation: " + skippedRows() + "\n"
          + " Mispredictions per tree (in rows): " + Arrays.toString(_errorsPerTree)+"\n";
      Log.info(Sys.RANDF,s);
    }
    /**
     * Reports size of dataset and computed classification error.
     */
    public final void report(StringBuilder sb) {
      double err = _errors / (double) _rows;
      sb.append(_rows).append(',');
      sb.append(err).append(',');
    }
  }

  /** Compute the sum of squared errors */
  static float doSSECalc(int[] votes, float[] preds, int cclass) {
    float err;
    // Get the total number of votes for the row
    float sum = doSum(votes);
    // No votes for the row
    if (sum == 0) {
      err = 1f - (1f / (votes.length - 0f));
      return err * err;
    }
    // err = 1 - (fraction of votes for the correct class); preds is 1-based by class.
    err = Float.isInfinite(sum) ? (Float.isInfinite(preds[cclass + 1]) ? 0f : 1f) : 1f - preds[cclass + 1] / sum;
    return err * err;
  }

  /** Total number of votes in the array, as a float. */
  static float doSum(int[] votes) {
    float sum = 0f;
    for (int v : votes) sum += v;
    return sum;
  }

  /** Normalize preds[1..] in place by s (index 0 is left untouched). */
  static float[] toProbs(float[] preds, float s ) {
    for (int i = 1; i < preds.length; ++i) { preds[i] /= s; }
    return preds;
  }

  /** Produce confusion matrix from given votes.
   * @param votes per-row vote counts per class
   * @param chks chunks of the scored frame (response read from _classcol)
   * @param local if true, skip the shared _sum/_cms accumulation (already done)
   * @param balance if true, rescale votes by prior/model class distributions */
  final CM computeCM(int[/**/][/**/] votes, Chunk[] chks, boolean local, boolean balance) {
    CM cm = new CM();
    int rows = votes.length;
    int validation_rows = 0;
    int cmin = (int) _data.vecs()[_classcol].min();
    // Assemble the votes-per-class into predictions & score each row
    // Make an empty confusion matrix for this chunk
    cm._matrix = new long[_N][_N];
    float preds[] = new float[_N+1];
    // NOTE(review): num_trees is never used — candidate for removal.
    float num_trees = _errorsPerTree.length;
    // Loop over the rows
    for( int row = 0; row < rows; row++ ) {
      // Skip rows with missing response values
      if (chks[_classcol].isNA0(row)) continue;
      // The class votes for the i-th row
      int[] vi = votes[row];
      // Fill the predictions with the vote counts, keeping the 0th index unchanged
      for( int v=0; v<_N; v++ ) preds[v+1] = vi[v];
      float s = doSum(vi);
      if (s == 0) { cm._skippedRows++; continue; } // no tree voted on this row (all in-bag)
      int result;
      if (balance) {
        // Undo training-time oversampling: rescale each class probability by
        // prior/model distribution ratio, then renormalize.
        float[] scored = toProbs(preds.clone(), doSum(vi));
        double probsum=0;
        for( int c=1; c<scored.length; c++ ) {
          final double original_fraction = _model.priordist()[c-1];
          assert(original_fraction > 0) : "original fraction should be > 0, but is " + original_fraction + ": not using enough training data?";
          final double oversampled_fraction = _model.modeldist()[c-1];
          assert(oversampled_fraction > 0) : "oversampled fraction should be > 0, but is " + oversampled_fraction + ": not using enough training data?";
          assert(!Double.isNaN(scored[c]));
          scored[c] *= original_fraction / oversampled_fraction;
          probsum += scored[c];
        }
        for (int i=1;i<scored.length;++i) scored[i] /= probsum;
        result = ModelUtils.getPrediction(scored, row);
      } else {
        // `result` is the class with the most votes, accounting for ties in the shared logic in ModelUtils
        result = ModelUtils.getPrediction(preds, row);
      }
      // Get the class value from the response column for the current row
      int cclass = alignDataIdx((int) chks[_classcol].at80(row) - cmin);
      assert 0 <= cclass && cclass < _N : ("cclass " + cclass + " < " + _N);
      // Ignore rows with zero votes, but still update the sum of squared errors
      if( vi[result]==0 ) {
        cm._skippedRows++;
        if (!local) _sum += doSSECalc(vi, preds, cclass);
        continue;
      }
      // Update the confusion matrix
      cm._matrix[cclass][result]++;
      if( result != cclass ) cm._errors++;
      validation_rows++;
      // Update the sum of squared errors
      if (!local) _sum += doSSECalc(vi, preds, cclass);
      float sum = doSum(vi);
      // Binomial classification -> compute AUC, draw ROC
      if(_N == 2 && !local) {
        float snd = preds[2] / sum; // fraction of votes for the second class
        for(int i = 0; i < ModelUtils.DEFAULT_THRESHOLDS.length; i++) {
          int p = snd >= ModelUtils.DEFAULT_THRESHOLDS[i] ? 1 : 0;
          _cms[i][cclass][p]++; // Increase matrix
        }
      }
    }
    // End of loop over rows, return confusion matrix
    cm._rows=validation_rows;
    return cm;
  }

  /** Distributed MSE over a frame whose last column is the (int) class and whose
   *  columns 1..K hold per-class probabilities: averages (1 - p[actual])^2. */
  public static class MSETask extends MRTask2<MSETask> { //M
    double _ss; // accumulated sum of squared errors
    /** Run over the frame and return the mean squared error. */
    public static double doTask(Frame fr) {
      MSETask tsk = new MSETask();
      tsk.doAll(fr);
      return tsk._ss / (double) fr.numRows();
    }
    @Override public void map(Chunk[] cks) {
      for (int i = 0; i < cks[0]._len; ++i) {
        int cls = (int)cks[cks.length - 1].at0(i);       // actual class from last column
        double err = ( 1 - cks[cls+1].at0(i));           // 1 - predicted probability of the actual class
        _ss += err * err;
      }
    }
    @Override public void reduce(MSETask tsk) { _ss += tsk._ss; }
  }
}