package hex.splitframe;
import java.util.Random;
import water.*;
import water.fvec.*;
/** Frame splitter function to divide given frame into multiple partitions
* based on given ratios.
*
* <p>The task creates <code>ratios.length</code> output frames, each
* containing the demanded fraction of rows from the source dataset.</p>
*
* Rows are selected at random for each split, but remain ordered.
*/
public class ShuffleSplitFrame {
  /**
   * Randomly distributes the rows of {@code fr} over {@code ratios.length}
   * output frames.
   *
   * <p>For every row a uniform random number in [0,1) is drawn and compared
   * against the running (cumulative) sums of {@code ratios}: the row goes to
   * the first split whose cumulative ratio is greater than the draw, or to the
   * last split if none is.</p>
   *
   * @param fr     source frame; string, UUID, numeric, categorical and time
   *               columns are copied, columns of type {@code Vec.T_BAD} get no values added
   * @param keys   one key per output frame; must have the same length as
   *               {@code ratios} (checked by assertion only)
   * @param ratios fraction of rows per output frame; expected to sum to 1.0
   *               (checked by assertion only). The array is read but not modified.
   * @param seed   seed used to derive the per-chunk random generators
   * @return the output frames in the order of {@code keys}; each frame is also
   *         put into the DKV before this method returns
   */
  public static Frame[] shuffleSplitFrame( Frame fr, Key<Frame>[] keys, final double ratios[], final long seed ) {
    // Sanity check the ratios
    assert keys.length == ratios.length;
    // Build the cumulative thresholds in a private array so that the caller's
    // ratios are left untouched and can safely be reused for another call.
    final double[] cumRatios = new double[ratios.length];
    double sum = ratios[0];
    cumRatios[0] = sum;
    for( int i = 1; i<ratios.length; i++ ) {
      sum += ratios[i];
      cumRatios[i] = sum;
    }
    assert water.util.MathUtils.equalsWithinOneSmallUlp(sum,1.0);

    // One full set of source column types per split: the output of the task is
    // laid out as [split0 col0..colN-1, split1 col0..colN-1, ...].
    byte[] types = fr.types();
    final int ncols = fr.numCols();
    byte[] alltypes = new byte[ncols*cumRatios.length];
    for( int i = 0; i<cumRatios.length; i++ )
      System.arraycopy(types,0,alltypes,i*ncols,ncols);

    // Do the split, into ratios.length groupings of NewChunks
    MRTask mr = new MRTask() {
      @Override public void map( Chunk cs[], NewChunk ncs[] ) {
        // NOTE(review): seed*cidx is 0 for chunk 0 (and for every chunk when
        // seed == 0), so those chunks draw the same sequence whatever the seed.
        // Left as is to keep the splits produced for existing seeds unchanged.
        Random rng = new Random(seed*cs[0].cidx());
        // The type of a column is the same for every row: look it up once per chunk.
        final byte[] colTypes = new byte[ncols];
        for( int j=0; j<ncols; j++ ) colTypes[j] = cs[j].vec().get_type();
        int nrows = cs[0]._len;
        for( int i=0; i<nrows; i++ ) {
          double r = rng.nextDouble();
          int x=0;              // Pick the NewChunk split
          // The last split is the fallback, so its threshold is never tested.
          for( ; x<cumRatios.length-1; x++ ) if( r<cumRatios[x] ) break;
          x *= ncols;           // Offset of the first NewChunk of the chosen split
          // Copy row to correct set of NewChunks
          for( int j=0; j<ncols; j++ ) {
            byte colType = colTypes[j];
            switch (colType) {
              case Vec.T_BAD : break; /* NOP */
              case Vec.T_STR : ncs[x + j].addStr(cs[j], i); break;
              case Vec.T_UUID: ncs[x + j].addUUID(cs[j], i); break;
              case Vec.T_NUM : /* fallthrough */
              case Vec.T_CAT :
              case Vec.T_TIME:
                ncs[x + j].addNum(cs[j].atd(i));
                break;
              default:
                throw new IllegalArgumentException("Unsupported vector type: " + colType);
            }
          }
        }
      }
    }.doAll(alltypes,fr);

    // Build output frames
    Frame frames[] = new Frame[cumRatios.length];
    Vec[] vecs = fr.vecs();
    String[] names = fr.names();
    AppendableVec[] appendables = mr.appendables();
    Futures fs = new Futures();
    for( int i=0; i<cumRatios.length; i++ ) {
      Vec[] nvecs = new Vec[ncols];
      // All columns of one split are closed with the row layout computed from its first column.
      final int rowLayout = appendables[i*ncols].compute_rowLayout();
      for( int c=0; c<ncols; c++ ) {
        AppendableVec av = appendables[i*ncols + c];
        av.setDomain(vecs[c].domain());   // Carry the source column's domain over
        nvecs[c] = av.close(rowLayout,fs);
      }
      frames[i] = new Frame(keys[i],names,nvecs);
      DKV.put(frames[i],fs);
    }
    fs.blockForPending();
    return frames;
  }
}