package water.rapids.ast.prims.mungers;
import water.H2O;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValFun;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;
import java.util.Arrays;
/**
* GroupBy
* Group the rows of 'data' by unique combinations of '[group-by-cols]'.
* Apply function 'fcn' to a Frame for each group, with a single column
* argument, and a NA-handling flag. Sets of tuples {fun,col,na} are allowed.
* <p/>
* 'fcn' must be one of a small set of functions, all reductions, and 'GB'
* returns a row per unique group, with the first columns being the grouping
* columns, and the last column(s) the reduction result(s).
* <p/>
* The returned column(s).
*/
public class AstGroup extends AstPrimitive {
/**
 * Treatment of NaN values in an aggregated column, as applied by
 * {@code AGG.op}.
 */
public enum NAHandling {
  /** A NaN is passed to the reduction function; it does not bump the row count. */
  ALL,
  /** A NaN is neither passed to the reduction function nor counted. */
  RM,
  /** A NaN is not passed to the reduction function, but the row is still counted. */
  IGNORE
}
/**
 * Reduction functions handled by GroupBy.
 * <p/>
 * Every function works on a small {@code double[]} accumulator:
 * {@link #initVal(int)} creates it, {@link #op(double[], double)} folds one
 * value into it, {@link #atomic_op(double[], double[])} folds a second
 * accumulator into it, and {@link #postPass(double[], long)} turns the
 * accumulator plus a row count into the final result.
 * <p/>
 * {@code atomic_op} performs no locking itself; {@code AGG.atomic_op} calls it
 * while holding a lock on the target accumulator.
 */
public enum FCN {
  /** Row count: slot 0 is bumped once per value seen. */
  nrow() {
    @Override
    public void op(double[] acc, double value) {
      acc[0]++;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] += other[0];
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0];
    }
  },
  /** Arithmetic mean: slot 0 holds the running sum, divided by the count at the end. */
  mean() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] += value;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] += other[0];
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0] / n;
    }
  },
  /** Sum: slot 0 holds the running sum. */
  sum() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] += value;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] += other[0];
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0];
    }
  },
  /** Sum of squares: slot 0 holds the running sum of value*value. */
  sumSquares() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] += value * value;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] += other[0];
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0];
    }
  },
  /** Variance with an (n - 1) denominator; slot 0 is the sum of squares, slot 1 the sum. */
  var() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] += value * value;
      acc[1] += value;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      ArrayUtils.add(acc, other);
    }

    @Override
    public double postPass(double[] acc, long n) {
      return varianceOf(acc, n);
    }

    @Override
    public double[] initVal(int ignored) {
      return new double[2]; /* 0 -> sum_squares; 1 -> sum*/
    }
  },
  /** Standard deviation: square root of the variance computed as for {@code var}. */
  sdev() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] += value * value;
      acc[1] += value;
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      ArrayUtils.add(acc, other);
    }

    @Override
    public double postPass(double[] acc, long n) {
      return Math.sqrt(varianceOf(acc, n));
    }

    @Override
    public double[] initVal(int ignored) {
      return new double[2]; /* 0 -> sum_squares; 1 -> sum*/
    }
  },
  /** Minimum: slot 0 starts at Double.MAX_VALUE and only ever shrinks. */
  min() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] = Math.min(acc[0], value);
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] = Math.min(acc[0], other[0]);
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0];
    }

    @Override
    public double[] initVal(int maxx) {
      return new double[]{Double.MAX_VALUE};
    }
  },
  /** Maximum: slot 0 starts at -Double.MAX_VALUE and only ever grows. */
  max() {
    @Override
    public void op(double[] acc, double value) {
      acc[0] = Math.max(acc[0], value);
    }

    @Override
    public void atomic_op(double[] acc, double[] other) {
      acc[0] = Math.max(acc[0], other[0]);
    }

    @Override
    public double postPass(double[] acc, long n) {
      return acc[0];
    }

    @Override
    public double[] initVal(int maxx) {
      return new double[]{-Double.MAX_VALUE};
    }
  },
  /**
   * Mode: the accumulator is a histogram with {@code maxx} slots, indexed by
   * the value cast to int; the result is the index ArrayUtils.maxIndex picks.
   */
  mode() {
    @Override
    public void op(double[] counts, double value) {
      // NOTE(review): (int) NaN is 0 in Java, so a NaN that reaches this method
      // is tallied under slot 0 - confirm that is intended.
      counts[(int) value]++;
    }

    @Override
    public void atomic_op(double[] counts, double[] other) {
      ArrayUtils.add(counts, other);
    }

    @Override
    public double postPass(double[] counts, long n) {
      return ArrayUtils.maxIndex(counts);
    }

    @Override
    public double[] initVal(int maxx) {
      return new double[maxx];
    }
  };

  /** Folds the single value {@code d1} into the accumulator {@code d0}. */
  public abstract void op(double[] d0, double d1);

  /** Folds the accumulator {@code d1} into the accumulator {@code d0}. */
  public abstract void atomic_op(double[] d0, double[] d1);

  /** Converts the accumulator {@code ds} and the row count {@code n} into the final result. */
  public abstract double postPass(double ds[], long n);

  /**
   * Creates a fresh accumulator. The default is a single slot holding zero;
   * {@code maxx} is only consulted by functions that override this method.
   */
  public double[] initVal(int maxx) {
    return new double[1];
  }

  // Shared by var and sdev: (sum_squares - sum^2 / n) / (n - 1), where a
  // numerator with magnitude below 1e-5 is snapped to exactly zero.
  private static double varianceOf(double[] acc, long n) {
    double numerator = acc[0] - acc[1] * acc[1] / n;
    if (Math.abs(numerator) < 1e-5) numerator = 0;
    return numerator / (n - 1);
  }
}
/**
 * Argument count reported for this primitive, whose call form is
 * {@code (GB data [group-by-cols] {fcn col "na"}...)}.
 * Presumably -1 marks a variable number of arguments - verify against AstPrimitive.
 */
