package hex;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.JUnitRunnerDebug;
import water.Key;
import water.TestUtil;
import water.UKV;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.fvec.Vec;
import water.util.Log;
/**
 * Trains a Deep Learning auto-encoder on the airlines training set (a mix of
 * categorical and numerical columns) and checks that the model's self-reported
 * MSE agrees with the mean of the per-row reconstruction error produced by
 * {@code scoreAutoEncoder}.
 */
public class DeepLearningAutoEncoderCategoricalTest extends TestUtil {
  /** Test data set, resolved through {@code find_test_file}. */
  static final String PATH = "smalldata/airlines/AirlinesTrain.csv.zip";

  /** Blocks until the cloud has reached the node count configured for the debug runner. */
  @BeforeClass public static void stall() {
    stall_till_cloudsize(JUnitRunnerDebug.NODES);
  }

  /**
   * Builds the auto-encoder, logs the rows whose reconstruction error exceeds the
   * outlier threshold, asserts that reported and recomputed MSE match to 1e-8, and
   * finally creates a full reconstruction of the training frame.
   *
   * <p>All frames and models created here are deleted in a {@code finally} block so
   * that a failing assertion or an exception during training/scoring does not leave
   * them behind for subsequent tests.
   */
  @Test
  public void run() {
    long seed = 0xDECAF;

    Key file_train = NFSFileVec.make(find_test_file(PATH));
    Frame train = ParseDataset2.parse(Key.make(), new Key[]{file_train});

    DeepLearning p = new DeepLearning();
    DeepLearningModel mymodel = null;
    Frame l2 = null;
    Frame recon_train = null;

    try {
      p.source = train;
      p.autoencoder = true;
      p.response = train.lastVec();
      p.seed = seed;
      p.hidden = new int[]{100, 50, 20};
      // Optional: ignore all categoricals
      //   p.ignored_cols = new int[]{0,1,2,3,6,7,8,10};
      // Optional: ignore all numericals
      //   p.ignored_cols = new int[]{4,5,9};
      p.adaptive_rate = true;
      p.l1 = 1e-4;
      p.activation = DeepLearning.Activation.Tanh;
      p.train_samples_per_iteration = -1;
      p.loss = DeepLearning.Loss.MeanSquare;
      p.epochs = 2;
      // Optional: p.shuffle_training_data = true;
      p.force_load_balance = true;
      p.score_training_samples = 0;
      p.score_validation_samples = 0;
      // Optional: p.reproducible = true;
      p.invoke();

      // Verification of results
      StringBuilder sb = new StringBuilder();
      sb.append("Verifying results.\n");

      mymodel = UKV.get(p.dest());
      sb.append("Reported mean reconstruction error: " + mymodel.mse() + "\n");

      // Training data:
      // reconstruct using the model's own helper and verify that the self-reported MSE agrees
      l2 = mymodel.scoreAutoEncoder(train);
      final Vec l2vec = l2.anyVec();
      sb.append("Actual mean reconstruction error: " + l2vec.mean() + "\n");

      // Print stats and potential outliers; the quantile is 1 - 5/numRows
      double quantile = 1 - 5. / train.numRows();
      sb.append("The following training points are reconstructed with an error above the "
          + quantile * 100 + "-th percentile - potential \"outliers\" in testing data.\n");
      double thresh = mymodel.calcOutlierThreshold(l2vec, quantile);
      for (long i = 0; i < l2vec.length(); i++) {
        // Read each row's error once instead of fetching it twice
        final double err = l2vec.at(i);
        if (err > thresh) {
          sb.append(String.format("row %d : l2vec error = %5f\n", i, err));
        }
      }
      Log.info(sb.toString());

      Assert.assertEquals(mymodel.mse(), l2vec.mean(), 1e-8);

      // Create reconstruction
      Log.info("Creating full reconstruction.");
      recon_train = mymodel.score(train);
    } finally {
      // Cleanup runs even when the assertion above fails; anything not yet created is skipped
      if (recon_train != null) recon_train.delete();
      train.delete();
      p.delete();
      if (mymodel != null) mymodel.delete();
      if (l2 != null) l2.delete();
    }
  }
}