package water.rapids.ast.prims.mungers;
import water.DKV;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.params.AstStrList;
import water.util.VecUtils;
import java.util.Arrays;
/**
*/
public class AstSetDomain extends AstPrimitive {
@Override
public String[] args() {
return new String[]{"ary", "newDomains"};
}
@Override
public int nargs() {
return 1 + 2;
} // (setDomain x [list of strings])
@Override
public String str() {
return "setDomain";
}
@Override
public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
Frame f = stk.track(asts[1].exec(env)).getFrame();
String[] _domains = ((AstStrList) asts[2])._strs;
if (f.numCols() != 1)
throw new IllegalArgumentException("Must be a single column. Got: " + f.numCols() + " columns.");
Vec v = f.anyVec();
if (!v.isCategorical())
throw new IllegalArgumentException("Vector must be a factor column. Got: " + v.get_type_str());
if (_domains != null && _domains.length != v.domain().length) {
// in this case we want to recollect the domain and check that number of levels matches _domains
VecUtils.CollectDomainFast t = new VecUtils.CollectDomainFast((int) v.max());
t.doAll(v);
final long[] dom = t.domain();
if (dom.length != _domains.length)
throw new IllegalArgumentException("Number of replacement factors must equal current number of levels. Current number of levels: " + dom.length + " != " + _domains.length);
new MRTask() {
@Override
public void map(Chunk c) {
for (int i = 0; i < c._len; ++i) {
if (!c.isNA(i)) {
long num = Arrays.binarySearch(dom, c.at8(i));
if (num < 0) throw new IllegalArgumentException("Could not find the categorical value!");
c.set(i, num);
}
}
}
}.doAll(v);
}
v.setDomain(_domains);
DKV.put(v);
return new ValFrame(f);
}
}