package water.util;
import water.*;
import water.H2O.H2OCallback;
import water.H2O.H2OCountedCompleter;
import water.fvec.*;
import water.nbhm.NonBlockingHashMap;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import water.parser.BufferedString;
import static water.util.RandomUtils.getRNG;
public class MRUtils {
/**
* Sample rows from a frame.
* Can be unlucky for small sampling fractions - will continue calling itself until at least 1 row is returned.
* @param fr Input frame
* @param rows Approximate number of rows to sample (across all chunks)
* @param seed Seed for RNG
* @return Sampled frame
*/
public static Frame sampleFrame(Frame fr, final long rows, final long seed) {
    if (fr == null) return null;
    // Per-row keep probability; rows <= 0 means "keep all rows".
    final float fraction = rows > 0 ? (float)rows / fr.numRows() : 1.f;
    // Asked for at least as many rows as we have: return the input frame itself (NOT a copy).
    if (fraction >= 1.f) return fr;
    // Human-readable destination key, e.g. "<src>.temporary.sample.10.0%" (null input key -> anonymous output frame).
    Key newKey = fr._key != null ? Key.make(fr._key.toString() + (fr._key.toString().contains("temporary") ? ".sample." : ".temporary.sample.") + PrettyPrint.formatPct(fraction).replace(" ","")) : null;
    Frame r = new MRTask() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            // The initial seed (0) is irrelevant: the RNG is re-seeded from the global row
            // number below, making the sample deterministic and independent of chunk layout.
            final Random rng = getRNG(0);
            final BufferedString bStr = new BufferedString(); // reused buffer for string reads
            int count = 0; // rows emitted from this chunk so far
            for (int r = 0; r < cs[0]._len; r++) {
                rng.setSeed(seed+r+cs[0].start()); // seed from the global (frame-wide) row index
                // Keep the row with probability 'fraction'; additionally force-keep the last row
                // of the chunk if nothing was sampled yet, so every non-empty chunk contributes
                // at least one row.
                if (rng.nextFloat() < fraction || (count == 0 && r == cs[0]._len-1) ) {
                    count++;
                    // Copy the row, dispatching on the chunk's storage type.
                    for (int i = 0; i < ncs.length; i++) {
                        if (cs[i].isNA(r)) ncs[i].addNA();
                        else if (cs[i] instanceof CStrChunk)
                            ncs[i].addStr(cs[i].atStr(bStr,r)); // string column
                        else if (cs[i] instanceof C16Chunk)
                            ncs[i].addUUID(cs[i].at16l(r),cs[i].at16h(r)); // UUID column: copy both 64-bit halves
                        else
                            ncs[i].addNum(cs[i].atd(r)); // numeric/categorical column
                    }
                }
            }
        }
    }.doAll(fr.types(), fr).outputFrame(newKey, fr.names(), fr.domains());
    // Extremely unlikely given the force-keep above, but retry with a bumped seed
    // until at least one row comes back.
    if (r.numRows() == 0) {
        Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
        Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
        return sampleFrame(fr, rows, seed+1);
    }
    return r;
}
/**
* Row-wise shuffle of a frame (only shuffles rows inside of each chunk)
* @param fr Input frame
* @return Shuffled frame
*/
public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
    return new MRTask() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            // Build the identity permutation over this chunk's rows and shuffle it.
            // Note: getRNG(seed) uses the same seed in every chunk, so equally-sized
            // chunks get the same permutation - acceptable for an intra-chunk shuffle.
            int[] idx = new int[cs[0]._len];
            for (int r = 0; r < idx.length; ++r) idx[r] = r;
            ArrayUtils.shuffleArray(idx, getRNG(seed));
            // Emit the rows in permuted order, dispatching on the chunk's storage type
            // (same dispatch as sampleFrame). Previously UUID (C16Chunk) columns were
            // copied via atd(), which loses the 128-bit value; now both halves are kept.
            for (int anIdx : idx) {
                for (int i = 0; i < ncs.length; i++) {
                    if (cs[i].isNA(anIdx)) ncs[i].addNA();
                    else if (cs[i] instanceof CStrChunk)
                        ncs[i].addStr(cs[i], cs[i].start() + anIdx); // string column: global row index
                    else if (cs[i] instanceof C16Chunk)
                        ncs[i].addUUID(cs[i].at16l(anIdx), cs[i].at16h(anIdx)); // UUID column
                    else
                        ncs[i].addNum(cs[i].atd(anIdx)); // numeric/categorical column
                }
            }
        }
    }.doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());
}
/**
* Compute the class distribution from a class label vector
* (not counting missing values)
*
* Usage 1: Label vector is categorical
* ------------------------------------
* Vec label = ...;
* assert(label.isCategorical());
* double[] dist = new ClassDist(label).doAll(label).dist();
*
* Usage 2: Label vector is numerical
* ----------------------------------
* Vec label = ...;
* int num_classes = ...;
* assert(label.isInt());
* double[] dist = new ClassDist(num_classes).doAll(label).dist();
*
*/
public static class ClassDist extends MRTask<ClassDist> {
    /** Number of classes (cardinality of the label's domain). */
    final int _nclass;
    /** Per-class (possibly weighted) counts; allocated in map(), summed in reduce(). */
    protected double[] _ys;
    public ClassDist(final Vec label) { _nclass = label.domain().length; }
    public ClassDist(int n) { _nclass = n; }
    /** @return absolute per-class counts (missing labels are not counted) */
    public final double[] dist() { return _ys; }
    /** @return per-class counts normalized to sum to 1 */
    public final double[] rel_dist() {
        final double total = ArrayUtils.sum(_ys);
        return ArrayUtils.div(Arrays.copyOf(_ys, _ys.length), total);
    }
    /** Unweighted count: each non-NA label contributes 1 to its class bucket. */
    @Override public void map(Chunk ys) {
        _ys = new double[_nclass];
        for (int row = 0; row < ys._len; ++row) {
            if (ys.isNA(row)) continue; // skip missing labels
            _ys[(int) ys.at8(row)]++;
        }
    }
    /** Weighted count: each non-NA label contributes its row weight to its class bucket. */
    @Override public void map(Chunk ys, Chunk ws) {
        _ys = new double[_nclass];
        for (int row = 0; row < ys._len; ++row) {
            if (ys.isNA(row)) continue; // skip missing labels
            _ys[(int) ys.at8(row)] += ws.atd(row);
        }
    }
    @Override public void reduce(ClassDist that) { ArrayUtils.add(_ys, that._ys); }
}
/** Computes the discrete distribution of a vector's values: maps each distinct
 *  non-NA value to its occurrence count. Use {@link #keys()} and {@link #dist()}
 *  (index-aligned) to read the result. */
