package hex.deeplearning;
import hex.FrameSplitter;
import water.TestUtil;
import hex.deeplearning.DeepLearningModel.DeepLearningParameters;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.FrameUtils;
import water.util.Log;
import java.util.*;
import static water.util.FrameUtils.generateNumKeys;
/**
 * Regression test for DeepLearning missing-value handling.
 *
 * <p>Injects increasing fractions of missing values into the training portion of the
 * weather dataset, trains a regularized DeepLearning model for each
 * {@code MissingValuesHandling} strategy ({@code MeanImputation} vs. {@code Skip}),
 * scores on a clean validation split, and asserts the summed logloss per strategy
 * against known-good reference values.</p>
 */
public class DeepLearningMissingTest extends TestUtil {
  /** Block until a single-node H2O cloud is up before any test runs. */
  @BeforeClass() public static void setup() { stall_till_cloudsize(1); }

  @Test public void run() {
    long seed = 1234; // fixed seed: both the MissingInserter and the model use it for reproducibility
    DeepLearningModel mymodel = null;
    Frame train = null;
    Frame test = null;
    Frame data = null;
    DeepLearningParameters p;
    Log.info("");
    Log.info("STARTING.");
    Log.info("Using seed " + seed);

    // Summed logloss per missing-value-handling strategy, for the final assertions.
    Map<DeepLearningParameters.MissingValuesHandling, Double> sumErr = new TreeMap<>();
    StringBuilder sb = new StringBuilder();
    for (DeepLearningParameters.MissingValuesHandling mvh :
        new DeepLearningParameters.MissingValuesHandling[]{
            DeepLearningParameters.MissingValuesHandling.MeanImputation,
            DeepLearningParameters.MissingValuesHandling.Skip
        })
    {
      double sumloss = 0;
      Map<Double, Double> map = new TreeMap<>(); // missing fraction -> logloss, for the report
      for (double missing_fraction : new double[]{0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99}) {
        double loss = 0;
        try {
          Scope.enter();
          NFSFileVec nfs = NFSFileVec.make("smalldata/junit/weather.csv");
          data = ParseDataset.parse(Key.make("data.hex"), nfs._key);
          Log.info("FrameSplitting");
          // Create holdout test data on clean data (before adding missing values)
          FrameSplitter fs = new FrameSplitter(data, new double[]{0.75f}, generateNumKeys(data._key, 2), null);
          H2O.submitTask(fs);//.join();
          Frame[] train_test = fs.getResult();
          train = train_test[0];
          test = train_test[1];
          Log.info("Done...");

          // Add missing values to the training data (excluding the response).
          if (missing_fraction > 0) {
            // Wrap the predictor columns in a temporary frame header; the response
            // (last column) is removed so it is never polluted.
            Frame frtmp = new Frame(Key.<Frame>make(), train.names(), train.vecs());
            frtmp.remove(frtmp.numCols() - 1); //exclude the response
            DKV.put(frtmp._key, frtmp); //need to put the frame (to be modified) into DKV for MissingInserter to pick up
            FrameUtils.MissingInserter j = new FrameUtils.MissingInserter(frtmp._key, seed, missing_fraction);
            j.execImpl().get(); //MissingInserter is non-blocking, must block here explicitly
            DKV.remove(frtmp._key); //Delete the frame header (not the data)
          }

          // Build a regularized DL model with polluted training data, score on clean validation set
          p = new DeepLearningParameters();
          p._train = train._key;
          p._valid = test._key;
          p._response_column = train._names[train.numCols() - 1];
          p._ignored_columns = new String[]{train._names[1], train._names[22]}; //only for weather data
          p._missing_values_handling = mvh;
          p._loss = DeepLearningParameters.Loss.CrossEntropy; // DeepLearningParameters.Loss.ModifiedHuber;
          p._activation = DeepLearningParameters.Activation.Rectifier;
          p._hidden = new int[]{50, 50};
          p._l1 = 1e-5;
          p._input_dropout_ratio = 0.2;
          p._epochs = 3;
          p._reproducible = true; // deterministic training so the reference losses are stable
          p._seed = seed;
          p._elastic_averaging = false;

          // Convert response to categorical (classification); track the replaced
          // numeric vecs so Scope.exit() cleans them up.
          int ri = train.numCols() - 1;
          int ci = test.find(p._response_column);
          Scope.track(train.replace(ri, train.vecs()[ri].toCategoricalVec()));
          Scope.track(test.replace(ci, test.vecs()[ci].toCategoricalVec()));
          DKV.put(train);
          DKV.put(test);

          DeepLearning dl = new DeepLearning(p);
          Log.info("Starting with " + missing_fraction * 100 + "% missing values added.");
          mymodel = dl.trainModel().get();

          // Extract the scoring on validation set from the model
          loss = mymodel.loss();
          Log.info("Missing " + missing_fraction * 100 + "% -> logloss: " + loss);
        } catch (Throwable t) {
          // Don't abort the sweep on a single failure: record a large penalty loss
          // so the final reference assertions fail visibly instead.
          t.printStackTrace();
          loss = 100;
        } finally {
          Scope.exit();
          // Cleanup. Null out each reference after deleting it so that, if a later
          // iteration's try block throws before reassigning these fields, this
          // finally block cannot call delete() a second time on stale objects.
          if (mymodel != null) {
            mymodel.delete();
            mymodel = null;
          }
          if (train != null) { train.delete(); train = null; }
          if (test != null) { test.delete(); test = null; }
          if (data != null) { data.delete(); data = null; }
        }
        map.put(missing_fraction, loss);
        sumloss += loss;
      }
      // Append a human-readable report section for this strategy.
      sb.append("\nMethod: ").append(mvh.toString()).append("\n");
      sb.append("missing fraction --> loss\n");
      for (String s : Arrays.toString(map.entrySet().toArray()).split(",")) sb.append(s.replace("=", " --> ")).append("\n");
      sb.append('\n');
      sb.append("sum loss: ").append(sumloss).append("\n");
      sumErr.put(mvh, sumloss);
    }
    Log.info(sb.toString());

    // Reference values for the reproducible runs above; Skip includes the 100-point
    // penalties for fractions where training cannot proceed.
    Assert.assertEquals(405.5017, sumErr.get(DeepLearningParameters.MissingValuesHandling.Skip), 1e-2);
    Assert.assertEquals(3.914915, sumErr.get(DeepLearningParameters.MissingValuesHandling.MeanImputation), 1e-3);
  }
}