/**
*
*/
package hex;
import java.util.Arrays;
import water.H2O.H2OCountedCompleter;
import water.*;
import water.fvec.*;
import water.util.Utils;
/**
 * Frame extractor producing two outputs from a dataset split into {@code nfolds} folds:
 * output 0 holds every row outside of the active fold, output 1 holds the rows of the
 * active fold ({@code afold}).
 *
 * <p>The fold position is obtained from {@code Utils.nfold}; this class reads element 0
 * of its result as the first row of the fold and element 1 as the number of rows in it.
 */
public class NFoldFrameExtractor extends FrameExtractor {

  /** Number of folds */
  final int nfolds;
  /** Active fold which will be extracted. */
  final int afold;

  /**
   * @param dataset  frame to extract from (handed to the superclass)
   * @param nfolds   total number of folds
   * @param afold    index of the fold to extract, expected in {@code [0, nfolds)}
   * @param destKeys destination keys (handed to the superclass)
   * @param jobKey   job key (handed to the superclass)
   */
  public NFoldFrameExtractor(Frame dataset, int nfolds, int afold, Key[] destKeys, Key jobKey) {
    super(dataset, destKeys, jobKey);
    assert 0 <= afold && afold < nfolds : "afold parameter is out of bound <0,nfolds)";
    this.afold = afold;
    this.nfolds = nfolds;
  }

  /**
   * Creates the task filling one of the two outputs.
   *
   * @param split 0 for the part outside of the fold, 1 for the fold itself
   */
  @Override protected MRTask2 createNewWorker(H2OCountedCompleter completer, Vec[] inputVecs, int split) {
    assert split == 0 || split == 1;
    final boolean extractFold = (split == 1);
    return new FoldExtractTask(completer, inputVecs, nfolds, afold, extractFold);
  }

  /**
   * Computes the cumulative chunk-boundary arrays of both outputs.
   *
   * @param espc  cumulative row offsets of the input chunks; the last element has to be {@code nrows}
   * @param nrows total number of input rows
   * @return two arrays: index 0 describes the rows outside of the fold, index 1 the fold;
   *         element 0 of each array is 0 and the last element is the row count of that part
   */
  @Override protected long[][] computeEspcPerSplit(long[] espc, long nrows) {
    assert espc[espc.length-1] == nrows : "Total number of rows does not match!";
    final long[] fold = Utils.nfold(nrows, nfolds, afold); // position of the requested fold
    final long foldStart = fold[0];
    final long foldLen = fold[1];
    final long foldEnd = foldStart + foldLen;
    final int lastIdx = espc.length - 1;
    // One extra slot each: cutting the fold out of the middle of a chunk adds a chunk.
    final long[] restEspc = new long[espc.length + 1];
    final long[] foldEspc = new long[espc.length + 1];
    int restChunks = 0; // chunks emitted so far for the part outside of the fold
    int foldChunks = 0; // chunks emitted so far for the fold
    int chk = 0;        // current input chunk

    // Whole chunks located before the fold are taken over unchanged.
    while (chk < lastIdx && espc[chk + 1] <= foldStart) {
      restChunks++;
      restEspc[restChunks] = espc[chk + 1];
      chk++;
    }
    // Chunk chk contains the fold start; its leading rows (if any) form one more chunk.
    // The same chunk may also contain the fold end, i.e., it can be cut into 3 pieces.
    final long rowsBeforeFold = restEspc[restChunks] + (foldStart - espc[chk]);
    if (restEspc[restChunks] < rowsBeforeFold) {
      restChunks++;
      restEspc[restChunks] = rowsBeforeFold;
    }

    // Chunks ending inside the fold, offsets being relative to the fold start.
    while (chk < lastIdx && espc[chk + 1] <= foldEnd) {
      foldChunks++;
      foldEspc[foldChunks] = espc[chk + 1] - foldStart;
      chk++;
    }
    // Trailing piece of the fold which ends in the middle of chunk chk.
    final long rowsInFold = foldEspc[foldChunks] + (foldEnd - Math.max(espc[chk], foldStart));
    if (foldEspc[foldChunks] < rowsInFold) {
      foldChunks++;
      foldEspc[foldChunks] = rowsInFold;
    }
    assert rowsInFold == foldLen;

    // Everything behind the fold, offsets shifted down by the size of the fold.
    while (chk < lastIdx) {
      restChunks++;
      restEspc[restChunks] = espc[chk + 1] - foldLen;
      chk++;
    }

    final long[][] result = new long[2][];
    result[0] = Arrays.copyOf(restEspc, restChunks + 1);
    result[1] = Arrays.copyOf(foldEspc, foldChunks + 1);
    // Post-conditions
    assert result[0][result[0].length-1] + result[1][result[1].length-1] == nrows;
    return result;
  }