public static class Dist extends MRTask<Dist> {
    // value -> occurrence count; created per map() call, merged pairwise in reduce()
    private IcedHashMap<IcedDouble,IcedAtomicInt> _dist;
    @Override public void map(Chunk ys) {
        _dist = new IcedHashMap<>();
        IcedDouble d = new IcedDouble(0); // reusable probe key to avoid per-row allocation
        for( int row=0; row< ys._len; row++ )
            if( !ys.isNA(row) ) {
                d._val = ys.atd(row);
                IcedAtomicInt oldV = _dist.get(d);
                if(oldV == null)
                    // Value not seen yet: insert count 1. putIfAbsent returns the previous
                    // mapping, which is non-null only if a concurrent insert won the race -
                    // in that case fall through and bump the existing counter instead.
                    oldV = _dist.putIfAbsent(new IcedDouble(d._val), new IcedAtomicInt(1));
                if(oldV != null)
                    oldV.incrementAndGet();
            }
    }
    @Override public void reduce(Dist mrt) {
        if( _dist != mrt._dist ) { // guard: same-node reduce may hand us the identical map
            IcedHashMap<IcedDouble,IcedAtomicInt> l = _dist;
            IcedHashMap<IcedDouble,IcedAtomicInt> r = mrt._dist;
            if( l.size() < r.size() ) { l=r; r=_dist; } // iterate the smaller map into the larger
            for( IcedDouble v: r.keySet() ) {
                IcedAtomicInt oldVal = l.putIfAbsent(v, r.get(v));
                if( oldVal!=null ) oldVal.addAndGet(r.get(v).get());
            }
            _dist=l;
            mrt._dist=null; // drop the merged-away map
        }
    }
    // NOTE: dist() and keys() rely on values() and keySet() iterating in the same
    // order over the unmodified map, so dist()[i] is the count of keys()[i].
    /** @return occurrence counts, index-aligned with {@link #keys()} */
    public double[] dist() {
        int i=0;
        double[] dist = new double[_dist.size()];
        for( IcedAtomicInt v: _dist.values() ) dist[i++] = v.get();
        return dist;
    }
    /** @return the distinct values observed, index-aligned with {@link #dist()} */
    public double[] keys() {
        int i=0;
        double[] keys = new double[_dist.size()];
        for( IcedDouble k: _dist.keySet() ) keys[i++] = k._val;
        return keys;
    }
}
/**
* Stratified sampling for classifiers - FIXME: For weights, this is not accurate, as the sampling is done with uniform weights
* @param fr Input frame
* @param label Label vector (must be categorical)
* @param weights Weights vector, can be null
* @param sampling_ratios Optional: array containing the requested sampling ratios per class (in order of domains), will be overwritten if it contains all 0s
* @param maxrows Maximum number of rows in the returned frame
* @param seed RNG seed for sampling
* @param allowOversampling Allow oversampling of minority classes
* @param verbose Whether to print verbose info
* @return Sampled frame, with approximately the same number of samples from each class (or given by the requested sampling ratios)
*/
public static Frame sampleFrameStratified(final Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, final long seed, final boolean allowOversampling, final boolean verbose) {
    if (fr == null) return null;
    assert(label.isCategorical());
    // Need room for at least one row per class.
    if (maxrows < label.domain().length) {
        Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + label.domain().length + ").");
        maxrows = label.domain().length;
    }
    // Current (weighted) class distribution of the input frame.
    ClassDist cd = new ClassDist(label);
    double[] dist = weights != null ? cd.doAll(label, weights).dist() : cd.doAll(label).dist();
    assert(dist.length > 0);
    Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
    if (verbose)
        for (int i=0; i<dist.length;++i)
            Log.info("Class " + label.factor(i) + ": count: " + dist[i] + " prior: " + (float)dist[i]/fr.numRows());
    // create sampling_ratios for class balance with max. maxrows rows (fill
    // existing array if not null). Make a defensive copy.
    sampling_ratios = sampling_ratios == null ? new float[dist.length] : sampling_ratios.clone();
    assert sampling_ratios.length == dist.length;
    // An all-zeros array means "not specified" - compute ratios that equalize the classes.
    if( ArrayUtils.minValue(sampling_ratios) == 0 && ArrayUtils.maxValue(sampling_ratios) == 0 ) {
        // compute sampling ratios to achieve class balance
        for (int i=0; i<dist.length;++i)
            sampling_ratios[i] = ((float)fr.numRows() / label.domain().length) / (float)dist[i]; // prior^-1 / num_classes
        final float inv_scale = ArrayUtils.minValue(sampling_ratios); //majority class has lowest required oversampling factor to achieve balance
        if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale))
            ArrayUtils.div(sampling_ratios, inv_scale); //want sampling_ratio 1.0 for majority class (no downsampling)
    }
    // Without oversampling, no class may be sampled at a rate above 1.
    if (!allowOversampling)
        for (int i=0; i<sampling_ratios.length; ++i)
            sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
    // given these sampling ratios, and the original class distribution, this is the expected number of resulting rows
    float numrows = 0;
    for (int i=0; i<sampling_ratios.length; ++i) {
        numrows += sampling_ratios[i] * dist[i];
    }
    if (Float.isNaN(numrows)) {
        throw new IllegalArgumentException("Error during sampling - too few points?");
    }
    final long actualnumrows = Math.min(maxrows, Math.round(numrows)); //cap #rows at maxrows
    assert(actualnumrows >= 0); //can have no matching rows in case of sparse data where we had to fill in a makeZero() vector
    Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + (actualnumrows < numrows ? " (limited by max_after_balance_size).":"."));
    if (actualnumrows != numrows) {
        // Rescale all ratios uniformly so the expected total matches the cap.
        ArrayUtils.mult(sampling_ratios, (float)actualnumrows/numrows); //adjust the sampling_ratios by the global rescaling factor
        if (verbose)
            Log.info("Downsampling majority class by " + (float)actualnumrows/numrows
                    + " to limit number of rows to " + String.format("%,d", maxrows));
    }
    for (int i=0;i<label.domain().length;++i) {
        Log.info("Class '" + label.domain()[i] + "' sampling ratio: " + sampling_ratios[i]);
    }
    // Hand the finalized per-class ratios to the actual sampler.
    return sampleFrameStratified(fr, label, weights, sampling_ratios, seed, verbose);
}
/**
* Stratified sampling
* @param fr Input frame
* @param label Label vector (from the input frame)
* @param weights Weight vector (from the input frame), can be null
* @param sampling_ratios Given sampling ratios for each class, in order of domains
* @param seed RNG seed
* @param debug Whether to print debug info
* @return Stratified frame
*/
public static Frame sampleFrameStratified(final Frame fr, Vec label, Vec weights, final float[] sampling_ratios, final long seed, final boolean debug) {
    // Delegate to the internal version, starting the retry counter at 0.
    return sampleFrameStratified(fr, label, weights, sampling_ratios, seed, debug, 0);
}
// internal version with repeat counter
// currently hardcoded to do up to 10 tries to get a row from each class, which can be impossible for certain wrong sampling ratios
private static Frame sampleFrameStratified(final Frame fr, Vec label, Vec weights, final float[] sampling_ratios, final long seed, final boolean debug, int count) {
    if (fr == null) return null;
    assert(label.isCategorical());
    assert(sampling_ratios != null && sampling_ratios.length == label.domain().length);
    final int labelidx = fr.find(label); //which column is the label?
    assert(labelidx >= 0);
    final int weightsidx = fr.find(weights); //which column is the weight?
    final boolean poisson = false; //beta feature
    //FIXME - this is doing uniform sampling, even if the weights are given
    Frame r = new MRTask() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            final Random rng = getRNG(seed);
            for (int r = 0; r < cs[0]._len; r++) {
                if (cs[labelidx].isNA(r)) continue; //skip missing labels
                // Re-seed from the global row index: deterministic and chunk-layout independent.
                rng.setSeed(cs[0].start()+r+seed);
                final int label = (int)cs[labelidx].at8(r);
                assert(sampling_ratios.length > label && label >= 0);
                int sampling_reps;
                if (poisson) {
                    throw H2O.unimpl();
                    // sampling_reps = ArrayUtils.getPoisson(sampling_ratios[label], rng);
                } else {
                    // Emit floor(ratio) copies of the row, plus one more with probability frac(ratio).
                    final float remainder = sampling_ratios[label] - (int)sampling_ratios[label];
                    sampling_reps = (int)sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0);
                }
                // Copy the row sampling_reps times, dispatching on the chunk's storage type
                // (same dispatch as sampleFrame). Previously UUID (C16Chunk) columns were
                // copied via atd(), which loses the 128-bit value; now both halves are kept.
                for (int i = 0; i < ncs.length; i++) {
                    for (int j = 0; j < sampling_reps; ++j) {
                        if (cs[i].isNA(r)) ncs[i].addNA();
                        else if (cs[i] instanceof CStrChunk)
                            ncs[i].addStr(cs[i],cs[0].start()+r); // string column: global row index
                        else if (cs[i] instanceof C16Chunk)
                            ncs[i].addUUID(cs[i].at16l(r),cs[i].at16h(r)); // UUID column
                        else
                            ncs[i].addNum(cs[i].atd(r)); // numeric/categorical column
                    }
                }
            }
        }
    }.doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());
    // Confirm the validity of the distribution
    Vec lab = r.vecs()[labelidx];
    Vec wei = weightsidx != -1 ? r.vecs()[weightsidx] : null;
    double[] dist = wei != null ? new ClassDist(lab).doAll(lab, wei).dist() : new ClassDist(lab).doAll(lab).dist();
    // if there are no training labels in the test set, then there is no point in sampling the test set
    if (dist == null) {
        r.remove(); // don't leak the sampled frame on this early-exit path
        return fr;
    }
    if (debug) {
        double sumdist = ArrayUtils.sum(dist);
        Log.info("After stratified sampling: " + sumdist + " rows.");
        for (int i=0; i<dist.length;++i) {
            Log.info("Class " + r.vecs()[labelidx].factor(i) + ": count: " + dist[i]
                    + " sampling ratio: " + sampling_ratios[i] + " actual relative frequency: " + (float)dist[i] / sumdist * dist.length);
        }
    }
    // Re-try if we didn't get at least one example from each class
    // (hardcoded to at most 10 attempts - impossible ratios would otherwise loop forever).
    if (ArrayUtils.minValue(dist) == 0 && count < 10) {
        Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
        r.remove();
        return sampleFrameStratified(fr, label, weights, sampling_ratios, seed+1, debug, ++count);
    }
    // shuffle intra-chunk
    Frame shuffled = shuffleFramePerChunk(r, seed+0x580FF13);
    r.remove(); // the un-shuffled intermediate frame is no longer needed
    return shuffled;
}
}