package hex;

import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.*;
import water.api.AUC;
import water.api.AUCData;
import water.exec.Env;
import water.exec.Exec2;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.util.Log;

import java.util.Random;

/**
 * Parameter-sweep test for DeepLearning on the prostate and iris datasets: builds a model for
 * many parameter combinations, continues training via checkpoint restart, and checks that the
 * AUC, confusion matrix and calcError() results are consistent with each other.
 */
public class DeepLearningProstateTest extends TestUtil {
  @BeforeClass public static void stall() { stall_till_cloudsize(JUnitRunnerDebug.NODES); }

  public void runFraction(float fraction) {
    long seed = 0xDECAF;
    Random rng = new Random(seed);
    String[] datasets = new String[2];
    int[][] responses = new int[datasets.length][];
    datasets[0] = "smalldata/./logreg/prostate.csv";
    responses[0] = new int[]{1, 2, 8}; //CAPSULE (binomial), AGE (regression), GLEASON (multi-class)
    datasets[1] = "smalldata/iris/iris.csv";
    responses[1] = new int[]{4}; //Iris-type (multi-class)
    int testcount = 0;
    int count = 0;
    for (int i = 0; i < datasets.length; ++i) {
      String dataset = datasets[i];
      Key file = NFSFileVec.make(find_test_file(dataset));
      Frame frame = ParseDataset2.parse(Key.make(), new Key[]{file});
      Key vfile = NFSFileVec.make(find_test_file(dataset));
      Frame vframe = ParseDataset2.parse(Key.make(), new Key[]{vfile});
      try {
        for (boolean replicate : new boolean[]{ true, false }) {
        for (boolean load_balance : new boolean[]{ true, false }) {
        for (boolean shuffle : new boolean[]{ true, false }) {
        for (boolean balance_classes : new boolean[]{ true, false }) {
        for (int resp : responses[i]) {
        for (DeepLearning.ClassSamplingMethod csm : new DeepLearning.ClassSamplingMethod[]{
                DeepLearning.ClassSamplingMethod.Stratified,
                DeepLearning.ClassSamplingMethod.Uniform }) {
        for (int scoretraining : new int[]{ 200, 20, 0 }) {
        for (int scorevalidation : new int[]{ 200, 20, 0 }) {
        for (int vf : new int[]{
                0,  //no validation
                1,  //same as source
                -1, //different validation frame
        }) {
        for (int n_folds : new int[]{ 0, 2 }) {
          if (n_folds != 0 && vf != 0) continue;
        for (boolean keep_cv_splits : new boolean[]{ false }) { //otherwise it leaks
        for (boolean override_with_best_model : new boolean[]{ false, true }) {
        for (int train_samples_per_iteration : new int[]{
                -2,               //auto-tune
                -1,               //N epochs per iteration
                0,                //1 epoch per iteration
                rng.nextInt(200), //<1 epoch per iteration
                500,              //>1 epoch per iteration
        }) {
          DeepLearningModel model1 = null, model2 = null;
          Key dest = null, dest_tmp = null;
          count++;
          if (fraction < rng.nextFloat()) continue;
          try {
            Log.info("**************************");
            Log.info("Starting test #" + count);
            Log.info("**************************");
            final double epochs = 7 + rng.nextDouble() + rng.nextInt(4);
            final int[] hidden = new int[]{ 1 + rng.nextInt(4), 1 + rng.nextInt(6) };
            Frame valid = null; //no validation
            if (vf == 1) valid = frame; //use the same frame for validation
            else if (vf == -1) valid = vframe; //different validation frame (here: from the same file)

            // build the model, with all kinds of shuffling/rebalancing/sampling
            dest_tmp = Key.make("first");
            {
              Log.info("Using seed: " + seed);
              DeepLearning p = new DeepLearning();
              p.checkpoint = null;
              p.destination_key = dest_tmp;
              p.source = frame;
              p.response = frame.vecs()[resp];
              p.validation = valid;
              p.hidden = hidden;
              if (i == 0 && resp == 2) p.classification = false;
//              p.best_model_key = best_model_key;
              p.override_with_best_model = override_with_best_model;
              p.epochs = epochs;
              p.n_folds = n_folds;
              p.keep_cross_validation_splits = keep_cv_splits;
              p.seed = seed;
              p.train_samples_per_iteration = train_samples_per_iteration;
              p.force_load_balance = load_balance;
              p.replicate_training_data = replicate;
              p.shuffle_training_data = shuffle;
              p.score_training_samples = scoretraining;
              p.score_validation_samples = scorevalidation;
              p.classification_stop = -1;
              p.regression_stop = -1;
              p.balance_classes = balance_classes;
              p.quiet_mode = true;
              p.score_validation_sampling = csm;
              try {
                p.invoke();
              } catch (Throwable t) {
                t.printStackTrace();
                throw new RuntimeException(t);
              } finally {
                p.delete();
              }
              model1 = UKV.get(dest_tmp);
              assert (((p.train_samples_per_iteration <= 0 || p.train_samples_per_iteration >= frame.numRows()) && model1.epoch_counter > epochs)
                      || Math.abs(model1.epoch_counter - epochs) / epochs < 0.20);
              if (n_folds != 0) { // test HTML of cv models
                for (Key k : model1.get_params().xval_models) {
                  DeepLearningModel cv_model = UKV.get(k);
                  StringBuilder sb = new StringBuilder();
                  cv_model.generateHTML("cv", sb);
                  cv_model.delete_best_model();
                  cv_model.delete();
                }
              }
            }

            // Do some more training via checkpoint restart
            // For n_folds, continue without n_folds (not yet implemented) - from now on, model2 will have n_folds=0...
            dest = Key.make("restart");
            DeepLearning p = new DeepLearning();
            final DeepLearningModel tmp_model = UKV.get(dest_tmp); //this actually *requires* frame to also still be in UKV (because of DataInfo...)
            assert (tmp_model != null);
            Assert.assertTrue(tmp_model.get_params().state == Job.JobState.DONE); //HEX-1817
            Assert.assertTrue(tmp_model.model_info().get_processed_total() >= frame.numRows() * epochs);
            p.checkpoint = dest_tmp;
            p.destination_key = dest;
            p.n_folds = 0;
            p.source = frame;
            p.validation = valid;
            p.response = frame.vecs()[resp];
            if (i == 0 && resp == 2) p.classification = false;
            p.override_with_best_model = override_with_best_model;
            p.epochs = epochs;
            p.seed = seed;
            p.train_samples_per_iteration = train_samples_per_iteration;
            try {
              p.invoke();
            } catch (Throwable t) {
              t.printStackTrace();
              throw new RuntimeException(t);
            } finally {
              p.delete();
            }

            // score and check result (on full data)
            model2 = UKV.get(dest); //this actually *requires* frame to also still be in UKV (because of DataInfo...)
            Assert.assertTrue(model2.get_params().state == Job.JobState.DONE); //HEX-1817

            // test HTML
            {
              StringBuilder sb = new StringBuilder();
              model2.generateHTML("test", sb);
            }

            // score and check result of the best_model
            if (model2.actual_best_model_key != null) {
              final DeepLearningModel best_model = UKV.get(model2.actual_best_model_key);
              Assert.assertTrue(best_model.get_params().state == Job.JobState.DONE); //HEX-1817
              // test HTML
              {
                StringBuilder sb = new StringBuilder();
                best_model.generateHTML("test", sb);
              }
              if (override_with_best_model) {
                Assert.assertEquals(best_model.error(), model2.error(), 0);
              }
            }

            if (valid == null) valid = frame;
            double threshold = 0;
            if (model2.isClassifier()) {
              Frame pred = null, pred2 = null;
              try {
                pred = model2.score(valid);
                StringBuilder sb = new StringBuilder();

                AUC auc = new AUC();
                double error = 0;
                // binary
                if (model2.nclasses() == 2) {
                  auc.actual = valid;
                  assert (resp == 1);
                  auc.vactual = valid.vecs()[resp];
                  auc.predict = pred;
                  auc.vpredict = pred.vecs()[2];
                  auc.invoke();
                  auc.toASCII(sb);
                  AUCData aucd = auc.data();
                  threshold = aucd.threshold();
                  error = aucd.err();
                  Log.info(sb);

                  // check that auc.cm() is the right CM
                  Assert.assertEquals(new ConfusionMatrix(aucd.cm()).err(), error, 1e-15);

                  // check that calcError() is consistent as well (for CM=null, AUC!=null)
                  Assert.assertEquals(model2.calcError(valid, auc.vactual, pred, pred, "training", false, 0, null, auc, null), error, 1e-15);
                }

                // Compute CM
                double CMerrorOrig;
                {
                  sb = new StringBuilder();
                  water.api.ConfusionMatrix CM = new water.api.ConfusionMatrix();
                  CM.actual = valid;
                  CM.vactual = valid.vecs()[resp];
                  CM.predict = pred;
                  CM.vpredict = pred.vecs()[0];
                  CM.invoke();
                  sb.append("\n");
                  sb.append("Threshold: " + "default\n");
                  CM.toASCII(sb);
                  Log.info(sb);
                  CMerrorOrig = new ConfusionMatrix(CM.cm).err();
                }

                // confirm that orig CM was made with threshold 0.5
                // put pred2 into UKV, and allow access
                pred2 = new Frame(Key.make("pred2"), pred.names(), pred.vecs());
                pred2.delete_and_lock(null);
                pred2.unlock(null);

                if (model2.nclasses() == 2) {
                  // make labels with 0.5 threshold for binary classifier
                  Env ev = Exec2.exec("pred2[,1]=pred2[,3]>=" + 0.5);
                  try {
                    pred2 = ev.popAry();
                    String skey = ev.key();
                    ev.subRef(pred2, skey);
                  } finally {
                    if (ev != null) ev.remove_and_unlock();
                  }
                  water.api.ConfusionMatrix CM = new water.api.ConfusionMatrix();
                  CM.actual = valid;
                  CM.vactual = valid.vecs()[1];
                  CM.predict = pred2;
                  CM.vpredict = pred2.vecs()[0];
                  CM.invoke();
                  sb = new StringBuilder();
                  sb.append("\n");
                  sb.append("Threshold: " + 0.5 + "\n");
                  CM.toASCII(sb);
                  Log.info(sb);
                  double threshErr = new ConfusionMatrix(CM.cm).err();
                  Assert.assertEquals(threshErr, CMerrorOrig, 1e-15);

                  // make labels with AUC-given threshold for best F1
                  ev = Exec2.exec("pred2[,1]=pred2[,3]>=" + threshold);
                  try {
                    pred2 = ev.popAry();
                    String skey = ev.key();
                    ev.subRef(pred2, skey);
                  } finally {
                    if (ev != null) ev.remove_and_unlock();
                  }
                  CM = new water.api.ConfusionMatrix();
                  CM.actual = valid;
                  CM.vactual = valid.vecs()[1];
                  CM.predict = pred2;
                  CM.vpredict = pred2.vecs()[0];
                  CM.invoke();
                  sb = new StringBuilder();
                  sb.append("\n");
                  sb.append("Threshold: ").append(threshold).append("\n");
                  CM.toASCII(sb);
                  Log.info(sb);
                  double threshErr2 = new ConfusionMatrix(CM.cm).err();
                  Assert.assertEquals(threshErr2, error, 1e-15);
                }
              } finally {
                if (pred != null) pred.delete();
                if (pred2 != null) pred2.delete();
              }
            } //classifier

            Log.info("Parameters combination " + count + ": PASS");
            testcount++;
          } catch (Throwable t) {
            t.printStackTrace();
            throw new RuntimeException(t);
          } finally {
            if (model1 != null) {
              model1.delete_xval_models();
              model1.delete_best_model();
              model1.delete();
            }
            if (model2 != null) {
              model2.delete_xval_models();
              model2.delete_best_model();
              model2.delete();
            }
          }
        }}}}}}}}}}}}} //end of the 13 nested parameter loops
      } finally {
        frame.delete();
        vframe.delete();
      }
    }
    Log.info("\n\n=============================================");
    Log.info("Tested " + testcount + " out of " + count + " parameter combinations.");
    Log.info("=============================================");
  }

  public static class Long extends DeepLearningProstateTest {
    @Test @Ignore
    public void run() throws Exception { runFraction(1f); }
  }

  public static class Mid extends DeepLearningProstateTest {
    @Test
    public void run() throws Exception { runFraction(0.01f); } //for nightly tests
  }

  public static class Short extends DeepLearningProstateTest {
    @Test
    public void run() throws Exception { runFraction(0.001f); }
  }
}