package water.rapids.ast.prims.mungers; import water.H2O; import water.fvec.*; import water.MRTask; import water.rapids.*; import water.rapids.ast.*; import water.rapids.ast.AstRoot; import water.rapids.vals.ValFrame; /** * Apply a Function to a frame * Typically, column-by-column, produces a 1-row frame as a result */ public class AstApply extends AstPrimitive { @Override public String[] args() { return new String[]{"ary", "margin", "fun"}; } @Override public int nargs() { return 1 + 3; } // (apply frame 1/2 fun) @Override public String str() { return "apply"; } @Override public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) { Frame fr = stk.track(asts[1].exec(env)).getFrame(); double margin = stk.track(asts[2].exec(env)).getNum(); AstPrimitive fun = stk.track(asts[3].exec(env)).getFun(); int nargs = fun.nargs(); if (nargs != -1 && nargs != 2) throw new IllegalArgumentException("Incorrect number of arguments; '" + fun + "' expects " + nargs + " but was passed " + 2); switch ((int) margin) { case 1: return rowwise(env, fr, fun); case 2: return colwise(env, stk, fr, fun); default: throw new IllegalArgumentException("Only row-wise (margin 1) or col-wise (margin 2) allowed"); } } // -------------------------------------------------------------------------- private ValFrame colwise(Env env, Env.StackHelp stk, Frame fr, AstPrimitive fun) { // Break each column into it's own Frame, then execute the function passing // the 1 argument. All columns are independent, and this loop should be // parallized over each column. Vec vecs[] = fr.vecs(); Val vals[] = new Val[vecs.length]; AstRoot[] asts = new AstRoot[]{fun, null}; for (int i = 0; i < vecs.length; i++) { asts[1] = new AstFrame(new Frame(new String[]{fr._names[i]}, new Vec[]{vecs[i]})); try (Env.StackHelp stk_inner = env.stk()) { vals[i] = fun.apply(env, stk_inner, asts); } } // All the resulting Vals must be the same scalar type (and if ValFrames, // the columns must be the same count and type). Build a Frame result with // 1 row column per applied function result (per column), and as many rows // as there are columns in the returned Frames. Val v0 = vals[0]; Vec ovecs[] = new Vec[vecs.length]; switch (v0.type()) { case Val.NUM: for (int i = 0; i < vecs.length; i++) ovecs[i] = Vec.makeCon(vals[i].getNum(), 1L); // Since the zero column is a number, all must be numbers break; case Val.FRM: long nrows = v0.getFrame().numRows(); for (int i = 0; i < vecs.length; i++) { Frame res = vals[i].getFrame(); // Since the zero column is a frame, all must be frames if (res.numCols() != 1) throw new IllegalArgumentException("apply result Frames must have one column, found " + res.numCols() + " cols"); if (res.numRows() != nrows) throw new IllegalArgumentException("apply result Frames must have all the same rows, found " + nrows + " rows and " + res.numRows()); ovecs[i] = res.vec(0); } break; case Val.NUMS: for (int i = 0; i < vecs.length; i++) ovecs[i] = Vec.makeCon(vals[i].getNums()[0], 1L); break; case Val.STRS: throw H2O.unimpl(); case Val.FUN: throw water.H2O.unimpl(); case Val.STR: throw water.H2O.unimpl(); default: throw water.H2O.unimpl(); } return new ValFrame(new Frame(fr._names, ovecs)); } // -------------------------------------------------------------------------- // Break each row into it's own Row, then execute the function passing the // 1 argument. All rows are independent, and run in parallel private ValFrame rowwise(Env env, Frame fr, final AstPrimitive fun) { final String[] names = fr._names; final AstFunction scope = env._scope; // Current execution scope; needed to lookup variables // do a single row of the frame to determine the size of the output. double[] ds = new double[fr.numCols()]; for (int col = 0; col < fr.numCols(); ++col) ds[col] = fr.vec(col).at(0); int noutputs = fun.apply(env, env.stk(), new AstRoot[]{fun, new AstRow(ds, fr.names())}).getRow().length; Frame res = new MRTask() { @Override public void map(Chunk chks[], NewChunk[] nc) { double ds[] = new double[chks.length]; // Working row AstRoot[] asts = new AstRoot[]{fun, new AstRow(ds, names)}; // Arguments to be called; they are reused endlessly Session ses = new Session(); // Session, again reused endlessly Env env = new Env(ses); env._scope = scope; // For proper namespace lookup for (int row = 0; row < chks[0]._len; row++) { for (int col = 0; col < chks.length; col++) // Fill the row ds[col] = chks[col].atd(row); try (Env.StackHelp stk_inner = env.stk()) { double[] valRow = fun.apply(env, stk_inner, asts).getRow(); // Make the call per-row for (int newCol = 0; newCol < nc.length; ++newCol) nc[newCol].addNum(valRow[newCol]); } } ses.end(null); // Mostly for the sanity checks } }.doAll(noutputs, Vec.T_NUM, fr).outputFrame(); return new ValFrame(res); } }