@Override
public int nargs() {
  return -1;
}
/** Argument names: a single "..." placeholder entry. */
@Override
public String[] args() {
  final String[] placeholder = {"..."};
  return placeholder;
}
/** The textual name of this primitive, "GB". */
@Override
public String str() {
  return "GB";
}
/**
 * Executes {@code (GB data [group-by-cols] {fcn col "na"}...)}.
 * <p/>
 * Reads the frame from {@code asts[1]} and the group-by columns from
 * {@code asts[2]}; every following triple describes one aggregate (function,
 * single column, NA handling). Groups are computed with {@link #doGroups},
 * sorted by their key when there is at least one group-by column, and written
 * out one row per group: key columns first, then one column per aggregate.
 *
 * @param env  execution environment used to evaluate the argument ASTs
 * @param stk  stack helper used to track the input frame
 * @param asts argument ASTs; {@code asts[0]} is not read here
 * @return the grouped result frame
 * @throws IllegalArgumentException if the arguments do not form complete
 *         triples, an aggregate names more than one column, the function or
 *         NA-handling name is unknown, or mode is requested on a
 *         non-categorical column
 */
@Override
public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
  // Fail early with a readable message instead of an index error further down.
  if (asts.length < 3 || (asts.length - 3) % 3 != 0)
    throw new IllegalArgumentException("Expected (GB data [group-by-cols] {fcn col \"na\"}...), but got "
        + (asts.length - 1) + " arguments");
  Frame fr = stk.track(asts[1].exec(env)).getFrame();
  int ncols = fr.numCols();
  AstNumList groupby = check(ncols, asts[2]);
  final int[] gbCols = groupby.expand4();
  // Count of aggregates; knock off the first 3 ASTs (GB data [group-by]),
  // then count by triples.
  int naggs = (asts.length - 3) / 3;
  final AGG[] aggs = new AGG[naggs];
  for (int idx = 3; idx < asts.length; idx += 3) {
    Val v = asts[idx].exec(env);
    String fn = v instanceof ValFun ? v.getFun().str() : v.getStr();
    FCN fcn = FCN.valueOf(fn);
    AstNumList col = check(ncols, asts[idx + 1]);
    if (col.cnt() != 1) throw new IllegalArgumentException("Group-By functions take only a single column");
    int agg_col = (int) col.min(); // Aggregate column
    if (fcn == FCN.mode && !fr.vec(agg_col).isCategorical())
      throw new IllegalArgumentException("Mode only allowed on categorical columns");
    NAHandling na = NAHandling.valueOf(asts[idx + 2].exec(env).getStr().toUpperCase());
    // The last argument (column max + 1) sizes the accumulator that mode uses.
    aggs[(idx - 3) / 3] = new AGG(fcn, agg_col, na, (int) fr.vec(agg_col).max() + 1);
  }
  // do the group by work now
  IcedHashMap<G, String> gss = doGroups(fr, gbCols, aggs);
  final G[] grps = gss.keySet().toArray(new G[gss.size()]);
  // apply an ORDER by here...
  if (gbCols.length > 0)
    Arrays.sort(grps, new java.util.Comparator<G>() {
      // Compare 2 groups. Iterate down _gs, stop when _gs[i] > that._gs[i],
      // or _gs[i] < that._gs[i]. Order by various columns specified by
      // gbCols. NaN is treated as least
      @Override
      public int compare(G g1, G g2) {
        for (int i = 0; i < gbCols.length; i++) {
          if (Double.isNaN(g1._gs[i]) && !Double.isNaN(g2._gs[i])) return -1;
          if (!Double.isNaN(g1._gs[i]) && Double.isNaN(g2._gs[i])) return 1;
          if (g1._gs[i] != g2._gs[i]) return g1._gs[i] < g2._gs[i] ? -1 : 1;
        }
        return 0;
      }
    });
  // Build the output!
  // Aggregate columns are named "<fcn>_<column>", except nrow which is just "nrow".
  String[] fcnames = new String[aggs.length];
  for (int i = 0; i < aggs.length; i++) {
    FCN fcn = aggs[i]._fcn;
    // Compare enum constants directly rather than their names by reference.
    if (fcn != FCN.nrow) {
      fcnames[i] = fcn.toString() + "_" + fr.name(aggs[i]._col);
    } else {
      fcnames[i] = fcn.toString();
    }
  }
  MRTask mrfill = new MRTask() {
    @Override
    public void map(Chunk[] c, NewChunk[] ncs) {
      int start = (int) c[0].start();
      for (int i = 0; i < c[0]._len; ++i) {
        G g = grps[i + start]; // One Group per row
        int j;
        for (j = 0; j < g._gs.length; j++) // The Group Key, as a row
          ncs[j].addNum(g._gs[j]);
        for (int a = 0; a < aggs.length; a++) // Then the finished aggregates
          ncs[j++].addNum(aggs[a]._fcn.postPass(g._dss[a], g._ns[a]));
      }
    }
  };
  Frame f = buildOutput(gbCols, naggs, fr, fcnames, grps.length, mrfill);
  return new ValFrame(f);
}
/**
 * Argument check helper: coerces {@code ast} to an AstNumList and verifies
 * that every expanded entry is a valid index for a dimension of size
 * {@code dstX}. A lone AstNum is wrapped in a one-element list; an empty list
 * is accepted as is.
 *
 * @param dstX size of the dimension being indexed (valid entries are 0 .. dstX-1)
 * @param ast  an AstNumList or AstNum
 * @return the (possibly wrapped) number list
 * @throws IllegalArgumentException if {@code ast} is neither kind of node, or
 *         an entry lies outside the valid range
 */
