package hex.drf;
import java.util.Arrays;
import java.util.Random;
import hex.gbm.DTree.TreeModel.CompressedTree;
import hex.gbm.*;
import water.*;
import water.fvec.Chunk;
/**
* Computing oob scores over all trees and rows
* and reconstructing <code>ntree_id, oobt</code> fields in given frame.
*
* <p>It prepares a voter per tree and also marks
* rows which were considered out-of-bag.</p>
*/
/* package */ class OOBScorer extends DTreeScorer<OOBScorer> {
  /** Row sampling rate used during training; rows skipped by the sampler are OOB. */
  /* @IN */ final protected float _rate;

  /**
   * @param ncols    number of feature columns
   * @param nclass   number of response classes (1 for regression)
   * @param rate     per-row sampling rate used when the trees were built
   * @param treeKeys keys of the compressed trees, one k-class set per tree
   */
  public OOBScorer(int ncols, int nclass, float rate, Key[][] treeKeys) {
    super(ncols, nclass, treeKeys);
    _rate = rate;
  }

  /**
   * For every tree, replays the training-time row sampler over this chunk and,
   * for each row the sampler would have skipped (i.e. an out-of-bag row),
   * accumulates the tree's prediction into the per-class tree columns and
   * marks/updates the OOB bookkeeping column.
   */
  @Override public void map(Chunk[] chks) {
    final double[] rowData   = new double[_ncols];
    final float [] treePreds = new float [_nclass + 1];
    final Chunk    oobCnt    = chk_oobt(chks);
    final Chunk    resp      = chk_resp(chks);
    final int      numTrees  = _trees.length;
    for (int t = 0; t < numTrees; t++) {
      // Recreate the exact RNG stream used when sampling rows for this tree.
      final Random rng = rngForTree(_trees[t], oobCnt.cidx());
      for (int r = 0; r < oobCnt._len; r++) {
        // Draw for EVERY row so the RNG stream stays aligned with training;
        // a row is OOB when the sampler would have rejected it.
        final boolean isOob = rng.nextFloat() >= _rate;
        if (isOob || Double.isNaN(resp.at0(r))) {
          // Flag the row as OOB; for regression also count the voting trees.
          oobCnt.set0(r, _nclass > 1 ? 1 : oobCnt.at0(r) + 1);
          // Gather the feature vector and score this tree on it.
          for (int col = 0; col < _ncols; col++) rowData[col] = chks[col].at0(r);
          Arrays.fill(treePreds, 0f);
          score0(rowData, treePreds, _trees[t]);
          if (_nclass == 1) treePreds[1] = treePreds[0]; // regression: keep slot layout consistent
          // Accumulate non-zero votes into the per-class tree columns.
          for (int klass = 0; klass < _nclass; klass++) {
            final float vote = treePreds[1 + klass];
            if (vote != 0) {
              final Chunk treeCol = chk_tree(chks, klass);
              treeCol.set0(r, (float) (treeCol.at0(r) + vote));
            }
          }
        }
      }
    }
  }

  /** All k class-trees of one iteration share a single RNG, seeded from tree 0. */
  private Random rngForTree(CompressedTree[] ts, int cidx) {
    return ts[0].rngForChunk(cidx);
  }
}