package hex.drf;
import hex.ShuffleTask;
import hex.gbm.DTreeUtils;
import hex.gbm.DTree.TreeModel.CompressedTree;
import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MRTask2;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ModelUtils;
import water.util.Utils;
/** Score given tree model and preserve errors per tree in form of votes (for classification)
* or MSE (for regression).
*
* This is different from Model.score() function since the MR task
* uses inverse loop: first over all trees and over all rows in chunk.
*/
public class TreeMeasuresCollector extends MRTask2<TreeMeasuresCollector> {
/* @IN */ final private float _rate;
/* @IN */ private CompressedTree[/*N*/][/*nclasses*/] _trees; // FIXME: Pass only tree-keys since serialized trees are passed over wire !!!
/* @IN */ final private int _var;
/* @IN */ final private boolean _oob;
/* @IN */ final private int _ncols;
/* @IN */ final private int _nclasses;
/* @IN */ final private boolean _classification;
/* @INOUT */ private final int _ntrees;
/* @OUT */ private long [/*ntrees*/] _votes; // Number of correct votes per tree (for classification only)
/* @OUT */ private long [/*ntrees*/] _nrows; // Number of scored row per tree (for classification/regression)
/* @OUT */ private float[/*ntrees*/] _sse; // Sum of squared errors per tree (for regression only)
private TreeMeasuresCollector(CompressedTree[/*N*/][/*nclasses*/] trees, int nclasses, int ncols, float rate, int variable) {
assert trees.length > 0;
assert nclasses == trees[0].length;
_trees = trees; _ncols = ncols;
_rate = rate; _var = variable;
_oob = true; _ntrees = trees.length;
_nclasses = nclasses;
_classification = (nclasses>1);
}
@Override public void map(Chunk[] chks) {
double[] data = new double[_ncols];
float [] preds = new float[_nclasses+1];
Chunk cresp = chk_resp(chks);
int nrows = cresp._len;
int [] oob = new int[2+Math.round((1f-_rate)*nrows*1.2f+0.5f)]; // preallocate
int [] soob = null;
// Prepare output data
_nrows = new long[_ntrees];
_votes = _classification ? new long[_ntrees] : null;
_sse = _classification ? null : new float[_ntrees];
long seedForOob = ShuffleTask.seed(cresp.cidx()); // seed for shuffling oob samples
// Start iteration
for( int tidx=0; tidx<_ntrees; tidx++) { // tree
// OOB RNG for this tree
Random rng = rngForTree(_trees[tidx], cresp.cidx());
// Collect oob rows and permutate them
oob = ModelUtils.sampleOOBRows(nrows, _rate, rng, oob); // reuse use the same array for sampling
int oobcnt = oob[0]; // Get number of sample rows
if (_var>=0) {
if (soob==null || soob.length < oobcnt) soob = new int[oobcnt];
Utils.shuffleArray(oob, oobcnt, soob, seedForOob, 1); // Shuffle array and copy results into <code>soob</code>
}
for(int j = 1; j < 1+oobcnt; j++) {
int row = oob[j];
if (cresp.isNA0(row)) continue; // we cannot deal with this row anyhow
// Do scoring:
// - prepare a row data
for (int i=0;i<_ncols;i++) data[i] = chks[i].at0(row); // 1+i - one free is expected by prediction
// - permute variable
if (_var>=0) data[_var] = chks[_var].at0(soob[j-1]);
else assert soob==null;
// - score data
Arrays.fill(preds, 0);
// - score only the tree
score0(data, preds, _trees[tidx]);
// - derive a prediction
if (_classification) {
int pred = ModelUtils.getPrediction(preds, data);
int actu = (int) cresp.at80(row);
// assert preds[pred] > 0 : "There should be a vote for at least one class.";
// - collect only correct votes
if (pred == actu) _votes[tidx]++;
} else { /* regression */
float pred = preds[0]; // Important!
float actu = (float) cresp.at0(row);
_sse[tidx] += (actu-pred)*(actu-pred);
}
// - collect rows which were used for voting
_nrows[tidx]++;
//if (_var<0) System.err.println("VARIMP OOB row: " + (cresp._start+row) + " : " + Arrays.toString(data) + " tree/actu: " + pred + "/" + actu);
}
}
// Clean-up
_trees = null;
}
@Override public void reduce( TreeMeasuresCollector t ) { Utils.add(_votes,t._votes); Utils.add(_nrows, t._nrows); Utils.add(_sse, t._sse); }
public TreeVotes resultVotes() { return new TreeVotes(_votes, _nrows, _ntrees); }
public TreeSSE resultSSE () { return new TreeSSE (_sse, _nrows, _ntrees); }
/* This is a copy of score0 method from DTree:615 */
private void score0(double data[], float preds[], CompressedTree[] ts) {
DTreeUtils.scoreTree(data, preds, ts);
}
private Chunk chk_resp( Chunk chks[] ) { return chks[_ncols]; }
private Random rngForTree(CompressedTree[] ts, int cidx) {
return _oob ? ts[0].rngForChunk(cidx) : new DummyRandom(); // k-class set of trees shares the same random number
}
/* For bulk scoring
public static TreeVotes collect(TreeModel tmodel, Frame f, int ncols, float rate, int variable) {
CompressedTree[][] trees = new CompressedTree[tmodel.ntrees()][];
for (int tidx = 0; tidx < tmodel.ntrees(); tidx++) trees[tidx] = tmodel.ctree(tidx);
return new TreeVotesCollector(trees, tmodel.nclasses(), ncols, rate, variable).doAll(f).result();
}*/
public static TreeVotes collectVotes(CompressedTree[/*nclass || 1 for regression*/] tree, int nclasses, Frame f, int ncols, float rate, int variable) {
return new TreeMeasuresCollector(new CompressedTree[][] {tree}, nclasses, ncols, rate, variable).doAll(f).resultVotes();
}
public static TreeSSE collectSSE(CompressedTree[/*nclass || 1 for regression*/] tree, int nclasses, Frame f, int ncols, float rate, int variable) {
return new TreeMeasuresCollector(new CompressedTree[][] {tree}, nclasses, ncols, rate, variable).doAll(f).resultSSE();
}
private static final class DummyRandom extends Random {
@Override public final float nextFloat() { return 1.0f; }
}
/** A simple holder for set of different tree measurements. */
public static abstract class TreeMeasures<T extends TreeMeasures> extends Iced {
/** Actual number of trees which votes are stored in this object */
protected int _ntrees;
/** Number of processed row per tree. */
protected long[/*ntrees*/] _nrows;
public TreeMeasures(int initialCapacity) { _nrows = new long[initialCapacity]; }
public TreeMeasures(long[] nrows, int ntrees) { _nrows = nrows; _ntrees = ntrees;}
/** Returns number of rows which were used during voting per individual tree. */
public final long[] nrows() { return _nrows; }
/** Returns number of voting predictors */
public final int npredictors() { return _ntrees; }
/** Returns a list of accuracies per tree. */
public abstract double accuracy(int tidx);
public final double[] accuracy() {
double[] r = new double[_ntrees];
// Average of all trees
for (int tidx=0; tidx<_ntrees; tidx++) r[tidx] = accuracy(tidx);
return r;
}
/** Compute variable importance with respect to given votes.
* The given {@link T} object represents correct votes.
* This object represents votes over shuffled data.
*
* @param right individual tree measurements performed over not shuffled data.
* @return computed importance and standard deviation
*/
public abstract double[/*2*/] imp(T right);
public abstract T append(T t);
}
/** A class holding tree votes. */
public static class TreeVotes extends TreeMeasures<TreeVotes> {
/** Number of correct votes per tree */
private long[/*ntrees*/] _votes;
public TreeVotes(int initialCapacity) {
super(initialCapacity);
_votes = new long[initialCapacity];
}
public TreeVotes(long[] votes, long[] nrows, int ntrees) {
super(nrows, ntrees);
_votes = votes;
}
/** Returns number of positive votes per tree. */
public final long[] votes() { return _votes; }
/** Returns accuracy per individual trees. */
@Override public final double accuracy(int tidx) {
assert tidx < _nrows.length && tidx < _votes.length;
return ((double) _votes[tidx]) / _nrows[tidx];
}
/** Compute variable importance with respect to given votes.
* The given {@link TreeVotes} object represents correct votes.
* This object represents votes over shuffled data.
*
* @param right individual tree voters performed over not shuffled data.
* @return computed importance and standard deviation
*/
@Override public final double[/*2*/] imp(TreeVotes right) {
assert npredictors() == right.npredictors();
int ntrees = npredictors();
double imp = 0;
double sd = 0;
// Over all trees
for (int tidx = 0; tidx < ntrees; tidx++) {
assert right.nrows()[tidx] == nrows()[tidx];
double delta = ((double) (right.votes()[tidx] - votes()[tidx])) / nrows()[tidx];
imp += delta;
sd += delta * delta;
}
double av = imp / ntrees;
double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
return new double[] { av, csd};
}
/** Append a tree votes to a list of trees. */
public TreeVotes append(long rightVotes, long allRows) {
assert _votes.length > _ntrees && _votes.length == _nrows.length : "TreeVotes inconsistency!";
_votes[_ntrees] = rightVotes;
_nrows[_ntrees] = allRows;
_ntrees++;
return this;
}
@Override public TreeVotes append(final TreeVotes tv) {
for (int i=0; i<tv.npredictors(); i++)
append(tv._votes[i], tv._nrows[i]);
return this;
}
}
/** A simple holder serving SSE per tree. */
public static class TreeSSE extends TreeMeasures<TreeSSE> {
/** SSE per tree */
private float[/*ntrees*/] _sse;
public TreeSSE(int initialCapacity) {
super(initialCapacity);
_sse = new float[initialCapacity];
}
public TreeSSE(float[] sse, long[] nrows, int ntrees) {
super(nrows, ntrees);
_sse = sse;
}
@Override public double accuracy(int tidx) {
return _sse[tidx] / _nrows[tidx];
}
@Override public double[] imp(TreeSSE right) {
assert npredictors() == right.npredictors();
int ntrees = npredictors();
double imp = 0;
double sd = 0;
// Over all trees
for (int tidx = 0; tidx < ntrees; tidx++) {
assert right.nrows()[tidx] == nrows()[tidx]; // check that we iterate over same OOB rows
double delta = ((double) (_sse[tidx] - right._sse[tidx])) / nrows()[tidx];
imp += delta;
sd += delta * delta;
}
double av = imp / ntrees;
double csd = Math.sqrt( (sd/ntrees - av*av) / ntrees );
return new double[] { av, csd };
}
@Override public TreeSSE append(TreeSSE t) {
for (int i=0; i<t.npredictors(); i++)
append(t._sse[i], t._nrows[i]);
return this;
}
/** Append a tree sse to a list of trees. */
public TreeSSE append(float sse, long allRows) {
assert _sse.length > _ntrees && _sse.length == _nrows.length : "TreeVotes inconsistency!";
_sse [_ntrees] = sse;
_nrows[_ntrees] = allRows;
_ntrees++;
return this;
}
}
public static TreeVotes asVotes(TreeMeasures tm) { return (TreeVotes) tm; }
public static TreeSSE asSSE (TreeMeasures tm) { return (TreeSSE) tm; }
}