package hex;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import junit.framework.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.util.Log;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
public class DeepLearningReproducibilityTest extends TestUtil {
@BeforeClass public static void stall() {
stall_till_cloudsize(JUnitRunnerDebug.NODES);
}
@Test
public void run() {
long seed = new Random().nextLong();
DeepLearningModel mymodel = null;
Frame train = null;
Frame test = null;
Frame data = null;
Log.info("");
Log.info("STARTING.");
Log.info("Using seed " + seed);
Map<Integer,Float> repeatErrs = new TreeMap<Integer,Float>();
int N = 6;
StringBuilder sb = new StringBuilder();
float repro_error = 0;
for (boolean repro : new boolean[]{true, false}) {
Frame[] preds = new Frame[N];
for (int repeat = 0; repeat < N; ++repeat) {
try {
Key file = NFSFileVec.make(find_test_file("smalldata/weather.csv"));
// Key file = NFSFileVec.make(find_test_file("smalldata/mnist/test.csv.gz"));
data = ParseDataset2.parse(Key.make("data.hex"), new Key[]{file});
// Create holdout test data on clean data (before adding missing values)
FrameSplitter fs = new FrameSplitter(data, new float[]{0.75f});
H2O.submitTask(fs).join();
Frame[] train_test = fs.getResult();
train = train_test[0];
test = train_test[1];
// Build a regularized DL model with polluted training data, score on clean validation set
DeepLearning p;
p = new DeepLearning();
p.source = train;
p.validation = test;
p.response = train.lastVec();
p.ignored_cols = new int[]{1, 22}; //for weather data
p.activation = DeepLearning.Activation.RectifierWithDropout;
p.hidden = new int[]{32, 58};
p.l1 = 1e-5;
p.l2 = 3e-5;
p.seed = 0xbebe;
p.input_dropout_ratio = 0.2;
p.hidden_dropout_ratios = new double[]{0.4, 0.1};
p.epochs = 3.32;
p.quiet_mode = true;
p.reproducible = repro;
try {
Log.info("Starting with #" + repeat);
p.invoke();
} catch (Throwable t) {
t.printStackTrace();
throw new RuntimeException(t);
} finally {
p.delete();
}
// Extract the scoring on validation set from the model
mymodel = UKV.get(p.dest());
preds[repeat] = mymodel.score(test);
repeatErrs.put(repeat, mymodel.error());
} catch (Throwable t) {
t.printStackTrace();
throw new RuntimeException(t);
} finally {
// cleanup
if (mymodel != null) {
mymodel.delete_xval_models();
mymodel.delete_best_model();
mymodel.delete();
}
if (train != null) train.delete();
if (test != null) test.delete();
if (data != null) data.delete();
}
}
sb.append("Reproducibility: " + (repro ? "on" : "off") + "\n");
sb.append("Repeat # --> Validation Error\n");
for (String s : Arrays.toString(repeatErrs.entrySet().toArray()).split(","))
sb.append(s.replace("=", " --> ")).append("\n");
sb.append('\n');
Log.info(sb.toString());
try {
if (repro) {
// check reproducibility
for (Float error : repeatErrs.values()) {
Assert.assertTrue(error.equals(repeatErrs.get(0)));
}
for (Frame f : preds) {
Assert.assertTrue(f.isIdentical(preds[0]));
}
repro_error = repeatErrs.get(0);
} else {
// check standard deviation of non-reproducible mode
double mean = 0;
for (Float error : repeatErrs.values()) {
mean += error;
}
mean /= N;
Log.info("mean error: " + mean);
double stddev = 0;
for (Float error : repeatErrs.values()) {
stddev += (error - mean) * (error - mean);
}
stddev /= N;
stddev = Math.sqrt(stddev);
Log.info("standard deviation: " + stddev);
Assert.assertTrue(stddev < 0.1 / Math.sqrt(N));
Log.info("difference to reproducible mode: " + Math.abs(mean - repro_error) / stddev + " standard deviations");
}
} finally {
for (Frame f : preds) if (f != null) f.delete();
}
}
}
}