public static AstNumList check(long dstX, AstRoot ast) {
  final AstNumList selection;
  if (ast instanceof AstNumList) {
    selection = (AstNumList) ast;
  } else if (ast instanceof AstNum) {
    // Jam the single-number case in as a list to keep one code path.
    selection = new AstNumList(((AstNum) ast).getNum());
  } else {
    throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
  }
  if (!selection.isEmpty()) {
    for (int idx : selection.expand4()) {
      if (idx < 0 || idx >= dstX)
        throw new IllegalArgumentException("Selection must be an integer from 0 to " + dstX);
    }
  }
  return selection;
}
/**
 * Does all the grouping work: runs a GBTask over {@code fr} to find the groups
 * defined by the {@code gbCols} columns and to accumulate the {@code aggs}
 * aggregates for each of them, logging the elapsed time.
 *
 * @return the task's group table; its keys are the groups, each carrying its
 *         aggregate accumulators
 */
public static IcedHashMap<G, String> doGroups(Frame fr, int[] gbCols, AGG[] aggs) {
  final long t0 = System.currentTimeMillis();
  final GBTask task = new GBTask(gbCols, aggs).doAll(fr);
  final double elapsedSecs = (System.currentTimeMillis() - t0) / 1000.;
  Log.info("Group By Task done in " + elapsedSecs + " (s)");
  return task._gss;
}
/**
 * Utility for AstDdply: a one-element aggregate list that counts rows per
 * group (nrow on column 0, NA handling IGNORE).
 */
