package water.api;

import water.*;
import water.exec.ASTTable;
import water.exec.ASTddply.Group;
import water.exec.Env;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.util.Log;

import java.util.Arrays;

/**
 * Imputes missing values (NAs) in a single column of a Frame, either with a
 * global statistic (mean, median, or mode) or, when grouping columns are
 * given, with a per-group statistic computed via ddply.
 */
public class Impute extends Request2 {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
  public static final String DOC_GET = "Impute";

  @API(help = "Data Frame containing columns to be imputed.", required = true, filter = Default.class, json = true)
  public Frame source;

  @API(help = "Column to impute.", required = true, filter = columnVecSelect.class, json = true)
  public Vec column;

  class columnVecSelect extends VecClassSelect { columnVecSelect() { super("source"); } }

  @API(help = "Method of impute: Mean, Median, Most Common", required = true, filter = Default.class, json = true) //, Regression, RandomForest
  public Method method = Method.mean;

  class colsFilter1 extends MultiVecSelect { public colsFilter1() { super("source"); } }

  @API(help = "Columns to Select for Grouping", filter = colsFilter1.class)
  int[] group_by;

  public enum Method {
    mean,
    median,
    mode
//    regression,
//    randomForest
  }

  public Impute() {}

  // Returns true if there is nothing to impute (serve() then just redirects).
  protected boolean init() throws IllegalArgumentException {
    // Input handling
    if (source == null || column == null)
      throw new IllegalArgumentException("Missing data or input column!");

    if (column.isBad()) {
      Log.info("Column is 100% NAs, nothing to do.");
      return true;
    }

    if (method != Method.mean && method != Method.median && method != Method.mode) // || method != Method.regression || method != Method.randomForest
      throw new IllegalArgumentException("method must be one of (mean, median, mode)"); // regression, randomForest)");

    if (!(column.isEnum()) && column.naCnt() <= 0) {
      Log.info("No NAs in the column, nothing to do.");
      return true;
    }

    if (column.isEnum() && !Arrays.asList(column._domain).contains("NA") && column.naCnt() <= 0) {
      Log.info("No NAs in the column, nothing to do.");
      return true;
    }

//    if (method == Method.regression && (column.isEnum() || column.isUUID() || column.isTime()))
//      throw new IllegalArgumentException("Trying to perform regression on non-numeric column! Please select a different column.");
    if (method == Method.mode && (!column.isEnum()))
      throw new IllegalArgumentException("Method `mode` only applicable to factor columns.");

    if (column.isEnum() && method != Method.mode) {
      Log.warn("Column to impute is a factor column, changing method to mode.");
      method = Method.mode;
    }
    return false;
  }

  @Override protected Response serve() {
    if (init()) return Inspect2.redirect(this, source._key.toString());
    final int col_id = source.find(column);
    final int[] _cols = group_by;
    final Key mykey = Key.make();
    try {
      if (group_by == null) {
        // just use "method" on the input "column"
        double _replace_val = 0;
        if (method == Method.mean) {
          _replace_val = column.mean();
        } else if (method == Method.median) {
          QuantilesPage qp = new QuantilesPage();
          qp.source_key = source;
          qp.column = column;
          qp.invoke();
          _replace_val = qp.result;
        } else if (method == Method.mode) {
          String dom[] = column.domain();
          long[][] levels = new long[1][];
          levels[0] = new Vec.CollectDomain(column).doAll(new Frame(column)).domain();
          long[][] counts = new ASTTable.Tabularize(levels).doAll(column)._counts;
          long maxCounts = -1;
          int mode = -1;
          for (int i = 0; i < counts[0].length; ++i) {
            if (counts[0][i] > maxCounts && !dom[i].equals("NA")) { // check for "NA" in domain -- corner case from R
              maxCounts = counts[0][i];
              mode = i;
            }
          }
          _replace_val = mode != -1 ? (double) mode : (double) Arrays.asList(dom).indexOf("NA"); // could produce -1 if "NA" not in the domain -- that is, we don't have the R corner case
          if (_replace_val == -1) _replace_val = Double.NaN; // OK to replace, since we're in the "mode" block
        }
        final double rv = _replace_val;
        new MRTask2() {
          @Override public void map(Chunk[] cs) {
            Chunk c = cs[col_id];
            int rows = c.len();
            for (int r = 0; r < rows; ++r) {
              if (c.isNA0(r) || (c._vec.isEnum() && c._vec.domain()[(int) c.at0(r)].equals("NA"))) {
                if (!Double.isNaN(rv)) c.set0(r, rv); // leave as NA if replace value is NA
              }
            }
          }
        }.doAll(source);
      } else {
        // collect the groups HashMap and the frame from the ddply,
        // create a vec of group IDs (each row is in some group),
        // then MRTask over the rows.
        // Define an anonymous Exec2 function that computes `method` of the target
        // column, then ddply over the grouping columns to get one replacement
        // value per group.
        water.exec.Exec2.exec(Key.make().toString() + " = anonymous <- function(x) \n{\n " + method + "(x[," + (col_id + 1) + "])\n}").remove_and_unlock();
        Env env = water.exec.Exec2.exec(mykey.toString() + " = ddply(" + source._key.toString() + ", " + toAryString(_cols) + ", anonymous)");
        final Frame grp_replacement = new Frame(env.peekAry());
        env.remove_and_unlock();
        Log.info("GROUP TASK NUM COLS: " + grp_replacement.numCols());
        // Build the Group -> replacement-value map from the ddply result.
        final GroupTask grp2val = new GroupTask(grp_replacement.numCols() - 1).doAll(grp_replacement);
        // Replace each NA with the value of its row's group (skip rows whose group is unknown).
        new MRTask2() {
          @Override public void map(Chunk[] cs) {
            Chunk c = cs[col_id];
            int rows = cs[0].len();
            for (int r = 0; r < rows; ++r) {
              if (c.isNA0(r) || (c._vec.isEnum() && c._vec.domain()[(int) c.at0(r)].equals("NA"))) {
                Group g = new Group(_cols.length);
                g.fill(r, cs, _cols);
                if (grp2val._grp2val.get(g) == null) continue;
                double rv = grp2val._grp2val.get(g);
                c.set0(r, rv);
              }
            }
          }
        }.doAll(source);
      }
      return Inspect2.redirect(this, source._key.toString());
    } catch (Throwable t) {
      return Response.error(t);
    } finally {
      // Delete frames
      UKV.remove(mykey);
    }
  }

  private String toAryString(int[] c) {
    String res = "c(";
    for (int i = 0; i < c.length; ++i) {
      if (i == c.length - 1) res += String.valueOf(c[i] + 1) + ")"; // + 1 for 0 -> 1 based indexing
      else res += String.valueOf(c[i] + 1) + ",";                   // + 1 for 0 -> 1 based indexing
    }
    return res;
  }

  @Override public boolean toHTML(StringBuilder sb) { return super.toHTML(sb); }

  // Create a table: Group -> Impute value
  private static class GroupTask extends MRTask2<GroupTask> {
    protected NonBlockingHashMap<Group, Double> _grp2val = new NonBlockingHashMap<Group, Double>();
    int[] _cols;
    int _ncols;

    GroupTask(int ncols) {
      _cols = new int[_ncols = ncols];
      for (int i = 0; i < _cols.length; ++i) _cols[i] = i;
    }

    @Override public void map(Chunk[] cs) {
      if (_grp2val == null) _grp2val = new NonBlockingHashMap<Group, Double>();
      if (_cols == null) {
        _cols = new int[cs.length - 1];
        for (int i = 0; i < _cols.length; ++i) _cols[i] = i;
      }
      int rows = cs[0].len();
      Chunk vals = cs[cs.length - 1]; // last column of the ddply result holds the group's value
      for (int row = 0; row < rows; ++row) {
        Group g = new Group(_cols.length);
        g.fill(row, cs, _cols);
        double val = vals.at0(row);
        _grp2val.putIfAbsent(g, val);
      }
    }

    @Override public void reduce(GroupTask gt) {
      for (Group g : gt._grp2val.keySet()) {
        Double val = gt._grp2val.get(g);
        if (g != null && val != null) _grp2val.putIfAbsent(g, val);
      }
    }

    // Custom serialization for NBHM. Much nicer when these are auto-gen'd.
    // Only sends Groups over the wire, NOT NewChunks with rows.
    @Override public AutoBuffer write(AutoBuffer ab) {
      super.write(ab);
      if (_grp2val == null) return ab.put4(0);
      ab.put4(_grp2val.size());
      for (Group g : _grp2val.keySet()) { ab.put(g); ab.put8d(_grp2val.get(g)); }
      return ab;
    }

    @Override public GroupTask read(AutoBuffer ab) {
      super.read(ab);
      int len = ab.get4();
      if (len == 0) return this;
      _grp2val = new NonBlockingHashMap<Group, Double>();
      for (int i = 0; i < len; i++) _grp2val.put(ab.get(Group.class), ab.get8d());
      return this;
    }

    @Override public void copyOver(Freezable dt) {
      GroupTask that = (GroupTask) dt;
      super.copyOver(that);
      this._ncols = that._ncols;
      this._cols = that._cols;
      this._grp2val = that._grp2val;
    }
  }
}
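// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original endpoint code):
// one way to drive this Request2 subclass programmatically from the same
// package, following the pattern used above for QuantilesPage.invoke().
// Field wiring and the invoke() plumbing may differ between H2O versions.
//
//   Frame fr = ...;                        // a parsed Frame already in the K/V store
//   Impute imp = new Impute();
//   imp.source   = fr;                     // frame holding the column to impute
//   imp.column   = fr.vecs()[2];           // column whose NAs should be filled
//   imp.method   = Impute.Method.median;   // mean, median, or mode
//   imp.group_by = new int[]{0, 1};        // optional: impute per group of columns 0 and 1
//   imp.invoke();                          // triggers the request (init() + serve() path)
// ---------------------------------------------------------------------------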