package water.rapids.ast.prims.advmath;
import hex.quantile.QuantileModel;
import water.Freezable;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.*;
import water.rapids.ast.AstFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.ast.prims.reducers.AstMean;
import water.rapids.ast.prims.reducers.AstMedian;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNums;
import water.util.ArrayUtils;
import water.util.IcedDouble;
import water.util.IcedHashMap;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
/**
* Impute columns of a data frame in place.
* <p/>
 * This primitive can impute either a whole Frame or a single Vec within the Frame. By
 * default, numeric columns are imputed with their mean and categorical columns with their
 * mode. String, date, and UUID columns are never imputed.
* <p/>
 * When groupByCols (or a pre-computed groupByFrame) is supplied, the impute values are
 * computed per group instead of over the whole column. This works for a single column as
 * well as for the whole Frame (col == -1). In the whole-Frame case every non-group-by
 * numeric column is imputed with its group mean and every categorical column with its
 * group mode. The column to impute should not itself be one of the groupByCols.
* <p/>
* The methods that a user may impute by are as follows:
* - mean: Vec.T_NUM
* - median: Vec.T_NUM
* - mode: Vec.T_CAT
* - bfill: Any valid Vec type
* - ffill: Any valid Vec type
* <p/>
 * All methods of imputation are done in place! The first three methods (mean, median,
 * mode) are self-explanatory. The bfill and ffill methods are meant to fill NAs using the
 * adjacent cell value (either the one before or the one after). They are not implemented
 * yet; requesting them without a group-by raises an "unimplemented" error:
 * <pre>
 *   Vec = [ bfill_value, NA, ffill_value ]
 *                |       ^^       |
 *                ->      ||      <-
 *                      impute
 * </pre>
* <p/>
* If the impute method is median then the combineMethod can be one of the Enum variants
* of QuantileModel.CombineMethod = { INTERPOLATE, AVERAGE, LOW, HIGH }. The Enum
* specifies how to combine quantiles on even sample sizes. This parameter is ignored in
* all other cases.
* <p/>
* Finally, the groupByFrame can be used to impute a column with a pre-computed groupby
* result.
* <p/>
* Other notes:
* <p/>
* If col is -1, then the entire Frame will be imputed using mean/mode where appropriate.
*/
public class AstImpute extends AstPrimitive {
  @Override
  public String[] args() {
    return new String[]{"ary", "col", "method", "combineMethod", "groupByCols", "groupByFrame", "values"};
  }

  @Override
  public String str() {
    return "h2o.impute";
  }

  @Override
  public int nargs() {
    return 1 + 7;
  } // (h2o.impute data col method combine_method groupby groupByFrame values)

  /**
   * Imputes the NAs of the frame in {@code asts[1]} in place.
   * <p/>
   * Arguments: {@code asts[1]} frame, {@code asts[2]} column index (-1 = all columns),
   * {@code asts[3]} method name (mean/median/mode/ffill/bfill), {@code asts[4]} median combine
   * method, {@code asts[5]} group-by columns (number, number list or list of column names),
   * {@code asts[6]} pre-computed group-by frame or {@code _}, {@code asts[7]} explicit impute
   * values (number or number list), or anything else for "compute them".
   *
   * @return without grouping, a {@link ValNums} holding one impute value per column (NaN for
   *         columns that were not imputed); with grouping, a {@link ValFrame} holding the
   *         group-by frame the impute values were taken from
   * @throws IllegalArgumentException on an out-of-range column, unknown method or group-by
   *         column, an ambiguous or too-narrow group-by frame, or when the imputed column is
   *         also a group-by column
   */
  @Override
  public Val apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
    // ---- Argument parsing and sanity checking ----
    // Whole frame being imputed
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    // Column within frame being imputed; -1 selects every column
    final int col = (int) asts[2].exec(env).getNum();
    if (col < -1 || col >= fr.numCols())
      throw new IllegalArgumentException("Column not -1 or in range 0 to " + fr.numCols());
    final boolean doAllVecs = col == -1;