public static AGG[] aggNRows() {
  final AGG rowCounter = new AGG(FCN.nrow, 0, NAHandling.IGNORE, 0);
  return new AGG[]{rowCounter};
}
/**
 * Builds the output frame from the multi-column results.
 * <p/>
 * The first {@code gbCols.length} columns take their names and domains from
 * the matching columns of {@code fr}; the following columns are named from
 * {@code fcnames} and have no domain. {@code mrfill} is run over a temporary
 * zero Vec of {@code ngrps} rows and is expected to emit one output row per
 * input row.
 *
 * @param gbCols   indices of the group-by columns in {@code fr}
 * @param noutCols number of output columns after the group-by columns
 * @param fr       source frame supplying group-by names and domains
 * @param fcnames  names for the output columns after the group-by columns
 * @param ngrps    number of groups, i.e. rows of the result
 * @param mrfill   task that fills the output columns
 * @return the assembled output frame
 */
public static Frame buildOutput(int[] gbCols, int noutCols, Frame fr, String[] fcnames, int ngrps, MRTask mrfill) {
  // Build the output!
  // the names of columns
  final int nCols = gbCols.length + noutCols;
  String[] names = new String[nCols];
  String[][] domains = new String[nCols][];
  for (int i = 0; i < gbCols.length; i++) {
    names[i] = fr.name(gbCols[i]);
    domains[i] = fr.domains()[gbCols[i]];
  }
  for (int i = 0; i < fcnames.length; i++)
    names[i + gbCols.length] = fcnames[i];
  Vec v = Vec.makeZero(ngrps); // dummy layout vec
  try {
    // Convert the output arrays into a Frame, also doing the post-pass work
    return mrfill.doAll(nCols, Vec.T_NUM, new Frame(v)).outputFrame(names, domains);
  } finally {
    // Drop the layout vec even when the fill task throws, so it is not leaked.
    v.remove();
  }
}
/**
 * Description of a single aggregate: the reduction function, the column it
 * reads, and the NA handling to apply.
 */
public static class AGG extends Iced {
  final FCN _fcn;        // Reduction function
  public final int _col; // Column being aggregated
  final NAHandling _na;  // What to do with NaN values
  final int _maxx;       // Passed to FCN.initVal to size the accumulator

  public AGG(FCN fcn, int col, NAHandling na, int maxx) {
    _fcn = fcn;
    _col = col;
    _na = na;
    _maxx = maxx;
  }

  /**
   * Updates the pair {d0ss[i], n0s[i]} with the value d1.
   *
   * @param d0ss reduction arrays, one per aggregate
   * @param n0s  element counts, one per aggregate
   * @param i    index of this aggregate in both arrays
   * @param d1   the value to fold in
   */
  public void op(double[][] d0ss, long[] n0s, int i, double d1) {
    if (!Double.isNaN(d1)) {
      // Normal number: reduce it and count it.
      _fcn.op(d0ss[i], d1);
      n0s[i]++;
      return;
    }
    // Missing value: ALL reduces it without counting, IGNORE counts it
    // without reducing, RM does neither.
    if (_na == NAHandling.ALL) {
      _fcn.op(d0ss[i], d1);
    } else if (_na == NAHandling.IGNORE) {
      n0s[i]++;
    }
  }

  /**
   * Merges the pair {d1s, n1} into the pair {d0ss[i], n0s[i]} while holding
   * the monitor of d0ss[i], so that racing callers merging into the same
   * accumulator are serialized.
   */
  public void atomic_op(double[][] d0ss, long[] n0s, int i, double[] d1s, long n1) {
    final double[] target = d0ss[i];
    synchronized (target) {
      _fcn.atomic_op(target, d1s);
      n0s[i] += n1;
    }
  }

