package water.rapids.ast.prims.reducers;

import water.H2O;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.util.ArrayUtils;

import java.util.Arrays;

/**
 * Base class for cumulative ("running") reductions over a frame.
 *
 * <p>Subclasses supply the binary operator via {@link #op(double, double)} and the starting
 * accumulator via {@link #init()}. The {@code axis} argument selects the direction:
 * {@code 0} accumulates down each column, {@code 1} accumulates across the columns of each row.
 *
 * <p>The axis-0 computation is done in two passes:
 * <ol>
 *   <li>every chunk is scanned independently, starting from {@code init()}, producing a
 *       chunk-local running result and recording the chunk's final accumulator value;</li>
 *   <li>the per-chunk finals are folded into running values across chunks, and every chunk
 *       except the first is then combined with the running value of the chunk before it.</li>
 * </ol>
 *
 * <p>Note that {@code op} is invoked as {@code op(accumulator, value)} in the first pass and as
 * {@code op(value, carry)} in the second, so implementations are presumably expected to be
 * commutative and associative (NOTE(review): confirm against the concrete subclasses).
 */
public abstract class AstCumu extends AstPrimitive {

  /** Names of the two arguments: the input frame and the axis selector. */
  @Override
  public String[] args() {
    return new String[]{"ary", "axis"};
  }

  /**
   * Arity of the primitive. {@link #apply} reads {@code asts[1]} (the frame) and
   * {@code asts[2]} (the axis), so the count is the function slot plus two arguments.
   */
  @Override
  public int nargs() {
    return 1 + 2;
  } // (cumu ary axis)

  /** Not implemented here; always throws {@code H2O.unimpl()}. */
  @Override
  public String str() {
    throw H2O.unimpl();
  }

  /**
   * The binary operator being accumulated.
   *
   * @param l left operand
   * @param r right operand
   * @return the combined value
   */
  public abstract double op(double l, double r);

  /** @return the accumulator value every scan starts from. */
  public abstract double init();

  /**
   * Evaluates the cumulative operation.
   *
   * @param env  execution environment used to evaluate the argument expressions
   * @param stk  stack helper used to track the evaluated frame argument
   * @param asts {@code asts[1]} is the frame expression, {@code asts[2]} the axis expression
   * @return a frame holding the cumulative results
   * @throws IllegalArgumentException if any column is categorical, string or UUID, or if the
   *                                  axis is neither 0 nor 1
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame f = stk.track(asts[1].exec(env)).getFrame();
    AstRoot axisAR = asts[2];

    // Column types are validated before the axis expression is evaluated.
    for (Vec v : f.vecs()) {
      if (v.isCategorical() || v.isString() || v.isUUID())
        throw new IllegalArgumentException(
            "Cumulative functions not applicable to enum, string, or UUID values");
    }

    double axis = axisAR.exec(env).getNum();
    if (axis != 1.0 && axis != 0.0)
      throw new IllegalArgumentException("Axis must be 0 or 1");

    if (f.numCols() == 1) {
      if (axis == 0.0) {
        // Pass 1: chunk-local scan; t._chkCumu ends up holding running totals across chunks.
        AstCumu.CumuTask t = new AstCumu.CumuTask(f.anyVec().nChunks(), init());
        t.doAll(new byte[]{Vec.T_NUM}, f.anyVec());
        final double[] chkCumu = t._chkCumu;
        Vec cumuVec = t.outputFrame().anyVec();

        // Pass 2: fold the running total of all preceding chunks into every later chunk.
        new MRTask() {
          @Override
          public void map(Chunk c) {
            if (c.cidx() != 0) {
              double d = chkCumu[c.cidx() - 1];
              for (int i = 0; i < c._len; ++i)
                c.set(i, op(c.atd(i), d));
            }
          }
        }.doAll(cumuVec);

        // Column names are passed as null here, unlike the multi-column branch below,
        // which reuses f.names().
        Key<Frame> k = Key.make();
        return new ValFrame(new Frame(k, null, new Vec[]{cumuVec}));
      } else {
        // A single column accumulated across columns: the input frame's columns are
        // returned without any computation.
        return new ValFrame(new Frame(f));
      }
    } else {
      if (axis == 0.0) {  // down the column implementation
        // Pass 1: chunk-local scan of every column.
        AstCumu.CumuTaskWholeFrame t =
            new AstCumu.CumuTaskWholeFrame(f.anyVec().nChunks(), init(), f.numCols());
        Frame fr2 = t.doAll(f.numCols(), Vec.T_NUM, f).outputFrame(null, f.names(), null);
        final double[][] chkCumu = t._chkCumu;

        // Pass 2: per column, fold in the running total of all preceding chunks.
        new MRTask() {
          @Override
          public void map(Chunk cs[]) {
            if (cs[0].cidx() != 0) {
              for (int i = 0; i < cs.length; i++) {
                double d = chkCumu[i][cs[i].cidx() - 1];
                for (int j = 0; j < cs[i]._len; ++j)
                  cs[i].set(j, op(cs[i].atd(j), d));
              }
            }
          }
        }.doAll(fr2);
        return new ValFrame(new Frame(fr2));
      } else {
        // Across the columns of each row; rows are independent, so one pass suffices.
        AstCumu.CumuTaskAxis1 t = new AstCumu.CumuTaskAxis1(init());
        Frame fr2 = t.doAll(f.numCols(), Vec.T_NUM, f).outputFrame(null, f.names(), null);
        return new ValFrame(new Frame(fr2));
      }
    }
  }

  /** Applies the operator along each row, left to right across the columns. */
  protected class CumuTaskAxis1 extends MRTask<AstCumu.CumuTaskAxis1> {
    // apply function along the rows
    final double _init;

