package water.rapids.ast.prims.operators;

import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.*;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.rapids.vals.ValRow;
import water.util.ArrayUtils;
import water.util.VecUtils;

import java.util.Arrays;

/**
 * If-Else -- ternary conditional operator, equivalent of "?:" in C++ and Java.
 * <p/>
 * "NaNs poison".  If the test is a NaN, evaluate neither side and return a NaN.
 * <p/>
 * "Frames poison".  If the test is a Frame, both sides are evaluated and selected between according to the test.
 * The result is a Frame.  All Frames must be compatible, and scalars and 1-column Frames are widened to match the
 * widest frame.  NaN test values produce NaN results.
 * <p/>
 * If the test is a scalar, then only the returned side is evaluated.  If both sides are scalars or frames, then the
 * evaluated result is returned.  The unevaluated side is not checked for being a compatible frame.  It is an error
 * if one side is typed as a scalar and the other as a Frame.
 * <p/>
 * If the test is a Row, both sides are evaluated and must be numbers or Rows (at least one of them a Row);
 * numbers and 1-element Rows are widened to the width of the test Row.
 */
public class AstIfElse extends AstPrimitive {
  @Override
  public String[] args() {
    return new String[]{"test", "true", "false"};
  }

  /* (ifelse test true false) */
  @Override
  public int nargs() {
    return 1 + 3;
  }

  public String str() {
    return "ifelse";
  }

  /**
   * Evaluates {@code (ifelse test true false)}.
   *
   * @param env  execution environment handed to every sub-expression
   * @param stk  stack helper; every value produced by a sub-expression is passed through {@code track}
   * @param asts {@code asts[1]} is the test, {@code asts[2]} the true branch, {@code asts[3]} the false branch
   * @return a number (or the selected branch value) for a scalar test, a {@link ValRow} for a Row test,
   *         and a {@link ValFrame} for a Frame test
   */
  @Override
  public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
    Val val = stk.track(asts[1].exec(env));

    if (val.isNum()) { // Scalar test, scalar result
      double d = val.getNum();
      if (Double.isNaN(d)) {
        return new ValNum(Double.NaN);
      }
      Val res = stk.track(asts[d == 0 ? 3 : 2].exec(env)); // exec only 1 of false and true
      // A Frame result is collapsed to its first cell.
      return res.isFrame() ? new ValNum(res.getFrame().vec(0).at(0)) : res;
    }

    // Row test.  Row result.  Both sides are evaluated; track them like every other
    // sub-expression result so that an unsupported (e.g. Frame) operand is still registered
    // with the stack helper before row_ifelse rejects it.
    if (val.type() == Val.ROW) {
      Val yes = stk.track(asts[2].exec(env));
      Val no = stk.track(asts[3].exec(env));
      return row_ifelse((ValRow) val, yes, no);
    }

    // Frame test.  Frame result.
    Frame tst = val.getFrame();
    final int ncols = tst.numCols();

    // "fr" is the working frame handed to the MRTask: the test columns first, then the columns
    // of the true-side frame (if it is a frame), then those of the false-side frame (if any).
    Frame fr = new Frame(tst);

    // If all zero's, return false and never execute true.
    Val tval = null;
    for (Vec vec : tst.vecs()) {
      if (vec.min() != 0 || vec.max() != 0) {
        tval = exec_check(env, stk, tst, asts[2], fr);
        break;
      }
    }
    final boolean has_tfr = tval != null && tval.isFrame();
    final String ts = (tval != null && tval.isStr()) ? tval.getStr() : null;
    final double td = (tval != null && tval.isNum()) ? tval.getNum() : Double.NaN;
    final int[] tsIntMap = new int[ncols];

    // If all nonzero's (or NA's), then never execute false.
    Val fval = null;
    for (Vec vec : tst.vecs()) {
      if (vec.nzCnt() + vec.naCnt() < vec.length()) {
        fval = exec_check(env, stk, tst, asts[3], fr);
        break;
      }
    }
    final boolean has_ffr = fval != null && fval.isFrame();
    final String fs = (fval != null && fval.isStr()) ? fval.getStr() : null;
    final double fd = (fval != null && fval.isNum()) ? fval.getNum() : Double.NaN;
    final int[] fsIntMap = new int[ncols];