  /** Creates a fresh accumulator for this aggregate's function. */
  public double[] initVal() {
    return _fcn.initVal(_maxx);
  }
}
// --------------------------------------------------------------------------
// Main worker MRTask. Makes 1 pass over the data, and accumulates both all
// groups and all aggregates.
//
// Each map() call builds a private table of the groups found in its chunk and
// then merges that table into _gss; reduce(GBTask) merges the _gss of another
// task instance into this one unless both already refer to the same table.
public static class GBTask extends MRTask<GBTask> {
final IcedHashMap<G, String> _gss; // Shared per-node, common, racy
private final int[] _gbCols; // Columns used to define group
private final AGG[] _aggs; // Aggregate descriptions
// Creates a task with an empty group table.
GBTask(int[] gbCols, AGG[] aggs) {
_gbCols = gbCols;
_aggs = aggs;
_gss = new IcedHashMap<>();
}
// Scans every row of the chunk: looks up (or creates) the row's group in a
// chunk-private table, folds the row's aggregate-column values into that
// group, and finally merges the private table into _gss.
@Override
public void map(Chunk[] cs) {
// Groups found in this Chunk
IcedHashMap<G, String> gs = new IcedHashMap<>();
G gWork = new G(_gbCols.length, _aggs); // Working Group
G gOld; // Existing Group to be filled in
for (int row = 0; row < cs[0]._len; row++) {
// Find the Group being worked on
gWork.fill(row, cs, _gbCols); // Fill the worker Group for the hashtable lookup
if (gs.putIfAbsent(gWork, "") == null) { // Insert if absent (note: 'gs' is local to this call, so no race, no need for atomic)
gOld = gWork; // Inserted 'gWork' into table
gWork = new G(_gbCols.length, _aggs); // need entirely new G, since the old one is now a key in the table
} else gOld = gs.getk(gWork); // Else get existing group
for (int i = 0; i < _aggs.length; i++) // Accumulate aggregate reductions
_aggs[i].op(gOld._dss, gOld._ns, i, cs[_aggs[i]._col].atd(row));
}
// This is a racy update into the node-local shared table of groups
reduce(gs); // Atomically merge Group stats
}
// Racy update on a subtle path: reduction is always single-threaded, but
// the shared global hashtable being reduced into is ALSO being written by
// parallel map calls.
// Nothing to do when both tasks already refer to the same table instance.
@Override
public void reduce(GBTask t) {
if (_gss != t._gss) reduce(t._gss);
}
// Non-blocking race-safe update of the shared per-node groups hashtable.
// A group missing from _gss is inserted as is (the G object itself becomes
// the key); a group already present has each of its aggregates merged via
// AGG.atomic_op.
private void reduce(IcedHashMap<G, String> r) {
for (G rg : r.keySet())
if (_gss.putIfAbsent(rg, "") != null) {
G lg = _gss.getk(rg); // The group already stored under an equal key
for (int i = 0; i < _aggs.length; i++)
_aggs[i].atomic_op(lg._dss, lg._ns, i, rg._dss[i], rg._ns[i]); // Need to atomically merge groups here
}
}
}
/**
 * Groups! Contains a Group Key - an array of doubles (often just 1 entry
 * long) that defines the Group. Also contains one accumulator array and one
 * row count per aggregate.
 * <p/>
 * Equality and hashing are defined on the key only. The hash is cached in
 * {@code _hash} and is refreshed only by {@link #fill}, so the key must not be
 * changed by any other means while the group sits in a hash table.
 */
public static class G extends Iced {
  public final double[] _gs; // Group Key: Array is final; contents change with the "fill"
  int _hash; // Hash is not final; changes with the "fill"
  public final double[][] _dss; // Aggregates: one accumulator array per aggregate, from AGG.initVal
  public final long[] _ns; // row counts per aggregate, varies by NA handling and column

  /**
   * @param ncols number of group-by columns, i.e. the key length
   * @param aggs  aggregate descriptions; null is treated as no aggregates
   */
  public G(int ncols, AGG[] aggs) {
    _gs = new double[ncols];
    int len = aggs == null ? 0 : aggs.length;
    _dss = new double[len][];
    _ns = new long[len];
    for (int i = 0; i < len; i++) _dss[i] = aggs[i].initVal();
  }

  /**
   * Loads the key from row {@code row} of the selected columns and refreshes
   * the cached hash.
   *
   * @return this group
   */
  public G fill(int row, Chunk chks[], int cols[]) {
    for (int c = 0; c < cols.length; c++) // For all selection cols
      _gs[c] = chks[cols[c]].atd(row); // Load into working array
    _hash = hash();
    return this;
  }

  /** Computes a non-negative hash of the key. */
  protected int hash() {
    long h = 0; // hash is sum of field bits
    // equals() relies on Arrays.equals(double[], double[]), which treats all
    // NaN bit patterns as equal; doubleToLongBits canonicalizes NaN so that
    // equal keys always hash alike (raw bits would not guarantee that).
    for (double d : _gs) h += Double.doubleToLongBits(d);
    // Doubles are lousy hashes; mix up the bits some
    h ^= (h >>> 20) ^ (h >>> 12);
    h ^= (h >>> 7) ^ (h >>> 4);
    return (int) ((h ^ (h >> 32)) & 0x7FFFFFFF);
  }

  /** Two groups are equal when their keys are element-wise equal. */
  @Override
  public boolean equals(Object o) {
    return o instanceof G && Arrays.equals(_gs, ((G) o)._gs);
  }

  /** Returns the hash cached by the last {@link #fill}; 0 before the first fill. */
  @Override
  public int hashCode() {
    return _hash;
  }

  @Override
  public String toString() {
    return Arrays.toString(_gs);
  }
}
}