package water.rapids.ast.prims.mungers;

import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.*;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.vals.ValFrame;
import water.util.MathUtils;

import java.util.Arrays;

/**
 * Rapids primitive {@code (cut x breaks labels include_lowest right digits)}.
 *
 * <p>Bins the single numeric column of the input frame into intervals delimited by
 * {@code breaks} and returns a one-column frame of bin indices whose domain is either the
 * supplied {@code labels} or generated interval names such as {@code (0.0,10.0]}.
 * Values that fall outside the outermost breaks, and NaN inputs, map to NaN.
 */
public class AstCut extends AstPrimitive {

  @Override
  public String[] args() {
    return new String[]{"ary", "breaks", "labels", "include_lowest", "right", "digits"};
  }

  // (cut x breaks labels include_lowest right digits)
  @Override
  public int nargs() {
    return 1 + 6;
  }

  @Override
  public String str() {
    return "cut";
  }

  /**
   * Evaluates the arguments and performs the binning.
   *
   * @param env  execution environment used to evaluate the argument ASTs
   * @param stk  stack helper used to track the input frame
   * @param asts {@code asts[1]} = frame, {@code asts[2]} = breaks (number list, or a single
   *             number giving the desired bin count), {@code asts[3]} = labels (string list,
   *             single string, or anything else for "no labels"), {@code asts[4]} =
   *             include_lowest flag, {@code asts[5]} = right flag, {@code asts[6]} = digits
   * @return a frame with one column of bin indices carrying the interval/label domain
   * @throws IllegalArgumentException if the frame is not a single non-categorical column,
   *                                  if {@code breaks} is empty, if a bin count below 2 (or NaN)
   *                                  is requested, or if {@code labels} does not have one
   *                                  entry per bin
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    double[] cuts = check(asts[2]);
    Arrays.sort(cuts);
    String[] labels = check2(asts[3]);
    final boolean lowest = asts[4].exec(env).getNum() == 1;
    final boolean rite = asts[5].exec(env).getNum() == 1;
    final int digits = Math.min((int) asts[6].exec(env).getNum(), 12); // cap at 12

    if (fr.vecs().length != 1 || fr.vecs()[0].isCategorical()) {
      throw new IllegalArgumentException("First argument must be a numeric column vector");
    }
    if (cuts.length == 0) {
      // Without this guard nbins would be -1 and the domain allocation below would fail
      // with a NegativeArraySizeException.
      throw new IllegalArgumentException("`breaks` must contain at least one value.");
    }

    double fmin = fr.anyVec().min();
    double fmax = fr.anyVec().max();

    int nbins = cuts.length - 1; // c(0,10,100) -> 2 bins (0,10] U (10, 100]
    if (nbins == 0) {
      // A single number is the number of equal-width bins wanted.
      // The negated comparison also rejects NaN.
      if (!(cuts[0] >= 2)) {
        throw new IllegalArgumentException("The number of cuts must be >= 2. Got: " + cuts[0]);
      }
      nbins = (int) Math.floor(cuts[0]);
      cuts = equalWidthBreaks(fmin, fmax, nbins);
    }

    if (labels != null && labels.length != nbins) {
      throw new IllegalArgumentException("`labels` vector does not match the number of cuts.");
    }

    // `cutz` is the final alias captured by the MRTask below. It refers to the same array as
    // `cuts`, so the rounding applied next also changes the break values used for binning.
    // NOTE(review): R's `dig.lab` only affects label formatting; confirm that rounding the
    // actual break points here is intended.
    final double cutz[] = cuts;

    // Round the breaks to `digits` decimals: example floor(2.676*100 + 0.5) / 100
    final double scale = Math.pow(10, digits);
    for (int i = 0; i < cuts.length; ++i) {
      cuts[i] = Math.floor(cuts[i] * scale + 0.5) / scale;
    }

    // Construct domain names from labels, or from the bin intervals if labels is null.
    final String[][] domains = new String[1][nbins];
    if (labels == null) {
      domains[0][0] = (lowest ? "[" : left(rite)) + cuts[0] + "," + cuts[1] + rite(rite);
      for (int i = 1; i < (cuts.length - 1); ++i) {
        domains[0][i] = left(rite) + cuts[i] + "," + cuts[i + 1] + rite(rite);
      }
    } else {
      domains[0] = labels;
    }

    Frame fr2 = new MRTask() {
      @Override
      public void map(Chunk c, NewChunk nc) {
        final double lo = cutz[0];
        final double hi = cutz[cutz.length - 1];
        final int rows = c._len;
        for (int r = 0; r < rows; ++r) {
          final double x = c.atd(r);
          if (Double.isNaN(x)) {
            nc.addNum(Double.NaN);
            continue;
          }
          // The lowest break itself is only accepted when include_lowest is set.
          // NOTE(review): for right == false the first label starts with "[" yet the lowest
          // break is still excluded unless include_lowest is set; confirm against R's cut.
          final boolean below = lowest
              ? x < lo
              : (x < lo || MathUtils.equalsWithinOneSmallUlp(x, lo));
          // The highest break itself is only accepted for right-closed intervals.
          final boolean above = rite
              ? x > hi
              : (x > hi || MathUtils.equalsWithinOneSmallUlp(x, hi));
          if (below || above) {
            nc.addNum(Double.NaN);
            continue;
          }
          // Find the first break that bounds x from above; its predecessor index is the bin.
          int bin = -1;
          for (int i = 1; i < cutz.length; ++i) {
            if (rite ? x <= cutz[i] : x < cutz[i]) {
              bin = i - 1;
              break;
            }
          }
          // Always emit exactly one value per input row, even if no bin matched
          // (e.g. a NaN break), so the output stays aligned with the input.
          if (bin >= 0) {
            nc.addNum(bin);
          } else {
            nc.addNum(Double.NaN);
          }
        }
      }
    }.doAll(1, Vec.T_NUM, fr).outputFrame(fr.names(), domains);
    return new ValFrame(fr2);
  }

  /**
   * Builds the {@code nbins + 1} break points that split {@code [fmin, fmax]} into
   * {@code nbins} equal-width bins. The two outer breaks are pushed outwards by 0.1% of the
   * range so that the minimum and maximum values fall strictly inside the outer bins.
   */
  private static double[] equalWidthBreaks(double fmin, double fmax, int nbins) {
    final double range = fmax - fmin;
    final double width = range / nbins;
    final double[] breaks = new double[nbins + 1];
    breaks[0] = fmin - 0.001 * range;
    for (int i = 1; i < nbins; ++i) {
      breaks[i] = fmin + i * width;
    }
    breaks[nbins] = fmax + 0.001 * range;
    return breaks;
  }