    // When at least one side is a string the result is categorical.  domains[i] is the output
    // domain of column i, ts/fsIntMap[i] the level index of the string constant in that domain,
    // and maps[i] translates level indices of the frame side into the output domain.
    String[][] domains = null;
    final int[][] maps = new int[ncols][];
    if (fs != null || ts != null) { // time to build domains...
      domains = new String[ncols][];
      if (fs != null && ts != null) {
        for (int i = 0; i < ncols; ++i) {
          domains[i] = new String[]{fs, ts}; // false => 0; truth => 1
          fsIntMap[i] = 0;
          tsIntMap[i] = 1;
        }
      } else if (ts != null) {
        for (int i = 0; i < ncols; ++i) {
          if (!has_ffr) {
            throw H2O.unimpl();
          }
          // False-side columns follow the test columns and the true-side columns (if any).
          Vec v = fr.vec(i + ncols + (has_tfr ? ncols : 0));
          String[] dom = domainWithLevel(v, ts);
          maps[i] = computeMap(v.domain(), dom);
          tsIntMap[i] = ArrayUtils.find(dom, ts);
          domains[i] = dom;
        }
      } else { // fs != null
        for (int i = 0; i < ncols; ++i) {
          if (!has_tfr) {
            throw H2O.unimpl();
          }
          // True-side columns always come directly after the test columns.
          Vec v = fr.vec(i + ncols);
          String[] dom = domainWithLevel(v, fs);
          maps[i] = computeMap(v.domain(), dom);
          fsIntMap[i] = ArrayUtils.find(dom, fs);
          domains[i] = dom;
        }
      }
    }

    // Now pick from left-or-right in the new frame
    Frame res = new MRTask() {
      @Override
      public void map(Chunk[] chks, NewChunk[] nchks) {
        assert nchks.length + (has_tfr ? nchks.length : 0) + (has_ffr ? nchks.length : 0) == chks.length;
        for (int i = 0; i < nchks.length; i++) {
          Chunk ctst = chks[i];
          NewChunk out = nchks[i];
          for (int row = 0; row < ctst._len; row++) {
            double d;
            if (ctst.isNA(row)) {
              d = Double.NaN;
            } else if (ctst.atd(row) == 0) {
              d = has_ffr
                  ? domainMap(chks[i + nchks.length + (has_tfr ? nchks.length : 0)].atd(row), maps[i])
                  : fs != null ? fsIntMap[i] : fd;
            } else {
              d = has_tfr
                  ? domainMap(chks[i + nchks.length].atd(row), maps[i])
                  : ts != null ? tsIntMap[i] : td;
            }
            out.addNum(d);
          }
        }
      }
    }.doAll(ncols, Vec.T_NUM, fr).outputFrame(null, domains);

    // flatten domains since they may be larger than needed
    if (domains != null) {
      for (int i = 0; i < res.numCols(); ++i) {
        if (res.vec(i).domain() != null) {
          // "dom" holds the level indices that actually occur in the column.
          final long[] dom = new VecUtils.CollectDomainFast((int) res.vec(i).max()).doAll(res.vec(i)).domain();
          String[] newDomain = new String[dom.length];
          for (int l = 0; l < dom.length; ++l) {
            newDomain[l] = res.vec(i).domain()[(int) dom[l]];
          }
          // Re-index every cell to its position within the compacted domain.
          new MRTask() {
            @Override
            public void map(Chunk c) {
              for (int r = 0; r < c._len; ++r) {
                if (!c.isNA(r)) {
                  c.set(r, ArrayUtils.find(dom, c.at8(r)));
                }
              }
            }
          }.doAll(res.vec(i));
          res.vec(i).setDomain(newDomain); // needs a DKVput?
        }
      }
    }

