package hex;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;
import water.util.Log;
import java.util.*;
public class DeepLearningMissingTest extends TestUtil {
  /** Blocks until the test cloud has reached {@code JUnitRunnerDebug.NODES} nodes. */
  @BeforeClass public static void stall() {
    stall_till_cloudsize(JUnitRunnerDebug.NODES);
  }

  /**
   * For each missing-value strategy (Skip, MeanImputation) and each missing fraction, this test:
   * parses smalldata/weather.csv, splits it 75/25 into train/holdout, inserts missing values into
   * the training predictors only, trains a regularized DeepLearning model, and records the last
   * validation error from the model's scoring history.
   *
   * <p>It asserts that the summed error for Skip exceeds the summed error for MeanImputation, and
   * that the MeanImputation sum is below 2. The random seed is logged and included in assertion
   * messages so that a failure can be reproduced.
   */
  @Test
  public void run() {
    long seed = new Random().nextLong();
    Log.info("");
    Log.info("STARTING.");
    Log.info("Using seed " + seed);
    Map<DeepLearning.MissingValuesHandling,Double> sumErr = new TreeMap<DeepLearning.MissingValuesHandling,Double>();
    StringBuilder sb = new StringBuilder();
    for (DeepLearning.MissingValuesHandling mvh : new DeepLearning.MissingValuesHandling[]{
        DeepLearning.MissingValuesHandling.Skip, DeepLearning.MissingValuesHandling.MeanImputation
    }) {
      double sumerr = 0;
      Map<Double,Double> map = new TreeMap<Double,Double>();
      for (double missing_fraction : new double[]{0, 0.1, 0.25, 0.5, 0.75, 1}) {
        // Declared per iteration so that the finally block can never re-delete objects
        // belonging to a previous, already cleaned-up iteration.
        DeepLearningModel mymodel = null;
        Frame train = null;
        Frame test = null;
        Frame data = null;
        try {
          Key file = NFSFileVec.make(find_test_file("smalldata/weather.csv"));
          // Key file = NFSFileVec.make(find_test_file("smalldata/mnist/test.csv.gz"));
          data = ParseDataset2.parse(Key.make("data.hex"), new Key[]{file});
          // Create holdout test data on clean data (before adding missing values)
          FrameSplitter fs = new FrameSplitter(data, new float[]{0.75f});
          H2O.submitTask(fs).join();
          Frame[] train_test = fs.getResult();
          train = train_test[0];
          test = train_test[1];
          // Add missing values to the training data, excluding the response (last column)
          if (missing_fraction > 0) {
            Frame frtmp = new Frame(Key.make(), train.names(), train.vecs());
            frtmp.remove(frtmp.numCols() - 1); // exclude the response
            DKV.put(frtmp._key, frtmp);
            try {
              InsertMissingValues imv = new InsertMissingValues();
              imv.missing_fraction = missing_fraction;
              imv.seed = seed; // use the same seed for Skip and MeanImputation!
              imv.key = frtmp._key;
              imv.serve();
            } finally {
              // Remove just the Frame header (not the chunks), even if insertion failed
              DKV.remove(frtmp._key);
            }
          }
          // Build a regularized DL model with polluted training data, score on clean validation set
          DeepLearning p = new DeepLearning();
          p.source = train;
          p.validation = test;
          p.response = train.lastVec();
          p.ignored_cols = new int[]{1,22}; // only for weather data
          p.missing_values_handling = mvh;
          p.activation = DeepLearning.Activation.RectifierWithDropout;
          p.hidden = new int[]{200,200};
          p.l1 = 1e-5;
          p.input_dropout_ratio = 0.2;
          p.epochs = 10;
          p.quiet_mode = true;
          try {
            Log.info("Starting with " + missing_fraction * 100 + "% missing values added.");
            p.invoke();
          } finally {
            // Failures propagate to the outer catch; no need to wrap them twice here
            p.delete();
          }
          // Extract the scoring on validation set from the model
          mymodel = UKV.get(p.dest());
          DeepLearningModel.Errors[] errs = mymodel.scoring_history();
          double err = errs[errs.length - 1].valid_err;
          Log.info("Missing " + missing_fraction * 100 + "% -> Err: " + err);
          map.put(missing_fraction, err);
          sumerr += err;
        } catch (Throwable t) {
          t.printStackTrace();
          throw new RuntimeException(t);
        } finally {
          cleanup(mymodel, train, test, data);
        }
      }
      appendSummary(sb, mvh, map, sumerr);
      sumErr.put(mvh, sumerr);
    }
    Log.info(sb.toString());
    double skipErr = sumErr.get(DeepLearning.MissingValuesHandling.Skip);
    double meanErr = sumErr.get(DeepLearning.MissingValuesHandling.MeanImputation);
    Assert.assertTrue("Expected Skip sum error (" + skipErr + ") > MeanImputation sum error ("
        + meanErr + "), seed " + seed, skipErr > meanErr);
    Assert.assertTrue("Expected MeanImputation sum error (" + meanErr + ") < 2, seed " + seed,
        meanErr < 2); // this holds true for both datasets
  }

  /**
   * Releases the objects created by one iteration of {@link #run()}.
   *
   * <p>If {@code model} is non-null, calls {@code delete_xval_models()},
   * {@code delete_best_model()} and {@code delete()} on it. Then deletes every non-null frame.
   */
  private static void cleanup(DeepLearningModel model, Frame... frames) {
    if (model != null) {
      model.delete_xval_models();
      model.delete_best_model();
      model.delete();
    }
    for (Frame f : frames) {
      if (f != null) f.delete();
    }
  }

  /**
   * Appends a human-readable report for one missing-value strategy to {@code sb}.
   *
   * <p>The report lists one "fraction --> error" line per entry, in the map's key order, followed
   * by the summed error.
   */
  private static void appendSummary(StringBuilder sb, DeepLearning.MissingValuesHandling mvh,
                                    Map<Double,Double> errByFraction, double sumerr) {
    sb.append("\nMethod: ").append(mvh.toString()).append('\n');
    sb.append("missing fraction --> Error\n");
    for (Map.Entry<Double,Double> e : errByFraction.entrySet()) {
      sb.append(e.getKey()).append(" --> ").append(e.getValue()).append('\n');
    }
    sb.append('\n');
    sb.append("Sum Err: ").append(sumerr).append('\n');
  }
}