package hex.tree.gbm;
import hex.FrameSplitter;
import hex.ModelMetricsBinomial;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.FileUtils;
import water.util.FrameUtils;
import static water.util.FrameUtils.generateNumKeys;
import water.util.Log;
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;
/**
 * Sweeps GBM over training data with an increasing fraction of injected missing values.
 *
 * <p>For each missing fraction, the weather data set is split 75/25 into train and holdout
 * frames. That fraction of the training predictors (the response column is excluded) is
 * replaced with missing values, and a GBM is trained and validated on the clean holdout frame.
 * The holdout logloss per fraction is logged at the end.
 *
 * <p>Every fraction is attempted even if an earlier one fails. After the summary is logged, the
 * test fails if any fraction threw. A failed fraction includes one where the model's Java
 * scoring ({@code testJavaScoring}) disagrees with {@code score} on the training frame.
 */
public class GBMMissingTest extends TestUtil {
  @BeforeClass() public static void setup() { stall_till_cloudsize(1); }

  /** Error value recorded in the summary for a missing fraction whose run failed. */
  private static final double FAILED_ERR = 100;

  @Test public void run() {
    final long seed = 1234;
    Log.info("");
    Log.info("STARTING.");
    Log.info("Using seed " + seed);
    double sumerr = 0;
    // TreeMap keeps the summary ordered by missing fraction.
    Map<Double,Double> map = new TreeMap<>();
    // First failure of the sweep; failures of later fractions are attached as suppressed.
    Throwable failure = null;
    for (double missing_fraction : new double[]{0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99}) {
      // Per-iteration handles: scoped to the loop body so the finally block below can never
      // delete objects that belong to (and were already deleted by) a previous iteration.
      GBMModel mymodel = null;
      Frame train = null;
      Frame test = null;
      Frame data = null;
      double err;
      Scope.enter();
      try {
        NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/junit/weather.csv");
        data = ParseDataset.parse(Key.make("data.hex"), nfs._key);
        Log.info("FrameSplitting");
        // Create holdout test data on clean data (before adding missing values)
        FrameSplitter fs = new FrameSplitter(data, new double[]{0.75}, generateNumKeys(data._key, 2), null);
        // NOTE(review): no explicit join(); this relies on getResult() waiting for the split.
        H2O.submitTask(fs);
        Frame[] train_test = fs.getResult();
        train = train_test[0];
        test = train_test[1];
        Log.info("Done...");
        // Add missing values to the training data (excluding the response).
        if (missing_fraction > 0) {
          // Header over train's own vecs minus the response, so the values MissingInserter
          // modifies are the training data itself.
          Frame frtmp = new Frame(Key.<Frame>make(), train.names(), train.vecs());
          frtmp.remove(frtmp.numCols() - 1); // exclude the response
          DKV.put(frtmp._key, frtmp); // MissingInserter picks the frame up from the DKV
          try {
            FrameUtils.MissingInserter j = new FrameUtils.MissingInserter(frtmp._key, seed, missing_fraction);
            j.execImpl().get(); // MissingInserter is non-blocking, must block here explicitly
          } finally {
            // Delete only the temporary frame header (not the shared data), even on failure.
            DKV.remove(frtmp._key);
          }
        }
        // Build a regularized GBM model with polluted training data, score on clean validation set
        GBMModel.GBMParameters p = new GBMModel.GBMParameters();
        p._train = train._key;
        p._valid = test._key;
        p._response_column = train._names[train.numCols() - 1];
        p._ignored_columns = new String[]{train._names[1], train._names[22]}; // only for weather data
        p._seed = seed;
        // Convert the response to categorical in both frames; the replaced numeric vecs are
        // tracked so Scope.exit() deletes them.
        int ri = train.numCols() - 1;
        int ci = test.find(p._response_column);
        Scope.track(train.replace(ri, train.vecs()[ri].toCategoricalVec()));
        Scope.track(test.replace(ci, test.vecs()[ci].toCategoricalVec()));
        // Re-publish the modified frame headers.
        DKV.put(train);
        DKV.put(test);
        GBM gbm = new GBM(p);
        Log.info("Starting with " + missing_fraction * 100 + "% missing values added.");
        mymodel = gbm.trainModel().get();
        // Extract the scoring on validation set from the model
        err = ((ModelMetricsBinomial) mymodel._output._validation_metrics).logloss();
        Frame train_preds = mymodel.score(train);
        try {
          Assert.assertTrue(mymodel.testJavaScoring(train, train_preds, 1e-15));
        } finally {
          train_preds.remove(); // release predictions even if the Java-scoring check fails
        }
        Log.info("Missing " + missing_fraction * 100 + "% -> logloss: " + err);
      } catch (Throwable t) {
        // Record the failure and continue the sweep; the test is failed after the summary.
        Log.info("Missing " + missing_fraction * 100 + "% -> failed: " + t);
        if (failure == null) {
          failure = t;
        } else {
          failure.addSuppressed(t);
        }
        err = FAILED_ERR;
      } finally {
        Scope.exit();
        // cleanup
        if (mymodel != null) mymodel.delete();
        if (train != null) train.delete();
        if (test != null) test.delete();
        if (data != null) data.delete();
      }
      map.put(missing_fraction, err);
      sumerr += err;
    }
    StringBuilder sb = new StringBuilder();
    sb.append("missing fraction --> Error\n");
    for (Map.Entry<Double,Double> e : map.entrySet()) {
      sb.append(e.getKey()).append(" --> ").append(e.getValue()).append("\n");
    }
    sb.append('\n');
    sb.append("Sum Err: ").append(sumerr).append("\n");
    Log.info(sb.toString());
    if (failure != null) {
      throw new AssertionError("GBM run failed for at least one missing fraction", failure);
    }
  }
}