package water.rapids.ast.prims.mungers;
import water.*;
import water.fvec.*;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstExec;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.params.AstId;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import java.util.*;
/**
* Row Slice
*/
public class AstRowSlice extends AstPrimitive {
@Override
public String[] args() {
return new String[]{"ary", "rows"};
}
// (rows src [row_list])
@Override
public int nargs() {
return 1 + 2;
}
@Override
public String str() {
return "rows";
}
@Override
public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
Frame fr = stk.track(asts[1].exec(env)).getFrame();
Frame returningFrame;
long nrows = fr.numRows();
if (asts[2] instanceof AstNumList) {
final AstNumList nums = (AstNumList) asts[2];
if (!nums._isSort && !nums.isEmpty() && nums._bases[0] >= 0)
throw new IllegalArgumentException("H2O does not currently reorder rows, please sort your row selection first");
long[] rows = (nums._isList || nums.min() < 0) ? nums.expand8Sort() : null;
if (rows != null) {
if (rows.length == 0) { // Empty inclusion list?
} else if (rows[0] >= 0) { // Positive (inclusion) list
if (rows[rows.length - 1] > nrows)
throw new IllegalArgumentException("Row must be an integer from 0 to " + (nrows - 1));
} else { // Negative (exclusion) list
if (rows[rows.length - 1] >= 0)
throw new IllegalArgumentException("Cannot mix negative and postive row selection");
// Invert the list to make a positive list, ignoring out-of-bounds values
BitSet bs = new BitSet((int) nrows);
for (long row : rows) {
int idx = (int) (-row - 1); // The positive index
if (idx >= 0 && idx < nrows)
bs.set(idx); // Set column to EXCLUDE
}
rows = new long[(int) nrows - bs.cardinality()];
for (int i = bs.nextClearBit(0), j = 0; i < nrows; i = bs.nextClearBit(i + 1))
rows[j++] = i;
}
}
final long[] ls = rows;
returningFrame = new MRTask() {
@Override
public void map(Chunk[] cs, NewChunk[] ncs) {
if (nums.cnt() == 0) return;
if (ls != null && ls.length == 0) return;
long start = cs[0].start();
long end = start + cs[0]._len;
long min = ls == null ? (long) nums.min() : ls[0], max = ls == null ? (long) nums.max() - 1 : ls[ls.length - 1]; // exclusive max to inclusive max when stride == 1
// [ start, ..., end ] the chunk
//1 [] nums out left: nums.max() < start
//2 [] nums out rite: nums.min() > end
//3 [ nums ] nums run left: nums.min() < start && nums.max() <= end
//4 [ nums ] nums run in : start <= nums.min() && nums.max() <= end
//5 [ nums ] nums run rite: start <= nums.min() && end < nums.max()
if (!(max < start || min > end)) { // not situation 1 or 2 above
long startOffset = (min > start ? min : start); // situation 4 and 5 => min > start;
for (int i = (int) (startOffset - start); i < cs[0]._len; ++i) {
if ((ls == null && nums.has(start + i)) || (ls != null && Arrays.binarySearch(ls, start + i) >= 0)) {
for (int c = 0; c < cs.length; ++c) {
if (cs[c] instanceof CStrChunk) ncs[c].addStr(cs[c], i);
else if (cs[c] instanceof C16Chunk) ncs[c].addUUID(cs[c], i);
else if (cs[c].isNA(i)) ncs[c].addNA();
else ncs[c].addNum(cs[c].atd(i));
}
}
}
}
}
}.doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());
} else if ((asts[2] instanceof AstNum)) {
long[] rows = new long[]{(long) (((AstNum) asts[2]).getNum())};
returningFrame = fr.deepSlice(rows, null);
} else if ((asts[2] instanceof AstExec) || (asts[2] instanceof AstId)) {
Frame predVec = stk.track(asts[2].exec(env)).getFrame();
if (predVec.numCols() != 1)
throw new IllegalArgumentException("Conditional Row Slicing Expression evaluated to " + predVec.numCols() + " columns. Must be a boolean Vec.");
returningFrame = fr.deepSlice(predVec, null);
} else
throw new IllegalArgumentException("Row slicing requires a number-list as the last argument, but found a " + asts[2].getClass());
return new ValFrame(returningFrame);
}
}