package water.rapids.ast.prims.advmath;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.util.VecUtils;
import java.util.*;
import static water.util.RandomUtils.getRNG;
/**
 * Rapids primitive {@code (h2o.random_stratified_split y test_frac seed)}.
 *
 * <p>Takes a frame with exactly one integer or categorical column and returns a single-column
 * frame named {@value #OUTPUT_COLUMN_NAME} whose categorical levels are {@code "train"} and
 * {@code "test"}. For every distinct value of the input column with {@code n} rows,
 * {@code max(round(n * test_frac), 1)} of those rows are randomly marked {@code "test"}; all
 * other rows stay {@code "train"}.
 */
public class AstStratifiedSplit extends AstPrimitive {
  /** Name of the single column in the frame produced by {@link #apply}. */
  public static final String OUTPUT_COLUMN_NAME = "test_train_split";
  /** Levels of the output column: value 0 is "train", value 1 is "test". */
  public static final String[] OUTPUT_COLUMN_DOMAIN = new String[]{"train", "test"};

  @Override
  public String[] args() {
    return new String[]{"ary", "test_frac", "seed"};
  }

  @Override
  public int nargs() {
    return 1 + 3;
  } // (h2o.random_stratified_split y test_frac seed)

  @Override
  public String str() {
    return "h2o.random_stratified_split";
  }

  /**
   * Evaluates the three arguments and builds the train/test indicator frame.
   *
   * @throws IllegalArgumentException if the input frame does not have exactly one column, or if
   *     {@link #split} rejects the column or the test fraction
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame frame = stk.track(asts[1].exec(env)).getFrame();
    final double testFrac = asts[2].exec(env).getNum();
    long seed = (long) asts[3].exec(env).getNum();
    // Stratification is driven by exactly one column.
    if (frame.numCols() != 1)
      throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + frame.numCols() + " columns.");
    Vec stratifyingColumn = frame.anyVec();
    Frame result = new Frame(Key.<Frame>make(),
        new String[]{OUTPUT_COLUMN_NAME},
        new Vec[]{split(stratifyingColumn, testFrac, seed, OUTPUT_COLUMN_DOMAIN)}
    );
    return new ValFrame(result);
  }

  /**
   * Builds a categorical column that gives every row of {@code stratifyingColumn} either value 0
   * or value 1 (levels taken from {@code splittingDom}). Value-1 rows are sampled separately
   * within each distinct value ("class") of the stratifying column.
   *
   * <p>For a class with {@code n} rows, {@code max(round(n * splittingFraction), 1)} distinct
   * rows receive value 1. Rows whose stratifying value is NA are never sampled and keep value 0.
   * The returned Vec has already been put into the DKV.
   *
   * @param stratifyingColumn integer or categorical column to stratify by
   * @param splittingFraction fraction of each class that receives value 1; must lie in [0, 1]
   * @param randomizationSeed sampling seed; {@code -1} requests a freshly generated random seed
   * @param splittingDom domain assigned to the output column
   * @return the new indicator column
   * @throws IllegalArgumentException if the column cannot be stratified by, or if
   *     {@code splittingFraction} is outside [0, 1] or NaN
   */
  public static Vec split(Vec stratifyingColumn, double splittingFraction, long randomizationSeed, String[] splittingDom) {
    checkIfCanStratifyBy(stratifyingColumn);
    checkSplittingFraction(splittingFraction);
    randomizationSeed = randomizationSeed == -1 ? new Random().nextLong() : randomizationSeed;
    // Distinct values present in the column; a value's position in this array is its class label.
    final long[] classes = new VecUtils.CollectDomain().doAll(stratifyingColumn).domain();
    // Use the observed values rather than the declared categorical domain. ClassIdxTask only
    // fills buckets for observed values, so an unused level would otherwise produce an empty
    // bucket and make the sampling below index out of bounds.
    final int numClasses = classes.length;
    // Inverse index: class label -> global row indices holding that class.
    ClassIdxTask finTask = new ClassIdxTask(numClasses, classes).doAll(stratifyingColumn);

    // Global indices of all rows sampled into the "test" side.
    HashSet<Long> usedIdxs = new HashSet<>();
    for (int classLabel = 0; classLabel < numClasses; classLabel++) {
      final LongAry indexAry = finTask._indexes[classLabel];
      final int classSize = indexAry.size();
      if (classSize == 0) continue; // nothing to sample from
      // Target number of rows of this class to move to "test"; at least one per class.
      // splittingFraction is in [0, 1], so this never exceeds classSize and the loop terminates.
      long tnum = Math.max(Math.round(classSize * splittingFraction), 1);
      // Rejection sampling of distinct rows. The RNG is re-seeded with (count + seed) on every
      // draw; this scheme is kept as-is so that a given seed reproduces earlier splits.
      int generated = 0;
      int count = 0;
      while (generated < tnum) {
        int i = (int) (getRNG(count + randomizationSeed).nextDouble() * classSize);
        // Row sets of different classes are disjoint, so adding to the global set only rejects
        // duplicates drawn from this same class.
        if (usedIdxs.add(indexAry.get(i))) generated++;
        count++;
      }
    }

    // Create the output column only after sampling succeeded, so that a failure above does not
    // leave an orphan Vec in the DKV. It must follow the input vector's layout.
    Vec outputVec = stratifyingColumn.makeCon(0.0, Vec.T_CAT);
    outputVec.setDomain(splittingDom);
    DKV.put(outputVec);
    new ClassAssignMRTask(usedIdxs).doAll(outputVec);
    return outputVec;
  }

