package water.rapids.ast.prims.advmath; import hex.CreateFrame; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import water.DKV; import water.Key; import water.TestUtil; import water.fvec.Frame; import water.parser.BufferedString; import water.rapids.Rapids; import water.rapids.Val; import water.util.ArrayUtils; public class StratifiedSplitTest extends TestUtil{ private static Frame f = null, fr1 = null, fanimal = null, fr2 = null; @BeforeClass public static void setup() { stall_till_cloudsize(1); } @AfterClass public static void teardown() { f.delete(); fr1.delete(); fanimal.delete(); fr2.delete(); } @Test public void testStratifiedSampling() { f = ArrayUtils.frame("response" ,vec(ari(1,0,0,0,0,0,0,0,0,0,0,1))); fanimal = ArrayUtils.frame("response" ,vec(ar("dog","cat"),ari(1,0,0,0,0,0,0,0,0,0,0,1))); f = new Frame(f); fanimal = new Frame(fanimal); f._key = Key.make(); fanimal._key = Key.make(); DKV.put(f); DKV.put(fanimal); Val res1 = Rapids.exec("(h2o.random_stratified_split (cols_py " + f._key + " 0) 0.3333333 123)"); // fr1 = res1.getFrame(); Assert.assertEquals(fr1.vec(0).at8(0),1); // minority class should be in the test split Assert.assertEquals(fr1.vec(0).at8(11),0); // minority class should be in the train split Assert.assertEquals(fr1.vec(0).mean(),1.0/3.0,1e-5); // minority class should be in the train split //test categorical Val res2 = Rapids.exec("(h2o.random_stratified_split (cols_py " + fanimal._key + " 0) 0.3333333 123)"); // fr2 = res2.getFrame(); Assert.assertEquals(fr2.vec(0).at8(0),1); // minority class should be in the test split Assert.assertEquals(fr2.vec(0).at8(11),0); // minority class should be in the test split Assert.assertEquals(fr2.vec(0).mean(),1.0/3.0,1e-5); // minority class should be in the test split } }