package hex.singlenoderf; import org.junit.Assert; import static org.junit.Assert.assertEquals; import org.junit.BeforeClass; import water.*; import water.fvec.Frame; import water.fvec.Vec; import water.util.Log; public class SpeeDRFTest extends TestUtil { private final void testHTML(SpeeDRFModel m) { StringBuilder sb = new StringBuilder(); SpeeDRFModelView drfv = new SpeeDRFModelView(); drfv.speedrf_model = m; drfv.toHTML(sb); assert(sb.length() > 0); } @BeforeClass public static void stall() { stall_till_cloudsize(1); } // Test kaggle/creditsample-test data @org.junit.Test public void kaggle_credit() { Key destTrain = Key.make("credit"); Frame fr = parseFrame(destTrain, "smalldata/kaggle/creditsample-training.csv.gz"); // Check parsed dataset final int n = 1; assertEquals("Number of chunks", n, fr.anyVec().nChunks()); assertEquals("Number of rows", 150000, fr.numRows()); assertEquals("Number of cols", 12, fr.numCols()); // setup DRF values Vec response = fr.vecs()[1]; int[] ignored_cols = new int[]{6}; SpeeDRF spdrf = new SpeeDRF(); spdrf.source = fr; spdrf.response = response; spdrf.ignored_cols = ignored_cols; spdrf.ntrees = 3; spdrf.max_depth = 30; spdrf.select_stat_type = Tree.SelectStatType.GINI; spdrf.seed = 42; Log.info("Invoking the SpeeDRF task."); spdrf.invoke(); SpeeDRFModel m = UKV.get(spdrf.dest()); Assert.assertTrue(m.get_params().state == Job.JobState.DONE); //HEX-1817 testHTML(m); assertEquals("Number of classes", 2, m.classes()); assertEquals("Number of trees", 3, m.size()); m.delete(); fr.delete(); } @org.junit.Test public void covtype() { Frame fr = parseFrame(Key.make("covtype.hex"), "smalldata/covtype/covtype.20k.data"); //Key okey = loadAndParseFile("covtype.hex", "../datasets/UCI/UCI-large/covtype/covtype.data"); //Key okey = loadAndParseFile("covtype.hex", "/home/0xdiag/datasets/standard/covtype.data"); //Key okey = loadAndParseFile("mnist.hex", "/home/0xdiag/datasets/mnist/mnist8m.csv"); // setup default values for DRF Vec response = fr.vecs()[54]; SpeeDRF spdrf = new SpeeDRF(); spdrf.source = fr; spdrf.response = response; spdrf.ntrees = 8; spdrf.max_depth = 999; spdrf.select_stat_type = Tree.SelectStatType.ENTROPY; spdrf.seed = 42; spdrf.invoke(); SpeeDRFModel m = UKV.get(spdrf.dest()); Assert.assertTrue(m.get_params().state == Job.JobState.DONE); //HEX-1817 testHTML(m); assertEquals("Number of classes", 7, m.classes()); assertEquals("Number of trees", 8, m.size()); m.delete(); fr.delete(); } // public static void main(String[] Args) { // kaggle_credit(); // covtype(); // } }