package hex.util; import hex.CreateFrame; import hex.DataInfo; import hex.Model; import hex.aggregator.Aggregator; import hex.aggregator.AggregatorModel; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; import water.*; import water.fvec.Frame; import water.fvec.RebalanceDataSet; import water.fvec.Vec; import water.util.Log; public class AggregatorTest extends TestUtil { @BeforeClass() public static void setup() { stall_till_cloudsize(1); } @Test public void testAggregator100() { testAggregator(100); } @Test public void testAggregator1k() { testAggregator(1000); } @Test public void testAggregator13() { testAggregator(13); } @Test public void testAggregator10k() { testAggregator(10000); } public void testAggregator(int max) { CreateFrame cf = new CreateFrame(); cf.rows = 100000; cf.cols = 2; cf.categorical_fraction = 0.1; cf.integer_fraction = 0.3; cf.real_range = 100; cf.integer_range = 100; cf.seed = 1234; Frame frame = cf.execImpl().get(); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = max; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); Frame output = agg._output._output_frame.get(); System.out.println(output.toTwoDimTable(0,10)); frame.delete(); checkNumExemplars(agg); output.remove(); agg.remove(); } @Test public void testAggregatorEigen() { CreateFrame cf = new CreateFrame(); cf.rows = 1000; cf.cols = 10; cf.categorical_fraction = 0.6; cf.integer_fraction = 0.0; cf.binary_fraction = 0.0; cf.real_range = 100; cf.integer_range = 100; cf.missing_fraction = 0; cf.factors = 5; cf.seed = 1234; Frame frame = cf.execImpl().get(); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.905 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); Frame output = agg._output._output_frame.get(); System.out.println(output.toTwoDimTable(0,10)); Log.info("Number of exemplars: " + agg._exemplars.length); output.remove(); frame.remove(); agg.remove(); } @Test public void testAggregatorBinary() { CreateFrame cf = new CreateFrame(); cf.rows = 1000; cf.cols = 10; cf.categorical_fraction = 0.6; cf.integer_fraction = 0.0; cf.binary_fraction = 0.0; cf.real_range = 100; cf.integer_range = 100; cf.missing_fraction = 0.1; cf.factors = 5; cf.seed = 1234; Frame frame = cf.execImpl().get(); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._transform = DataInfo.TransformType.NORMALIZE; parms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Binary; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.905 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); Frame output = agg._output._output_frame.get(); System.out.println(output.toTwoDimTable(0,10)); Log.info("Number of exemplars: " + agg._exemplars.length); Assert.assertTrue(agg._exemplars.length==1000); output.remove(); frame.remove(); agg.remove(); } @Test public void testAggregatorOneHot() { Scope.enter(); CreateFrame cf = new CreateFrame(); cf.rows = 1000; cf.cols = 10; cf.categorical_fraction = 0.6; cf.integer_fraction = 0.0; cf.binary_fraction = 0.0; cf.real_range = 100; cf.integer_range = 100; cf.missing_fraction = 0.1; cf.factors = 5; cf.seed = 1234; Frame frame = cf.execImpl().get(); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 278; parms._transform = DataInfo.TransformType.NORMALIZE; parms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotExplicit; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.905 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); Frame output = agg._output._output_frame.get(); System.out.println(output.toTwoDimTable(0,10)); checkNumExemplars(agg); output.remove(); frame.remove(); agg.remove(); Scope.exit(); } @Ignore @Test public void testAirlines() { Frame frame = parse_test_file("smalldata/airlines/allyears2k_headers.zip"); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 500; parms._rel_tol_num_exemplars = 0.05; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.179 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); frame.delete(); Frame output = agg._output._output_frame.get(); output.remove(); checkNumExemplars(agg); agg.remove(); } @Test public void testCovtype() { Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data"); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 500; parms._rel_tol_num_exemplars = 0.05; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.179 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); frame.delete(); Frame output = agg._output._output_frame.get(); Log.info("Exemplars: " + output.toString()); output.remove(); checkNumExemplars(agg); agg.remove(); } public void checkNumExemplars(AggregatorModel m) { Log.info("Number of exemplars: " + m._exemplars.length); Assert.assertTrue(m._exemplars.length >= (1.-m._parms._rel_tol_num_exemplars)*m._parms._target_num_exemplars); Assert.assertTrue(m._exemplars.length <= (1.+m._parms._rel_tol_num_exemplars)*m._parms._target_num_exemplars); } @Test public void testChunks() { Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data"); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 137; parms._rel_tol_num_exemplars = 0.05; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 0.418 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); Frame output = agg._output._output_frame.get(); checkNumExemplars(agg); output.remove(); agg.remove(); for (int i : new int[]{1,2,5,10,50,100}) { Key key = Key.make(); RebalanceDataSet rb = new RebalanceDataSet(frame, key, i); H2O.submitTask(rb); rb.join(); Frame rebalanced = DKV.get(key).get(); parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 137; parms._rel_tol_num_exemplars = 0.05; start = System.currentTimeMillis(); AggregatorModel agg2 = new Aggregator(parms).trainModel().get(); // 0.373 0.504 0.357 0.454 0.368 0.355 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg2.checkConsistency(); Log.info("Number of exemplars for " + i + " chunks: " + agg2._exemplars.length); rebalanced.delete(); Assert.assertTrue(Math.abs(agg._exemplars.length - agg2._exemplars.length) == 0); output = agg2._output._output_frame.get(); output.remove(); checkNumExemplars(agg); agg2.remove(); } frame.delete(); } @Ignore @Test public void testCovtypeMemberIndices() { Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data"); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 117; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); // 1.489 System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); // Frame assignment = new Frame(new Vec[]{(Vec)agg._exemplar_assignment_vec_key.get()}); // Frame.export(assignment, "/tmp/assignment", "yada", true); // Log.info("Exemplars: " + new Frame(new Vec[]{(Vec)agg._exemplar_assignment_vec_key.get()}).toString(0,20000)); Log.info("Number of exemplars: " + agg._exemplars.length); Key<Frame> memberKey = Key.make(); for (int i=0; i<agg._exemplars.length; ++i) { Frame members = agg.scoreExemplarMembers(memberKey, i); assert (members.numRows() == agg._counts[i]); // Log.info(members); members.delete(); } Frame output = agg._output._output_frame.get(); output.remove(); checkNumExemplars(agg); frame.delete(); agg.remove(); } @Test public void testDomains() { Frame frame = parse_test_file("smalldata/junit/weather.csv"); for (String s : new String[]{"MaxWindSpeed", "RelHumid9am", "Cloud9am"}) { Vec v = frame.vec(s); Vec newV = v.toCategoricalVec(); frame.remove(s); frame.add(s,newV); v.remove(); } DKV.put(frame); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; parms._target_num_exemplars = 17; AggregatorModel agg = new Aggregator(parms).trainModel().get(); Frame output = agg._output._output_frame.get(); Assert.assertTrue(output.numRows() <= 17); boolean same = true; for (int i=0;i<frame.numCols();++i) { if (frame.vec(i).isCategorical()) { same = (frame.domains()[i].length == output.domains()[i].length); if (!same) break; } } frame.remove(); output.remove(); agg.remove(); Assert.assertFalse(same); } @Ignore @Test public void testMNIST() { Frame frame = parse_test_file("bigdata/laptop/mnist/train.csv.gz"); AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters(); parms._train = frame._key; long start = System.currentTimeMillis(); AggregatorModel agg = new Aggregator(parms).trainModel().get(); System.out.println("AggregatorModel finished in: " + (System.currentTimeMillis() - start)/1000. + " seconds"); agg.checkConsistency(); frame.delete(); Frame output = agg._output._output_frame.get(); // Log.info("Exemplars: " + output); output.remove(); Log.info("Number of exemplars: " + agg._exemplars.length); checkNumExemplars(agg); agg.remove(); } }