package hex;
import org.junit.*;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
/**
 * Tests for {@code Model.adaptTestForTrain}: a throw-away model is given the
 * column names and domains of a training frame, and a test frame is then
 * adapted against it.  Each test checks the returned warnings and/or the
 * shape of the adapted frame, and always releases the frames it created.
 */
public class ModelAdaptTest extends TestUtil {
  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  // Private junk model class to test Adaption logic.  Scoring and metric
  // building are deliberately unimplemented; only the names/domains carried
  // by the Output are used by these tests.
  private static class AModel extends Model {
    AModel( Key key, Parameters p, Output o ) { super(key,p,o); }
    @Override protected double[] score0(double[] data /*ncols*/, double[] preds /*nclasses+1*/) { throw H2O.unimpl(); }
    @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) { throw H2O.unimpl(); }
    static class AParms extends Model.Parameters {
      public String algoName() { return "A"; }
      public String fullName() { return "A"; }
      public String javaName() { return AModel.class.getName(); }
      @Override public long progressUnits() { return 0; }
    }
    static class AOutput extends Model.Output { }
  }

  /**
   * Builds an {@code AModel} whose output carries the column names and domains
   * of {@code trn}, then removes {@code trn}.  The frame is removed even if
   * reading its names/domains fails, and the model is constructed only after
   * the removal (same order as the tests originally used).
   *
   * @param trn training frame; removed by this call
   * @return a model that remembers only the training names and domains
   */
  private static AModel modelTrainedOn(Frame trn) {
    AModel.AParms parms = new AModel.AParms();
    AModel.AOutput output = new AModel.AOutput();
    try {
      output.setNames(trn.names());
      output._domains = trn.domains();
    } finally {
      trn.remove();
    }
    return new AModel(Key.make(), parms, output);
  }

  /**
   * Asserts that {@code warns} contains exactly the message {@code expected},
   * listing the warnings actually produced when it does not.
   */
  private static void assertWarned(String[] warns, String expected) {
    Assert.assertTrue("Missing warning <" + expected + "> in " + java.util.Arrays.toString(warns),
                      ArrayUtils.find(warns, expected) != -1);
  }

  /**
   * Releases the frames of one test: first whatever {@code Model.cleanup_adapt}
   * cleans up for the adapted frame, then the test frame itself.  Either
   * argument may be null when the test failed before creating it.
   */
  private static void cleanup(Frame adapt, Frame tst) {
    if( tst == null ) return;
    if( adapt != null ) Model.cleanup_adapt(adapt, tst);
    tst.remove();
  }

  @Test public void testModelAdaptMultinomial() {
    AModel am = modelTrainedOn(parse_test_file("smalldata/junit/mixcat_train.csv"));
    Frame tst = null;
    Frame adapt = null;
    try {
      tst = parse_test_file("smalldata/junit/mixcat_test.csv");
      adapt = new Frame(tst);
      String[] warns = am.adaptTestForTrain(adapt, true, true);
      assertWarned(warns, "Test/Validation dataset column 'Feature_1' has levels not trained on: [D]");
      assertWarned(warns, "Test/Validation dataset is missing column 'Const': substituting in a column of NaN");
      assertWarned(warns, "Test/Validation dataset is missing column 'Useless': substituting in a column of NaN");
      assertWarned(warns, "Test/Validation dataset column 'Response' has levels not trained on: [W]");
      // Feature_1: merged test & train domains
      Assert.assertArrayEquals(new String[]{"A","B","C","D"}, adapt.vec("Feature_1").domain());
      // Const: all NAs
      Assert.assertTrue(adapt.vec("Const").isBad());
      // Useless: all NAs
      Assert.assertTrue(adapt.vec("Useless").isBad());
      // Response: merged test & train domains
      Assert.assertArrayEquals(new String[]{"X","Y","Z","W"}, adapt.vec("Response").domain());
    } finally {
      // Runs on assertion failure too, so a failing test does not leave frames behind
      cleanup(adapt, tst);
    }
  }

  // If the train set has a categorical, and the test set column is all missing
  // then by-default it is treated as a numeric column (no domain).  Verify that
  // adaption of such a column completes without producing any warnings.
  @Test public void testModelAdaptMissing() {
    Vec cat = vec(new String[]{"A","B"},0,1,0,1);
    Frame trn = new Frame();
    trn.add("cat",cat);
    AModel am = modelTrainedOn(trn);
    Frame tst = null;
    Frame adapt = null;
    try {
      tst = new Frame();
      tst.add("cat", cat.makeCon(Double.NaN)); // All NAN/missing column
      adapt = new Frame(tst);
      String[] warns = am.adaptTestForTrain(adapt, true, true);
      Assert.assertEquals("Unexpected adaption warnings: " + java.util.Arrays.toString(warns), 0, warns.length);
    } finally {
      cleanup(adapt, tst);
    }
  }

  // If the train set has a categorical (levels "A","B"), and the test set
  // column of the same name is numeric (values 2,3), this test requires
  // adaption to be rejected with an IllegalArgumentException.
  @Test public void testModelAdaptConvert() {
    Frame trn = new Frame();
    trn.add("dog",vec(new String[]{"A","B"},0,1,0,1));
    AModel am = modelTrainedOn(trn);
    Frame tst = null;
    Frame adapt = null;
    try {
      tst = new Frame();
      tst.add("dog",vec(2, 3, 2, 3));
      adapt = new Frame(tst);
      try {
        am.adaptTestForTrain(adapt, true, true);
        // Assert.fail throws AssertionError, which the catch below does not swallow
        Assert.fail("Expected IllegalArgumentException from adaptTestForTrain");
      } catch( IllegalArgumentException expected ) {
        // This is the outcome the test requires
      }
    } finally {
      // Clean up even though adaption was aborted part-way by the exception
      cleanup(adapt, tst);
    }
  }
}