package hex.singlenoderf;

import hex.ShuffleTask;
//import hex.gbm.DTree.TreeModel.CompressedTree;

import java.util.ArrayList;
//import java.util.Arrays;
import java.util.Random;

import water.AutoBuffer;
import water.Iced;
//import water.Key;
import water.MRTask2;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
//import water.util.ModelUtils;
import water.util.Utils;

/** Score given tree model and preserve errors per tree in form of votes (for classification)
 * or MSE (for regression).
 *
 * This is different from Model.score() function since the MR task
 * uses inverse loop: first over all trees and over all rows in chunk.
 *
 * Permutation variable importance: for each tree, the out-of-bag (OOB) rows of each
 * chunk are scored twice — once as-is ({@code _votesOOB}) and once with the single
 * column {@code _var} permuted within the OOB rows ({@code _votesSOOB}).  The
 * difference in per-tree accuracy is later turned into an importance measure by
 * {@link TreeVotes#imp(TreeVotes)}.
 */
public class VariableImportance extends MRTask2<VariableImportance> {
  /* @IN */ final private float _rate;            // NOTE(review): never read in this class — presumably kept for wire/ctor compatibility; confirm before removing
//  /* @IN */ private int _trees; // FIXME: Pass only tree-keys since serialized trees are passed over wire !!!
  /* @IN */ final private int _var;               // Index of the column to permute (must be >= 0; see assert in map())
  /* @IN */ final private boolean _oob;           // Always true (set in ctor): score only out-of-bag rows
  /* @IN */ final private int _ncols;             // Number of predictor columns (data[] row buffer size)
  /* @IN */ final private int _nclasses;
  /* @IN */ final private boolean _classification; // true when _nclasses > 1
  /* @IN */ final private SpeeDRFModel _model;
  /* @IN */ final private int[] _modelDataMap;    // Mapping of model columns onto frame columns (from _model.colMap)
  /* @IN */ private Frame _data;
  /* @IN */ private int _classcol;                // Response column index: last column of the frame
  /** Computed mapping of model prediction classes to confusion matrix classes */
  /* @IN */ private int[] _model_classes_mapping;
  /** Computed mapping of data prediction classes to confusion matrix classes */
  /* @IN */ private int[] _data_classes_mapping;
  /** Difference between model cmin and CM cmin */
  /* @IN */ private int _cmin_model_mapping;
  /** Difference between data cmin and CM cmin */
  /* @IN */ private int _cmin_data_mapping;
  /* @IN */ private int _cmin;                    // Minimum of the model response vector

  /* @INOUT */ private final int _ntrees;
//  /* @OUT */ private long [/*ntrees*/] _votes; // Number of correct votes per tree (for classification only)
  /* @OUT */ private long [/*ntrees*/] _nrows; // Number of scored row per tree (for classification/regression)
//  /* @OUT */ private float[/*ntrees*/] _sse; // Sum of squared errors per tree (for regression only)
  /* @OUT */ private long [/*ntrees*/] _votesSOOB; // Correct votes per tree on shuffled (permuted) OOB rows
  /* @OUT */ private long [/*ntrees*/] _votesOOB;  // Correct votes per tree on unshuffled OOB rows
  /* @OUT */ private long [/*ntrees*/] _voteDiffs; // NOTE(review): allocated in map() but never filled nor reduced — appears dead
  /* @OUT */ private float _varimp;                // NOTE(review): zeroed in map() but never computed nor reduced — appears dead
  /* @OUT */ private float _varimpSD;              // NOTE(review): zeroed in map() but never computed nor reduced — appears dead
  /* @OUT */ private int[] _oobs;                  // NOTE(review): overwritten per tree and not combined in reduce() — scratch state, not a real output

  /** Private ctor; use {@link #collectVotes}.  Assumes the response is the LAST column of {@code fr}. */
  private VariableImportance(int trees, int nclasses, int ncols, float rate, int variable, SpeeDRFModel model, Frame fr, Vec resp) {
    _ncols = ncols;
    _rate = rate;
    _var = variable;
    _oob = true;
    _ntrees = trees;
    _nclasses = nclasses;
    _classification = (nclasses>1);
    _classcol = fr.numCols() - 1;
    _data = fr;
    _cmin = (int) resp.min();
    _model = model;
    _modelDataMap = _model.colMap(_data);
    init(resp);
  }

