package water.rapids.ast.prims.search;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.rapids.Env;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.util.MathUtils;
import java.util.Arrays;
/**
 * Rapids primitive {@code (match ary table nomatch incomparables)}.
 *
 * <p>Takes a frame with exactly one categorical or string column and produces a single numeric
 * column: a row gets {@code 1} when its value is found in {@code table}, and the {@code nomatch}
 * value otherwise. Missing (NA) rows always get {@code nomatch}.
 *
 * <p>The {@code incomparables} argument is part of the signature but is never read by this
 * implementation.
 */
public class AstMatch extends AstPrimitive {
  @Override
  public String[] args() {
    return new String[]{"ary", "table", "nomatch", "incomparables"};
  }

  @Override
  public int nargs() {
    return 1 + 4;
  } // (match fr table nomatch incomps)

  @Override
  public String str() {
    return "match";
  }

  /**
   * Evaluates the match against the single column of the input frame.
   *
   * @param env  execution environment used to evaluate the argument ASTs
   * @param stk  stack helper that tracks the evaluated input frame
   * @param asts {@code asts[1]} is the frame, {@code asts[2]} the lookup table (a number, a number
   *             list, a string or a string list), {@code asts[3]} the numeric no-match value;
   *             {@code asts[4]} is not read
   * @return a frame with one numeric column holding {@code 1} or the no-match value per row
   * @throws IllegalArgumentException if the frame is not a single categorical/string column, or if
   *                                  the table is neither numbers nor strings
   */
  @Override
  public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    if ((fr.numCols() != 1) || ! (fr.anyVec().isCategorical() || fr.anyVec().isString()))
      throw new IllegalArgumentException("can only match on a single categorical/string column.");

    final MRTask<?> matchTask;
    double noMatch = asts[3].exec(env).getNum();

    if (asts[2] instanceof AstNumList) {
      // NOTE(review): the lookup below is a binary search, so this relies on sort().expand()
      // yielding ascending values - confirm against AstNumList.
      matchTask = new NumMatchTask(((AstNumList) asts[2]).sort().expand(), noMatch);
    } else if (asts[2] instanceof AstNum) {
      matchTask = new NumMatchTask(new double[]{asts[2].exec(env).getNum()}, noMatch);
    } else if (asts[2] instanceof AstStrList) {
      // Sort a private copy: sorting _strs directly would reorder the AST node's own array as a
      // side effect of evaluating the expression.
      String[] values = ((AstStrList) asts[2])._strs.clone();
      Arrays.sort(values);
      matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch) : new CatMatchTask(values, noMatch);
    } else if (asts[2] instanceof AstStr) {
      String[] values = new String[]{asts[2].exec(env).getStr()};
      matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch) : new CatMatchTask(values, noMatch);
    } else {
      throw new IllegalArgumentException("Expected numbers/strings. Got: " + asts[2].getClass());
    }

    Frame result = matchTask.doAll(Vec.T_NUM, fr.anyVec()).outputFrame();
    return new ValFrame(result);
  }

  /** Matches the string values of a chunk against a sorted array of strings. */
  private static class StrMatchTask extends MRTask<StrMatchTask> {
    String[] _values;   // sorted lookup table
    double _noMatch;    // value emitted for NA rows and for rows not found in _values

    StrMatchTask(String[] values, double noMatch) {
      _values = values;
      _noMatch = noMatch;
    }

    @Override
    public void map(Chunk c, NewChunk nc) {
      BufferedString bs = new BufferedString(); // reused as the read buffer for every row
      int rows = c._len;
      for (int r = 0; r < rows; r++) {
        double x = c.isNA(r) ? _noMatch : in(_values, c.atStr(bs, r).toString(), _noMatch);
        nc.addNum(x);
      }
    }
  }

  /** Matches the categorical levels of a chunk, by level name, against a sorted array of strings. */
  private static class CatMatchTask extends MRTask<CatMatchTask> {
    String[] _values;   // sorted lookup table
    double _noMatch;    // value emitted for NA rows and for rows not found in _values

    CatMatchTask(String[] values, double noMatch) {
      _values = values;
      _noMatch = noMatch;
    }

    @Override
    public void map(Chunk c, NewChunk nc) {
      String[] domain = c.vec().domain();
      int rows = c._len;
      for (int r = 0; r < rows; r++) {
        // The stored value is used as an index into the domain to obtain the level name.
        double x = c.isNA(r) ? _noMatch : in(_values, domain[(int) c.at8(r)], _noMatch);
        nc.addNum(x);
      }
    }
  }

  /** Matches the numeric value read from each row of a chunk against a sorted array of doubles. */
  private static class NumMatchTask extends MRTask<NumMatchTask> {
    double[] _values;   // lookup table, searched with a binary search
    double _noMatch;    // value emitted for NA rows and for rows not found in _values

    NumMatchTask(double[] values, double noMatch) {
      _values = values;
      _noMatch = noMatch;
    }

    @Override
    public void map(Chunk c, NewChunk nc) {
      int rows = c._len;
      for (int r = 0; r < rows; r++) {
        double x = c.isNA(r) ? _noMatch : in(_values, c.atd(r), _noMatch);
        nc.addNum(x);
      }
    }
  }

  /**
   * Returns {@code 1} if {@code s} is present in {@code matches}, otherwise {@code nomatch}.
   * {@code matches} must be sorted in natural order because the lookup is a binary search.
   */
  private static double in(String[] matches, String s, double nomatch) {
    return Arrays.binarySearch(matches, s) >= 0 ? 1 : nomatch;
  }

  /**
   * Returns {@code 1} if {@code d} is present in {@code matches} (compared with
   * {@link MathUtils#equalsWithinOneSmallUlp}), otherwise {@code nomatch}.
   * {@code matches} must be sorted ascending because the lookup is a binary search.
   */
  private static double in(double[] matches, double d, double nomatch) {
    return binarySearchDoublesUlp(matches, 0, matches.length, d) >= 0 ? 1 : nomatch;
  }

  /**
   * Binary search over {@code a[from, to)} that treats two doubles as equal when
   * {@link MathUtils#equalsWithinOneSmallUlp} says so.
   *
   * @param a    array to search; expected to be sorted ascending
   * @param from first index of the range (inclusive)
   * @param to   end of the range (exclusive)
   * @param key  value to look for
   * @return the index of a matching element, or {@code -(insertionPoint + 1)} if none was found
   */
  private static int binarySearchDoublesUlp(double[] a, int from, int to, double key) {
    int lo = from;
    int hi = to - 1;
    while (lo <= hi) {
      int mid = (lo + hi) >>> 1; // unsigned shift keeps the midpoint valid even if lo + hi overflows
      double midVal = a[mid];
      if (MathUtils.equalsWithinOneSmallUlp(midVal, key)) return mid;
      if (midVal < key) lo = mid + 1;
      else if (midVal > key) hi = mid - 1;
      else {
        // Neither ordering comparison held (for example when either value is NaN):
        // fall back to ordering by the raw bit patterns.
        long midBits = Double.doubleToLongBits(midVal);
        long keyBits = Double.doubleToLongBits(key);
        if (midBits == keyBits) return mid;
        else if (midBits < keyBits) lo = mid + 1;
        else hi = mid - 1;
      }
    }
    return -(lo + 1); // key not found.
  }
}