package hex.deeplearning;

import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.Log;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;

import java.io.File;

/**
 * Simple Deep Neural Network on MNIST.
 * Note: requires './gradlew syncBigDataLaptop'
 *
 * <p>7 hours on i7 5820k to get to 0.91% test set error (or wait longer to get to
 * world-record test set error: 0.83%).
 *
 * <pre>
 * Duration                         6:59:29.885
 * Training Speed                   2384.428 rows/sec
 * Epochs                           1000.26288
 * Samples                          60015771
 * Training MSE                     0.00005
 * Training R^2                     0.99999
 * Training LogLoss                 0.00016
 * Training Classification Error    0.00010
 * Validation MSE                   0.00823
 * Validation R^2                   0.99902
 * Validation LogLoss               0.05626
 * Validation Classification Error  0.00910
 *
 * INFO: Confusion Matrix (vertical: actual; across: predicted):
 * INFO:           0     1     2     3     4     5     6     7     8     9   Error   Rate
 * INFO:      0  975     0     1     0     0     0     1     1     2     0  0.0051 =  5 / 980
 * INFO:      1    0  1131     0     1     0     0     2     0     1     0  0.0035 =  4 / 1,135
 * INFO:      2    0     0  1025     0     1     0     1     3     2     0  0.0068 =  7 / 1,032
 * INFO:      3    0     0     1  1004     0     0     0     3     2     0  0.0059 =  6 / 1,010
 * INFO:      4    1     0     0     0   970     0     5     0     0     6  0.0122 = 12 / 982
 * INFO:      5    2     0     0     2     0   885     2     0     1     0  0.0078 =  7 / 892
 * INFO:      6    3     3     0     0     1     2   948     0     1     0  0.0104 = 10 / 958
 * INFO:      7    1     2     5     0     0     0     0  1017     1     2  0.0107 = 11 / 1,028
 * INFO:      8    1     0     0     2     0     3     0     1   963     4  0.0113 = 11 / 974
 * INFO:      9    1     2     0     3     6     2     0     4     0   991  0.0178 = 18 / 1,009
 * INFO: Totals  984  1138  1032  1012   978   892   959  1029   973  1003  0.0091 = 91 / 10,000
 * INFO: Top-10 Hit Ratios:
 * INFO:  K  Hit Ratio
 * INFO:  1  0.990900
 * INFO:  2  0.998100
 * INFO:  3  0.999200
 * INFO:  4  0.999500
 * INFO:  5  0.999900
 * INFO:  6  0.999900
 * INFO:  7  1.000000
 * INFO:  8  1.000000
 * INFO:  9  1.000000
 * INFO: 10  1.000000
 * </pre>
 */
public class DeepLearningMNIST extends TestUtil {

  /** Name of the label column; it is the last column of the MNIST CSV files. */
  private static final String RESPONSE_COLUMN = "C785";

  /** Hint logged when either MNIST data file cannot be located. */
  private static final String MISSING_DATA_HINT =
      "Please run ./gradlew syncBigDataLaptop in the top-level directory of h2o-3.";

  /** Blocks until a cloud of at least one node is available. */
  @BeforeClass()
  public static void setup() {
    stall_till_cloudsize(1);
  }

  /**
   * Parses the MNIST train/test sets, trains a 3x128 rectifier-with-dropout network with the
   * test set as validation frame, then deletes the resulting model.
   *
   * <p>If either data file cannot be located, the run is skipped and a hint on how to fetch
   * the data is logged. Frames are removed in all cases.
   */
  @Test
  @Ignore
  public void run() {
    Scope.enter();
    Frame frame = null;
    Frame vframe = null;
    try {
      File file = FileUtils.locateFile("bigdata/laptop/mnist/train.csv.gz");
      File valid = FileUtils.locateFile("bigdata/laptop/mnist/test.csv.gz");
      // Both files are required; previously only the training file was checked, so a missing
      // test set surfaced as a failure in NFSFileVec.make instead of this hint.
      if (file == null || valid == null) {
        Log.info(MISSING_DATA_HINT);
        return;
      }

      frame = parse(file);
      vframe = parse(valid);

      // Convert response 'C785' to categorical (digits 0 to 9). Each frame resolves its own
      // column index rather than assuming both files share the same layout.
      responseToCategorical(frame);
      responseToCategorical(vframe);

      DeepLearningParameters p = buildParameters(frame, vframe);
      DeepLearning dl = new DeepLearning(p, Key.<DeepLearningModel>make("dl_mnist_model"));
      DeepLearningModel model = dl.trainModel().get();
      if (model != null) model.delete();
    } finally {
      Scope.exit();
      if (vframe != null) vframe.remove();
      if (frame != null) frame.remove();
    }
  }

  /** Parses a local file into a Frame stored under a freshly generated key. */
  private static Frame parse(File f) {
    NFSFileVec fv = NFSFileVec.make(f);
    return ParseDataset.parse(Key.make(), fv._key);
  }

  /** Replaces the response column of {@code fr} with its categorical version and re-publishes the frame. */
  private static void responseToCategorical(Frame fr) {
    int idx = fr.find(RESPONSE_COLUMN);
    // Track the vec returned by replace() so the enclosing Scope cleans it up on exit.
    Scope.track(fr.replace(idx, fr.vecs()[idx].toCategoricalVec()));
    DKV.put(fr);
  }

  /** Builds the model parameters for training on {@code train} with {@code valid} as validation frame. */
  private static DeepLearningParameters buildParameters(Frame train, Frame valid) {
    DeepLearningParameters p = new DeepLearningParameters();

    // data
    p._train = train._key;
    p._valid = valid._key;
    p._response_column = RESPONSE_COLUMN; // last column is the response

    // network topology
    p._activation = DeepLearningParameters.Activation.RectifierWithDropout;
    // p._activation = DeepLearningParameters.Activation.MaxoutWithDropout;
    p._hidden = new int[]{128, 128, 128};
    p._input_dropout_ratio = 0.0;

    // plain SGD: fixed learning rate, no annealing, no momentum
    p._adaptive_rate = false;
    p._rate = 0.005;
    p._rate_annealing = 0;
    p._momentum_start = 0;
    p._momentum_stable = 0;
    p._mini_batch_size = 1;
    p._train_samples_per_iteration = -1;
    // p._score_duty_cycle = 0.1;
    p._shuffle_training_data = true;
    // p._reproducible = true;
    // p._l1 = 1e-5;
    p._max_w2 = 1;
    p._epochs = 20; // 1000*10*5./6;
    p._sparse = true; // faster as activations remain sparse

    // speed up training
    p._replicate_training_data = true; // avoid extra communication cost upfront, got enough data on each node for load balancing
    // NOTE(review): the original comment here read "no need to keep the best model around",
    // which seems to contradict the value `true` -- confirm the intended setting.
    p._overwrite_with_best_model = true;
    p._classification_stop = -1; // presumably disables stopping on classification error -- TODO confirm
    // p._score_interval = 5;
    // Only score on a small sample of the training set -> don't want to spend too much time
    // scoring (note: there will be at least 1 row per chunk). An earlier, immediately
    // overwritten assignment of 0 to this field was removed as dead code.
    p._score_training_samples = 10000;

    return p;
  }
}