package hex.quantile; import hex.ModelBuilder; import hex.ModelCategory; import water.*; import water.fvec.*; import water.util.ArrayUtils; import water.util.Log; import java.util.Arrays; /** * Quantile model builder... building a simple QuantileModel */ public class Quantile extends ModelBuilder<QuantileModel,QuantileModel.QuantileParameters,QuantileModel.QuantileOutput> { private int _ncols; @Override protected boolean logMe() { return false; } @Override public boolean isSupervised() { return false; } // Called from Nano thread; start the Quantile Job on a F/J thread public Quantile( QuantileModel.QuantileParameters parms ) { super(parms); init(false); } public Quantile( QuantileModel.QuantileParameters parms, Job job ) { super(parms, job); init(false); } @Override public Driver trainModelImpl() { return new QuantileDriver(); } @Override public ModelCategory[] can_build() { return new ModelCategory[]{ModelCategory.Unknown}; } // any number of chunks is fine - don't rebalance - it's not worth it for a few passes over the data (at most) @Override protected int desiredChunks(final Frame original_fr, boolean local) { return 1; } /** Initialize the ModelBuilder, validating all arguments and preparing the * training frame. This call is expected to be overridden in the subclasses * and each subclass will start with "super.init();". This call is made * by the front-end whenever the GUI is clicked, and needs to be fast; * heavy-weight prep needs to wait for the trainModel() call. * * Validate the probs. */ @Override public void init(boolean expensive) { super.init(expensive); for( double p : _parms._probs ) if( p < 0.0 || p > 1.0 ) error("_probs","Probabilities must be between 0 and 1"); _ncols = train().numCols()-numSpecialCols(); //offset/weights/nfold - should only ever be weights if ( numSpecialCols() == 1 && _weights == null) throw new IllegalArgumentException("The only special Vec that is supported for Quantiles is observation weights."); if ( numSpecialCols() >1 ) throw new IllegalArgumentException("Cannot handle more than 1 special vec (weights)"); } private static class SumWeights extends MRTask<SumWeights> { double sum; @Override public void map(Chunk c, Chunk w) { for (int i=0;i<c.len();++i) if (!c.isNA(i)) { double wt = w.atd(i); // For now: let the user give small weights, results are probably not very good (same as for wtd.quantile in R) // if (wt > 0 && wt < 1) throw new H2OIllegalArgumentException("Quantiles only accepts weights that are either 0 or >= 1."); sum += wt; } } @Override public void reduce(SumWeights mrt) { sum+=mrt.sum; } } // ---------------------- private class QuantileDriver extends Driver { @Override public void computeImpl() { QuantileModel model = null; try { init(true); // The model to be built model = new QuantileModel(dest(), _parms, new QuantileModel.QuantileOutput(Quantile.this)); model._output._parameters = _parms; model._output._quantiles = new double[_ncols][_parms._probs.length]; model.delete_and_lock(_job); // --- // Run the main Quantile Loop Vec vecs[] = train().vecs(); for( int n=0; n<_ncols; n++ ) { if( stop_requested() ) return; // Stopped/cancelled Vec vec = vecs[n]; if (vec.isBad() || vec.isCategorical() || vec.isString() || vec.isTime() || vec.isUUID()) { model._output._quantiles[n] = new double[_parms._probs.length]; Arrays.fill(model._output._quantiles[n], Double.NaN); continue; } double sumRows=_weights == null ? vec.length()-vec.naCnt() : new SumWeights().doAll(vec, _weights).sum; // Compute top-level histogram Histo h1 = new Histo(vec.min(),vec.max(),0,sumRows,vec.isInt()); h1 = _weights==null ? h1.doAll(vec) : h1.doAll(vec, _weights); // For each probability, see if we have it exactly - or else run // passes until we do. for( int p = 0; p < _parms._probs.length; p++ ) { double prob = _parms._probs[p]; Histo h = h1; // Start from the first global histogram model._output._iterations++; // At least one iter per-prob-per-column while( Double.isNaN(model._output._quantiles[n][p] = h.findQuantile(prob,_parms._combine_method)) ) { h = _weights == null ? h.refinePass(prob).doAll(vec) : h.refinePass(prob).doAll(vec, _weights); // Full pass at higher resolution model._output._iterations++; // also count refinement iterations } // Update the model model.update(_job); // Update model in K/V store _job.update(0); // One unit of work } StringBuilder sb = new StringBuilder(); sb.append("Quantile: iter: ").append(model._output._iterations).append(" Qs=").append(Arrays.toString(model._output._quantiles[n])); Log.debug(sb); } } finally { if( model != null ) model.unlock(_job); } } } public static class StratifiedQuantilesTask extends H2O.H2OCountedCompleter<StratifiedQuantilesTask> { // INPUT final double _prob; final Vec _response; //vec to compute quantile for final Vec _weights; //obs weights final Vec _strata; //continuous integer range mapping into the _quantiles[id][] final QuantileModel.CombineMethod _combine_method; // OUTPUT public double[/*strata*/] _quantiles; public int[] _nids; public StratifiedQuantilesTask(H2O.H2OCountedCompleter cc, double prob, Vec response, // response Vec weights, // obs weights Vec strata, // stratification (can be null) QuantileModel.CombineMethod combine_method) { super(cc); _response = response; _prob=prob; _combine_method=combine_method; _weights=weights; _strata=strata; } @Override public void compute2() { final int strataMin = (int)_strata.min(); final int strataMax = (int)_strata.max(); if (strataMin < 0 && strataMax < 0) { Log.warn("No quantiles can be computed since there are no non-OOB rows."); tryComplete(); return; } final int nstrata = strataMax - strataMin + 1; Log.info("Computing quantiles for (up to) " + nstrata + " different strata."); _quantiles = new double[nstrata]; _nids = new int[nstrata]; Arrays.fill(_quantiles,Double.NaN); Vec weights = _weights != null ? _weights : _response.makeCon(1); for (int i=0;i<nstrata;++i) { //loop over nodes Vec newWeights = weights.makeCopy(); //only keep weights for this stratum (node), set rest to 0 if (_strata!=null) { _nids[i] = strataMin+i; new KeepOnlyOneStrata(_nids[i]).doAll(_strata, newWeights); } double sumRows = new SumWeights().doAll(_response, newWeights).sum; if (sumRows>0) { Histo h = new Histo(_response.min(), _response.max(), 0, sumRows, _response.isInt()); h.doAll(_response, newWeights); while (Double.isNaN(_quantiles[i] = h.findQuantile(_prob, _combine_method))) h = h.refinePass(_prob).doAll(_response, newWeights); newWeights.remove(); //sanity check quantiles assert (_quantiles[i] <= _response.max() + 1e-6); assert (_quantiles[i] >= _response.min() - 1e-6); } } if (_weights != weights) weights.remove(); tryComplete(); } private static class KeepOnlyOneStrata extends MRTask<KeepOnlyOneStrata> { KeepOnlyOneStrata(int stratumToKeep) { this.stratumToKeep = stratumToKeep; } int stratumToKeep; @Override public void map(Chunk strata, Chunk newW) { for (int i=0; i<strata._len; ++i) { // Log.info("NID:" + ((int) strata.at8(i))); if ((int) strata.at8(i) != stratumToKeep) newW.set(i, 0); } } } } // ------------------------------------------------------------------------- private final static class Histo extends MRTask<Histo> { private static final int NBINS = 1024; // Default bin count private final int _nbins; // Actual bin count private final double _lb; // Lower bound of bin[0] private final double _step; // Step-size per-bin private final double _start_row; // Starting cumulative count of weighted rows for this lower-bound private final double _nrows; // Total datasets (weighted) rows private final boolean _isInt; // Column only holds ints // Big Data output result double _bins[/*nbins*/]; // Weighted count of rows in each bin double _mins[/*nbins*/]; // Smallest element in bin double _maxs[/*nbins*/]; // Largest element in bin private Histo(double lb, double ub, double start_row, double nrows, boolean isInt) { boolean is_int = (isInt && (ub - lb < NBINS)); _nbins = is_int ? (int) (ub - lb + 1) : NBINS; _lb = lb; double ulp = Math.ulp(Math.max(Math.abs(lb), Math.abs(ub))); _step = is_int ? 1 : (ub + ulp - lb) / _nbins; _start_row = start_row; _nrows = nrows; _isInt = isInt; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("range : " + _lb + " ... " + (_lb + _nbins * _step)); sb.append("\npsum0 : " + _start_row); sb.append("\ncounts: " + Arrays.toString(_bins)); sb.append("\nmaxs : " + Arrays.toString(_maxs)); sb.append("\nmins : " + Arrays.toString(_mins)); sb.append("\n"); return sb.toString(); } @Override public void map(Chunk chk, Chunk weight) { _bins = new double[_nbins]; _mins = new double[_nbins]; _maxs = new double[_nbins]; Arrays.fill(_mins, Double.MAX_VALUE); Arrays.fill(_maxs, -Double.MAX_VALUE); double d; for (int row = 0; row < chk._len; row++) { double w = weight.atd(row); if (w == 0) continue; if (!Double.isNaN(d = chk.atd(row))) { // na.rm=true double idx = (d - _lb) / _step; if (!(0.0 <= idx && idx < _bins.length)) continue; int i = (int) idx; if (_bins[i] == 0) _mins[i] = _maxs[i] = d; // Capture unique value else { if (d < _mins[i]) _mins[i] = d; if (d > _maxs[i]) _maxs[i] = d; } _bins[i] += w; // Bump row counts by row weight } } } @Override public void map(Chunk chk) { map(chk, new C0DChunk(1, chk.len())); } @Override public void reduce(Histo h) { for (int i = 0; i < _nbins; i++) { // Keep min/max if (_mins[i] > h._mins[i]) _mins[i] = h._mins[i]; if (_maxs[i] < h._maxs[i]) _maxs[i] = h._maxs[i]; } ArrayUtils.add(_bins, h._bins); } /** @return Quantile for probability prob, or NaN if another pass is needed. */ double findQuantile( double prob, QuantileModel.CombineMethod method ) { double p2 = prob*(_nrows-1); // Desired fractional row number for this probability long r2 = (long)p2; int loidx = findBin(r2); // Find bin holding low value double lo = (loidx == _nbins) ? binEdge(_nbins) : _maxs[loidx]; if( loidx<_nbins && r2==p2 && _mins[loidx]==lo ) return lo; // Exact row number, exact bin? Then quantile is exact long r3 = r2+1; int hiidx = findBin(r3); // Find bin holding high value double hi = (hiidx == _nbins) ? binEdge(_nbins) : _mins[hiidx]; if( loidx==hiidx ) // Somewhere in the same bin? return (lo==hi) ? lo : Double.NaN; // Only if bin is constant, otherwise must refine the bin // Split across bins - the interpolate between the hi of the lo bin, and // the lo of the hi bin return computeQuantile(lo,hi,r2,_nrows,prob,method); } private double binEdge( int idx ) { return _lb+_step*idx; } // bin for row; can be _nbins if just off the end (normally expect 0 to nbins-1) // row == position in (weighted) population private int findBin( double row ) { long sum = (long)_start_row; for( int i=0; i<_nbins; i++ ) if( (long)row < (sum += _bins[i]) ) return i; return _nbins; } // Run another pass over the data, with refined endpoints, to home in on // the exact elements for this probability. Histo refinePass( double prob ) { double prow = prob*(_nrows-1); // Desired fractional row number for this probability long lorow = (long)prow; // Lower integral row number int loidx = findBin(lorow); // Find bin holding low value // If loidx is the last bin, then high must be also the last bin - and we // have an exact quantile (equal to the high bin) and we didn't need // another refinement pass assert loidx < _nbins; double lo = _mins[loidx]; // Lower end of range to explore // If probability does not hit an exact row, we need the elements on // either side - so the next row up from the low row long hirow = lorow==prow ? lorow : lorow+1; int hiidx = findBin(hirow); // Find bin holding high value // Upper end of range to explore - except at the very high end cap double hi = hiidx==_nbins ? binEdge(_nbins) : _maxs[hiidx]; long sum = (long)_start_row; for( int i=0; i<loidx; i++ ) sum += _bins[i]; return new Histo(lo,hi,sum,_nrows,_isInt); } } /** Compute the correct final quantile from these 4 values. If the lo and hi * elements are equal, use them. However if they differ, then there is no * single value which exactly matches the desired quantile. There are * several well-accepted definitions in this case - including picking either * the lo or the hi, or averaging them, or doing a linear interpolation. * @param lo the highest element less than or equal to the desired quantile * @param hi the lowest element greater than or equal to the desired quantile * @param row row number (zero based) of the lo element; high element is +1 * @return desired quantile. */ static double computeQuantile( double lo, double hi, double row, double nrows, double prob, QuantileModel.CombineMethod method ) { if( lo==hi ) return lo; // Equal; pick either if( method == null ) method= QuantileModel.CombineMethod.INTERPOLATE; switch( method ) { case INTERPOLATE: return linearInterpolate(lo,hi,row,nrows,prob); case AVERAGE: return 0.5*(hi+lo); case LOW: return lo; case HIGH: return hi; default: Log.info("Unknown even sample size quantile combination type: " + method + ". Doing linear interpolation."); return linearInterpolate(lo,hi,row,nrows,prob); } } private static double linearInterpolate(double lo, double hi, double row, double nrows, double prob) { // Unequal, linear interpolation double plo = (row+0)/(nrows-1); // Note that row numbers are inclusive on the end point, means we need a -1 double phi = (row+1)/(nrows-1); // Passed in the row number for the low value, high is the next row, so +1 assert plo <= prob && prob <= phi; return lo + (hi-lo)*(prob-plo)/(phi-plo); // Classic linear interpolation } }