  /** Precompute the alignment between model response classes and data response classes:
   *  enum responses are aligned by domain strings, numeric responses by their minima. */
  private void init(Vec resp) {
    Vec respData = _data.vecs()[_classcol];
    int model_min = (int) resp.min();
    int data_min = (int)respData.min();
    if (resp._domain!=null) {
      assert respData._domain != null;
      _model_classes_mapping = new int[resp._domain.length];
      _data_classes_mapping = new int[respData._domain.length];
      // compute mapping
      alignEnumDomains(resp._domain, respData._domain, _model_classes_mapping, _data_classes_mapping);
    } else {
      assert respData._domain == null;
      _model_classes_mapping = null;
      _data_classes_mapping = null;
      // compute mapping: shift both responses onto a common zero-based scale
      _cmin_model_mapping = model_min - Math.min(model_min, data_min);
      _cmin_data_mapping = data_min - Math.min(model_min, data_min);
    }
  }

  /** For every tree: replay the training sampling RNG to recover this chunk's OOB rows,
   *  vote each OOB row as-is (into {@code _votesOOB}), then re-vote with column {@code _var}
   *  permuted among the OOB rows (into {@code _votesSOOB}). */
  @Override public void map(Chunk[] chks) {
    _votesOOB = new long[_ntrees];
    _votesSOOB = new long[_ntrees];
    _voteDiffs = new long[_ntrees];
    _varimp = 0.f;
    _varimpSD = 0.f;
    _nrows = new long[_ntrees];
    double[] data = new double[_ncols];
    float [] preds = new float[_nclasses+1]; // NOTE(review): unused local
    final int rows = chks[0]._len;
    int _N = _nclasses;
    int[] soob = null; // shuffled oob rows
    boolean collectOOB = true; // NOTE(review): unused local
    final int cmin = _cmin;

    //Need the chunk of code to score over every tree...
    //Doesn't do anything with the first tree, we score time last *manually* (after looping over all da trees)
    long seedForOob = ShuffleTask.seed(chks[0].cidx()); // seed is per-chunk so every tree permutes identically on this chunk
    for( int ntree = 0; ntree < _ntrees; ntree++ ) {
      int oobcnt = 0;
      ArrayList<Integer> oob = new ArrayList<Integer>(); // oob rows
      long treeSeed = _model.seed(ntree);
      byte producerId = _model.producerId(ntree); // NOTE(review): unused local (side-effect-free lookup?) — confirm before removing
      int init_row = (int)chks[0]._start;
      long seed = Sampling.chunkSampleSeed(treeSeed, init_row);
      Random rand = Utils.getDeterRNG(seed);
      // Now for all rows, classify & vote!
      for (int row = 0; row < rows; row++) {
        // int row = r + (int)chks[0]._start;
        // ------ THIS CODE is crucial and serve to replay the same sequence
        // of random numbers as in the method Data.sampleFair()
        // Skip row used during training if OOB is computed
        float sampledItem = rand.nextFloat(); // must be drawn for EVERY row to keep the RNG stream in sync with training
        // Bail out of broken rows with NA in class column.
        // Do not skip yet the rows with NAs in the rest of columns
        // NOTE(review): this reads chks[_ncols - 1], but the class column is chks[_classcol]
        // elsewhere in this method (and the commented-out chk_resp used chks[_ncols]) —
        // verify this is not off by one / checking the last predictor instead of the response.
        if (chks[_ncols - 1].isNA0(row)) continue;
        if (sampledItem < _model.sample) continue; // row was in-bag during training -> skip
        oob.add(row);
        oobcnt++;
        // Predict with this tree - produce 0-based class index
        int prediction = (int) _model.classify0(ntree, chks, row, _modelDataMap, (short) _N, false);
        if (prediction >= _nclasses) continue; // Junk row cannot be predicted
        // Check tree miss
        int alignedPrediction = alignModelIdx(prediction);
        int alignedData = alignDataIdx((int) chks[_classcol].at80(row) - cmin);
        if (alignedPrediction == alignedData) _votesOOB[ntree]++;
      }
      _oobs = new int[oob.size()];
      for (int i = 0; i < oob.size(); ++i) _oobs[i] = oob.get(i);

      //score on shuffled data...
      if (soob==null || soob.length < oobcnt) soob = new int[oobcnt];
      Utils.shuffleArray(_oobs, oobcnt, soob, seedForOob, 0); // Shuffle array and copy results into <code>soob</code>
      for(int j = 0; j < oobcnt; j++) {
        int row = _oobs[j];
        // Do scoring:
        // - prepare a row data
        for (int i=0;i<chks.length - 1;i++) {
          data[i] = chks[i].at0(row); // 1+i - one free is expected by prediction
        }
        // - permute variable: take column _var's value from a shuffled OOB row instead
        if (_var>=0) data[_var] = chks[_var].at0(soob[j]);
        else assert false;
        // - score data
        // - score only the tree
        int prediction = (int) Tree.classify(new AutoBuffer(_model.tree(ntree)), data, (double)_N, false); //.classify0(ntree, _data, chks, row, _modelDataMap, numClasses );
        if( prediction >= _nclasses ) continue;
        int pred = alignModelIdx(prediction);
        int actu = alignDataIdx((int) chks[_classcol].at80(_oobs[j]) - cmin);
        if (pred == actu) _votesSOOB[ntree]++;
        _nrows[ntree]++; // NOTE(review): counts only rows surviving the shuffled-score junk check, not the OOB-score one — confirm intended
      }
    }
  }

