package hex.mli; import hex.genmodel.utils.DistributionFamily; import hex.mli.loco.LeaveOneCovarOut; import hex.tree.gbm.GBM; import org.junit.BeforeClass; import org.junit.Test; import water.DKV; import water.Scope; import water.TestUtil; import water.fvec.Frame; import hex.tree.gbm.GBMModel; import static hex.genmodel.utils.DistributionFamily.gaussian; import static hex.genmodel.utils.DistributionFamily.multinomial; import static hex.genmodel.utils.DistributionFamily.bernoulli; /** * This Junit is mainly used to detect leaks in Leave One Covariate Out (LOCO) */ public class LeaveOneCovarOutTest extends TestUtil { @BeforeClass() public static void setup() { stall_till_cloudsize(1); } @Test public void testLocoRegressionDefault() { //Regression case locoRun("./smalldata/junit/cars.csv", "economy (mpg)", gaussian,null); } @Test public void testLocoBernoulliDefault() { //Bernoulli case locoRun("./smalldata/logreg/prostate.csv", "CAPSULE", bernoulli,null); } @Test public void testLocoMultinomialDefault(){ //Multinomial case locoRun("./smalldata/junit/cars.csv", "cylinders", multinomial,null); } @Test public void testLocoRegressionMean() { //Regression case locoRun("./smalldata/junit/cars.csv", "economy (mpg)", gaussian,"mean"); } @Test public void testLocoBernoulliMean() { //Bernoulli case locoRun("./smalldata/logreg/prostate.csv", "CAPSULE", bernoulli,"mean"); } @Test public void testLocoMultinomialMean(){ //Multinomial case locoRun("./smalldata/junit/cars.csv", "cylinders", multinomial,"mean"); } @Test public void testLocoRegressionMedian() { //Regression case locoRun("./smalldata/junit/cars.csv", "economy (mpg)", gaussian,"median"); } @Test public void testLocoBernoulliMedian() { //Bernoulli case locoRun("./smalldata/logreg/prostate.csv", "CAPSULE", bernoulli,"median"); } @Test public void testLocoMultinomialMedian(){ //Multinomial case locoRun("./smalldata/junit/cars.csv", "cylinders", multinomial,"median"); } public Frame locoRun(String fname, String response, DistributionFamily family, String method) { GBMModel gbm = null; Frame fr = null; Frame loco=null; try { Scope.enter(); fr = parse_test_file(fname); int idx = fr.find(response); if (family == DistributionFamily.bernoulli || family == DistributionFamily.multinomial || family == DistributionFamily.modified_huber) { if (!fr.vecs()[idx].isCategorical()) { Scope.track(fr.replace(idx, fr.vecs()[idx].toCategoricalVec())); } } DKV.put(fr); // Update frame after hacking it GBMModel.GBMParameters parms = new GBMModel.GBMParameters(); if( idx < 0 ) idx = ~idx; parms._train = fr._key; parms._response_column = fr._names[idx]; parms._ntrees = 5; parms._distribution = family; parms._max_depth = 4; parms._min_rows = 1; parms._nbins = 50; parms._learn_rate = .2f; parms._score_each_iteration = true; GBM job = new GBM(parms); gbm = job.trainModel().get(); if(method == null) { loco = LeaveOneCovarOut.leaveOneCovarOut(gbm, fr, job._job, null,null); assert DKV.get(loco._key) != null : "LOCO frame with default transform is not in DKV!"; } else if(method == "mean"){ loco = LeaveOneCovarOut.leaveOneCovarOut(gbm, fr, job._job, "mean",null); assert DKV.get(loco._key) != null : "LOCO frame with mean transform is not in DKV!"; } else{ loco = LeaveOneCovarOut.leaveOneCovarOut(gbm, fr, job._job, "median",null); assert DKV.get(loco._key) != null : "LOCO frame with median transform is not in DKV!"; } return loco; } finally { if( fr != null ) fr.remove(); if( gbm != null ) gbm.delete(); if( loco != null ) loco.remove(); Scope.exit(); } } }