    CumuTaskAxis1(double init) {
      _init = init;
    }

    @Override
    public void map(Chunk cs[], NewChunk nc[]) {
      for (int i = 0; i < cs[0].len(); i++) {
        for (int j = 0; j < cs.length; j++) {
          // The previous column's freshly written output is the running value for this row.
          double preVal = j == 0 ? _init : nc[j - 1].atd(i);
          nc[j].addNum(op(preVal, cs[j].atd(i)));
        }
      }
    }
  }

  /**
   * First pass of the multi-column, axis-0 computation: scans each column of each chunk
   * independently and records the chunk's final accumulator in {@code _chkCumu[col][chunk]}.
   */
  protected class CumuTaskWholeFrame extends MRTask<AstCumu.CumuTaskWholeFrame> {
    final int _nchks;   // IN
    final double _init; // IN
    final int _ncols;   // IN
    double[][] _chkCumu;  // OUT, accumulation over each chunk, indexed [column][chunk]

    CumuTaskWholeFrame(int nchks, double init, int ncols) {
      _nchks = nchks;
      _init = init;
      _ncols = ncols;
    }

    @Override
    public void setupLocal() {
      _chkCumu = new double[_ncols][_nchks];
    }

    @Override
    public void map(Chunk cs[], NewChunk nc[]) {
      double acc[] = new double[cs.length];
      Arrays.fill(acc, _init);
      for (int i = 0; i < cs.length; i++) {
        for (int j = 0; j < cs[i]._len; ++j)
          nc[i].addNum(acc[i] = op(acc[i], cs[i].atd(j)));
        _chkCumu[i][cs[i].cidx()] = acc[i];
      }
    }

    @Override
    public void reduce(AstCumu.CumuTaskWholeFrame t) {
      // Distinct arrays are merged by element-wise addition; presumably each slot is written
      // by only one map call, so the other side contributes zero (NOTE(review): confirm).
      if (_chkCumu != t._chkCumu) ArrayUtils.add(_chkCumu, t._chkCumu);
    }

    @Override
    public void postGlobal() {
      // Turn per-chunk finals into running totals across chunks for EVERY column. The column
      // loop must start at 0 (it previously started at 1, leaving column 0 with chunk-local
      // totals only, which corrupts results once a column has three or more chunks); only
      // the chunk loop starts at 1 because chunk 0 has no predecessor.
      for (int i = 0; i < _chkCumu.length; i++) {
        for (int j = 1; j < _chkCumu[i].length; ++j) {
          _chkCumu[i][j] = op(_chkCumu[i][j], _chkCumu[i][j - 1]);
        }
      }
    }
  }

  /**
   * First pass of the single-column, axis-0 computation: scans each chunk independently and
   * records the chunk's final accumulator in {@code _chkCumu[chunk]}.
   */
  protected class CumuTask extends MRTask<AstCumu.CumuTask> {
    final int _nchks;   // IN
    final double _init; // IN
    double[] _chkCumu;  // OUT, accumulation over each chunk

    CumuTask(int nchks, double init) {
      _nchks = nchks;
      _init = init;
    }

    @Override
    public void setupLocal() {
      _chkCumu = new double[_nchks];
    }

    @Override
    public void map(Chunk c, NewChunk nc) {
      double acc = _init;
      for (int i = 0; i < c._len; ++i)
        nc.addNum(acc = op(acc, c.atd(i)));
      _chkCumu[c.cidx()] = acc;
    }

    @Override
    public void reduce(AstCumu.CumuTask t) {
      // Distinct arrays are merged by element-wise addition; presumably each slot is written
      // by only one map call, so the other side contributes zero (NOTE(review): confirm).
      if (_chkCumu != t._chkCumu) ArrayUtils.add(_chkCumu, t._chkCumu);
    }

    @Override
    public void postGlobal() {
      // Turn per-chunk finals into running totals across chunks.
      for (int i = 1; i < _chkCumu.length; ++i)
        _chkCumu[i] = op(_chkCumu[i], _chkCumu[i - 1]);
    }
  }
}