package water.rapids.ast.prims.mungers; import water.H2O; import water.MRTask; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.NewChunk; import water.fvec.Vec; import water.rapids.Env; import water.rapids.Val; import water.rapids.ast.AstRoot; import water.rapids.vals.ValFrame; import water.rapids.ast.AstPrimitive; import water.rapids.ast.params.AstNumList; import water.util.IcedHashMap; import water.util.Log; import java.util.HashMap; public class AstGroupedPermute extends AstPrimitive { // .newExpr("grouped_permute", fr, permCol, permByCol, groupByCols, keepCol) @Override public String[] args() { return new String[]{"ary", "permCol", "groupBy", "permuteBy", "keepCol"}; } // currently only allow 2 items in permuteBy @Override public int nargs() { return 1 + 5; } // (trim x col groupBy permuteBy keepCol) @Override public String str() { return "grouped_permute"; } @Override public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) { Frame fr = stk.track(asts[1].exec(env)).getFrame(); final int permCol = (int) asts[2].exec(env).getNum(); AstNumList groupby = AstGroup.check(fr.numCols(), asts[3]); final int[] gbCols = groupby.expand4(); final int permuteBy = (int) asts[4].exec(env).getNum(); final int keepCol = (int) asts[5].exec(env).getNum(); String[] names = new String[gbCols.length + 4]; int i = 0; for (; i < gbCols.length; ++i) names[i] = fr.name(gbCols[i]); names[i++] = "In"; names[i++] = "Out"; names[i++] = "InAmnt"; names[i] = "OutAmnt"; String[][] domains = new String[names.length][]; int d = 0; for (; d < gbCols.length; d++) domains[d] = fr.domains()[gbCols[d]]; domains[d++] = fr.domains()[permCol]; domains[d++] = fr.domains()[permCol]; domains[d++] = fr.domains()[keepCol]; domains[d] = fr.domains()[keepCol]; long s = System.currentTimeMillis(); BuildGroups t = new BuildGroups(gbCols, permuteBy, permCol, keepCol).doAll(fr); Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s"); s = System.currentTimeMillis(); SmashGroups sg; H2O.submitTask(sg = new SmashGroups(t._grps)).join(); Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s"); return new ValFrame(buildOutput(sg._res.values().toArray(new double[0][][]), names, domains)); } private static Frame buildOutput(final double[][][] a, String[] names, String[][] domains) { Frame dVec = new Frame(Vec.makeSeq(0, a.length)); long s = System.currentTimeMillis(); Frame res = new MRTask() { @Override public void map(Chunk[] cs, NewChunk[] ncs) { for (int i = 0; i < cs[0]._len; ++i) for (double[] anAa : a[(int) cs[0].at8(i)]) for (int k = 0; k < anAa.length; ++k) ncs[k].addNum(anAa[k]); } }.doAll(5, Vec.T_NUM, dVec).outputFrame(null, names, domains); Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s"); dVec.delete(); return res; } private static class BuildGroups extends MRTask<BuildGroups> { IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps; // shared per node (all grps with permutations atomically inserted) private final int _gbCols[]; private final int _permuteBy; private final int _permuteCol; private final int _amntCol; BuildGroups(int[] gbCols, int permuteBy, int permuteCol, int amntCol) { _gbCols = gbCols; _permuteBy = permuteBy; _permuteCol = permuteCol; _amntCol = amntCol; } @Override public void setupLocal() { _grps = new IcedHashMap<>(); } @Override public void map(Chunk[] chks) { String[] dom = chks[_permuteBy].vec().domain(); IcedHashMap<Long, IcedHashMap<Long, double[]>[]> grps = new IcedHashMap<>(); for (int row = 0; row < chks[0]._len; ++row) { long jid = chks[_gbCols[0]].at8(row); long rid = chks[_permuteCol].at8(row); double[] aci = new double[]{rid, chks[_amntCol].atd(row)}; int type = dom[(int) chks[_permuteBy].at8(row)].equals("D") ? 0 : 1; if (grps.containsKey(jid)) { IcedHashMap<Long, double[]>[] dcWork = grps.get(jid); if (dcWork[type].putIfAbsent(rid, aci) != null) dcWork[type].get(rid)[1] += aci[1]; } else { IcedHashMap<Long, double[]>[] dcAcnts = new IcedHashMap[2]; dcAcnts[0] = new IcedHashMap<>(); dcAcnts[1] = new IcedHashMap<>(); dcAcnts[type].put(rid, aci); grps.put(jid, dcAcnts); } } reduce(grps); } @Override public void reduce(BuildGroups t) { if (_grps != t._grps) reduce(t._grps); } private void reduce(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> r) { for (Long l : r.keySet()) { if (_grps.putIfAbsent(l, r.get(l)) != null) { IcedHashMap<Long, double[]>[] rdbls = r.get(l); IcedHashMap<Long, double[]>[] ldbls = _grps.get(l); for (Long rr : rdbls[0].keySet()) if (ldbls[0].putIfAbsent(rr, rdbls[0].get(rr)) != null) ldbls[0].get(rr)[1] += rdbls[0].get(rr)[1]; for (Long rr : rdbls[1].keySet()) if (ldbls[1].putIfAbsent(rr, rdbls[1].get(rr)) != null) ldbls[1].get(rr)[1] += rdbls[1].get(rr)[1]; } } } } private static class SmashGroups extends H2O.H2OCountedCompleter<SmashGroups> { private final IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps; private final HashMap<Integer, Long> _map; private int _hi; private int _lo; SmashGroups _left; SmashGroups _rite; private IcedHashMap<Long, double[][]> _res; SmashGroups(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> grps) { _grps = grps; _lo = 0; _hi = _grps.size(); _res = new IcedHashMap<>(); _map = new HashMap<>(); int i = 0; for (Long l : _grps.keySet()) _map.put(i++, l); } @Override public void compute2() { assert _left == null && _rite == null; if ((_hi - _lo) >= 2) { // divide/conquer down to 1 IHM final int mid = (_lo + _hi) >>> 1; // Mid-point _left = copyAndInit(); _rite = copyAndInit(); _left._hi = mid; // Reset mid-point _rite._lo = mid; // Also set self mid-point addToPendingCount(1); // One fork awaiting completion _left.fork(); // Runs in another thread/FJ instance _rite.compute2(); // Runs in THIS F/J thread return; } if (_hi > _lo) { smash(); } tryComplete(); } private void smash() { long key = _map.get(_lo); IcedHashMap<Long, double[]>[] pair = _grps.get(key); double[][] res = new double[pair[0].size() * pair[1].size()][]; // all combos int d0 = 0; for (double[] ds0 : pair[0].values()) { for (double[] ds1 : pair[1].values()) res[d0++] = new double[]{key, ds0[0], ds1[0], ds0[1], ds1[1]}; } _res.put(key, res); } private SmashGroups copyAndInit() { SmashGroups x = SmashGroups.this.clone(); x.setCompleter(this); x._left = x._rite = null; x.setPendingCount(0); return x; } } }