    // Technique used for imputation; ffill/bfill have no aggregator, so method stays null
    AstRoot method = null;
    boolean fill = false;
    // Locale.ROOT: a default-locale toUpperCase() would turn "median" into "MEDİAN" under tr_TR
    switch (asts[3].exec(env).getStr().toUpperCase(Locale.ROOT)) {
      case "MEAN":
        method = new AstMean();
        break;
      case "MEDIAN":
        method = new AstMedian();
        break;
      case "MODE":
        method = new AstMode();
        break;
      case "FFILL":
      case "BFILL":
        fill = true;
        break;
      default:
        throw new IllegalArgumentException("Method must be one of mean, median or mode");
    }
    // Only for median, how is the median computed on even sample sizes?
    QuantileModel.CombineMethod combine =
        QuantileModel.CombineMethod.valueOf(asts[4].exec(env).getStr().toUpperCase(Locale.ROOT));
    // Group-by columns. Empty is allowed, and perfectly normal.
    AstNumList by2 = parseGroupByCols(asts[5], fr);
    Frame groupByFrame = asts[6].str().equals("_") ? null : stk.track(asts[6].exec(env)).getFrame();
    AstRoot vals = asts[7];
    AstNumList values;
    if (vals instanceof AstNumList) values = (AstNumList) vals;
    else if (vals instanceof AstNum) values = new AstNumList(((AstNum) vals).getNum());
    else values = null;

    if (!doAllVecs && by2.has(col))
      throw new IllegalArgumentException("Column to impute (" + col + ") cannot also be a group-by column");

