package water.util; import hex.Interaction; import water.*; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.Vec; /** * Simple Co-Occurrence based tabulation of X vs Y, where X and Y are two Vecs in a given dataset * Uses histogram of given resolution in X and Y * Handles numerical/categorical data and missing values * Supports observation weights * * Fills up two double[][] arrays: * _countData[xbin][ybin] contains the sum of observation weights (or 1) for co-occurrences in bins xbin/ybin * _responseData[xbin][2] contains the mean value of Y and the sum of observation weights for a given bin for X */ public class Tabulate extends Keyed<Tabulate> { public final Job<Tabulate> _job; public Frame _dataset; public Key[] _vecs = new Key[2]; public String _predictor; public String _response; public String _weight; int _nbins_predictor = 20; int _nbins_response = 10; // result double[][] _count_data; double[][] _response_data; public TwoDimTable _count_table; public TwoDimTable _response_table; // helper to speed up stuff static private class Stats extends Iced { Stats(Vec v) { _min = v.min(); _max = v.max(); _isCategorical = v.isCategorical(); _isInt = v.isInt(); _cardinality = v.cardinality(); _missing = v.naCnt() > 0 ? 1 : 0; _domain = v.domain(); } final double _min; final double _max; final boolean _isCategorical; final boolean _isInt; final int _cardinality; final int _missing; //0 or 1 final String[] _domain; } final private Stats[] _stats = new Stats[2]; public Tabulate() { _job = new Job(Key.<Tabulate>make(), Tabulate.class.getName(), "Tabulate job"); } private int bins(int v) { return v==1 ? _nbins_response : _nbins_predictor; } private int res(final int v) { final int missing = _stats[v]._missing; if (_stats[v]._isCategorical) return _stats[v]._cardinality + missing; return bins(v) + missing; } private int bin(final int v, final double val) { if (Double.isNaN(val)) { return 0; } int b; int bins = bins(v); if (_stats[v]._isCategorical) { assert((int)val == val); b = (int) val; } else { double d = (_stats[v]._max - _stats[v]._min) / bins; b = (int) ((val - _stats[v]._min) / d); assert(b>=0 && b<= bins); b = Math.min(b, bins -1);//avoid AIOOBE at upper bound } return b+_stats[v]._missing; } private String labelForBin(final int v, int b) { int missing = _stats[v]._missing; if (missing == 1 && b==0) return "missing(NA)"; if (missing == 1) b--; if (_stats[v]._isCategorical) return _stats[v]._domain[b]; int bins = bins(v); if (_stats[v]._isInt && (_stats[v]._max - _stats[v]._min + 1) <= bins) return Integer.toString((int)(_stats[v]._min + b)); double d = (_stats[v]._max - _stats[v]._min)/bins; return String.format("%5f", _stats[v]._min + (b + 0.5) * d); } public Tabulate execImpl() { if (_dataset == null) throw new H2OIllegalArgumentException("Dataset not found"); if (_nbins_predictor < 1) throw new H2OIllegalArgumentException("Number of bins for predictor must be >= 1"); if (_nbins_response < 1) throw new H2OIllegalArgumentException("Number of bins for response must be >= 1"); Vec x = _dataset.vec(_predictor); if (x == null) throw new H2OIllegalArgumentException("Predictor column " + _predictor + " not found"); if (x.cardinality() > _nbins_predictor) { Interaction in = new Interaction(); in._source_frame = _dataset._key; in._factor_columns = new String[]{_predictor}; in._max_factors = _nbins_predictor -1; in.execImpl(null); x = in._job._result.get().anyVec(); } else if (x.isInt() && (x.max() - x.min() + 1) <= _nbins_predictor) { x = x.toCategoricalVec(); } Vec y = _dataset.vec(_response); if (y == null) throw new H2OIllegalArgumentException("Response column " + _response + " not found"); if (y.cardinality() > _nbins_response) { Interaction in = new Interaction(); in._source_frame = _dataset._key; in._factor_columns = new String[]{_response}; in._max_factors = _nbins_response -1; in.execImpl(null); y = in._job._result.get().anyVec(); } else if (y.isInt() && (y.max() - y.min() + 1) <= _nbins_response) { y = y.toCategoricalVec(); } if (y!=null && y.cardinality() > 2) Log.warn("Response column has more than two factor levels - mean response depends on lexicographic order of factors!"); Vec w = _dataset.vec(_weight); //can be null if (w != null && (!w.isNumeric() && w.min() < 0)) throw new H2OIllegalArgumentException("Observation weights must be numeric with values >= 0"); if (x!=null) { _vecs[0] = x._key; _stats[0] = new Stats(x); } if (y!=null) { _vecs[1] = y._key; _stats[1] = new Stats(y); } Tabulate sp = w != null ? new CoOccurrence(this).doAll(x, y, w)._sp : new CoOccurrence(this).doAll(x, y)._sp; _count_table = sp.tabulationTwoDimTable(); _response_table = sp.responseCharTwoDimTable(); Log.info(_count_table.toString(2, false)); Log.info(_response_table.toString(2, false)); return sp; } private static class CoOccurrence extends MRTask<CoOccurrence> { final Tabulate _sp; CoOccurrence(Tabulate sp) {_sp = sp;} @Override protected void setupLocal() { _sp._count_data = new double[_sp.res(0)][_sp.res(1)]; _sp._response_data = new double[_sp.res(0)][2]; } @Override public void map(Chunk x, Chunk y) { map(x, y, (Chunk)null); } @Override public void map(Chunk x, Chunk y, Chunk w) { for (int r=0; r<x.len(); ++r) { int xbin = _sp.bin(0, x.atd(r)); int ybin = _sp.bin(1, y.atd(r)); double weight = w!=null?w.atd(r):1; if (Double.isNaN(weight)) continue; AtomicUtils.DoubleArray.add(_sp._count_data[xbin], ybin, weight); //increment co-occurrence count by w if (!y.isNA(r)) { AtomicUtils.DoubleArray.add(_sp._response_data[xbin], 0, weight * y.atd(r)); //add to mean response for x AtomicUtils.DoubleArray.add(_sp._response_data[xbin], 1, weight); //increment total for x } } } @Override public void reduce(CoOccurrence mrt) { if (_sp._response_data == mrt._sp._response_data) return; ArrayUtils.add(_sp._response_data, mrt._sp._response_data); } @Override protected void postGlobal() { //compute mean response for (int i=0; i<_sp._response_data.length; ++i) { _sp._response_data[i][0] /= _sp._response_data[i][1]; } } } public TwoDimTable tabulationTwoDimTable() { if (_response_data == null) return null; int predN = _count_data.length; int respN = _count_data[0].length; String tableHeader = "(Weighted) co-occurrence counts of '" + _predictor + "' and '" + _response + "'"; String[] rowHeaders = new String[predN * respN]; String[] colHeaders = new String[3]; //predictor response wcount String[] colTypes = new String[colHeaders.length]; String[] colFormats = new String[colHeaders.length]; colHeaders[0] = _predictor; colHeaders[1] = _response; colTypes[0] = "string"; colFormats[0] = "%s"; colTypes[1] = "string"; colFormats[1] = "%s"; colHeaders[2] = "counts"; colTypes[2] = "double"; colFormats[2] = "%f"; TwoDimTable table = new TwoDimTable( tableHeader, null/*tableDescription*/, rowHeaders, colHeaders, colTypes, colFormats, null); for (int p=0; p<predN; ++p) { String plabel = labelForBin(0, p); for (int r=0; r<respN; ++r) { String rlabel = labelForBin(1, r); for (int c=0; c<3; ++c) { table.set(r*predN + p, 0, plabel); table.set(r*predN + p, 1, rlabel); table.set(r*predN + p, 2, _count_data[p][r]); } } } return table; } public TwoDimTable responseCharTwoDimTable() { if (_response_data == null) return null; String tableHeader = "Mean value of '" + _response + "' and (weighted) counts for '" + _predictor + "' values"; int predN = _count_data.length; String[] rowHeaders = new String[predN]; //X String[] colHeaders = new String[3]; //Y String[] colTypes = new String[colHeaders.length]; String[] colFormats = new String[colHeaders.length]; colHeaders[0] = _predictor; colTypes[0] = "string"; colFormats[0] = "%s"; colHeaders[1] = "mean " + _response; colTypes[2] = "double"; colFormats[2] = "%f"; colHeaders[2] = "counts"; colTypes[1] = "double"; colFormats[1] = "%f"; TwoDimTable table = new TwoDimTable( tableHeader, null/*tableDescription*/, rowHeaders, colHeaders, colTypes, colFormats, null); for (int p=0; p<predN; ++p) { String plabel = labelForBin(0, p); table.set(p, 0, plabel); table.set(p, 1, _response_data[p][0]); table.set(p, 2, _response_data[p][1]); } return table; } }