package water.rapids; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import water.DKV; import water.Key; import water.Keyed; import water.TestUtil; import water.fvec.Frame; import water.rapids.vals.ValFrame; public class GroupByTest extends TestUtil { @BeforeClass public static void setup() { stall_till_cloudsize(5); } @Test public void testBasic() { Frame fr = null; String tree = "(GB hex [1] mean 2 \"all\")"; // Group-By on col 1 (not 0), no order-by, mean of col 2 try { fr = chkTree(tree,"smalldata/iris/iris_wheader.csv"); chkDim(fr,2,23); chkFr(fr,0,0,2.0); // Group 2.0, mean is 3.5 chkFr(fr,1,0,3.5); chkFr(fr,0,1,2.2); // Group 2.2, mean is 4.5 chkFr(fr,1,1,4.5); chkFr(fr,0,7,2.8); // Group 2.8, mean is 5.043, largest group chkFr(fr,1,7,5.042857142857143); chkFr(fr,0,22,4.4); // Group 4.4, mean is 1.5, last group chkFr(fr,1,22,1.5); } finally { if( fr != null ) fr.delete(); Keyed.remove(Key.make("hex")); } } @Test public void testCatGroup() { Frame fr = null; String tree = "(GB hex [4] nrow 0 \"all\" mean 2 \"all\")"; // Group-By on col 4, no order-by, nrow and mean of col 2 try { fr = chkTree(tree,"smalldata/iris/iris_wheader.csv"); chkDim(fr,3,3); chkFr(fr,0,0,"Iris-setosa"); chkFr(fr,1,0,50); chkFr(fr,2,0,1.464); chkFr(fr,0,1,"Iris-versicolor"); chkFr(fr,1,1,50); chkFr(fr,2,1,4.26 ); chkFr(fr,0,2,"Iris-virginica"); chkFr(fr,1,2,50); chkFr(fr,2,2,5.552); fr.delete(); fr = chkTree("(GB hex [1] mode 4 \"all\" )","smalldata/iris/iris_wheader.csv"); chkDim(fr,2,23); } finally { if( fr != null ) fr.delete(); Keyed.remove(Key.make("hex")); } } @Test public void testNAHandle() { Frame fr = null; try { String tree = "(GB hex [7] nrow 0 \"all\" mean 1 \"all\")"; // Group-By on year, no order-by, mean of economy fr = chkTree(tree,"smalldata/junit/cars.csv"); chkDim(fr,3,13); chkFr(fr,0,0,70); // 1970, 35 cars, NA in economy chkFr(fr,1,0,35); chkFr(fr,2,0,Double.NaN); chkFr(fr,0,2,72); // 1972, 28 cars, 18.714 in economy chkFr(fr,1,2,28); chkFr(fr,2,2,18.714,1e-1); fr.delete(); tree = "(GB hex [7] nrow 1 \"all\" nrow 1 \"rm\" nrow 1 \"ignore\")"; // Group-By on year, no order-by, nrow of economy fr = chkTree(tree,"smalldata/junit/cars.csv"); chkDim(fr,4,13); chkFr(fr,0,0,70); // 1970, 35 cars, 29 have economy chkFr(fr,1,0,35); // ALL chkFr(fr,2,0,29); // RM chkFr(fr,3,0,29); // IGNORE fr.delete(); tree = "(GB hex [7] mean 1 \"all\" mean 1 \"rm\" mean 1 \"ignore\")"; // Group-By on year, no order-by, mean of economy fr = chkTree(tree,"smalldata/junit/cars.csv"); chkDim(fr,4,13); chkFr(fr,0,0,70); // 1970, 35 cars, 29 have economy chkFr(fr,1,0,Double.NaN); // ALL chkFr(fr,2,0,17.69, 1e-1); // RM chkFr(fr,3,0,14.66, 1e-1); // IGNORE } finally { if( fr != null ) fr.delete(); Keyed.remove(Key.make("hex")); } } @Test public void testAllAggs() { Frame fr = null; try { String tree = "(GB hex [4] nrow 0 \"rm\" mean 1 \"rm\" sum 1 \"rm\" min 1 \"rm\" max 1 \"rm\" )"; fr = chkTree(tree,"smalldata/iris/iris_wheader.csv"); chkDim(fr,6,3); chkFr(fr,0,0,"Iris-setosa"); chkFr(fr,1,0,50); // nrow chkFr(fr,2,0,3.418); // mean chkFr(fr,3,0,170.9); // sum chkFr(fr,4,0, 2.3); // min chkFr(fr,5,0, 4.4); // max chkFr(fr,0,1,"Iris-versicolor"); chkFr(fr,1,1,50); // nrow chkFr(fr,2,1,2.770); // mean chkFr(fr,3,1,138.5); // sum chkFr(fr,4,1, 2.0); // min chkFr(fr,5,1, 3.4); // max chkFr(fr,0,2,"Iris-virginica"); chkFr(fr,1,2,50); // nrow chkFr(fr,2,2,2.974); // mean chkFr(fr,3,2,148.7); // sum chkFr(fr,4,2, 2.2); // min chkFr(fr,5,2, 3.8); // max } finally { if( fr != null ) fr.delete(); Keyed.remove(Key.make("hex")); } } @Test public void testImpute() { Frame fr = null; Frame fr2 =null; try { // Impute fuel economy via the "mean" method, no. String tree = "(h2o.impute hex 1 \"mean\" \"low\" [] _ _)"; // (h2o.impute data col method combine_method groupby groupByFrame values) chkTree(tree,"smalldata/junit/cars.csv",1f); fr = DKV.getGet("hex"); chkDim(fr,8,406); Assert.assertEquals(0,fr.vec(1).naCnt()); // No NAs anymore Assert.assertEquals(23.51,fr.vec(1).at(26),1e-1); // Row 26 was an NA, now as mean economy fr.delete(); // Impute fuel economy via the "mean" method, after grouping by year. Update in place. tree = "(h2o.impute hex 1 \"mean\" \"low\" [7] _ _)"; fr2 = chkTree(tree,"smalldata/junit/cars.csv",1f); fr = DKV.getGet("hex"); chkDim(fr,8,406); Assert.assertEquals(0,fr.vec(1).naCnt()); // No NAs anymore Assert.assertEquals(17.69,fr.vec(1).at(26),1e-1); // Row 26 was an NA, now as 1970 mean economy } finally { if( fr != null ) fr.delete(); if( fr2!=null ) fr2.delete(); Keyed.remove(Key.make("hex")); } } @Test public void testBasicDdply() { Frame fr = null; String tree = "(ddply hex [1] {x . (flatten (mean (cols x 2) TRUE))})"; // Group-By on col 1 (not 0) mean of col 2 try { fr = chkTree(tree,"smalldata/iris/iris_wheader.csv"); chkDim(fr,2,23); chkFr(fr,0,0,2.0); // Group 2.0, mean is 3.5 chkFr(fr,1,0,3.5); chkFr(fr,0,1,2.2); // Group 2.2, mean is 4.5 chkFr(fr,1,1,4.5); chkFr(fr,0,7,2.8); // Group 2.8, mean is 5.043, largest group chkFr(fr,1,7,5.042857142857143); chkFr(fr,0,22,4.4); // Group 4.4, mean is 1.5, last group chkFr(fr,1,22,1.5); fr.delete(); fr = chkTree("(ddply hex [1] {x . (sum (* (cols x 2) (cols x 3)))})","smalldata/iris/iris_wheader.csv"); chkDim(fr,2,23); } finally { if( fr != null ) fr.delete(); Keyed.remove(Key.make("hex")); } } // covtype.altered response column has this distribution: // -1 20510 // 1 211840 // 2 283301 // 3 35754 // 4 2747 // 6 17367 // 10000 9493 @Test public void testSplitCats() { Frame cov = parse_test_file(Key.make("cov"),"smalldata/covtype/covtype.altered.gz"); System.out.println(cov.toString(0,10)); Val v_ddply = Rapids.exec("(ddply cov [54] nrow)"); System.out.println(v_ddply.toString()); v_ddply.getFrame().delete(); Val v_groupby = Rapids.exec("(GB cov [54] nrow 54 \"all\")"); System.out.println(v_groupby.toString()); v_groupby.getFrame().delete(); cov.delete(); } @Test public void testGroupbyTableSpeed() { Frame ids = parse_test_file(Key.make("cov"),"smalldata/junit/id_cols.csv"); ids.replace(0,ids.anyVec().toCategoricalVec()).remove(); System.out.println(ids.toString(0,10)); long start = System.currentTimeMillis(); Val v_gb = Rapids.exec("(GB cov [0] nrow 0 \"all\")"); System.out.println("GB Time= "+(System.currentTimeMillis()-start)+"msec"); System.out.println(v_gb.toString()); v_gb.getFrame().delete(); long start2 = System.currentTimeMillis(); Val v_tb = Rapids.exec("(table cov FALSE)"); System.out.println("Table Time= "+(System.currentTimeMillis()-start2)+"msec"); System.out.println(v_tb.toString()); v_tb.getFrame().delete(); ids.delete(); } private void chkDim( Frame fr, int col, int row ) { Assert.assertEquals(col,fr.numCols()); Assert.assertEquals(row,fr.numRows()); } private void chkFr( Frame fr, int col, int row, double exp ) { chkFr(fr,col,row,exp,Math.ulp(1)); } private void chkFr( Frame fr, int col, int row, double exp, double tol ) { if( Double.isNaN(exp) ) Assert.assertTrue(fr.vec(col).isNA(row)); else Assert.assertEquals(exp, fr.vec(col).at(row),tol); } private void chkFr( Frame fr, int col, int row, String exp ) { String[] dom = fr.vec(col).domain(); Assert.assertEquals(exp, dom[(int)fr.vec(col).at8(row)]); } private Frame chkTree(String tree, String fname, float d) { Frame fr = parse_test_file(Key.make("hex"),fname); Val val = Rapids.exec(tree); System.out.println(val.toString()); if( val instanceof ValFrame ) return val.getFrame(); return null; } private Frame chkTree(String tree, String fname) { return chkTree(tree,fname,false); } private Frame chkTree(String tree, String fname, boolean expectThrow) { Frame fr = parse_test_file(Key.make("hex"),fname); try { Val val = Rapids.exec(tree); System.out.println(val.toString()); if( val instanceof ValFrame ) return val.getFrame(); throw new IllegalArgumentException("expected a frame return"); } catch( IllegalArgumentException iae ) { if( !expectThrow ) throw iae; // If not expecting a throw, then throw which fails the junit fr.delete(); // If expecting, then cleanup return null; } } }