  /** Rejects fractions outside [0, 1]; the negated comparison also rejects NaN. */
  private static void checkSplittingFraction(double fraction) {
    if (!(fraction >= 0 && fraction <= 1))
      throw new IllegalArgumentException("Test fraction must be in [0, 1]. Got: " + fraction);
  }

  /**
   * Verifies that {@code vec} is categorical or integer-valued numeric, and that its length fits
   * in an {@code int}.
   *
   * @throws IllegalArgumentException otherwise
   */
  static void checkIfCanStratifyBy(Vec vec) {
    if (!(vec.isCategorical() || (vec.isNumeric() && vec.isInt())))
      throw new IllegalArgumentException("Stratification only applies to integer and categorical columns. Got: " + vec.get_type_str());
    if (vec.length() > Integer.MAX_VALUE) {
      throw new IllegalArgumentException("Cannot stratified the frame because it is too long: nrows=" + vec.length());
    }
  }

  /** Sets value 1.0 on every row whose global row index is in the given set. */
  public static class ClassAssignMRTask extends MRTask<AstStratifiedSplit.ClassAssignMRTask> {
    HashSet<Long> _idx;

    ClassAssignMRTask(HashSet<Long> idx) {
      _idx = idx;
    }

    @Override
    public void map(Chunk ck) {
      final long start = ck.start();
      for (int i = 0; i < ck.len(); i++) {
        if (_idx.contains(start + i)) {
          ck.set(i, 1.0);
        }
      }
      _idx = null; // Do not send it back
    }
  }

  /**
   * Collects, for each class label (a position in {@code classes}), the global row indices whose
   * value equals that class. NA rows and values not listed in {@code classes} are skipped.
   */
  public static class ClassIdxTask extends MRTask<AstStratifiedSplit.ClassIdxTask> {
    LongAry[] _indexes;
    private final int _nclasses;
    private long[] _classes;
    // Value -> class label lookup, rebuilt on each node in setupLocal.
    private transient HashMap<Long, Integer> _classMap;

    public ClassIdxTask(int nclasses, long[] classes) {
      _nclasses = nclasses;
      _classes = classes;
    }

    @Override
    protected void setupLocal() {
      _classMap = new HashMap<>(2 * _classes.length);
      for (int i = 0; i < _classes.length; i++) {
        _classMap.put(_classes[i], i);
      }
    }

    @Override
    public void map(Chunk[] ck) {
      _indexes = new LongAry[_nclasses];
      for (int i = 0; i < _nclasses; i++) { _indexes[i] = new LongAry(); }
      final Chunk chunk = ck[0];
      final long start = chunk.start();
      for (int i = 0; i < chunk.len(); i++) {
        // Skip missing values: they belong to no class, and at8 cannot be read from an NA row.
        if (chunk.isNA(i)) continue;
        Integer clas_idx = _classMap.get(chunk.at8(i));
        if (clas_idx != null) _indexes[clas_idx].add(start + i);
      }
      _classes = null; // no need to ship the class list back with the result
    }

    @Override
    public void reduce(AstStratifiedSplit.ClassIdxTask c) {
      // Tolerate a side that produced no buckets.
      if (c._indexes == null) return;
      if (_indexes == null) { _indexes = c._indexes; return; }
      for (int i = 0; i < c._indexes.length; i++) {
        final LongAry other = c._indexes[i];
        for (int j = 0; j < other.size(); j++) {
          _indexes[i].add(other.get(j));
        }
      }
    }
  }

  /** Minimal growable array of primitive longs (avoids boxing), serializable via Iced. */
  public static class LongAry extends Iced<AstStratifiedSplit.LongAry> {
    /** Wraps {@code vals} directly (no copy); the array is replaced once it needs to grow. */
    public LongAry(long ...vals){_ary = vals; _sz = vals.length;}
    long [] _ary = new long[4];
    int _sz;

    /** Appends {@code i}, growing the backing array (doubling, minimum 4) when it is full. */
    public void add(long i){
      if (_sz == _ary.length)
        _ary = Arrays.copyOf(_ary, Math.max(4, _ary.length * 2));
      _ary[_sz++] = i;
    }

    /** Returns the element at {@code i}; throws if {@code i} is outside [0, size()). */
    public long get(int i){
      if (i < 0 || i >= _sz) throw new ArrayIndexOutOfBoundsException(i);
      return _ary[i];
    }

    public int size(){return _sz;}

    /** Returns a trimmed copy of the stored elements. */
    public long[] toArray(){return Arrays.copyOf(_ary,_sz);}

    /** Logically empties the array; the backing storage is kept. */
    public void clear() {_sz = 0;}
  }
}