package water.rapids.ast.prims.search; import water.H2O; import water.Key; import water.MRTask; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.NewChunk; import water.rapids.Env; import water.rapids.Val; import water.rapids.ast.AstBuiltin; import water.fvec.Vec; import water.rapids.ast.AstRoot; import water.rapids.vals.ValFrame; import water.rapids.vals.ValRow; public abstract class AstWhichFunc extends AstBuiltin<AstWhichFunc> { @Override public String[] args() { return new String[]{"frame", "na_rm", "axis"}; } @Override public int nargs() { return 1 + 1; } @Override public String str() { throw H2O.unimpl(); } public abstract double op(Vec l); //Operation to perform in colWiseWhichVal() -> Vec.max() or Vec.min(). public abstract String searchVal(); //String indicating what we are searching for across rows in rowWiseWhichVal() -> max or min. public abstract double init(); @Override public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) { Val val1 = asts[1].exec(env); if (val1 instanceof ValFrame) { Frame fr = stk.track(val1).getFrame(); boolean na_rm = asts[2].exec(env).getNum() == 1; boolean axis = asts.length == 4 && (asts[3].exec(env).getNum() == 1); return axis ? rowwiseWhichVal(fr, na_rm) : colwiseWhichVal(fr, na_rm); } else if (val1 instanceof ValRow) { // This may be called from AstApply when doing per-row computations. double[] row = val1.getRow(); boolean na_rm = asts[2].exec(env).getNum() == 1; double val = Double.NEGATIVE_INFINITY; double valIndex = 0; if(searchVal() == "max") { //Looking for the max? for (int i = 0; i < row.length; i++) { if (Double.isNaN(row[i])) { if (!na_rm) return new ValRow(new double[]{Double.NaN}, null); } else { if (row[i] > val) { val = row[i]; valIndex = i; } } } }else if(searchVal() == "min"){ //Looking for the min? for (int i = 0; i < row.length; i++) { if (Double.isNaN(row[i])) { if (!na_rm) return new ValRow(new double[]{Double.NaN}, null); } else { if (row[i] < val) { val = row[i]; valIndex = i; } } } } else{ throw new IllegalArgumentException("Incorrect argument: expected to search for max() or min(), received " + searchVal()); } return new ValRow(new double[]{valIndex}, null); } else throw new IllegalArgumentException("Incorrect argument: expected a frame or a row, received " + val1.getClass()); } /** * Compute row-wise, and return a frame consisting of a single Vec of value indexes in each row. */ private ValFrame rowwiseWhichVal(Frame fr, final boolean na_rm) { String[] newnames = {"which." + searchVal()}; Key<Frame> newkey = Key.make(); // Determine how many columns of different types we have int n_numeric = 0, n_time = 0; for (Vec vec : fr.vecs()) { if (vec.isNumeric()) n_numeric++; if (vec.isTime()) n_time++; } // Compute the type of the resulting column: if all columns are TIME then the result is also time; otherwise // if at least one column is numeric then the result is also numeric. byte resType = n_numeric > 0 ? Vec.T_NUM : Vec.T_TIME; // Construct the frame over which the val index should be computed Frame compFrame = new Frame(); for (int i = 0; i < fr.numCols(); i++) { Vec vec = fr.vec(i); if (n_numeric > 0? vec.isNumeric() : vec.isTime()) compFrame.add(fr.name(i), vec); } Vec anyvec = compFrame.anyVec(); // Take into account certain corner cases if (anyvec == null) { Frame res = new Frame(newkey); anyvec = fr.anyVec(); if (anyvec != null) { // All columns in the original frame are non-numeric -> return a vec of NAs res.add("which." + searchVal(), anyvec.makeCon(Double.NaN)); } // else the original frame is empty, in which case we return an empty frame too return new ValFrame(res); } if (!na_rm && n_numeric < fr.numCols() && n_time < fr.numCols()) { // If some of the columns are non-numeric and na_rm==false, then the result is a vec of NAs Frame res = new Frame(newkey, newnames, new Vec[]{anyvec.makeCon(Double.NaN)}); return new ValFrame(res); } // Compute over all rows final int numCols = compFrame.numCols(); Frame res = new MRTask() { @Override public void map(Chunk[] cs, NewChunk nc) { for (int i = 0; i < cs[0]._len; i++) { int numNaColumns = 0; double value = Double.NEGATIVE_INFINITY; int valueIndex = 0; if (searchVal() == "max") { //Looking for the max? for (int j = 0; j < numCols; j++) { double val = cs[j].atd(i); if (Double.isNaN(val)) { numNaColumns++; } else if (val > value) { //Return the first occurrence of the val value = val; valueIndex = j; } } }else if(searchVal()=="min"){ //Looking for the min? for (int j = 0; j < numCols; j++) { double val = cs[j].atd(i); if (Double.isNaN(val)) { numNaColumns++; } else if(val < value) { //Return the first occurrence of the min index value = val; valueIndex = j; } } }else{ throw new IllegalArgumentException("Incorrect argument: expected to search for max() or min(), received " + searchVal()); } if (na_rm ? numNaColumns < numCols : numNaColumns == 0) nc.addNum(valueIndex); else nc.addNum(Double.NaN); } } }.doAll(1, resType, compFrame) .outputFrame(newkey, newnames, null); // Return the result return new ValFrame(res); } /** * Compute column-wise (i.e.value index of each column), and return a frame having a single row. */ private ValFrame colwiseWhichVal(Frame fr, final boolean na_rm) { Frame res = new Frame(); Vec vec1 = Vec.makeCon(null, 0); assert vec1.length() == 1; for (int i = 0; i < fr.numCols(); i++) { Vec v = fr.vec(i); double searchValue = op(v); boolean valid = (v.isNumeric() || v.isTime() || v.isBinary()) && v.length() > 0 && (na_rm || v.naCnt() == 0); FindIndexCol findIndexCol = new FindIndexCol(searchValue).doAll(new byte[]{Vec.T_NUM}, v); Vec newvec = vec1.makeCon(valid ? findIndexCol._valIndex : Double.NaN, v.isTime()? Vec.T_TIME : Vec.T_NUM); res.add(fr.name(i), newvec); } vec1.remove(); return new ValFrame(res); } private static class FindIndexCol extends MRTask<AstWhichFunc.FindIndexCol>{ double _val; double _valIndex; FindIndexCol(double val) { _val = val; _valIndex = Double.POSITIVE_INFINITY; } @Override public void map(Chunk c, NewChunk nc) { long start = c.start(); for (int i = 0; i < c._len; ++i) { if (c.atd(i) == _val) { _valIndex = start + i; break; } } } @Override public void reduce(AstWhichFunc.FindIndexCol mic) { _valIndex = Math.min(_valIndex, mic._valIndex); //Return the first occurrence of the val index } } }