package water.rapids.ast.prims.string;

import water.MRTask;
import water.fvec.*;
import water.parser.BufferedString;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;

/**
 * Rapids primitive {@code (strsplit x split)}.
 *
 * <p>Splits every column of the input frame on a regular expression using
 * {@link String#split(String)} and returns a frame holding one output column
 * per split position. Rows that produce fewer pieces than the widest row are
 * padded with NAs. Only string and categorical columns are accepted.
 */
public class AstStrSplit extends AstPrimitive {

  @Override
  public String[] args() {
    return new String[]{"ary", "split"};
  }

  @Override
  public int nargs() {
    return 1 + 2;
  } // (strsplit x split)

  @Override
  public String str() {
    return "strsplit";
  }

  /**
   * Evaluates {@code asts[1]} to a frame and {@code asts[2]} to the split regex,
   * then splits each column of the frame in turn.
   *
   * @return a frame containing, for each input column in order, all of the
   *         columns produced by splitting it
   * @throws IllegalArgumentException if any column is neither string nor categorical
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    String splitRegEx = asts[2].exec(env).getStr();

    // Type check: report the type of the column that actually failed the check,
    // not the type of some other column of the frame.
    for (Vec v : fr.vecs())
      if (!(v.isCategorical() || v.isString()))
        throw new IllegalArgumentException("strsplit() requires a string or categorical column. "
            + "Received " + v.get_type_str()
            + ". Please convert column to a string or categorical first.");

    // Transform each vec
    ArrayList<Vec> vs = new ArrayList<>(fr.numCols());
    for (Vec v : fr.vecs()) {
      Vec[] splits = v.isCategorical()
          ? strSplitCategoricalCol(v, splitRegEx)
          : strSplitStringCol(v, splitRegEx);
      vs.addAll(Arrays.asList(splits));
    }

    return new ValFrame(new Frame(vs.toArray(new Vec[vs.size()])));
  }

  /**
   * Splits a categorical column. The new per-position domains are derived from
   * the old domain up front; each row is then mapped to the index of its piece
   * within the domain of the corresponding output column.
   */
  private Vec[] strSplitCategoricalCol(Vec vec, String splitRegEx) {
    final String[] old_domains = vec.domain();
    final String[][] new_domains = newDomains(old_domains, splitRegEx);

    final String regex = splitRegEx;
    return new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        Chunk c = cs[0];
        for (int i = 0; i < c._len; ++i) {
          int cnt = 0;
          if (!c.isNA(i)) {
            int idx = (int) c.at8(i);
            String[] ss = old_domains[idx].split(regex);
            for (String s1 : ss) {
              int n_idx = Arrays.asList(new_domains[cnt]).indexOf(s1);
              if (n_idx == -1) ncs[cnt++].addNA();
              else ncs[cnt++].addNum(n_idx);
            }
          }
          // fill remaining cols w/ NA (all of them when the row itself is NA)
          for (; cnt < ncs.length; ++cnt) ncs[cnt].addNA();
        }
      }
    }.doAll(new_domains.length, Vec.T_CAT, new Frame(vec)).outputFrame(null, null, new_domains).vecs();
  }

  // each domain level may split in its own uniq way.
  // hold onto a hashset of domain levels for each "new" column
  private String[][] newDomains(String[] domains, String regex) {
    ArrayList<HashSet<String>> strs = new ArrayList<>();

    // loop over each level in the domain
    for (String domain : domains) {
      String[] news = domain.split(regex);
      for (int i = 0; i < news.length; ++i) {
        // we have a "new" column, must add a new HashSet to the array
        // list and start tracking levels for this "i"
        if (strs.size() == i) strs.add(new HashSet<String>());
        // whip out the current set of levels and add the new one
        strs.get(i).add(news[i]);
      }
    }
    return listToArray(strs);
  }

  /** Converts the per-column level sets into a {@code String[][]} of domains. */
  private String[][] listToArray(ArrayList<HashSet<String>> strs) {
    String[][] doms = new String[strs.size()][];
    int i = 0;
    for (HashSet<String> h : strs)
      doms[i++] = h.toArray(new String[h.size()]);
    return doms;
  }

  /**
   * Splits a string column. A first pass ({@link CountSplits}) finds the largest
   * number of pieces any row yields; a second pass distributes the pieces of
   * each row over that many string columns.
   */
  private Vec[] strSplitStringCol(Vec vec, final String splitRegEx) {
    final int newColCnt = (new AstStrSplit.CountSplits(splitRegEx)).doAll(vec)._maxSplits;
    return new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        Chunk chk = cs[0];
        if (chk instanceof C0DChunk) { // all NAs
          for (int row = 0; row < chk.len(); row++)
            for (int col = 0; col < ncs.length; col++)
              ncs[col].addNA();
        } else {
          BufferedString tmpStr = new BufferedString();
          for (int row = 0; row < chk._len; ++row) {
            int col = 0;
            if (!chk.isNA(row)) {
              String[] ss = chk.atStr(tmpStr, row).toString().split(splitRegEx);
              for (String s : ss) // distribute strings among new cols
                ncs[col++].addStr(s);
            }
            // fill remaining cols w/ NA
            for (; col < ncs.length; col++) ncs[col].addNA();
          }
        }
      }
    }.doAll(newColCnt, Vec.T_STR, new Frame(vec)).outputFrame().vecs();
  }

  /**
   * Run through column to figure out the maximum split that
   * any string in the column will need.
   */
  private static class CountSplits extends MRTask<AstStrSplit.CountSplits> {
    // IN
    private final String _regex;
    // OUT
    int _maxSplits = 0;

    CountSplits(String regex) {
      _regex = regex;
    }

    @Override
    public void map(Chunk chk) {
      BufferedString tmpStr = new BufferedString();
      for (int row = 0; row < chk._len; row++) {
        if (!chk.isNA(row)) {
          int split = chk.atStr(tmpStr, row).toString().split(_regex).length;
          if (split > _maxSplits) _maxSplits = split;
        }
      }
    }

    @Override
    public void reduce(AstStrSplit.CountSplits that) {
      if (this._maxSplits < that._maxSplits) this._maxSplits = that._maxSplits;
    }
  }
}