package hex.pdp;
import hex.PartialDependence;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;
/**
 * Tests for {@link PartialDependence} computed over GBM models trained on small bundled
 * datasets (prostate and weather), covering binary classification, regression, default
 * column selection and explicitly picked columns.
 */
public class PartialDependenceTest extends TestUtil {

  /** Prostate columns that are converted to categoricals before training. */
  private static final String[] PROSTATE_CATEGORICALS =
      {"RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE"};

  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }

  @Test
  public void prostateBinary() {
    Frame fr = null;
    GBMModel model = null;
    PartialDependence partialDependence = null;
    try {
      fr = parse_test_file("smalldata/prostate/prostate.csv");
      convertToCategorical(fr, PROSTATE_CATEGORICALS);
      model = trainGbm(fr, "CAPSULE", "ID");
      // _cols is left unset so PartialDependence uses its default column selection.
      partialDependence = newPartialDependence(model, fr, 10, null);
      partialDependence.execImpl().get();
      logTables(partialDependence);
      Assert.assertNotNull(partialDependence._partial_dependence_data);
      Assert.assertTrue(partialDependence._partial_dependence_data.length > 0);
    } finally {
      cleanup(fr, model, partialDependence);
    }
  }

  @Test
  public void prostateBinaryPickCols() {
    Frame fr = null;
    GBMModel model = null;
    PartialDependence partialDependence = null;
    try {
      fr = parse_test_file("smalldata/prostate/prostate.csv");
      convertToCategorical(fr, PROSTATE_CATEGORICALS);
      model = trainGbm(fr, "CAPSULE", "ID");
      // Pick columns manually: one result table is expected per picked column.
      partialDependence = newPartialDependence(model, fr, 10, new String[]{"DPROS", "GLEASON"});
      partialDependence.execImpl().get();
      logTables(partialDependence);
      Assert.assertEquals(2, partialDependence._partial_dependence_data.length);
    } finally {
      cleanup(fr, model, partialDependence);
    }
  }

  @Test
  public void prostateRegression() {
    Frame fr = null;
    GBMModel model = null;
    PartialDependence partialDependence = null;
    try {
      fr = parse_test_file("smalldata/prostate/prostate.csv");
      convertToCategorical(fr, PROSTATE_CATEGORICALS);
      model = trainGbm(fr, "AGE", "ID");
      partialDependence = newPartialDependence(model, fr, 10, null);
      partialDependence.execImpl().get();
      logTables(partialDependence);
      Assert.assertNotNull(partialDependence._partial_dependence_data);
      Assert.assertTrue(partialDependence._partial_dependence_data.length > 0);
    } finally {
      cleanup(fr, model, partialDependence);
    }
  }

  @Test
  public void weatherBinary() {
    Frame fr = null;
    GBMModel model = null;
    PartialDependence partialDependence = null;
    try {
      fr = parse_test_file("smalldata/junit/weather.csv");
      model = trainGbm(fr, "RainTomorrow", "Date", "RISK_MM", "EvapMM");
      partialDependence = newPartialDependence(model, fr, 33,
          new String[]{"Sunshine", "MaxWindPeriod", "WindSpeed9am"});
      partialDependence.execImpl().get();
      logTables(partialDependence);
      // Same convention as prostateBinaryPickCols: one table per picked column.
      Assert.assertEquals(3, partialDependence._partial_dependence_data.length);
    } finally {
      cleanup(fr, model, partialDependence);
    }
  }

  /**
   * Replaces each named column of {@code fr} with its categorical version, then re-puts the
   * frame into the DKV. The replaced original vecs are removed.
   */
  private static void convertToCategorical(Frame fr, String... names) {
    for (String name : names) {
      Vec v = fr.remove(name);
      fr.add(name, v.toCategoricalVec());
      v.remove();
    }
    DKV.put(fr);
  }

  /** Trains a default-parameter GBM on {@code fr}, predicting {@code response}. */
  private static GBMModel trainGbm(Frame fr, String response, String... ignored) {
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = fr._key;
    parms._ignored_columns = ignored;
    parms._response_column = response;
    return new GBM(parms).trainModel().get();
  }

  /**
   * Builds, but does not execute, a PartialDependence job for {@code model} over {@code fr}.
   * When {@code cols} is null, {@code _cols} is left at its default.
   */
  @SuppressWarnings("unchecked") // raw Key cast, as in the original test code
  private static PartialDependence newPartialDependence(GBMModel model, Frame fr, int nbins,
                                                        String[] cols) {
    PartialDependence pd = new PartialDependence(Key.<PartialDependence>make());
    if (cols != null) pd._cols = cols;
    pd._nbins = nbins;
    pd._model_id = (Key) model._key;
    pd._frame_id = fr._key;
    return pd;
  }

  /** Logs every partial-dependence table produced by {@code pd}. */
  private static void logTables(PartialDependence pd) {
    for (TwoDimTable t : pd._partial_dependence_data) Log.info(t);
  }

  /**
   * Removes whichever of the test objects were created. Nested try/finally ensures a failure
   * removing one object does not leak the others.
   */
  private static void cleanup(Frame fr, GBMModel model, PartialDependence pd) {
    try {
      if (fr != null) fr.remove();
    } finally {
      try {
        if (model != null) model.remove();
      } finally {
        if (pd != null) pd.remove();
      }
    }
  }
}