package water.rapids.ast.prims.advmath;
import water.AutoBuffer;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMapLong;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.util.ArrayUtils;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
/**
* Compute a 1-D or 2-D table of value counts over one or two columns of a frame
* TODO: Define "table" in terms of "groupby"
* TODO: keep dense format for two-column comparison (like in previous version of Rapids)
* (table X Y) ==>
* (groupby (cbind X Y) [X Y] nrow TRUE)
*/
public class AstTable extends AstPrimitive {
  @Override
  public String[] args() {
    return new String[]{"X", "Y", "dense"};
  }

  @Override
  public int nargs() {
    return -1; // variable arity: (table X dense) or (table X Y dense)
  }

  @Override
  public String str() {
    return "table";
  }

  /**
   * Count the occurrences of each value (1 column) or each value pair
   * (2 columns) and return the counts as a new Frame.
   *
   * @param env  execution environment
   * @param stk  stack helper used to track intermediate frames
   * @param asts arguments: (table X dense) or (table X Y dense); the last
   *             argument is the numeric 'dense' flag (1 == dense output)
   * @return a ValFrame holding the count table
   * @throws IllegalArgumentException if more than two columns are supplied
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame fr1 = stk.track(asts[1].exec(env)).getFrame();
    // The trailing argument is always the 'dense' flag.
    final boolean dense = asts[asts.length - 1].exec(env).getNum() == 1;
    Frame fr2 = asts.length == 4 ? stk.track(asts[2].exec(env)).getFrame() : null;
    int ncols = fr1.numCols() + (fr2 == null ? 0 : fr2.numCols());
    Vec vec1 = fr1.vec(0);
    // Fast path: a single integer column with a small value span.
    ValFrame res = fast_table(vec1, ncols, fr1._names[0]);
    if (res != null) return res;
    if (!(asts.length == 3 || asts.length == 4) || ncols > 2)
      throw new IllegalArgumentException("table expects one or two columns");
    // Second column may come from fr1 (two-column frame) or from fr2.
    Vec vec2 = fr1.numCols() == 2 ? fr1.vec(1) : fr2 != null ? fr2.vec(0) : null;
    int sz = fr1._names.length + (fr2 != null ? fr2._names.length : 0);
    String[] colnames = new String[sz];
    int i = 0;
    for (String name : fr1._names) colnames[i++] = name;
    if (fr2 != null) for (String name : fr2._names) colnames[i++] = name;
    return slow_table(vec1, vec2, colnames, dense);
  }

  // -------------------------------------------------------------------------
  // Fast-path for 1 integer column: count into a plain long[] indexed by
  // (value - min).  Returns null when the fast path does not apply, letting
  // the caller fall back to slow_table.
  private ValFrame fast_table(Vec v1, int ncols, String colname) {
    if (ncols != 1 || !v1.isInt()) return null;
    long spanl = (long) v1.max() - (long) v1.min() + 1;
    if (spanl > 1000000) return null; // Cap at decent array size, for performance
    // First fast-pass counting
    AstTable.FastCnt fastCnt = new AstTable.FastCnt((long) v1.min(), (int) spanl).doAll(v1);
    final long cnts[] = fastCnt._cnts;
    final long minVal = fastCnt._min;
    // Second pass to build the result frame, skipping zeros.  The temporary
    // dataLayoutVec only provides the row layout (one row per possible value).
    Vec dataLayoutVec = Vec.makeCon(0, cnts.length);
    Frame fr = new MRTask() {
      @Override
      public void map(Chunk cs[], NewChunk nc0, NewChunk nc1) {
        final Chunk c = cs[0];
        for (int i = 0; i < c._len; ++i) {
          int idx = (int) (i + c.start()); // global row == index into cnts
          if (cnts[idx] > 0) {
            nc0.addNum(idx + minVal);      // the value itself
            nc1.addNum(cnts[idx]);         // its count
          }
        }
      }
    }.doAll(new byte[]{Vec.T_NUM, Vec.T_NUM}, dataLayoutVec).outputFrame(new String[]{colname, "Count"},
        new String[][]{v1.domain(), null});
    dataLayoutVec.remove();
    return new ValFrame(fr);
  }

  // Fast-pass for counting unique integers in a span.  Each map() call gets
  // its own task clone, so allocating _cnts in map() is race-free; reduce()
  // sums the per-clone arrays.
  private static class FastCnt extends MRTask<AstTable.FastCnt> {
    final long _min;   // smallest value in the column; offsets array indexing
    final int _span;   // number of possible values, (max - min + 1)
    long _cnts[];      // per-value counts, index = value - _min

    FastCnt(long min, int span) {
      _min = min;
      _span = span;
    }

    @Override
    public void map(Chunk c) {
      _cnts = new long[_span];
      for (int i = 0; i < c._len; i++)
        if (!c.isNA(i)) // NAs are excluded from the table
          _cnts[(int) (c.at8(i) - _min)]++;
    }

    @Override
    public void reduce(AstTable.FastCnt fc) {
      ArrayUtils.add(_cnts, fc._cnts);
    }
  }

  // -------------------------------------------------------------------------
  // Count unique combos in 1 or 2 columns, where the values are not integers,
  // or cover a very large span.
  private ValFrame slow_table(Vec v1, Vec v2, String[] colnames, boolean dense) {
    // For simplicity, repeat v1 if v2 is missing; this will end up filling in
    // only the diagonal of a 2-D array (in what is otherwise a 1-D array).
    // This should be nearly the same cost as a 1-D array, since everything is
    // sparsely filled in.
    // If this is the 1-column case (all counts on the diagonals), just build a
    // 1-d result.
    if (v2 == null) {
      // Slow-pass group counting with very sparse hashtables; v1 is paired
      // with itself, so only diagonal entries (key==key) exist.
      AstTable.SlowCnt sc = new AstTable.SlowCnt().doAll(v1, v1);
      // Get the column headers as sorted doubles
      double dcols[] = collectDomain(sc._col0s);
      Frame res = new Frame();
      Vec rowlabel = Vec.makeVec(dcols, Vec.VectorGroup.VG_LEN1.addVec());
      rowlabel.setDomain(v1.domain());
      res.add(colnames[0], rowlabel);
      long cnts[] = new long[dcols.length];
      for (int col = 0; col < dcols.length; col++) {
        long lkey = Double.doubleToRawLongBits(dcols[col]);
        // Diagonal lookup: outer and inner key are the same value.
        NonBlockingHashMapLong<AtomicLong> colx = sc._col0s.get(lkey);
        AtomicLong al = colx.get(lkey);
        cnts[col] = al.get();
      }
      Vec vec = Vec.makeVec(cnts, null, Vec.VectorGroup.VG_LEN1.addVec());
      res.add("Counts", vec);
      return new ValFrame(res);
    }
    // 2-d table result.
    Frame res = new Frame();
    if (!dense) {
      // Slow-pass group counting, very sparse hashtables. Note that Vec v2 is
      // used as the left-most arg, or OUTER dimension - which will be columns in
      // the final result.
      AstTable.SlowCnt sc = new AstTable.SlowCnt().doAll(v2, v1);
      // Get the column headers as sorted doubles
      double dcols[] = collectDomain(sc._col0s);
      // Need the row headers as sorted doubles also, but these are scattered
      // throughout the nested tables. Fold 'em into 1 table.
      NonBlockingHashMapLong<AtomicLong> rows = new NonBlockingHashMapLong<>();
      for (NonBlockingHashMapLong.IteratorLong i = iter(sc._col0s); i.hasNext(); )
        rows.putAll(sc._col0s.get(i.nextLong()));
      double drows[] = collectDomain(rows);
      // Now walk the columns one by one, building a Vec per column, building a
      // Frame result. Rowlabel for first column.
      Vec rowlabel = Vec.makeVec(drows, Vec.VectorGroup.VG_LEN1.addVec());
      rowlabel.setDomain(v1.domain());
      res.add(colnames[0], rowlabel);
      long cnts[] = new long[drows.length];
      for (int col = 0; col < dcols.length; col++) {
        NonBlockingHashMapLong<AtomicLong> colx = sc._col0s.get(Double.doubleToRawLongBits(dcols[col]));
        for (int row = 0; row < drows.length; row++) {
          AtomicLong al = colx.get(Double.doubleToRawLongBits(drows[row]));
          cnts[row] = al == null ? 0 : al.get(); // missing combo == 0 count
        }
        Vec vec = Vec.makeVec(cnts, null, Vec.VectorGroup.VG_LEN1.addVec());
        // BUGFIX: for a categorical v2, dcols[col] is itself the domain level
        // index (observed levels may be a sparse subset of the full domain),
        // so index the domain by the value, not by the loop position 'col'.
        res.add(v2.isCategorical() ? v2.domain()[(int) dcols[col]] : Double.toString(dcols[col]), vec);
      }
    } else {
      // Dense output: one row per observed (v1,v2) combo, three columns.
      AstTable.SlowCnt sc = new AstTable.SlowCnt().doAll(v1, v2);
      double dcols[] = collectDomain(sc._col0s);
      NonBlockingHashMapLong<AtomicLong> rows = new NonBlockingHashMapLong<>();
      for (NonBlockingHashMapLong.IteratorLong i = iter(sc._col0s); i.hasNext(); )
        rows.putAll(sc._col0s.get(i.nextLong()));
      double drows[] = collectDomain(rows);
      int x = 0;
      // Total number of observed combos == sum of nested map sizes.
      int sz = 0;
      for (NonBlockingHashMapLong.IteratorLong i = iter(sc._col0s); i.hasNext(); ) {
        sz += sc._col0s.get(i.nextLong()).size();
      }
      long cnts[] = new long[sz];
      double[] left_categ = new double[sz];
      double[] right_categ = new double[sz];
      for (double dcol : dcols) {
        NonBlockingHashMapLong<AtomicLong> colx = sc._col0s.get(Double.doubleToRawLongBits(dcol));
        for (double drow : drows) {
          AtomicLong al = colx.get(Double.doubleToRawLongBits(drow));
          if (al != null) { // only emit combos that actually occurred
            left_categ[x] = dcol;
            right_categ[x] = drow;
            cnts[x] = al.get();
            x++;
          }
        }
      }
      Vec vec = Vec.makeVec(left_categ, Vec.VectorGroup.VG_LEN1.addVec());
      if (v1.isCategorical()) vec.setDomain(v1.domain());
      res.add(colnames[0], vec);
      vec = Vec.makeVec(right_categ, Vec.VectorGroup.VG_LEN1.addVec());
      if (v2.isCategorical()) vec.setDomain(v2.domain());
      res.add(colnames[1], vec);
      vec = Vec.makeVec(cnts, null, Vec.VectorGroup.VG_LEN1.addVec());
      res.add("Counts", vec);
    }
    return new ValFrame(res);
  }

  // Collect the unique longs from this NBHML, convert to doubles and return
  // them as a sorted double[].  Keys are the raw long bits of double values.
  private static double[] collectDomain(NonBlockingHashMapLong ls) {
    int sz = ls.size(); // Uniques
    double ds[] = new double[sz];
    int x = 0;
    for (NonBlockingHashMapLong.IteratorLong i = iter(ls); i.hasNext(); )
      ds[x++] = Double.longBitsToDouble(i.nextLong());
    Arrays.sort(ds);
    return ds;
  }

  // Convenience cast: NBHML's key iterator is its specialized IteratorLong.
  private static NonBlockingHashMapLong.IteratorLong iter(NonBlockingHashMapLong nbhml) {
    return (NonBlockingHashMapLong.IteratorLong) nbhml.keySet().iterator();
  }

  // Implementation is a double-dimension NBHML. Each dimension key is the raw
  // long bits of the double column. Bottoms out in an AtomicLong.
  // The outer map is node-local (built in setupLocal, shared by all task
  // clones on a node) and shipped between nodes via write_impl/read_impl.
  private static class SlowCnt extends MRTask<AstTable.SlowCnt> {
    transient NonBlockingHashMapLong<NonBlockingHashMapLong<AtomicLong>> _col0s;

    @Override
    public void setupLocal() {
      _col0s = new NonBlockingHashMapLong<>();
    }

    @Override
    public void map(Chunk c0, Chunk c1) {
      for (int i = 0; i < c0._len; i++) {
        double d0 = c0.atd(i);
        if (Double.isNaN(d0)) continue; // rows with NA in either column are dropped
        long l0 = Double.doubleToRawLongBits(d0);
        double d1 = c1.atd(i);
        if (Double.isNaN(d1)) continue;
        long l1 = Double.doubleToRawLongBits(d1);
        // Atomically fetch/create nested NBHM
        NonBlockingHashMapLong<AtomicLong> col1s = _col0s.get(l0);
        if (col1s == null) { // Speed filter pre-filled entries
          col1s = new NonBlockingHashMapLong<>();
          NonBlockingHashMapLong<AtomicLong> old = _col0s.putIfAbsent(l0, col1s);
          if (old != null) col1s = old; // Lost race, use old value
        }
        // Atomically fetch/create nested AtomicLong
        AtomicLong cnt = col1s.get(l1);
        if (cnt == null) { // Speed filter pre-filled entries
          cnt = new AtomicLong();
          AtomicLong old = col1s.putIfAbsent(l1, cnt);
          if (old != null) cnt = old; // Lost race, use old value
        }
        // Atomically bump counter
        cnt.incrementAndGet();
      }
    }

    @Override
    public void reduce(AstTable.SlowCnt sc) {
      // Same-node clones share one map; cross-node merge is handled by the
      // serialization path and is not expected to reach here.
      if (_col0s == sc._col0s) return;
      throw water.H2O.unimpl();
    }

    // Manual serialization: nested map sizes followed by key/value longs.
    public final AutoBuffer write_impl(AutoBuffer ab) {
      if (_col0s == null) return ab.put8(0);
      ab.put8(_col0s.size());
      for (long col0 : _col0s.keySetLong()) {
        ab.put8(col0);
        NonBlockingHashMapLong<AtomicLong> col1s = _col0s.get(col0);
        ab.put8(col1s.size());
        for (long col1 : col1s.keySetLong()) {
          ab.put8(col1);
          ab.put8(col1s.get(col1).get());
        }
      }
      return ab;
    }

    // Inverse of write_impl; an initial 0 means "no map was serialized".
    public final AstTable.SlowCnt read_impl(AutoBuffer ab) {
      long len0 = ab.get8();
      if (len0 == 0) return this;
      _col0s = new NonBlockingHashMapLong<>();
      for (long i = 0; i < len0; i++) {
        NonBlockingHashMapLong<AtomicLong> col1s = new NonBlockingHashMapLong<>();
        _col0s.put(ab.get8(), col1s);
        long len1 = ab.get8();
        for (long j = 0; j < len1; j++)
          col1s.put(ab.get8(), new AtomicLong(ab.get8()));
      }
      return this;
    }

    // Debug dump: one line per outer key, nested counts in braces.
    @Override
    public String toString() {
      StringBuilder sb = new StringBuilder();
      for (NonBlockingHashMapLong.IteratorLong i = iter(_col0s); i.hasNext(); ) {
        long l = i.nextLong();
        double d = Double.longBitsToDouble(l);
        sb.append(d).append(": {");
        NonBlockingHashMapLong<AtomicLong> col1s = _col0s.get(l);
        for (NonBlockingHashMapLong.IteratorLong j = iter(col1s); j.hasNext(); ) {
          long l2 = j.nextLong();
          double d2 = Double.longBitsToDouble(l2);
          AtomicLong al = col1s.get(l2);
          sb.append(d2).append(": ").append(al.get()).append(", ");
        }
        sb.append("}\n");
      }
      return sb.toString();
    }
  }
}