package hex.deeplearning;

import java.util.*;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.FrameUtils;
import water.util.Log;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;

import static org.junit.Assert.assertTrue;

/**
 * Trains the same Deep Learning model several times on the weather data set, once with
 * {@code _reproducible = true} and once with {@code _reproducible = false}.
 *
 * <p>In reproducible mode every repeat must yield the same sampled weight, the same loss, the
 * same model-info checksum and (within a tolerance) the same predictions. In non-reproducible
 * mode the losses of the repeats are required to differ, and their mean / standard deviation
 * are logged together with the distance to the reproducible-mode loss.
 */
public class DeepLearningReproducibilityTest extends TestUtil {

  /** Number of times the model is retrained for each reproducibility mode. */
  private static final int N = 3;

  @BeforeClass()
  public static void setup() {
    stall_till_cloudsize(1);
  }

  /**
   * Runs {@link #N} training repeats per mode and checks the results as described on the class.
   * All frames, models and the {@link Scope} are released in {@code finally} blocks so that a
   * failing assertion does not leak keys into subsequent tests.
   */
  @Test
  public void run() {
    NFSFileVec ff = TestUtil.makeNfsFileVec("smalldata/junit/weather.csv");
    Frame golden = ParseDataset.parse(Key.make("golden.hex"), ff._key);
    float repro_error = 0; // loss of the reproducible run, reported against in the second pass

    try {
      for (boolean repro : new boolean[]{true, false}) {
        // Fresh per-mode state: nothing from the previous mode may leak into the checks or the log.
        Map<Integer, Float> repeatErrs = new TreeMap<>();
        StringBuilder sb = new StringBuilder();
        Frame[] preds = new Frame[N];
        long[] checksums = new long[N];
        double[] numbers = new double[N];

        Scope.enter();
        try {
          for (int repeat = 0; repeat < N; ++repeat) {
            // Declared per repeat so the cleanup below never touches objects of an earlier repeat.
            DeepLearningModel mymodel = null;
            Frame data = null;
            try {
              NFSFileVec file = TestUtil.makeNfsFileVec("smalldata/junit/weather.csv");
              data = ParseDataset.parse(Key.make("data.hex"), file._key);
              Assert.assertTrue(TestUtil.isBitIdentical(data, golden)); // test parser consistency

              // The same frame is used for both training and validation.
              Frame train = data;
              Frame test = data;

              DeepLearningParameters p = new DeepLearningParameters();
              p._train = train._key;
              p._valid = test._key;
              p._response_column = train.names()[train.names().length - 1];

              // Turn the response (last column) into a categorical; the replaced Vec is tracked
              // by the Scope and the updated frame is published to the DKV.
              int ci = train.names().length - 1;
              Scope.track(train.replace(ci, train.vecs()[ci].toCategoricalVec()));
              DKV.put(train);

              p._ignored_columns = new String[]{"EvapMM", "RISK_MM"}; // for weather data
              p._activation = DeepLearningParameters.Activation.RectifierWithDropout;
              p._hidden = new int[]{32, 58};
              p._l1 = 1e-5;
              p._l2 = 3e-5;
              p._seed = 0xbebe;
              p._loss = DeepLearningParameters.Loss.CrossEntropy;
              p._input_dropout_ratio = 0.2;
              p._train_samples_per_iteration = 3;
              p._hidden_dropout_ratios = new double[]{0.4, 0.1};
              p._epochs = 1.32;
              // p._nfolds = 2;
              p._quiet_mode = true;
              p._reproducible = repro;

              DeepLearning dl = new DeepLearning(p);
              mymodel = dl.trainModel().get();

              // Score the validation frame; scoring the same model repeatedly must be bit-identical.
              preds[repeat] = mymodel.score(test);
              for (int i = 0; i < 5; ++i) {
                Frame tmp = mymodel.score(test);
                Assert.assertTrue("Prediction #" + i + " for repeat #" + repeat + " differs!",
                        TestUtil.isBitIdentical(preds[repeat], tmp));
                tmp.delete();
              }
              Log.info("Prediction:\n" + FrameUtils.chunkSummary(preds[repeat]).toString());

              // Sample one weight, the model checksum and the loss for the cross-repeat checks.
              numbers[repeat] = mymodel.model_info().get_weights(0).get(23, 4);
              checksums[repeat] = mymodel.model_info().checksum_impl(); // check that the model state is consistent
              repeatErrs.put(repeat, mymodel.loss());
            } finally {
              // cleanup; train/test alias data, so a single delete releases the frame
              if (mymodel != null) mymodel.delete();
              if (data != null) data.delete();
            }
          }

          sb.append("Reproducibility: ").append(repro ? "on" : "off").append("\n");
          sb.append("Repeat # --> Validation Loss\n");
          for (String s : Arrays.toString(repeatErrs.entrySet().toArray()).split(","))
            sb.append(s.replace("=", " --> ")).append("\n");
          sb.append('\n');
          Log.info(sb.toString());

          if (repro) {
            assertReproducible(numbers, repeatErrs, checksums, preds);
            repro_error = repeatErrs.get(0);
          } else {
            checkNonReproducible(repeatErrs, repro_error);
          }
        } finally {
          for (Frame f : preds) if (f != null) f.delete();
          Scope.exit();
        }
      }
    } finally {
      golden.delete();
    }
  }

