package water.api;
import hex.FrameSplitter;
import java.util.Arrays;
import java.util.Random;
import water.*;
import water.fvec.Frame;
import water.util.MRUtils;
import water.util.Utils;
/** Small utility page that splits a frame
* into n+1 parts based on n given ratios.
*
* <p>The user specifies n split ratios, each giving the fraction of source rows
* that goes into the corresponding output frame; the remaining fraction
* (1 - sum of ratios) forms the last, (n+1)-th output frame. Rows can optionally
* be shuffled (currently within each chunk) before splitting.</p>
*
* <p>The original chunk distribution is kept.</p>
*
* @see FrameSplitter
*/
public class FrameSplitPage extends Func {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  @API(help = "Data frame", required = true, filter = Default.class)
  public Frame source;

  @API(help = "Split ratio - can be an array of split ratios", required = true, filter = Default.class)
  public float[] ratios = new float[] {0.75f}; // n-values => n+1 output datasets

  @API(help = "Shuffle rows before splitting", required = false, filter = Default.class)
  public boolean shuffle = false;

  @API(help = "Seed for reproducible shuffling.", required = false, filter = Default.class)
  public long seed = new Random().nextLong();

  @API(help = "Keys for each split partition.")
  public Key[] split_keys;

  @API(help = "Holds a number of rows per each output partition.")
  public long[] split_rows;

  @API(help = "Holds a number of split ratios per partition.")
  public float[] split_ratios;

  /**
   * Validates the request parameters before the split is executed.
   *
   * <p>Every explicit ratio must lie in the open interval (0,1), the ratios must sum
   * to less than 1, and every resulting partition - including the implicit last one
   * that receives the remaining {@code 1 - sum(ratios)} fraction - must be expected
   * to receive more than one row.</p>
   *
   * @throws IllegalArgumentException if any of the constraints above is violated
   */
  @Override protected void init() throws IllegalArgumentException {
    super.init();
    long nrows = source.numRows();
    // n ratios produce n+1 partitions, hence at least n+1 rows are required.
    if (nrows <= ratios.length) throw new IllegalArgumentException("Dataset does not have enough row to be split!");
    float sum = 0;
    for (float ratio : ratios) {
      if (!(ratio > 0 && ratio < 1)) throw new IllegalArgumentException("Split ratio has to be in (0,1) interval!");
      if (ratio * nrows <= 1) throw new IllegalArgumentException("Ratio " + ratio + " produces empty frame since the source frame has only " + nrows + "!");
      sum += ratio;
    }
    if (!(sum < 1f)) throw new IllegalArgumentException("Sum of split ratios has to be less than 1!");
    // The last partition is not listed in 'ratios'; it gets whatever fraction is left.
    // Apply the same emptiness check to it, otherwise e.g. {0.5, 0.4999} passes
    // validation and silently yields an empty final frame.
    float rest = 1f - sum;
    if (rest * nrows <= 1) throw new IllegalArgumentException("Remaining ratio " + rest + " (1 - sum of ratios) produces empty frame since the source frame has only " + nrows + "!");
  }

  /**
   * Optionally shuffles the source frame, splits it with {@link FrameSplitter} and
   * records the key, row count and ratio of every resulting partition in
   * {@link #split_keys}, {@link #split_rows} and {@link #split_ratios}.
   */
  @Override protected void execImpl() {
    Frame frame = source;
    if (shuffle) {
      // FIXME: switch to global shuffle
      frame = MRUtils.shuffleFramePerChunk(Utils.generateShuffledKey(frame._key), frame, seed);
      frame.delete_and_lock(null).unlock(null); // save frame to DKV
      // The shuffled copy is only an intermediate; register it for deletion at the end.
      gtrash(frame);
    }
    FrameSplitter fs = new FrameSplitter(frame, ratios);
    H2O.submitTask(fs);
    Frame[] splits = fs.getResult();
    split_keys = new Key [splits.length];
    split_rows = new long[splits.length];
    // The ratio of the last partition is implicit: it receives the remainder.
    float rsum = Utils.sum(ratios);
    split_ratios = Arrays.copyOf(ratios, splits.length);
    split_ratios[splits.length-1] = 1f-rsum;
    long sum = 0;
    for (int i = 0; i < splits.length; i++) {
      long rows = splits[i].numRows();
      sum += rows;
      split_keys[i] = splits[i]._key;
      split_rows[i] = rows;
    }
    assert sum == source.numRows() : "Frame split produced wrong number of rows: nrows(source) != sum(nrows(splits))";
  }

  /**
   * Renders a table with one column per output partition plus a "Total" column,
   * and rows for partition keys, row counts and ratios (in percent).
   * The "Total" column of the key row links to the source frame.
   *
   * @param sb builder receiving the HTML
   * @return always {@code true}
   */
  @Override public boolean toHTML(StringBuilder sb) {
    int nsplits = split_keys.length;
    String [] headers = new String[nsplits+2];
    headers[0] = "";
    for (int i = 0; i < nsplits; i++) headers[i+1] = "Split #"+i;
    headers[nsplits+1] = "Total";
    DocGen.HTML.arrayHead(sb, headers);
    // Key table row
    sb.append("<tr><td>").append(DocGen.HTML.bold("Keys")).append("</td>");
    for (int i = 0; i < nsplits; i++) {
      Key k = split_keys[i];
      sb.append("<td>").append(Inspect2.link(k)).append("</td>");
    }
    sb.append("<td>").append(Inspect2.link(source._key)).append("</td>");
    sb.append("</tr>");
    // Number of rows row
    sb.append("<tr><td>").append(DocGen.HTML.bold("Rows")).append("</td>");
    for (int i = 0; i < nsplits; i++) {
      long r = split_rows[i];
      sb.append("<td>").append(String.format("%,d", r)).append("</td>");
    }
    sb.append("<td>").append(String.format("%,d", Utils.sum(split_rows))).append("</td>");
    sb.append("</tr>");
    // Split ratios row (shown as percentages)
    sb.append("<tr><td>").append(DocGen.HTML.bold("Ratios")).append("</td>");
    for (int i = 0; i < nsplits; i++) {
      float r = 100*split_ratios[i];
      sb.append("<td>").append(String.format("%.2f %%", r)).append("</td>");
    }
    sb.append("<td>").append(String.format("%.2f %%", 100*Utils.sum(split_ratios))).append("</td>");
    sb.append("</tr>");
    DocGen.HTML.arrayTail(sb);
    return true;
  }
}