  /** Combine per-chunk tallies.  Only votes and row counts are reduced;
   *  _oobs/_voteDiffs/_varimp/_varimpSD are intentionally(?) not combined — see field notes. */
  @Override public void reduce( VariableImportance t ) {
    Utils.add(_votesOOB, t._votesOOB);
    Utils.add(_votesSOOB, t._votesSOOB);
    Utils.add(_nrows, t._nrows);
  }

  /** Transforms 0-based class produced by model to CF zero-based */
  private int alignModelIdx(int modelClazz) {
    if (_model_classes_mapping!=null) return _model_classes_mapping[modelClazz];
    else return modelClazz + _cmin_model_mapping;
  }
  /** Transforms 0-based class from input data to CF zero-based */
  private int alignDataIdx(int dataClazz) {
    if (_data_classes_mapping!=null) return _data_classes_mapping[dataClazz];
    else return dataClazz + _cmin_data_mapping;
  }

  /** Merge-align two sorted enum domains onto a shared zero-based index space.
   * Walks both domains simultaneously (classic sorted merge by string compare);
   * identical strings share one index, unique strings each get their own.
   *
   * @param modelDomain sorted domain of the model's response
   * @param dataDomain  sorted domain of the data's response
   * @param modelMapping out: modelMapping[i] = shared index of modelDomain[i]
   * @param dataMapping  out: dataMapping[i]  = shared index of dataDomain[i]
   * @return total number of distinct classes in the union of both domains
   */
  public static int alignEnumDomains(final String[] modelDomain, final String[] dataDomain, int[] modelMapping, int[] dataMapping) {
    assert modelMapping!=null && modelMapping.length == modelDomain.length;
    assert dataMapping!=null && dataMapping.length == dataDomain.length;
    int idx = 0, idxM = 0, idxD = 0;
    while(idxM!=modelDomain.length || idxD!=dataDomain.length) {
      if (idxM==modelDomain.length) { dataMapping[idxD++] = idx++; continue; }
      if (idxD==dataDomain.length) { modelMapping[idxM++] = idx++; continue; }
      int c = modelDomain[idxM].compareTo(dataDomain[idxD]);
      if (c < 0) { modelMapping[idxM] = idx; idxM++; }
      else if (c > 0) { dataMapping[idxD] = idx; idxD++; }
      else { // strings are identical
        modelMapping[idxM] = idx;
        dataMapping[idxD] = idx;
        idxM++; idxD++;
      }
      idx++;
    }
    return idx;
  }

