package water.rapids.ast.prims.operators;

import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.rapids.*;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.rapids.vals.ValRow;
import water.util.ArrayUtils;

import java.util.Arrays;

/**
 * Binary operator.
 * Subclasses auto-widen between scalars and Frames, and have exactly two arguments.
 *
 * <p>Subclasses supply the numeric primitive via {@link #op(double, double)} and may
 * override {@link #str_op(BufferedString, BufferedString)} and {@link #categoricalOK()}.
 */
abstract public class AstBinOp extends AstPrimitive {

  @Override
  public String[] args() {
    return new String[]{"leftArg", "rightArg"};
  }

  /** Argument count; presumably the operator slot itself plus the two operands. */
  @Override
  public int nargs() {
    return 1 + 2;
  }

  /**
   * Executes both operand ASTs ({@code asts[1]} then {@code asts[2]}), tracks each result on
   * the stack helper, and dispatches to {@link #prim_apply(Val, Val)}.
   */
  @Override
  public Val apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Val left = stk.track(asts[1].exec(env));
    Val rite = stk.track(asts[2].exec(env));
    return prim_apply(left, rite);
  }

  /**
   * Applies this operator to two already-evaluated values, dispatching on the type of the
   * left operand and then on the type of the right operand.
   *
   * @param left left operand
   * @param rite right operand
   * @return the result value (number, row or frame depending on the operand types)
   * @throws RuntimeException via {@code H2O.unimpl()} for unsupported type combinations
   */
  public Val prim_apply(Val left, Val rite) {
    switch (left.type()) {
      case Val.NUM:
        return num_op_val(left.getNum(), rite);
      case Val.NUMS:
        // Only the first element of a number array takes part in the operation.
        return num_op_val(left.getNums()[0], rite);
      case Val.FRM:
        return frame_op_val(left.getFrame(), rite);
      case Val.STR:
        return str_op_val(left.getStr(), rite);
      case Val.STRS:
        // Only the first element of a string array takes part in the operation.
        return str_op_val(left.getStrs()[0], rite);
      case Val.ROW:
        return row_op_val(left, rite);
      default:
        throw H2O.unimpl();
    }
  }

  /** Numeric scalar on the left; dispatch on the right operand's type. */
  private Val num_op_val(final double d, Val rite) {
    switch (rite.type()) {
      case Val.NUM:
        return new ValNum(op(d, rite.getNum()));
      case Val.NUMS:
        return new ValNum(op(d, rite.getNums()[0]));
      case Val.FRM:
        return scalar_op_frame(d, rite.getFrame());
      case Val.ROW: {
        // Widen the scalar to a row of the same length as the right-hand row.
        double[] rt = rite.getRow();
        double[] lft = new double[rt.length];
        Arrays.fill(lft, d);
        return row_op_row(lft, rt, ((ValRow) rite).getNames());
      }
      default: // includes STR and STRS
        throw H2O.unimpl();
    }
  }

  /** Frame on the left; dispatch on the right operand's type. */
  private Val frame_op_val(Frame flf, Val rite) {
    switch (rite.type()) {
      case Val.NUM:
        return frame_op_scalar(flf, rite.getNum());
      case Val.NUMS:
        return frame_op_scalar(flf, rite.getNums()[0]);
      case Val.STR:
        return frame_op_scalar(flf, rite.getStr());
      case Val.STRS:
        return frame_op_scalar(flf, rite.getStrs()[0]);
      case Val.FRM:
        return frame_op_frame(flf, rite.getFrame());
      default:
        throw H2O.unimpl();
    }
  }

  /** String scalar on the left; only a Frame on the right is supported. */
  private Val str_op_val(String s, Val rite) {
    if (rite.type() == Val.FRM) {
      return scalar_op_frame(s, rite.getFrame());
    }
    throw H2O.unimpl();
  }

  /** Row on the left; dispatch on the right operand's type. */
  private Val row_op_val(Val left, Val rite) {
    final double[] row = left.getRow();
    switch (rite.type()) {
      case Val.NUM: {
        // Widen the scalar to a row of the same length as the left-hand row.
        double[] right = new double[row.length];
        Arrays.fill(right, rite.getNum());
        return row_op_row(row, right, ((ValRow) left).getNames());
      }
      case Val.ROW:
        return row_op_row(row, rite.getRow(), ((ValRow) rite).getNames());
      case Val.FRM:
        return row_op_row(row, rite.getRow(), rite.getFrame().names());
      default:
        throw H2O.unimpl();
    }
  }

  /**
   * Override to express a basic math primitive
   */
  public abstract double op(double l, double r);

  /**
   * String flavour of the primitive, used for String columns. The default implementation
   * fails; subclasses that support strings must override it.
   */
  public double str_op(BufferedString l, BufferedString r) {
    throw H2O.fail();
  }

  /**
   * Auto-widen the scalar to every element of the frame; the scalar is the LEFT operand.
   */
  private ValFrame scalar_op_frame(final double d, Frame fr) {
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        for (int c = 0; c < chks.length; c++) {
          Chunk chk = chks[c];
          NewChunk cres = cress[c];
          for (int i = 0; i < chk._len; i++) {
            cres.addNum(op(d, chk.atd(i)));
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, fr).outputFrame(fr._names, null);
    return cleanCategorical(fr, res); // Cleanup categorical misuse
  }

  /**
   * Auto-widen the scalar to every element of the frame; the scalar is the RIGHT operand.
   */
  public ValFrame frame_op_scalar(Frame fr, final double d) {
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        for (int c = 0; c < chks.length; c++) {
          Chunk chk = chks[c];
          NewChunk cres = cress[c];
          for (int i = 0; i < chk._len; i++) {
            cres.addNum(op(chk.atd(i), d));
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, fr).outputFrame(fr._names, null);
    return cleanCategorical(fr, res); // Cleanup categorical misuse
  }

  /**
   * Ops do not make sense on categoricals, except EQ/NE; flip such ops to NAs.
   * For every categorical column of {@code oldfr}, the matching result column is swapped
   * for a constant NaN column unless {@link #categoricalOK()} returns true.
   *
   * NOTE(review): the array returned by {@code newfr.vecs()} is patched in place and the
   * replaced Vec is not explicitly removed here -- confirm whether the Frame's own
   * bookkeeping and the old Vec need additional cleanup.
   */
  private ValFrame cleanCategorical(Frame oldfr, Frame newfr) {
    final boolean categoricalOK = categoricalOK();
    final Vec oldvecs[] = oldfr.vecs();
    final Vec newvecs[] = newfr.vecs();
    for (int i = 0; i < oldvecs.length; i++) {
      if (oldvecs[i].isCategorical() && !categoricalOK) { // categorical are OK (op is EQ/NE)
        newvecs[i] = newvecs[i].makeCon(Double.NaN);
      }
    }
    return new ValFrame(newfr);
  }

  /**
   * Auto-widen the string scalar to every element of the frame; the string is the RIGHT
   * operand.
   */
  private ValFrame frame_op_scalar(Frame fr, final String str) {
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        BufferedString vstr = new BufferedString();
        for (int c = 0; c < chks.length; c++) {
          Chunk chk = chks[c];
          NewChunk cres = cress[c];
          Vec vec = chk.vec();
          if (vec.isString()) {
            // String Vectors: apply str_op as BufferedStrings to all elements
            final BufferedString conStr = new BufferedString(str);
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(str_op(chk.atStr(vstr, i), conStr));
            }
          } else if (vec.isCategorical()) {
            // categorical Vectors: convert string to domain value; apply op (not
            // str_op).  Not sure what the "right" behavior here is, can
            // easily argue that should instead apply str_op to the categorical
            // string domain value - except that this whole operation only
            // makes sense for EQ/NE, and is much faster when just comparing
            // doubles vs comparing strings.  Note that if the string is not
            // part of the categorical domain, the find op returns -1 which is never
            // equal to any categorical dense integer (which are always 0+).
            final double d = (double) ArrayUtils.find(vec.domain(), str);
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(op(chk.atd(i), d));
            }
          } else {
            // mixing string and numeric: the answer is the same for every element
            final double d = op(1, 2); // false or true only
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(d);
            }
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, fr).outputFrame(fr._names, null);
    return new ValFrame(res);
  }

  /**
   * Auto-widen the string scalar to every element of the frame; the string is the LEFT
   * operand.
   */
  private ValFrame scalar_op_frame(final String str, Frame fr) {
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        BufferedString vstr = new BufferedString();
        for (int c = 0; c < chks.length; c++) {
          Chunk chk = chks[c];
          NewChunk cres = cress[c];
          Vec vec = chk.vec();
          if (vec.isString()) {
            // String Vectors: apply str_op as BufferedStrings to all elements
            final BufferedString conStr = new BufferedString(str);
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(str_op(conStr, chk.atStr(vstr, i)));
            }
          } else if (vec.isCategorical()) {
            // categorical Vectors: convert string to domain value; apply op (not
            // str_op).  Not sure what the "right" behavior here is, can
            // easily argue that should instead apply str_op to the categorical
            // string domain value - except that this whole operation only
            // makes sense for EQ/NE, and is much faster when just comparing
            // doubles vs comparing strings.
            final double d = (double) ArrayUtils.find(vec.domain(), str);
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(op(d, chk.atd(i)));
            }
          } else {
            // mixing string and numeric: the answer is the same for every element
            final double d = op(1, 2); // false or true only
            for (int i = 0; i < chk._len; i++) {
              cres.addNum(d);
            }
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, fr).outputFrame(fr._names, null);
    return new ValFrame(res);
  }

  /**
   * Auto-widen: If one frame has only 1 column, auto-widen that 1 column to
   * the rest.  Otherwise the frames must have the same column count, and
   * auto-widen element-by-element.  Short-cut if one frame has zero
   * columns.
   *
   * <p>If the row counts differ, exactly one of the frames must have a single row, which is
   * then broadcast across the other frame; operand order is preserved in both directions.
   *
   * @throws IllegalArgumentException if the row counts or column counts are incompatible
   */
  private ValFrame frame_op_frame(Frame lf, Frame rt) {
    if (lf.numRows() != rt.numRows()) {
      // special case for broadcasting a single row of data across a frame
      if (lf.numRows() == 1 || rt.numRows() == 1) {
        if (lf.numCols() != rt.numCols()) {
          throw new IllegalArgumentException("Frames must have same columns, found " + lf.numCols() + " columns and " + rt.numCols() + " columns.");
        }
        // The single-row frame is the one being broadcast.  When it is the LEFT operand the
        // right frame must be iterated, otherwise only one output row would be produced.
        return lf.numRows() == 1 ? row_op_frame(lf, rt) : frame_op_row(lf, rt);
      } else {
        throw new IllegalArgumentException("Frames must have same rows, found " + lf.numRows() + " rows and " + rt.numRows() + " rows.");
      }
    }
    if (lf.numCols() == 0) {
      return new ValFrame(lf);
    }
    if (rt.numCols() == 0) {
      return new ValFrame(rt);
    }
    if (lf.numCols() == 1 && rt.numCols() > 1) {
      return vec_op_frame(lf.vecs()[0], rt);
    }
    if (rt.numCols() == 1 && lf.numCols() > 1) {
      return frame_op_vec(lf, rt.vecs()[0]);
    }
    if (lf.numCols() != rt.numCols()) {
      throw new IllegalArgumentException("Frames must have same columns, found " + lf.numCols() + " columns and " + rt.numCols() + " columns.");
    }

    // Element-by-element: the task runs over [lf columns..., rt columns...], so the right
    // chunk for output column c sits at index c + cress.length.
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        BufferedString lfstr = new BufferedString();
        BufferedString rtstr = new BufferedString();
        assert (cress.length << 1) == chks.length;
        for (int c = 0; c < cress.length; c++) {
          Chunk clf = chks[c];
          Chunk crt = chks[c + cress.length];
          NewChunk cres = cress[c];
          if (clf.vec().isString()) {
            for (int i = 0; i < clf._len; i++) {
              cres.addNum(str_op(clf.atStr(lfstr, i), crt.atStr(rtstr, i)));
            }
          } else {
            for (int i = 0; i < clf._len; i++) {
              cres.addNum(op(clf.atd(i), crt.atd(i)));
            }
          }
        }
      }
    }.doAll(lf.numCols(), Vec.T_NUM, new Frame(lf).add(rt)).outputFrame(lf._names, null);
    return cleanCategorical(lf, res); // Cleanup categorical misuse
  }

  /**
   * Reads row 0 of a frame as doubles.  Columns that are neither numeric nor time
   * contribute NaN.
   */
  private static double[] rowValues(Frame row) {
    final double[] raw = new double[row.numCols()];
    for (int i = 0; i < raw.length; ++i) {
      Vec v = row.vec(i);
      raw[i] = v.isNumeric() || v.isTime() ? v.at(0) : Double.NaN; // is numberlike, if not then NaN
    }
    return raw;
  }

  /**
   * Broadcast the first row of {@code row} across every row of {@code lf}; the frame is the
   * LEFT operand.  String columns of {@code lf} yield NaN.
   */
  private ValFrame frame_op_row(Frame lf, Frame row) {
    final double[] rawRow = rowValues(row);
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        for (int c = 0; c < cress.length; c++) {
          Chunk clf = chks[c];
          NewChunk cres = cress[c];
          if (clf.vec().isString()) {
            for (int r = 0; r < clf._len; ++r) {
              cres.addNum(Double.NaN); // TODO: improve
            }
          } else {
            for (int r = 0; r < clf._len; ++r) {
              cres.addNum(op(clf.atd(r), rawRow[c]));
            }
          }
        }
      }
    }.doAll(lf.numCols(), Vec.T_NUM, lf).outputFrame(lf._names, null);
    return cleanCategorical(lf, res);
  }

  /**
   * Broadcast the first row of {@code row} across every row of {@code rt}; the row is the
   * LEFT operand.  Mirror image of {@link #frame_op_row(Frame, Frame)}: String columns of
   * {@code rt} yield NaN.
   */
  private ValFrame row_op_frame(Frame row, Frame rt) {
    final double[] rawRow = rowValues(row);
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        for (int c = 0; c < cress.length; c++) {
          Chunk crt = chks[c];
          NewChunk cres = cress[c];
          if (crt.vec().isString()) {
            for (int r = 0; r < crt._len; ++r) {
              cres.addNum(Double.NaN);
            }
          } else {
            for (int r = 0; r < crt._len; ++r) {
              cres.addNum(op(rawRow[c], crt.atd(r)));
            }
          }
        }
      }
    }.doAll(rt.numCols(), Vec.T_NUM, rt).outputFrame(rt._names, null);
    return cleanCategorical(rt, res);
  }

  /**
   * Element-wise operation on two rows.
   *
   * @param lf    left-hand values
   * @param rt    right-hand values; must be the same length as {@code lf}
   * @param names column names attached to the resulting row
   * @throws IllegalArgumentException if the two rows differ in length
   */
  private ValRow row_op_row(double[] lf, double[] rt, String[] names) {
    if (lf.length != rt.length) {
      throw new IllegalArgumentException("Rows must have same length, found " + lf.length + " and " + rt.length + ".");
    }
    double[] res = new double[lf.length];
    for (int i = 0; i < lf.length; i++) {
      res[i] = op(lf[i], rt[i]);
    }
    return new ValRow(res, names);
  }

  /**
   * Auto-widen a single column (LEFT operand) across every column of the frame.
   */
  private ValFrame vec_op_frame(Vec vec, Frame fr) {
    // Already checked for same rows, non-zero frame
    Frame rt = new Frame(fr);
    rt.add("", vec); // the widened column rides along as the last chunk
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        assert cress.length == chks.length - 1;
        Chunk clf = chks[cress.length];
        for (int c = 0; c < cress.length; c++) {
          Chunk crt = chks[c];
          NewChunk cres = cress[c];
          for (int i = 0; i < clf._len; i++) {
            cres.addNum(op(clf.atd(i), crt.atd(i)));
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, rt).outputFrame(fr._names, null);
    return cleanCategorical(fr, res); // Cleanup categorical misuse
  }

  /**
   * Auto-widen a single column (RIGHT operand) across every column of the frame.
   */
  private ValFrame frame_op_vec(Frame fr, Vec vec) {
    // Already checked for same rows, non-zero frame
    Frame lf = new Frame(fr);
    lf.add("", vec); // the widened column rides along as the last chunk
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] cress) {
        assert cress.length == chks.length - 1;
        Chunk crt = chks[cress.length];
        for (int c = 0; c < cress.length; c++) {
          Chunk clf = chks[c];
          NewChunk cres = cress[c];
          for (int i = 0; i < clf._len; i++) {
            cres.addNum(op(clf.atd(i), crt.atd(i)));
          }
        }
      }
    }.doAll(fr.numCols(), Vec.T_NUM, lf).outputFrame(fr._names, null);
    return cleanCategorical(fr, res); // Cleanup categorical misuse
  }

  /**
   * Does it make sense to run this op on a categorical (enum) column?  Defaults to false,
   * in which case results for categorical columns are replaced by NaN.
   */
  public boolean categoricalOK() {
    return false;
  }
}