package hex; import hex.deeplearning.*; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import water.*; import water.fvec.*; import water.util.Log; import java.util.HashSet; public class DeepLearningAutoEncoderTest extends TestUtil { /* Visualize outliers with the following R code (from smalldata/anomaly dir): train <- scan("ecg_discord_train.csv", sep=",") test <- scan("ecg_discord_test.csv", sep=",") plot.ts(train) plot.ts(test) */ static final String PATH = "smalldata/anomaly/ecg_discord_train.csv"; //first 20 points static final String PATH2 = "smalldata/anomaly/ecg_discord_test.csv"; //first 22 points @BeforeClass public static void stall() { stall_till_cloudsize(JUnitRunnerDebug.NODES); } @Test public void run() { long seed = 0xDECAF; Frame train=null, test=null; try { Key file_train = NFSFileVec.make(find_test_file(PATH)); train = ParseDataset2.parse(Key.make(), new Key[]{file_train}); Key file_test = NFSFileVec.make(find_test_file(PATH2)); test = ParseDataset2.parse(Key.make(), new Key[]{file_test}); for (float sparsity_beta : new float[]{0, 0.1f}) { DeepLearning p = new DeepLearning(); p.source = train; p.autoencoder = true; p.response = train.lastVec(); p.classification = false; p.seed = seed; p.hidden = new int[]{100, 100}; p.adaptive_rate = true; p.train_samples_per_iteration = -1; p.sparsity_beta = sparsity_beta; p.average_activation = -0.7; p.l1 = 1e-4; // p.l2 = 1e-4; // p.rate = 1e-5; p.activation = DeepLearning.Activation.Tanh; p.loss = DeepLearning.Loss.MeanSquare; // p.initial_weight_distribution = DeepLearning.InitialWeightDistribution.Normal; // p.initial_weight_scale = 1e-3; p.epochs = 500; // p.shuffle_training_data = true; p.force_load_balance = false; //if enabled, Hogwild gets ugly on many cores p.invoke(); DeepLearningModel mymodel = UKV.get(p.dest()); Frame l2_frame_train=null, l2_frame_test=null; // Verification of results StringBuilder sb = new StringBuilder(); try { sb.append("Verifying results.\n"); // Training data // Reconstruct data using the same helper functions and verify that self-reported MSE agrees double quantile = 0.95; l2_frame_train = mymodel.scoreAutoEncoder(train); final Vec l2_train = l2_frame_train.anyVec(); sb.append("Mean reconstruction error: " + l2_train.mean() + "\n"); Assert.assertEquals(mymodel.mse(), l2_train.mean(), 1e-7); Assert.assertTrue("too big a reconstruction error: " + l2_train.mean(), l2_train.mean() < 0.06); // manually compute L2 Frame reconstr = mymodel.score(train); //this creates real values in original space double mean_l2 = 0; for (int r = 0; r < reconstr.numRows(); ++r) { double my_l2 = 0; for (int c = 0; c < reconstr.numCols(); ++c) { my_l2 += Math.pow((reconstr.vec(c).at(r) - train.vec(c).at(r)) * mymodel.model_info().data_info()._normMul[c], 2); //undo normalization here } my_l2 /= reconstr.numCols(); mean_l2 += my_l2; } mean_l2 /= reconstr.numRows(); reconstr.delete(); sb.append("Mean reconstruction error (train): " + l2_train.mean() + "\n"); Assert.assertEquals(mymodel.mse(), mean_l2, 1e-7); // print stats and potential outliers sb.append("The following training points are reconstructed with an error above the " + quantile * 100 + "-th percentile - check for \"goodness\" of training data.\n"); double thresh_train = mymodel.calcOutlierThreshold(l2_train, quantile); for (long i = 0; i < l2_train.length(); i++) { if (l2_train.at(i) > thresh_train) { sb.append(String.format("row %d : l2_train error = %5f\n", i, l2_train.at(i))); } } // Test data // Reconstruct data using the same helper functions and verify that self-reported MSE agrees l2_frame_test = mymodel.scoreAutoEncoder(test); final Vec l2_test = l2_frame_test.anyVec(); double mult = 10; double thresh_test = mult * thresh_train; sb.append("\nFinding outliers.\n"); sb.append("Mean reconstruction error (test): " + l2_test.mean() + "\n"); // print stats and potential outliers sb.append("The following test points are reconstructed with an error greater than " + mult + " times the mean reconstruction error of the training data:\n"); HashSet<Long> outliers = new HashSet<Long>(); for (long i = 0; i < l2_test.length(); i++) { if (l2_test.at(i) > thresh_test) { outliers.add(i); sb.append(String.format("row %d : l2 error = %5f\n", i, l2_test.at(i))); } } // check that the all outliers are found (and nothing else) Assert.assertTrue(outliers.contains(new Long(20))); Assert.assertTrue(outliers.contains(new Long(21))); Assert.assertTrue(outliers.contains(new Long(22))); Assert.assertTrue(outliers.size() == 3); } finally { Log.info(sb); // cleanup if (p!=null) p.delete(); if (mymodel!=null) mymodel.delete(); if (l2_frame_train!=null) l2_frame_train.delete(); if (l2_frame_test!=null) l2_frame_test.delete(); } } } finally { if (train!=null) train.delete(); if (test!=null) test.delete(); } } }