  /** @return [0] votes over unshuffled OOB rows, [1] votes over rows with column _var permuted.
   *  Note both TreeVotes share the same _nrows array. */
  public TreeVotes[] resultVotes() { return new TreeVotes[]{new TreeVotes(_votesOOB, _nrows, _ntrees), new TreeVotes(_votesSOOB, _nrows, _ntrees)}; }
//  public TreeSSE resultSSE () { return new TreeSSE (_sse, _nrows, _ntrees); }

  /* This is a copy of score0 method from DTree:615 */
//  private void score0(double data[], float preds[], CompressedTree[] ts) {
//    for( int c=0; c<ts.length; c++ )
//      if( ts[c] != null )
//        preds[ts.length==1?0:c+1] += ts[c].score(data);
//  }

//  private Chunk chk_resp( Chunk chks[] ) { return chks[_ncols]; }
//
//  private Random rngForTree(CompressedTree[] ts, int cidx) {
//    return _oob ? ts[0].rngForChunk(cidx) : new DummyRandom(); // k-class set of trees shares the same random number
//  }

  /* For bulk scoring
  public static TreeVotes collect(TreeModel tmodel, Frame f, int ncols, float rate, int variable) {
    CompressedTree[][] trees = new CompressedTree[tmodel.ntrees()][];
    for (int tidx = 0; tidx < tmodel.ntrees(); tidx++) trees[tidx] = tmodel.ctree(tidx);
    return new TreeVotesCollector(trees, tmodel.nclasses(), ncols, rate, variable).doAll(f).result();
  }*/

//  VariableImportance(int trees, int nclasses, int ncols, float rate, int variable, SpeeDRFModel model)
  /** Entry point: run the permutation-importance MR task for column {@code variable} over frame {@code f}. */
  public static TreeVotes[] collectVotes(int trees, int nclasses, Frame f, int ncols, float rate, int variable, SpeeDRFModel model, Vec resp) {
    return new VariableImportance(trees, nclasses, ncols, rate, variable, model, f, resp).doAll(f).resultVotes();
  }

//  public static TreeSSE collectSSE(CompressedTree[/*nclass || 1 for regression*/] tree, int nclasses, Frame f, int ncols, float rate, int variable) {
//    return new TreeMeasuresCollector(new CompressedTree[][] {tree}, nclasses, ncols, rate, variable).doAll(f).resultSSE();
//  }

//  private static final class DummyRandom extends Random {
//    @Override public final float nextFloat() { return 1.0f; }
//  }

  /** A simple holder for set of different tree measurements. */
  public static abstract class TreeMeasures<T extends TreeMeasures> extends Iced {
    /** Actual number of trees which votes are stored in this object */
    protected int _ntrees;
    /** Number of processed row per tree. */
    protected long[/*ntrees*/] _nrows;

    public TreeMeasures(int initialCapacity) { _nrows = new long[initialCapacity]; }
    public TreeMeasures(long[] nrows, int ntrees) { _nrows = nrows; _ntrees = ntrees;}
    /** Returns number of rows which were used during voting per individual tree. */
    public final long[] nrows() { return _nrows; }
    /** Returns number of voting predictors */
    public final int npredictors() { return _ntrees; }
    /** Returns a list of accuracies per tree. */
    public abstract double accuracy(int tidx);
    public final double[] accuracy() {
      double[] r = new double[_ntrees];
      // Average of all trees
      for (int tidx=0; tidx<_ntrees; tidx++) r[tidx] = accuracy(tidx);
      return r;
    }
    /** Compute variable importance with respect to given votes.
     * The given {@code T} object represents correct votes.
     * This object represents votes over shuffled data.
     *
     * @param right individual tree measurements performed over not shuffled data.
     * @return computed importance and standard deviation
     */
    public abstract double[/*2*/] imp(T right);
    public abstract T append(T t);
  }

  /** A class holding tree votes. */
  public static class TreeVotes extends TreeMeasures<TreeVotes> {
    /** Number of correct votes per tree */
    private long[/*ntrees*/] _votes;

