package water.util; import static water.util.Utils.getDeterRNG; import water.*; import water.fvec.*; import java.util.Random; public class MRUtils { /** * Sample rows from a frame. * Can be unlucky for small sampling fractions - will continue calling itself until at least 1 row is returned. * @param fr Input frame * @param rows Approximate number of rows to sample (across all chunks) * @param seed Seed for RNG * @return Sampled frame */ public static Frame sampleFrame(Frame fr, final long rows, final long seed) { if (fr == null) return null; final float fraction = rows > 0 ? (float)rows / fr.numRows() : 1.f; if (fraction >= 1.f) return fr; Frame r = new MRTask2() { @Override public void map(Chunk[] cs, NewChunk[] ncs) { final Random rng = getDeterRNG(seed + cs[0].cidx()); int count = 0; for (int r = 0; r < cs[0]._len; r++) if (rng.nextFloat() < fraction || (count == 0 && r == cs[0]._len-1) ) { count++; for (int i = 0; i < ncs.length; i++) { ncs[i].addNum(cs[i].at0(r)); } } } }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains()); if (r.numRows() == 0) { Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ")."); Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\""); return sampleFrame(fr, rows, seed+1); } return r; } /** * Row-wise shuffle of a frame (only shuffles rows inside of each chunk) * @param fr Input frame * @return Shuffled frame */ public static Frame shuffleFramePerChunk(Frame fr, final long seed) { return shuffleFramePerChunk(null, fr, seed); } public static Frame shuffleFramePerChunk(Key outputFrameKey, Frame fr, final long seed) { Frame r = new MRTask2() { @Override public void map(Chunk[] cs, NewChunk[] ncs) { long[] idx = new long[cs[0]._len]; for (int r=0; r<idx.length; ++r) idx[r] = r; Utils.shuffleArray(idx, seed); for (int r=0; r<idx.length; ++r) { for (int i = 0; i < ncs.length; i++) { ncs[i].addNum(cs[i].at0((int)idx[r])); } } } }.doAll(fr.numCols(), fr).outputFrame(outputFrameKey, fr.names(), fr.domains()); return r; } /** * Global redistribution of a Frame (balancing of chunks), done by calling process (all-to-one + one-to-all) * @param fr Input frame * @param seed RNG seed * @param shuffle whether to shuffle the data globally * @return Shuffled frame */ public static Frame shuffleAndBalance(final Frame fr, int splits, long seed, final boolean local, final boolean shuffle) { if( (fr.vecs()[0].nChunks() < splits || shuffle) && fr.numRows() > splits) { Vec[] vecs = fr.vecs().clone(); Log.info("Load balancing dataset, splitting it into up to " + splits + " chunks."); long[] idx = null; if (shuffle) { idx = new long[splits]; for (int r=0; r<idx.length; ++r) idx[r] = r; Utils.shuffleArray(idx, seed); } Key keys[] = new Vec.VectorGroup().addVecs(vecs.length); final long rows_per_new_chunk = (long)(Math.ceil((double)fr.numRows()/splits)); //loop over cols (same indexing for each column) Futures fs = new Futures(); for(int col=0; col<vecs.length; col++) { AppendableVec vec = new AppendableVec(keys[col]); // create outgoing chunks for this col NewChunk[] outCkg = new NewChunk[splits]; for(int i=0; i<splits; ++i) outCkg[i] = new NewChunk(vec, i); //loop over all incoming chunks for( int ckg = 0; ckg < vecs[col].nChunks(); ckg++ ) { final Chunk inCkg = vecs[col].chunkForChunkIdx(ckg); // loop over local rows of incoming chunks (fast path) for (int row = 0; row < inCkg._len; ++row) { int outCkgIdx = (int)((inCkg._start + row) / rows_per_new_chunk); // destination chunk idx if (shuffle) outCkgIdx = (int)(idx[outCkgIdx]); //shuffle: choose a different output chunk assert(outCkgIdx >= 0 && outCkgIdx < splits); outCkg[outCkgIdx].addNum(inCkg.at0(row)); } } for(int i=0; i<outCkg.length; ++i) outCkg[i].close(i, fs); Vec t = vec.close(fs); t._domain = vecs[col]._domain; vecs[col] = t; } fs.blockForPending(); Log.info("Load balancing done."); return new Frame(fr.names(), vecs); } return fr; } /** * Compute the class distribution from a class label vector * (not counting missing values) * * Usage 1: Label vector is categorical * ------------------------------------ * Vec label = ...; * assert(label.isEnum()); * long[] dist = new ClassDist(label).doAll(label).dist(); * * Usage 2: Label vector is numerical * ---------------------------------- * Vec label = ...; * int num_classes = ...; * assert(label.isInt()); * long[] dist = new ClassDist(num_classes).doAll(label).dist(); * */ public static class ClassDist extends ClassDistHelper { public ClassDist(final Vec label) { super(label.domain().length); } public ClassDist(int n) { super(n); } public final long[] dist() { return _ys; } public final float[] rel_dist() { float[] rel = new float[_ys.length]; for (int i=0; i<_ys.length; ++i) rel[i] = (float)_ys[i]; final float sum = Utils.sum(rel); assert(sum != 0.); Utils.div(rel, sum); return rel; } } private static class ClassDistHelper extends MRTask2<ClassDist> { private ClassDistHelper(int nclass) { _nclass = nclass; } final int _nclass; protected long[] _ys; @Override public void map(Chunk ys) { _ys = new long[_nclass]; for( int i=0; i<ys._len; i++ ) if( !ys.isNA0(i) ) _ys[(int)ys.at80(i)]++; } @Override public void reduce( ClassDist that ) { Utils.add(_ys,that._ys); } } /** * Stratified sampling for classifiers * @param fr Input frame * @param label Label vector (must be enum) * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in order of domains), will be overwritten if it contains all 0s * @param maxrows Maximum number of rows in the returned frame * @param seed RNG seed for sampling * @param allowOversampling Allow oversampling of minority classes * @param verbose Whether to print verbose info * @return Sampled frame, with approximately the same number of samples from each class (or given by the requested sampling ratios) */ public static Frame sampleFrameStratified(final Frame fr, Vec label, float[] sampling_ratios, long maxrows, final long seed, final boolean allowOversampling, final boolean verbose) { if (fr == null) return null; assert(label.isEnum()); assert(maxrows >= label.domain().length); long[] dist = new ClassDist(label).doAll(label).dist(); assert(dist.length > 0); Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off")); if (verbose) { for (int i=0; i<dist.length;++i) { Log.info("Class " + label.domain(i) + ": count: " + dist[i] + " prior: " + (float)dist[i]/fr.numRows()); } } // create sampling_ratios for class balance with max. maxrows rows (fill existing array if not null) if (sampling_ratios == null || (Utils.minValue(sampling_ratios) == 0 && Utils.maxValue(sampling_ratios) == 0)) { // compute sampling ratios to achieve class balance if (sampling_ratios == null) { sampling_ratios = new float[dist.length]; } assert(sampling_ratios.length == dist.length); for (int i=0; i<dist.length;++i) { if (dist[i] == 0) { Log.warn("No rows of class " + label.domain()[i] + " found."); } sampling_ratios[i] = dist[i] == 0 ? 1 // don't sample if there's no rows of a certain class (avoid division by 0) : ((float)fr.numRows() / label.domain().length) / dist[i]; // prior^-1 / num_classes assert(sampling_ratios[i] >= 0); } final float inv_scale = Utils.minValue(sampling_ratios); //majority class has lowest required oversampling factor to achieve balance if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale)) Utils.div(sampling_ratios, inv_scale); //want sampling_ratio 1.0 for majority class (no downsampling) } for (float s : sampling_ratios) assert(!Float.isNaN(s) && !Float.isInfinite(s)); if (!allowOversampling) { for (int i=0; i<sampling_ratios.length; ++i) { sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]); } } for (float s : sampling_ratios) assert(!Float.isNaN(s) && !Float.isInfinite(s)); // given these sampling ratios, and the original class distribution, this is the expected number of resulting rows float numrows = 0; for (int i=0; i<sampling_ratios.length; ++i) { numrows += sampling_ratios[i] * dist[i]; } final long actualnumrows = Math.min(maxrows, Math.round(numrows)); //cap #rows at maxrows assert(actualnumrows >= 0); //can have no matching rows in case of sparse data where we had to fill in a makeZero() vector Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + (actualnumrows < numrows ? " (limited by max_after_balance_size).":".")); if (actualnumrows != numrows) { assert(numrows > 0); Utils.mult(sampling_ratios, (float)actualnumrows/numrows); //adjust the sampling_ratios by the global rescaling factor if (verbose) Log.info("Downsampling majority class by " + (float)actualnumrows/numrows + " to limit number of rows to " + String.format("%,d", maxrows)); } if (Utils.minIndex(sampling_ratios) == Utils.maxIndex(sampling_ratios)) { Log.info("All classes are sampled with sampling ratio: " + Utils.minValue(sampling_ratios)); } else { for (int i=0;i<label.domain().length;++i) { Log.info("Class '" + label.domain()[i].toString() + "' sampling ratio: " + sampling_ratios[i]); } } return sampleFrameStratified(fr, label, sampling_ratios, seed, verbose); } /** * Stratified sampling * @param fr Input frame * @param label Label vector (from the input frame) * @param sampling_ratios Given sampling ratios for each class, in order of domains * @param seed RNG seed * @param debug Whether to print debug info * @return Stratified frame */ public static Frame sampleFrameStratified(final Frame fr, Vec label, final float[] sampling_ratios, final long seed, final boolean debug) { return sampleFrameStratified(fr, label, sampling_ratios, seed, debug, 0); } // internal version with repeat counter // currently hardcoded to do up to 10 tries to get a row from each class, which can be impossible for certain wrong sampling ratios private static Frame sampleFrameStratified(final Frame fr, Vec label, final float[] sampling_ratios, final long seed, final boolean debug, int count) { if (fr == null) return null; assert(label.isEnum()); assert(sampling_ratios != null && sampling_ratios.length == label.domain().length); for (float s : sampling_ratios) assert(!Float.isNaN(s)); final int labelidx = fr.find(label); //which column is the label? assert(labelidx >= 0); final boolean poisson = false; //beta feature Frame r = new MRTask2() { @Override public void map(Chunk[] cs, NewChunk[] ncs) { final Random rng = getDeterRNG(seed + cs[0].cidx()); for (int r = 0; r < cs[0]._len; r++) { if (cs[labelidx].isNA0(r)) continue; //skip missing labels final int label = (int)cs[labelidx].at80(r); assert(sampling_ratios.length > label && label >= 0); int sampling_reps; if (poisson) { sampling_reps = Utils.getPoisson(sampling_ratios[label], rng); } else { final float remainder = sampling_ratios[label] - (int)sampling_ratios[label]; sampling_reps = (int)sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0); } for (int i = 0; i < ncs.length; i++) { for (int j = 0; j < sampling_reps; ++j) { ncs[i].addNum(cs[i].at0(r)); } } } } }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains()); assert(r.numCols() == fr.numCols()); // Confirm the validity of the distribution long[] dist = new ClassDist(r.vecs()[labelidx]).doAll(r.vecs()[labelidx]).dist(); // if there are no training labels in the test set, then there is no point in sampling the test set if (dist == null) { r.delete(); return fr; } if (debug) { long sumdist = Utils.sum(dist); Log.info("After stratified sampling: " + sumdist + " rows."); for (int i=0; i<dist.length;++i) { Log.info("Class " + r.vecs()[labelidx].domain(i) + ": count: " + dist[i] + " sampling ratio: " + sampling_ratios[i] + " actual relative frequency: " + (float)dist[i] / sumdist * dist.length); } } // Re-try if we didn't get at least one example from each class if (Utils.minValue(dist) == 0 && count < 10) { Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw)."); r.delete(); return sampleFrameStratified(fr, label, sampling_ratios, seed+1, debug, ++count); } // // shuffle intra-chunk // Frame shuffled = shuffleFramePerChunk(r, seed + 0x580FF13); // r.delete(); // return shuffled; return r; } /** * Compute the L2 norm for each row of the frame * @param fr Input frame * @return Vec containing L2 values for each row, is in K-V store */ public static Vec getL2(final Frame fr, final double[] scale) { // add workspace vec at end final int idx = fr.numCols(); assert(scale.length == idx) : "Mismatch for number of columns"; fr.add("L2", fr.anyVec().makeZero()); Vec res; try { new MRTask2() { @Override public void map(Chunk[] cs) { for (int r = 0; r < cs[0]._len; r++) { double norm2 = 0; for (int i = 0; i < idx; i++) norm2 += Math.pow(cs[i].at0(r) * scale[i], 2); cs[idx].set0(r, Math.sqrt(norm2)); } } }.doAll(fr); } finally { res = fr.remove(idx); } res.rollupStats(); return res; } }