package water.rapids.ast.prims.string; import org.apache.commons.io.FileUtils; import water.MRTask; import water.fvec.*; import water.parser.BufferedString; import water.rapids.Env; import water.rapids.vals.ValFrame; import water.rapids.ast.AstPrimitive; import water.rapids.ast.AstRoot; import java.io.File; import java.io.IOException; import java.util.HashSet; /** */ public class AstCountSubstringsWords extends AstPrimitive { @Override public String[] args() { return new String[]{"ary", "words"}; } @Override public int nargs() { return 1 + 2; } // (num_valid_substrings x words) @Override public String str() { return "num_valid_substrings"; } @Override public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) { Frame fr = stk.track(asts[1].exec(env)).getFrame(); String wordsPath = asts[2].exec(env).getStr(); //Type check for (Vec v : fr.vecs()) if (!(v.isCategorical() || v.isString())) throw new IllegalArgumentException("num_valid_substrings() requires a string or categorical column. " + "Received " + fr.anyVec().get_type_str() + ". Please convert column to a string or categorical first."); HashSet<String> words = null; try { words = new HashSet<>(FileUtils.readLines(new File(wordsPath))); } catch (IOException e) { e.printStackTrace(); } //Transform each vec Vec nvs[] = new Vec[fr.numCols()]; int i = 0; for (Vec v : fr.vecs()) { if (v.isCategorical()) nvs[i] = countSubstringsWordsCategoricalCol(v, words); else nvs[i] = countSubstringsWordsStringCol(v, words); i++; } return new ValFrame(new Frame(nvs)); } private Vec countSubstringsWordsCategoricalCol(Vec vec, final HashSet<String> words) { Vec res = new MRTask() { transient double[] catCounts; @Override public void setupLocal() { String[] doms = _fr.anyVec().domain(); catCounts = new double[doms.length]; for (int i = 0; i < doms.length; i++) catCounts[i] = calcCountSubstringsWords(doms[i], words); } @Override public void map(Chunk chk, NewChunk newChk) { //pre-allocate since the size is known newChk.alloc_doubles(chk._len); for (int i = 0; i < chk._len; i++) if (chk.isNA(i)) newChk.addNA(); else newChk.addNum(catCounts[(int) chk.atd(i)]); } }.doAll(1, Vec.T_NUM, new Frame(vec)).outputFrame().anyVec(); return res; } private Vec countSubstringsWordsStringCol(Vec vec, final HashSet<String> words) { return new MRTask() { @Override public void map(Chunk chk, NewChunk newChk) { if (chk instanceof C0DChunk) //all NAs newChk.addNAs(chk.len()); else { //UTF requires Java string methods BufferedString tmpStr = new BufferedString(); for (int i = 0; i < chk._len; i++) { if (chk.isNA(i)) newChk.addNA(); else { String str = chk.atStr(tmpStr, i).toString(); newChk.addNum(calcCountSubstringsWords(str, words)); } } } } }.doAll(new byte[]{Vec.T_NUM}, vec).outputFrame().anyVec(); } // count all substrings >= 2 chars that are in words private int calcCountSubstringsWords(String str, HashSet<String> words) { int wordCount = 0; int N = str.length(); for (int i = 0; i < N - 1; i++) for (int j = i + 2; j < N + 1; j++) { if (words.contains(str.substring(i, j))) wordCount += 1; } return wordCount; } }