package water.rapids.ast.prims.string;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.rapids.Val;
import water.rapids.ast.AstBuiltin;
import water.rapids.vals.ValFrame;

import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
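/**
 * Searches for matches of the regular expression {@code regex} within each element
 * of a single string or categorical column. An element matches when the regex is found
 * anywhere inside it ({@link java.util.regex.Matcher#find()} semantics); missing values never match.
 *
 * Params:
 * - regex           regular expression (java.util.regex syntax)
 * - ignore_case     if ignore_case == 1, matching is case-insensitive
 * - invert          if invert == 1, selects the elements that do NOT match the regex
 * - output_logical  if output_logical == 1, the result is a 0/1 indicator column with one entry per row;
 *                   otherwise the result contains the 0-based row indices of the selected elements
 */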
public class AstGrep extends AstBuiltin<AstGrep> {
  @Override
  public String[] args() {
    return new String[]{"ary", "regex", "ignore_case", "invert", "output_logical"};
  }

  @Override
  public int nargs() {
    return 1 + 5;
  } // (grep x regex ignore_case invert output_logical)

  @Override
  public String str() {
    return "grep";
  }

  /**
   * Evaluates {@code (grep x regex ignore_case invert output_logical)}.
   *
   * @param args args[1] is the frame to search, args[2] the regex, and args[3..5] the
   *             ignore_case / invert / output_logical flags (a flag is set only when it equals 1)
   * @return a single numeric column: per-row 0/1 indicators when output_logical is set,
   *         otherwise the 0-based row indices of the selected rows
   * @throws IllegalArgumentException if the frame is not a single categorical/string column,
   *         or if {@code regex} is not a valid regular expression
   */
  @Override
  protected Val exec(Val[] args) {
    Frame fr = args[1].getFrame();
    String regex = args[2].getStr();
    boolean ignoreCase = args[3].getNum() == 1;
    boolean invert = args[4].getNum() == 1;
    boolean outputLogical = args[5].getNum() == 1;
    if ((fr.numCols() != 1) || !(fr.anyVec().isCategorical() || fr.anyVec().isString()))
      throw new IllegalArgumentException("can only grep on a single categorical/string column.");
    Vec v = fr.anyVec();
    assert v != null;
    GrepHelper grepHelper = new GrepHelper(regex, ignoreCase, invert, outputLogical);
    // Compile once on the calling node so that a malformed regex is reported here, up front,
    // instead of being thrown from the distributed map() of every chunk of a string column.
    Pattern pattern = compileOrFail(grepHelper);
    Frame result;
    if (v.isCategorical()) {
      // For categoricals the regex only needs to be evaluated once per domain level, not per row.
      int[] matchingCats = grepDomain(pattern, v.domain());
      result = new GrepCatTask(grepHelper, matchingCats).doAll(Vec.T_NUM, v).outputFrame();
    } else {
      result = new GrepStrTask(grepHelper).doAll(Vec.T_NUM, v).outputFrame();
    }
    return new ValFrame(result);
  }

  /**
   * Compiles the helper's regex, rethrowing a syntax error as an IllegalArgumentException
   * with grep-specific context (the original exception is kept as the cause).
   */
  private static Pattern compileOrFail(GrepHelper grepHelper) {
    try {
      return grepHelper.compilePattern();
    } catch (PatternSyntaxException e) {
      throw new IllegalArgumentException("grep: invalid regular expression: " + e.getMessage(), e);
    }
  }

  /**
   * Returns the indices of the domain levels in which {@code p} finds a match.
   * Indices are collected in strictly ascending order, so the result is already sorted
   * and can be searched with {@link Arrays#binarySearch(int[], int)} directly.
   */
  private static int[] grepDomain(Pattern p, String[] domain) {
    Matcher m = p.matcher("");
    int cnt = 0;
    int[] matching = new int[domain.length];
    for (int i = 0; i < domain.length; i++) {
      if (m.reset(domain[i]).find())
        matching[cnt++] = i;
    }
    return Arrays.copyOf(matching, cnt);
  }

  /** Map task for categorical columns: a row matches iff its level index is in {@code _matchingCats}. */
  private static class GrepCatTask extends MRTask<GrepCatTask> {
    /** Matching level indices, sorted ascending (required by Arrays.binarySearch). */
    private final int[] _matchingCats;
    private final GrepHelper _gh;

    GrepCatTask(GrepHelper gh, int[] matchingCats) {
      _matchingCats = matchingCats;
      _gh = gh;
    }

    @Override
    public void map(Chunk c, NewChunk n) {
      OutputWriter w = OutputWriter.makeWriter(_gh, n, c.start());
      int rows = c._len;
      for (int r = 0; r < rows; r++) {
        if (c.isNA(r)) {
          w.addNA(r);
        } else {
          int cat = (int) c.at8(r);
          w.addRow(r, Arrays.binarySearch(_matchingCats, cat) >= 0);
        }
      }
    }
  }

  /** Map task for string columns: runs the regex against every non-missing element. */
  private static class GrepStrTask extends MRTask<GrepStrTask> {
    private final GrepHelper _gh;

    GrepStrTask(GrepHelper gh) {
      _gh = gh;
    }

    @Override
    public void map(Chunk c, NewChunk n) {
      OutputWriter w = OutputWriter.makeWriter(_gh, n, c.start());
      // The task carries only the regex source (via GrepHelper); the Pattern is compiled
      // locally in each map() call and a single Matcher is reused for all rows of the chunk.
      Pattern p = _gh.compilePattern();
      Matcher m = p.matcher("");
      BufferedString bs = new BufferedString();
      int rows = c._len;
      for (int r = 0; r < rows; r++) {
        if (c.isNA(r)) {
          w.addNA(r);
        } else {
          m.reset(c.atStr(bs, r).toString());
          w.addRow(r, m.find());
        }
      }
    }
  }

  /** Serializable holder for the grep parameters that are shipped with the map tasks. */
  private static class GrepHelper extends Iced<GrepHelper> {
    private String _regex;
    private boolean _ignoreCase;
    private boolean _invert;
    private boolean _outputLogical;

    public GrepHelper() {}

    GrepHelper(String regex, boolean ignoreCase, boolean invert, boolean outputLogical) {
      _regex = regex;
      _ignoreCase = ignoreCase;
      _invert = invert;
      _outputLogical = outputLogical;
    }

    /**
     * Compiles {@code _regex}; when ignore_case is set, matching is case-insensitive
     * with Unicode-aware case folding.
     *
     * @throws PatternSyntaxException if the regex is malformed
     */
    Pattern compilePattern() {
      int flags = _ignoreCase ? Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE : 0;
      return Pattern.compile(_regex, flags);
    }
  }

  /** Emits one chunk's worth of output for a per-row match result; subclasses choose the format. */
  private static abstract class OutputWriter {
    static final double MATCH = 1;
    static final double NO_MATCH = 0;
    final NewChunk _nc;
    /** Global row index of the first row of the chunk being processed. */
    final long _start;
    final boolean _invert;

    OutputWriter(NewChunk nc, long start, boolean invert) {
      _nc = nc;
      _start = start;
      _invert = invert;
    }

    /** Records a missing value at chunk-local {@code row}; missing values count as "not matched". */
    abstract void addNA(int row);

    /** Records whether the element at chunk-local {@code row} matched the regex. */
    abstract void addRow(int row, boolean matched);

    static OutputWriter makeWriter(GrepHelper gh, NewChunk nc, long start) {
      return gh._outputLogical ? new IndicatorWriter(nc, start, gh._invert) : new PositionWriter(nc, start, gh._invert);
    }
  }

  /** Writes one 0/1 value per input row (1 = selected, honoring invert). */
  private static class IndicatorWriter extends OutputWriter {
    IndicatorWriter(NewChunk nc, long start, boolean invert) {
      super(nc, start, invert);
    }

    @Override
    void addNA(int row) {
      _nc.addNum(_invert ? MATCH : NO_MATCH);
    }

    @Override
    void addRow(int row, boolean matched) {
      // matched XOR invert: selected when it matched, or when it did not match and invert is set
      _nc.addNum(matched != _invert ? MATCH : NO_MATCH);
    }
  }

  /** Writes the global 0-based row index of each selected row (so output chunks may be shorter). */
  private static class PositionWriter extends OutputWriter {
    PositionWriter(NewChunk nc, long start, boolean invert) {
      super(nc, start, invert);
    }

    @Override
    void addNA(int row) {
      if (_invert)
        _nc.addNum(_start + row);
    }

    @Override
    void addRow(int row, boolean matched) {
      if (matched != _invert)
        _nc.addNum(_start + row);
    }
  }
}