  /** Two outputs: the part outside of the fold and the fold. */
  @Override protected int numOfOutputs() {
    return 2;
  }

  /**
   * Task copying rows of the source vectors into the chunks of one output:
   * either the fold or the part outside of it, depending on {@code _inFold}.
   */
  private static class FoldExtractTask extends MRTask2<FoldExtractTask> {
    private final Vec [] _vecs; // source vectors
    private final int _nfolds;
    private final int _afold;
    private final boolean _inFold;

    // Values below are derived by setupLocal() from the layout of the first source vector.
    transient int _precedingChks; // number of preceding chunks
    transient int _startFoldChkIdx; // idx of 1st chunk for the fold
    transient int _startRestChkIdx; // idx of 1st of remaining part
    transient int _startFoldRow; // fold start row inside the chunk _startFoldChkIdx
    transient int _startRestRow; // index of the 1st row inside chunk _startRestChkIdx beginning remaining part of data

    public FoldExtractTask(H2OCountedCompleter completer, Vec[] srcVecs, int nfold, int afold, boolean inFold) {
      super(completer);
      _vecs = srcVecs;
      _nfolds = nfold;
      _afold = afold;
      _inFold = inFold;
    }

    /** Locates the input chunks and in-chunk rows where the fold begins and where it ends. */
    @Override protected void setupLocal() {
      final Vec firstVec = _vecs[0];
      final long[] fold = Utils.nfold(firstVec.length(), _nfolds, _afold);
      final long foldStart = fold[0];
      final long foldEnd = foldStart + fold[1];
      final long[] espc = firstVec._espc;
      final int lastIdx = espc.length - 1;

      int chk = 0;
      // Skip chunks which end before the fold starts.
      while (chk < lastIdx && espc[chk + 1] <= foldStart) {
        chk++;
      }
      _startFoldChkIdx = chk;
      _startFoldRow = (int) (foldStart - espc[chk]);
      // A fold starting inside of a chunk leaves one more (partial) chunk in front of it.
      _precedingChks = (_startFoldRow > 0) ? chk + 1 : chk;

      // Skip chunks which end inside the fold.
      while (chk < lastIdx && espc[chk + 1] <= foldEnd) {
        chk++;
      }
      _startRestChkIdx = chk;
      _startRestRow = (int) (foldEnd - espc[chk]);
    }

    /** Fills the given output chunks with the matching rows of the source chunks. */
    @Override public void map(Chunk[] cs) {
      final Chunk head = cs[0];
      final int outChunkIdx = head.cidx();                 // output chunk being filled
      final int inChunkIdx = getInChunkIdx(outChunkIdx);   // input chunk providing the rows
      final int firstRow = getStartRow(outChunkIdx);       // first row to copy from the input chunk
      final int rowCount = head._len;                      // number of rows to copy
      int col = 0;
      while (col < cs.length) {
        ChunkSplitter.extractChunkPart(_vecs[col].chunkForChunkIdx(inChunkIdx), cs[col], firstRow, rowCount, _fs);
        col++;
      }
    }

    /** Maps an output chunk index to the index of the input chunk it is read from. */
    private int getInChunkIdx(int coutidx) {
      if (!_inFold) { // out fold part
        if (coutidx >= _precedingChks) {
          return _startRestChkIdx + (coutidx - _precedingChks);
        }
        return coutidx;
      }
      // Fold starting and ending in the same input chunk
      if (_startFoldChkIdx == _startRestChkIdx) {
        return _startFoldChkIdx;
      }
      return coutidx + _startFoldChkIdx;
    }

    /** Returns the row of the input chunk at which the given output chunk starts. */
    private int getStartRow(int coutidx) {
      // Only a single output chunk per part can start in the middle of an input chunk.
      final int shiftedChunk = _inFold ? 0 : _precedingChks;
      if (coutidx != shiftedChunk) {
        return 0;
      }
      return _inFold ? _startFoldRow : _startRestRow;
    }
  }
}