    return new ValFrame(res);
  }

  /**
   * Translates the level index {@code d} through {@code maps}; values that are not valid integral
   * indices into {@code maps} (including NaN), or any value when {@code maps} is null, pass through unchanged.
   */
  private static double domainMap(double d, int[] maps) {
    if (maps != null && d == (int) d && (0 <= d && d < maps.length)) {
      return maps[(int) d];
    }
    return d;
  }

  /** For every level of {@code from}, looks up its index in {@code to} using {@code ArrayUtils.find}. */
  private static int[] computeMap(String[] from, String[] to) {
    int[] map = new int[from.length];
    for (int i = 0; i < from.length; ++i) {
      map[i] = ArrayUtils.find(to, from[i]);
    }
    return map;
  }

  /**
   * Returns a sorted copy of the domain of {@code v} with {@code level} appended.
   *
   * @throws RuntimeException (via {@code H2O.unimpl}) if {@code v} is not categorical
   */
  private static String[] domainWithLevel(Vec v, String level) {
    if (!v.isCategorical()) {
      throw H2O.unimpl("Column is not categorical.");
    }
    String[] dom = Arrays.copyOf(v.domain(), v.domain().length + 1);
    dom[dom.length - 1] = level;
    Arrays.sort(dom);
    return dom;
  }

  /**
   * Executes one branch of a Frame-test ifelse.  If the branch yields a Frame it is tracked,
   * checked against the dimensions of the test frame, and its columns are appended to {@code xfr}.
   *
   * @return the value produced by {@code ast}
   * @throws IllegalArgumentException if the branch frame does not match the test frame's dimensions
   */
  Val exec_check(Env env, Env.StackHelp stk, Frame tst, AstRoot ast, Frame xfr) {
    Val val = ast.exec(env);
    if (val.isFrame()) {
      Frame fr = stk.track(val).getFrame();
      if (tst.numCols() != fr.numCols() || tst.numRows() != fr.numRows()) {
        throw new IllegalArgumentException("ifelse test frame and other frames must match dimensions, found " + tst + " and " + fr);
      }
      xfr.add(fr);
    }
    return val;
  }

  /**
   * Element-wise ifelse over a Row test.  Result column {@code i} is NaN when {@code test[i]} is NaN,
   * the false-side value when {@code test[i] == 0}, and the true-side value otherwise.  Result
   * columns are named {@code C1, C2, ...}.
   * <p/>
   * Each side must be a number or a Row, and at least one side must be a Row.  Numbers and
   * 1-element Rows are widened to the width of the test Row.
   *
   * @throws IllegalArgumentException if a selected side has more than one element but fewer than needed
   */
  ValRow row_ifelse(ValRow tst, Val yes, Val no) {
    double[] test = tst.getRow();
    if (!(yes.isRow() || no.isRow())) {
      throw H2O.unimpl();
    }
    double[] True = rowOperand(yes);
    double[] False = rowOperand(no);
    double[] ds = new double[test.length];
    String[] ns = new String[test.length];
    for (int i = 0; i < test.length; ++i) {
      ns[i] = "C" + (i + 1);
      if (Double.isNaN(test[i])) {
        ds[i] = Double.NaN;
      } else {
        ds[i] = test[i] == 0 ? pick(False, i, "false") : pick(True, i, "true");
      }
    }
    return new ValRow(ds, ns);
  }

  /** Extracts the values of a row-ifelse operand: a 1-element array for a number, the row itself for a Row. */
  private static double[] rowOperand(Val val) {
    switch (val.type()) {
      case Val.NUM:
        return new double[]{val.getNum()};
      case Val.ROW:
        return val.getRow();
      default:
        throw H2O.unimpl("row ifelse unimpl: " + val.getClass());
    }
  }

  /** Reads element {@code i} of one side, broadcasting a 1-element side to every column. */
  private static double pick(double[] side, int i, String which) {
    if (side.length == 1) {
      return side[0]; // scalar or 1-element row: widen
    }
    if (i >= side.length) {
      throw new IllegalArgumentException("ifelse " + which + " row has " + side.length
          + " elements but the test row needs element " + (i + 1));
    }
    return side[i];
  }
}