    boolean doGrpBy = !by2.isEmpty() || groupByFrame != null;
    if (!doGrpBy) {
      if (fill)
        // TODO: requires chk.previousNonNA and chk.nextNonNA style methods (which may go across chk boundaries)
        throw H2O.unimpl("No ffill or bfill imputation supported");
      return imputeWithoutGroups(fr, col, method, combine, values);
    }
    // A single-column group-by computed here needs an aggregator; ffill/bfill has none.
    if (method == null && !doAllVecs && groupByFrame == null)
      throw H2O.unimpl("No ffill or bfill imputation supported");
    return imputeByGroup(env, stk, fr, col, method, by2, groupByFrame);
  }

  /** Converts the group-by argument (number, number list, or column names) into column indices. */
  private static AstNumList parseGroupByCols(AstRoot ast, Frame fr) {
    if (ast instanceof AstNumList) return (AstNumList) ast;
    if (ast instanceof AstNum) return new AstNumList(((AstNum) ast).getNum());
    if (ast instanceof AstStrList) {
      String[] names = ((AstStrList) ast)._strs;
      double[] list = new double[names.length];
      for (int i = 0; i < names.length; ++i) {
        int idx = fr.find(names[i]);
        if (idx < 0) // find() reports a missing name as -1; never use that as a column index
          throw new IllegalArgumentException("Group-by column '" + names[i] + "' not found in frame");
        list[i] = idx;
      }
      Arrays.sort(list);
      return new AstNumList(list);
    }
    throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
  }

  /**
   * Replaces NAs column-wise with one value per column: the caller's {@code values}, or else
   * mean/mode (all columns) or the requested method (single column).
   * A NaN impute value means "leave this column alone".
   */
  private static Val imputeWithoutGroups(Frame fr, int col, AstRoot method,
                                         QuantileModel.CombineMethod combine, AstNumList values) {
    final double[] res;
    if (values != null) {
      res = values.expand();
      if (res.length < fr.numCols())
        throw new IllegalArgumentException("Expected " + fr.numCols() + " impute values, but found " + res.length);
    } else {
      res = new double[fr.numCols()];
      // Start from NaN so string/time/UUID columns are skipped instead of being filled with 0
      Arrays.fill(res, Double.NaN);
      if (col == -1) {
        for (int i = 0; i < res.length; ++i) {
          Vec v = fr.vec(i);
          if (v.isNumeric()) res[i] = v.mean();
          else if (v.isCategorical()) res[i] = ArrayUtils.maxIndex(v.bins());
        }
      } else {
        Vec vec = fr.vec(col);
        if (method instanceof AstMean) res[col] = vec.mean();
        if (method instanceof AstMedian) res[col] = AstMedian.median(new Frame(vec), combine);
        if (method instanceof AstMode) res[col] = AstMode.mode(vec);
      }
    }
    new MRTask() {
      @Override
      public void map(Chunk[] cs) {
        int len = cs[0]._len;
        // run down each chk
        for (int c = 0; c < cs.length; ++c) {
          if (Double.isNaN(res[c])) continue; // nothing to impute for this column
          for (int row = 0; row < len; ++row)
            if (cs[c].isNA(row))
              cs[c].set(row, res[c]);
        }
      }
    }.doAll(fr);
    return new ValNums(res);
  }

  /**
   * Replaces NAs with per-group aggregates. The aggregates come either from a group-by
   * computed here over {@code by2}, or from the caller's {@code groupByFrame}. Returns the
   * group-by frame that was used.
   */
  private static Val imputeByGroup(Env env, Env.StackHelp stk, Frame fr, int col, AstRoot method,
                                   AstNumList by2, Frame groupByFrame) {
    final boolean doAllVecs = col == -1;
    final int ncol = fr.numCols();
    Frame imputes = groupByFrame;
    // Only for a whole-frame group-by computed here: position of each column's aggregate among
    // the aggregate columns of the result, or -1 when the column is not aggregated.
    int[] aggOrdinal = null;
    int numAggs = 0;
    if (imputes == null) {
      // Build and run a GroupBy command
      AstGroup astGrp = new AstGroup();
      if (doAllVecs) {
        aggOrdinal = new int[ncol];
        Arrays.fill(aggOrdinal, -1);
        // Sized for the worst case and trimmed below: only numeric/categorical non-by columns get an aggregate
        AstRoot[] aggs = new AstRoot[3 + 3 * ncol];
        aggs[0] = astGrp;
        aggs[1] = new AstFrame(fr);
        aggs[2] = by2;
        int c = 3;
        for (int i = 0; i < ncol; ++i) {
          Vec v = fr.vec(i);
          if (!by2.has(i) && (v.isCategorical() || v.isNumeric())) {
            aggOrdinal[i] = numAggs++;
            aggs[c] = v.isNumeric() ? new AstMean() : new AstMode();
            aggs[c + 1] = new AstNumList(i, i + 1);
            aggs[c + 2] = new AstStr("rm");
            c += 3;
          }
        }
        imputes = astGrp.apply(env, stk, Arrays.copyOf(aggs, c)).getFrame();
      } else {
        imputes = astGrp.apply(env, stk, new AstRoot[]{astGrp, new AstFrame(fr), by2, /**/method, new AstNumList(col, col + 1), new AstStr("rm") /**/}).getFrame();
      }
    }
    if (by2.isEmpty() && imputes.numCols() > 2) // >2 makes it ambiguous which columns are groupby cols and which are aggs, throw IAE
      throw new IllegalArgumentException("Ambiguous group-by frame. Supply the `by` columns to proceed.");
    final int[] keyCols = ArrayUtils.seq(0, Math.max((int) by2.cnt(), 1)); // key columns inside `imputes`
    if (by2.isEmpty()) {
      // Only a group-by frame was given: its leading column(s) are the keys; find them in fr by name
      int[] byCols = new int[imputes.numCols() - 1];
      for (int i = 0; i < byCols.length; ++i) {
        byCols[i] = fr.find(imputes.name(i));
        if (byCols[i] < 0)
          throw new IllegalArgumentException("Group-by column '" + imputes.name(i) + "' not found in frame");
      }
      by2 = new AstNumList(byCols);
    }
    final int[] bycols = by2.expand4();
    boolean[] isByCol = new boolean[ncol];
    for (int b : bycols) {
      if (b < 0 || b >= ncol)
        throw new IllegalArgumentException("Group-by column " + b + " not in range 0 to " + ncol);
      isByCol[b] = true;
    }

    // For every column of fr: the column of `imputes` holding its aggregate, or -1
    int[] aggCols = new int[ncol];
    Arrays.fill(aggCols, -1);
    if (!doAllVecs) {
      aggCols[col] = imputes.numCols() - 1; // the single aggregate is the last column
    } else if (aggOrdinal != null) {
      int firstAgg = imputes.numCols() - numAggs; // aggregates follow the key columns
      for (int i = 0; i < ncol; ++i)
        if (aggOrdinal[i] >= 0) aggCols[i] = firstAgg + aggOrdinal[i];
    } else {
      // Caller-supplied frame: keys first, then one aggregate per non-key column of fr, in column order
      int z = keyCols.length;
      for (int i = 0; i < ncol; ++i) {
        if (isByCol[i]) continue;
        if (z >= imputes.numCols())
          throw new IllegalArgumentException("Group-by frame has " + imputes.numCols()
              + " columns, too few to impute every non-group-by column");
        aggCols[i] = z++;
      }
    }

    IcedHashMap<AstGroup.G, Freezable[]> gathered = new Gather(keyCols, aggCols).doAll(imputes)._group_impute_map;
    // null when the group-by frame had no rows, so no map() ever ran
    final IcedHashMap<AstGroup.G, Freezable[]> groupImputeMap =
        gathered == null ? new IcedHashMap<AstGroup.G, Freezable[]>() : gathered;

    // Now walk over the data, replace NAs with the imputed results
    new MRTask() {
      @Override
      public void map(Chunk cs[]) {
        boolean[] skip = new boolean[cs.length];
        for (int b : bycols) skip[b] = true;
        AstGroup.G g = new AstGroup.G(bycols.length, null);
        for (int row = 0; row < cs[0]._len; row++) {
          Freezable[] groupValues = null; // looked up lazily, once per row
          for (int c = 0; c < cs.length; ++c) {
            if (skip[c] || !cs[c].isNA(row)) continue;
            if (groupValues == null) {
              groupValues = groupImputeMap.get(g.fill(row, cs, bycols));
              if (groupValues == null) break; // group absent from the group-by frame: leave the row as-is
            }
            double v = ((IcedDouble) groupValues[c])._val;
            if (!Double.isNaN(v)) // NaN marks a column with no aggregate; never write it back
              cs[c].set(row, v);
          }
        }
      }
    }.doAll(fr);
    return new ValFrame(imputes);
  }

  /** Flattens the GroupBy result Frame into a map: group key -> per-column impute value (NaN = none). */
  private static class Gather extends MRTask<Gather> {
    private final int[] _keyCols;  // key columns inside the grouped-by frame result
    private final int[] _aggCols;  // per target column: its aggregate column in the result, or -1
    private IcedHashMap<AstGroup.G, Freezable[]> _group_impute_map;

    Gather(int[] keyCols, int[] aggCols) {
      _keyCols = keyCols;
      _aggCols = aggCols;
    }

    @Override
    public void map(Chunk cs[]) {
      if (_group_impute_map == null) _group_impute_map = new IcedHashMap<>();
      for (int row = 0; row < cs[0]._len; ++row) {
        IcedDouble[] imputes = new IcedDouble[_aggCols.length];
        for (int c = 0; c < imputes.length; ++c) {
          int agg = _aggCols[c];
          imputes[c] = new IcedDouble(agg < 0 ? Double.NaN : cs[agg].atd(row));
        }
        _group_impute_map.put(new AstGroup.G(_keyCols.length, null).fill(row, cs, _keyCols), imputes);
      }
    }

    @Override
    public void reduce(Gather mrt) {
      // Either side may never have run map() (no local chunks)
      if (_group_impute_map == null) _group_impute_map = mrt._group_impute_map;
      else if (mrt._group_impute_map != null) _group_impute_map.putAll(mrt._group_impute_map);
    }
  }
}