package water.rapids.ast.prims.advmath;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.util.VecUtils;
import java.util.Random;
import static water.util.RandomUtils.getRNG;
public class AstKFold extends AstPrimitive {
@Override
public String[] args() {
return new String[]{"ary", "nfolds", "seed"};
}
@Override
public int nargs() {
return 1 + 3;
} // (kfold_column x nfolds seed)
@Override
public String str() {
return "kfold_column";
}
public static Vec kfoldColumn(Vec v, final int nfolds, final long seed) {
new MRTask() {
@Override
public void map(Chunk c) {
long start = c.start();
for (int i = 0; i < c._len; ++i) {
int fold = Math.abs(getRNG(start + seed + i).nextInt()) % nfolds;
c.set(i, fold);
}
}
}.doAll(v);
return v;
}
public static Vec moduloKfoldColumn(Vec v, final int nfolds) {
new MRTask() {
@Override
public void map(Chunk c) {
long start = c.start();
for (int i = 0; i < c._len; ++i)
c.set(i, (int) ((start + i) % nfolds));
}
}.doAll(v);
return v;
}
public static Vec stratifiedKFoldColumn(Vec y, final int nfolds, final long seed) {
// for each class, generate a fold column (never materialized)
// therefore, have a seed per class to be used by the map call
if (!(y.isCategorical() || (y.isNumeric() && y.isInt())))
throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + y.get_type_str());
final long[] classes = new VecUtils.CollectDomain().doAll(y).domain();
final int nClass = y.isNumeric() ? classes.length : y.domain().length;
final long[] seeds = new long[nClass]; // seed for each regular fold column (one per class)
for (int i = 0; i < nClass; ++i)
seeds[i] = getRNG(seed + i).nextLong();
return new MRTask() {
private int getFoldId(long absoluteRow, long seed) {
return Math.abs(getRNG(absoluteRow + seed).nextInt()) % nfolds;
}
// dress up the foldColumn (y[1]) as follows:
// 1. For each testFold and each classLabel loop over the response column (y[0])
// 2. If the classLabel is the current response and the testFold is the foldId
// for the current row and classLabel, then set the foldColumn to testFold
//
// How this balances labels per fold:
// Imagine that a KFold column was generated for each class. Observe that this
// makes the outer loop a way of selecting only the test rows from each fold
// (i.e., the holdout rows). Each fold is balanced sequentially in this way
// since y[1] is only updated if the current row happens to be a holdout row
// for the given classLabel.
//
// Next observe that looping over each classLabel filters down each KFold
// so that it contains labels for just THAT class. This is how the balancing
// can be made so that it is independent of the chunk distribution and the
// per chunk class distribution.
//
// Downside is this performs nfolds*nClass passes over each Chunk. For
// "reasonable" classification problems, this could be 100 passes per Chunk.
@Override
public void map(Chunk[] y) {
long start = y[0].start();
for (int testFold = 0; testFold < nfolds; ++testFold) {
for (int classLabel = 0; classLabel < nClass; ++classLabel) {
for (int row = 0; row < y[0]._len; ++row) {
// missing response gets spread around
if (y[0].isNA(row)) {
if ((start + row) % nfolds == testFold)
y[1].set(row, testFold);
} else {
if (y[0].at8(row) == (classes == null ? classLabel : classes[classLabel])) {
if (testFold == getFoldId(start + row, seeds[classLabel]))
y[1].set(row, testFold);
}
}
}
}
}
}
}.doAll(new Frame(y, y.makeZero()))._fr.vec(1);
}
@Override
public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
Vec foldVec = stk.track(asts[1].exec(env)).getFrame().anyVec().makeZero();
int nfolds = (int) asts[2].exec(env).getNum();
long seed = (long) asts[3].exec(env).getNum();
return new ValFrame(new Frame(kfoldColumn(foldVec, nfolds, seed == -1 ? new Random().nextLong() : seed)));
}
}