package hex.splitframe;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import java.util.Arrays;
import static water.fvec.FrameTestUtil.createFrame;
import static water.fvec.FrameTestUtil.collectS;
import static water.util.ArrayUtils.flat;
import static water.util.ArrayUtils.append;
/**
 * Tests for {@link ShuffleSplitFrame}: a frame is shuffle-split into two parts and the
 * parts are checked against the data the frame was built from.
 */
public class ShuffleSplitFrameTest extends TestUtil {

  /** Runs once before the tests of this class and calls {@code stall_till_cloudsize(1)}. */
  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }

  /**
   * Shuffle-splits frames made of a single string column that contains a missing value.
   * Reported as PUBDEV-452.
   */
  @Test
  public void testShuffleSplitOnStringColumn() {
    // First layout: seven values spread over three chunks (2 + 2 + 3).
    long[] layout = ar(2L, 2L, 3L);
    String[][] cells = ar(ar("A", "B"), ar(null, "C"), ar("D", "E", "F"));
    Frame frame = createFrame("ShuffleSplitTest1.hex", layout, cells);
    testScenario(frame, flat(cells));

    // Second layout: six values spread over two chunks (3 + 3).
    layout = ar(3L, 3L);
    cells = ar(ar("A", null, "B"), ar("C", "D", "E"));
    frame = createFrame("test2.hex", layout, cells);
    testScenario(frame, flat(cells));
  }

  /**
   * Makes sure that the rows of the split frames are preserved (including UUID).
   *
   * <p>A four-column frame is derived from a string column: the string itself, the string
   * parsed as a number (0 when the cell is missing), the position of the row inside its
   * chunk, and a UUID built from that position and the parsed number. After the split
   * every row must still carry values that agree with each other.
   */
  @Test
  public void testShuffleSplitWithMultipleColumns() {
    long[] layout = ar(2L, 2L, 3L);
    String[][] cells = ar(ar("1", "2"), ar(null, "3"), ar("4", "5", "6"));
    Frame widened = null;
    Frame source = createFrame("ShuffleSplitMCTest1.hex", layout, cells);
    try {
      MRTask expander = new MRTask() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
          Chunk text = cs[0];
          for (int row = 0; row < text._len; row++) {
            BufferedString cell = text.atStr(new BufferedString(), row);
            // A missing cell is represented by the number 0.
            int parsed = 0;
            if (cell != null) {
              parsed = Integer.parseInt(cell.toString());
            }
            ncs[0].addStr(cell);
            ncs[1].addNum(parsed);
            ncs[2].addNum(row);
            ncs[3].addUUID(row, parsed);
          }
        }
      };
      widened = expander
          .doAll(new byte[]{Vec.T_STR, Vec.T_NUM, Vec.T_NUM, Vec.T_UUID}, source)
          .outputFrame();
    } finally {
      // The single-column input is only needed to build the wide frame.
      source.delete();
    }

    // Per-chunk check: the four cells of every row must still belong together.
    MRTask rowChecker = new MRTask() {
      @Override
      public void map(Chunk[] cs) {
        final int rowCount = cs[0]._len;
        for (int row = 0; row < rowCount; row++) {
          BufferedString cell = cs[0].atStr(new BufferedString(), row);
          int wantedValue = (cell != null) ? Integer.parseInt(cell.toString()) : 0;
          int wantedIndex = (int) cs[2].atd(row);
          Assert.assertEquals((double) wantedValue, cs[1].atd(row), 0.00001);
          Assert.assertEquals(wantedIndex, (int) cs[3].at16l(row));
          Assert.assertEquals(wantedValue, (int) cs[3].at16h(row));
        }
      }
    };
    testScenario(widened, flat(cells), rowChecker);
  }

  /** Same as the three-argument scenario, without any extra per-chunk assertions. */
  static void testScenario(Frame f, String[] expValues) {
    testScenario(f, expValues, null);
  }

  /**
   * Simple testing scenario: splits the frame with ratios 0.5/0.5 and compares the values
   * of the first column of both splits with the expected values, ignoring order.
   *
   * <p>The frame {@code f} and all produced splits are deleted before this method returns.
   * {@code expValues} is modified in place (nulls replaced, then sorted).
   *
   * @param f               frame to split
   * @param expValues       values expected in the first column, in any order; may contain nulls
   * @param chunkAssertions optional task run over each split for additional checks, or null
   */
  static void testScenario(Frame f, String[] expValues, MRTask chunkAssertions) {
    double[] halves = ard(0.5, 0.5);
    Key<Frame>[] destKeys = aro(Key.<Frame>make("test.hex"), Key.<Frame>make("train.hex"));
    Frame[] splits = null;
    try {
      splits = ShuffleSplitFrame.shuffleSplitFrame(f, destKeys, halves, 42);
      Assert.assertEquals("Expecting 2 splits", 2, splits.length);

      // Gather the first column of both splits into one array.
      String[] fromFirst = collectS(splits[0].vec(0));
      String[] fromSecond = collectS(splits[1].vec(0));
      String[] actual = append(fromFirst, fromSecond);

      // Replace nulls by a placeholder first, then sort both sides so order does not matter.
      String[] expected = replaceNulls(expValues);
      Arrays.sort(expected);
      replaceNulls(actual);
      Arrays.sort(actual);
      Assert.assertArrayEquals("Values should match", expected, actual);

      if (chunkAssertions != null) {
        for (Frame split : splits) {
          chunkAssertions.doAll(split).getResult();
        }
      }
    } finally {
      f.delete();
      if (splits != null) {
        for (Frame split : splits) {
          split.delete();
        }
      }
    }
  }

  /** Replaces every null in {@code ary} by the placeholder {@code "_NA_#"}; returns {@code ary}. */
  private static String[] replaceNulls(String[] ary) {
    return replaceNulls(ary, "_NA_#");
  }

  /**
   * Overwrites every null element of {@code ary} with {@code replacement}.
   *
   * @return the same array instance that was passed in
   */
  private static String[] replaceNulls(String[] ary, String replacement) {
    int pos = 0;
    while (pos < ary.length) {
      if (ary[pos] == null) {
        ary[pos] = replacement;
      }
      pos++;
    }
    return ary;
  }
}