package water.rapids.transforms; import org.apache.commons.lang.ArrayUtils; import water.DKV; import water.H2O; import water.fvec.Frame; import water.rapids.*; import water.rapids.ast.AstExec; import water.rapids.ast.AstParameter; import water.rapids.ast.AstRoot; import water.rapids.ast.params.AstId; public class H2OColOp extends Transform<H2OColOp> { protected final String _fun; private final String _oldCol; private String[] _newCol; private String _newColTypes; boolean _multiColReturn; public H2OColOp(String name, String ast, boolean inplace, String[] newNames) { // (op (cols fr cols) {extra_args}) super(name,ast,inplace,newNames); _fun = _ast._asts[0].str(); _oldCol = ((AstExec)_ast._asts[1])._asts[2].str(); setupParams(); } private void setupParams() { String[] args = _ast.getArgs(); if( args!=null && args.length > 1 ) { // first arg is the frame for(int i=1; i<args.length; ++i) setupParamsImpl(i,args); } } protected void setupParamsImpl(int i, String[] args) { _params.put(args[i], (AstParameter) _ast._asts[i + 1]); } @Override public Transform<H2OColOp> fit(Frame f) { return this; } @Override protected Frame transformImpl(Frame f) { ((AstExec)_ast._asts[1])._asts[1] = new AstId(f); Session ses = new Session(); Frame fr = ses.exec(_ast, null).getFrame(); _newCol = _newNames==null?new String[fr.numCols()]:_newNames; _newColTypes = toJavaPrimitive(fr.anyVec().get_type_str()); if( (_multiColReturn=fr.numCols() > 1) ) { for(int i=0;i<_newCol.length;i++) { if(_newNames==null) _newCol[i] = f.uniquify(i > 0 ? _newCol[i - 1] : _oldCol); f.add(_newCol[i], fr.vec(i)); } if( _inplace ) f.remove(f.find(_oldCol)).remove(); } else { _newCol = _newNames==null?new String[]{_inplace ? _oldCol : f.uniquify(_oldCol)}:_newCol; if( _inplace ) f.replace(f.find(_oldCol), fr.anyVec()).remove(); else f.add(_newNames == null ? _newCol[0] : _newNames[0], fr.anyVec()); } DKV.put(f); return f; } @Override Frame inverseTransform(Frame f) { throw H2O.unimpl(); } @Override public String genClassImpl() { String typeCast = _inTypes[ArrayUtils.indexOf(_inNames, _oldCol)].equals("Numeric")?"Double":"String"; if( _multiColReturn ) { StringBuilder sb = new StringBuilder( " @Override public RowData transform(RowData row) {\n"+ (paramIsRow() ? addRowParam() : "") + " "+_newColTypes+"[] res = GenMunger."+lookup(_fun)+"(("+typeCast+")row.get(\""+_oldCol+"\"), _params);\n"); for(int i=0;i<_newCol.length;i++) sb.append( " row.put(\""+_newCol[i]+"\",("+i+">=res.length)?\"\":res["+i+"]);\n"); sb.append( " return row;\n" + " }\n"); return sb.toString(); } else { return " @Override public RowData transform(RowData row) {\n"+ (paramIsRow() ? addRowParam() : "") + " "+_newColTypes+" res = GenMunger."+lookup(_fun)+"(("+typeCast+")row.get(\""+_oldCol+"\"), _params);\n"+ " row.put(\""+_newCol[0]+"\", res);\n" + " return row;\n" + " }\n"; } } protected boolean paramIsRow() { return false; } protected String addRowParam() { return ""; } protected String lookup(String op) { return op.replaceAll("\\.",""); } }