package hex.gbm;
import java.io.File;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.*;
/**
 * Checks that a GBM model trained on one dataset can score a test dataset whose
 * enum (categorical) column has a different domain than the training column.
 */
public class GBMDomainTest extends TestUtil {

  /** Callback that selects the response column from the parsed training frame. */
  private abstract static class PrepData { abstract Vec prep(Frame fr); }

  /**
   * Runs once before the tests; calls {@code stall_till_cloudsize(3)}
   * (presumably blocks until a 3-node cloud is up -- see TestUtil).
   */
  @BeforeClass public static void stall() { stall_till_cloudsize(3); }

  /**
   * The scenario:
   *  - test data contains an input column which has fewer enum values than
   *    the same column in the train data.
   * In this case we should provide a correct values mapping:
   *
   *   train domain    test domain    test domain after adaptation
   *   A - 0
   *   B - 1           B - 0          B - 1
   *   C - 2           D - 1          D - 3
   *   D - 3
   */
  @Test public void testModelAdapt() {
    runAndScoreGBM(
        "./smalldata/test/classifier/coldom_train_1.csv",
        "./smalldata/test/classifier/coldom_test_1.csv",
        new PrepData() { @Override Vec prep(Frame fr) { return fr.vecs()[fr.numCols()-1]; } });
  }

  /**
   * The scenario:
   *  - test data contains an input column which has more enum values than
   *    the same column in the train data (here the unseen value X).
   *
   *   train domain    test domain    test domain after adaptation
   *   A - 0
   *   B - 1           B - 0          B - 1
   *   C - 2           X - 1          X - NA
   *   D - 3
   */
  @Test public void testModelAdapt2() {
    runAndScoreGBM(
        "./smalldata/test/classifier/coldom_train_1.csv",
        "./smalldata/test/classifier/coldom_test_1_2.csv",
        new PrepData() { @Override Vec prep(Frame fr) { return fr.vecs()[fr.numCols()-1]; } });
  }

  /**
   * Trains a small GBM on {@code train}, scores {@code test} with it, and removes
   * every key created along the way.
   *
   * <p>Any failure while parsing, training or scoring is rethrown so that the
   * calling test fails; cleanup still runs and tolerates objects that were never
   * created because of an early failure.
   *
   * @param train    path of the training CSV, resolved via {@code TestUtil.find_test_file}
   * @param test     path of the test CSV, resolved via {@code TestUtil.find_test_file}
   * @param prepData picks the response column out of the parsed training frame
   */
  void runAndScoreGBM(String train, String test, PrepData prepData) {
    File file1 = TestUtil.find_test_file(train);
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    File file2 = TestUtil.find_test_file(test);
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    GBM gbm = null;
    GBM.GBMModel model = null;
    Frame ftest = null;
    Frame preds = null;
    try {
      gbm = new GBM();
      gbm.source = ParseDataset2.parse(dest1,new Key[]{fkey1});
      gbm.response = prepData.prep(gbm.source);
      gbm.ntrees = 2;
      gbm.max_depth = 3;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 1024;
      gbm.cols = new int[] {0,1,2};
      gbm.invoke();
      model = UKV.get(gbm.dest());
      // The test data set has a different enum domain than the train data set
      ftest = ParseDataset2.parse(dest2,new Key[]{fkey2});
      preds = gbm.score(ftest);
    } catch (Throwable t) {
      // Do not swallow: a test that only prints the stack trace can never fail.
      if( t instanceof RuntimeException ) throw (RuntimeException) t;
      if( t instanceof Error ) throw (Error) t;
      throw new RuntimeException("GBM train/score failed for " + train + " / " + test, t);
    } finally {
      // Every step is null-guarded so that an early failure in the try block is
      // not masked by a NullPointerException thrown from cleanup.
      if( ftest != null ) ftest.delete();                       // Delete test frame
      if( gbm != null && gbm.source != null ) gbm.source.delete(); // Remove original hex frame key
      if( preds != null ) preds.delete();
      if( model != null ) model.delete();
      if( gbm != null ) {
        if( gbm.response != null ) UKV.remove(gbm.response._key);
        gbm.remove();                                           // Remove GBM Job
      }
    }
  }
}