package water.rapids.ast.prims.string;
import org.apache.commons.lang.StringUtils;
import water.MRTask;
import water.fvec.*;
import water.parser.BufferedString;
import water.rapids.Env;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstStrList;
/**
 * Accepts a frame whose columns are all string or categorical, and a substring (or a list of
 * substrings) to look for in each value. Returns a new frame with one numeric column per input
 * column, containing for each row the countMatches result summed over all given substrings.
 * Missing values (NAs) in the input produce NAs in the output.
 * <p/>
 * countMatches - Counts how many times the substring appears in the larger string.
 * If either the target string or the substring is empty (""), 0 is returned.
 */
public class AstCountMatches extends AstPrimitive {
  @Override
  public String[] args() {
    return new String[]{"ary", "pattern"};
  }

  @Override
  public int nargs() {
    return 1 + 2;
  } // (countmatches x pattern)

  @Override
  public String str() {
    return "countmatches";
  }

  /**
   * Evaluates {@code (countmatches x pattern)}.
   *
   * @param env  execution environment used to evaluate the argument ASTs
   * @param stk  stack helper used to track the evaluated frame argument
   * @param asts {@code asts[1]} is the frame argument; {@code asts[2]} is either a string list
   *             or an expression evaluating to a single string
   * @return a frame with one count column per input column
   * @throws IllegalArgumentException if any input column is neither string nor categorical
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    // A string list is used as-is; anything else is evaluated and wrapped as a one-element list.
    final String[] pattern = asts[2] instanceof AstStrList
        ? ((AstStrList) asts[2])._strs
        : new String[]{asts[2].exec(env).getStr()};

    // Type check: validate every column up front so no work is done on a partially valid frame.
    for (Vec v : fr.vecs())
      if (!(v.isCategorical() || v.isString()))
        // Report the type of the offending column itself, not of an arbitrary column of the frame.
        throw new IllegalArgumentException("countmatches() requires a string or categorical column. "
            + "Received " + v.get_type_str()
            + ". Please convert column to a string or categorical first.");

    // Transform each vec
    Vec nvs[] = new Vec[fr.numCols()];
    int i = 0;
    for (Vec v : fr.vecs()) {
      if (v.isCategorical())
        nvs[i] = countMatchesCategoricalCol(v, pattern);
      else
        nvs[i] = countMatchesStringCol(v, pattern);
      i++;
    }
    return new ValFrame(new Frame(nvs));
  }

  /**
   * Counts matches for a categorical column. The counts are computed once per domain level
   * and then looked up for every row by its level index.
   *
   * @param vec     categorical column
   * @param pattern substrings to count
   * @return a new numeric column of per-row counts, NA where the input is NA
   */
  private Vec countMatchesCategoricalCol(Vec vec, String[] pattern) {
    final int[] matchCounts = countDomainMatches(vec.domain(), pattern);
    return new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        Chunk c = cs[0];
        for (int i = 0; i < c._len; ++i) {
          if (!c.isNA(i)) {
            int idx = (int) c.at8(i); // row value is used as the index into the domain
            ncs[0].addNum(matchCounts[idx]);
          } else ncs[0].addNA();
        }
      }
    }.doAll(1, Vec.T_NUM, new Frame(vec)).outputFrame().anyVec();
  }

  /**
   * Computes, for every domain level, the total number of matches over all patterns.
   *
   * @param domain  categorical levels
   * @param pattern substrings to count
   * @return array parallel to {@code domain} holding the summed match counts
   */
  int[] countDomainMatches(String[] domain, String[] pattern) {
    int[] res = new int[domain.length];
    for (int i = 0; i < domain.length; i++)
      for (String aPattern : pattern)
        res[i] += StringUtils.countMatches(domain[i], aPattern);
    return res;
  }

  /**
   * Counts matches for a string column, row by row.
   *
   * @param vec     string column
   * @param pattern substrings to count
   * @return a new numeric column of per-row counts, NA where the input is NA
   */
  private Vec countMatchesStringCol(Vec vec, final String[] pattern) {
    return new MRTask() {
      @Override
      public void map(Chunk chk, NewChunk newChk) {
        if (chk instanceof C0DChunk) // all NAs
          for (int i = 0; i < chk.len(); i++)
            newChk.addNA();
        else {
          BufferedString tmpStr = new BufferedString();
          for (int i = 0; i < chk._len; ++i) {
            if (chk.isNA(i)) newChk.addNA();
            else {
              // Materialize the row's string once; it does not depend on the pattern.
              String str = chk.atStr(tmpStr, i).toString();
              int cnt = 0;
              for (String aPattern : pattern)
                cnt += StringUtils.countMatches(str, aPattern);
              newChk.addNum(cnt, 0);
            }
          }
        }
      }
    }.doAll(Vec.T_NUM, new Frame(vec)).outputFrame().anyVec();
  }
}