  /**
   * Asserts that every repeat of the reproducible mode matches the first repeat.
   *
   * @param numbers    sampled weight of each repeat; compared exactly
   * @param repeatErrs loss per repeat index; compared by value
   * @param checksums  model-info checksum of each repeat; compared exactly
   * @param preds      prediction frame of each repeat; compared column-wise with a tolerance
   */
  private static void assertReproducible(double[] numbers, Map<Integer, Float> repeatErrs,
                                         long[] checksums, Frame[] preds) {
    for (double number : numbers)
      assertTrue(Arrays.toString(numbers), number == numbers[0]);
    for (Float error : repeatErrs.values())
      assertTrue("Losses differ between reproducible repeats: " + repeatErrs,
              error.equals(repeatErrs.get(0)));
    for (long cs : checksums)
      assertTrue("Checksums differ between reproducible repeats: " + Arrays.toString(checksums),
              cs == checksums[0]);
    for (Frame f : preds) {
      // PUBDEV-892: TestUtil.isBitIdentical(f, preds[0]) should have passed all the time;
      // until then the columns are compared with a tolerance.
      for (int i = 0; i < f.vecs().length; ++i) {
        TestUtil.assertVecEquals(f.vecs()[i], preds[0].vecs()[i], 1e-5); // PUBDEV-892: This tolerance should be 1e-15
      }
    }
  }

  /**
   * Asserts that the losses of the non-reproducible repeats differ from the first repeat and
   * logs their mean, standard deviation and distance to the reproducible-mode loss.
   *
   * @param repeatErrs loss per repeat index (keys {@code 0..N-1})
   * @param reproError loss obtained in reproducible mode
   */
  private static void checkNonReproducible(Map<Integer, Float> repeatErrs, float reproError) {
    double mean = 0;
    for (Float error : repeatErrs.values()) {
      mean += error;
    }
    mean /= N;

    // check non-reproducibility (Hogwild! will never reproduce).
    // Compare by value: '!=' on boxed Floats compares references and would always be true.
    Float first = repeatErrs.get(0);
    for (int i = 1; i < N; ++i)
      assertTrue("Repeat #" + i + " unexpectedly reproduced the loss of repeat #0: " + first,
              !repeatErrs.get(i).equals(first));
    Log.info("mean error: " + mean);

    double stddev = 0;
    for (Float error : repeatErrs.values()) {
      stddev += (error - mean) * (error - mean);
    }
    stddev /= N;
    stddev = Math.sqrt(stddev);
    Log.info("standard deviation: " + stddev);
    // Disabled check: assertTrue(stddev < 0.3 / Math.sqrt(N));
    Log.info("difference to reproducible mode: " + Math.abs(mean - reproError) / stddev + " standard deviations");
  }
}