    public TreeVotes(int initialCapacity) { super(initialCapacity); _votes = new long[initialCapacity]; }
    public TreeVotes(long[] votes, long[] nrows, int ntrees) { super(nrows, ntrees); _votes = votes; }
    /** Returns number of positive votes per tree. */
    public final long[] votes() { return _votes; }
    /** Returns accuracy per individual trees. */
    @Override public final double accuracy(int tidx) {
      assert tidx < _nrows.length && tidx < _votes.length;
      return ((double) _votes[tidx]) / _nrows[tidx];
    }
    /** Compute variable importance with respect to given votes.
     * The given {@link TreeVotes} object represents correct votes.
     * This object represents votes over shuffled data.
     *
     * @param right individual tree voters performed over not shuffled data.
     * @return computed importance and standard deviation
     */
    @Override public final double[/*2*/] imp(TreeVotes right) {
      assert npredictors() == right.npredictors();
      int ntrees = npredictors();
      double imp = 0;
      double sd = 0;
      // Over all trees: importance = mean per-tree accuracy drop caused by permutation
      for (int tidx = 0; tidx < ntrees; tidx++) {
        assert right.nrows()[tidx] == nrows()[tidx];
        double delta = ((double) (right.votes()[tidx] - votes()[tidx])) / nrows()[tidx];
        imp += delta;
        sd += delta * delta;
      }
      double av = imp / ntrees;
      // standard error of the mean delta across trees
      double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
      return new double[] { av, csd};
    }
    /** Append a tree votes to a list of trees. */
    public TreeVotes append(long rightVotes, long allRows) {
      assert _votes.length > _ntrees && _votes.length == _nrows.length : "TreeVotes inconsistency!";
      _votes[_ntrees] = rightVotes;
      _nrows[_ntrees] = allRows;
      _ntrees++;
      return this;
    }

    @Override public TreeVotes append(final TreeVotes tv) {
      for (int i=0; i<tv.npredictors(); i++)
        append(tv._votes[i], tv._nrows[i]);
      return this;
    }
  }

  /** A simple holder serving SSE per tree. */
//  public static class TreeSSE extends TreeMeasures<TreeSSE> {
//    /** SSE per tree */
//    private float[/*ntrees*/] _sse;
//
//    public TreeSSE(int initialCapacity) {
//      super(initialCapacity);
//      _sse = new float[initialCapacity];
//    }
//    public TreeSSE(float[] sse, long[] nrows, int ntrees) {
//      super(nrows, ntrees);
//      _sse = sse;
//    }
//    @Override public double accuracy(int tidx) {
//      return _sse[tidx] / _nrows[tidx];
//    }
//    @Override public double[] imp(TreeSSE right) {
//      assert npredictors() == right.npredictors();
//      int ntrees = npredictors();
//      double imp = 0;
//      double sd = 0;
//      // Over all trees
//      for (int tidx = 0; tidx < ntrees; tidx++) {
//        assert right.nrows()[tidx] == nrows()[tidx]; // check that we iterate over same OOB rows
//        double delta = ((double) (_sse[tidx] - right._sse[tidx])) / nrows()[tidx];
//        imp += delta;
//        sd += delta * delta;
//      }
//      double av = imp / ntrees;
//      double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
//      return new double[] { av, csd };
//    }
//    @Override public TreeSSE append(TreeSSE t) {
//      for (int i=0; i<t.npredictors(); i++)
//        append(t._sse[i], t._nrows[i]);
//      return this;
//    }
//    /** Append a tree sse to a list of trees. */
//    public TreeSSE append(float sse, long allRows) {
//      assert _sse.length > _ntrees && _sse.length == _nrows.length : "TreeVotes inconsistency!";
//      _sse [_ntrees] = sse;
//      _nrows[_ntrees] = allRows;
//      _ntrees++;
//      return this;
//    }
//  }

  public static TreeVotes asVotes(TreeMeasures tm) { return (TreeVotes) tm; }
//  public static TreeSSE asSSE (TreeMeasures tm) { return (TreeSSE) tm; }
}