  /** Opening bracket of an interval label: open for right-closed intervals, closed otherwise. */
  private String left(boolean rite) {
    return rite ? "(" : "[";
  }

  /** Closing bracket of an interval label: closed for right-closed intervals, open otherwise. */
  private String rite(boolean rite) {
    return rite ? "]" : ")";
  }

  /**
   * Extracts the breaks argument as a double array.
   *
   * @param ast an {@link AstNumList} (explicit breaks) or an {@link AstNum} (number of bins)
   * @return the expanded number list, or a one-element array holding the single number
   * @throws IllegalArgumentException if {@code ast} is neither of the two accepted types
   */
  private double[] check(AstRoot ast) {
    if (ast instanceof AstNumList) {
      return ((AstNumList) ast).expand();
    }
    if (ast instanceof AstNum) {
      // this is the number of breaks wanted...
      return new double[]{((AstNum) ast).getNum()};
    }
    throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
  }

  /**
   * Extracts the labels argument.
   *
   * @param ast an {@link AstStrList}, an {@link AstStr}, or any other AST
   * @return the list's strings, a one-element array for a single string, or {@code null} when
   *         {@code ast} is neither type (meaning interval names are generated instead)
   */
  private String[] check2(AstRoot ast) {
    if (ast instanceof AstStrList) {
      return ((AstStrList) ast)._strs;
    }
    if (ast instanceof AstStr) {
      return new String[]{ast.str()};
    }
    return null;
  }
}