package hex.naivebayes;
import hex.SplitFrame;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import hex.naivebayes.NaiveBayesModel.NaiveBayesParameters;
import java.util.concurrent.ExecutionException;
/**
 * Tests for the {@link NaiveBayes} model builder.
 *
 * <p>Every test follows the same pattern: load a data set, configure
 * {@link NaiveBayesParameters}, train a model, score a frame with it, and assert that
 * {@code testJavaScoring} reports success for that frame and its predictions at a
 * tolerance of {@code 1e-6}. Everything created along the way is deleted in a
 * {@code finally} block.
 */
public class NaiveBayesTest extends TestUtil {

  /** Runs once before the tests; calls {@code stall_till_cloudsize(1)}. */
  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }

  /**
   * Trains on the iris data set with no Laplace smoothing and metrics disabled, using the
   * column at index 4 as the response, then scores the training frame.
   */
  @Test
  public void testIris() throws InterruptedException, ExecutionException {
    NaiveBayesModel nb = null;
    Frame iris = null;
    Frame preds = null;
    try {
      iris = parse_test_file(Key.make("iris_wheader.hex"), "smalldata/iris/iris_wheader.csv");

      NaiveBayesParameters params = new NaiveBayesParameters();
      params._train = iris._key;
      params._laplace = 0;
      params._response_column = iris._names[4];
      params._compute_metrics = false;

      NaiveBayes builder = new NaiveBayes(params);
      nb = builder.trainModel().get();

      // Model is built; generate the prediction frame and verify it.
      preds = nb.score(iris);
      boolean scoringOk = nb.testJavaScoring(iris, preds, 1e-6);
      Assert.assertTrue(scoringOk);
    } finally {
      if (iris != null) {
        iris.delete();
      }
      if (preds != null) {
        preds.delete();
      }
      if (nb != null) {
        nb.delete();
      }
    }
  }

  /**
   * Splits the iris data set 50/50 into "train.hex" and "test.hex", trains on the first
   * split with the second as the validation frame (Laplace smoothing of 0.01, metrics
   * enabled), then scores the second split.
   */
  @Test
  public void testIrisValidation() throws InterruptedException, ExecutionException {
    NaiveBayesModel nb = null;
    Frame iris = null;
    Frame preds = null;
    Frame trainSplit = null;
    Frame testSplit = null;
    try {
      iris = parse_test_file("smalldata/iris/iris_wheader.csv");

      double[] ratios = new double[] { 0.5, 0.5 };
      Key[] destKeys = new Key[] { Key.make("train.hex"), Key.make("test.hex") };
      SplitFrame splitter = new SplitFrame(iris, ratios, destKeys);

      // Run the split job and wait for it to finish.
      splitter.exec().get();
      Key[] splitKeys = splitter._destination_frames;
      trainSplit = DKV.get(splitKeys[0]).get();
      testSplit = DKV.get(splitKeys[1]).get();

      NaiveBayesParameters params = new NaiveBayesParameters();
      params._train = splitKeys[0];
      params._valid = splitKeys[1];
      params._laplace = 0.01; // Need Laplace smoothing
      params._response_column = iris._names[4];
      params._compute_metrics = true;

      NaiveBayes builder = new NaiveBayes(params);
      nb = builder.trainModel().get();

      // Model is built; generate the prediction frame for the held-out split and verify it.
      preds = nb.score(testSplit);
      boolean scoringOk = nb.testJavaScoring(testSplit, preds, 1e-6);
      Assert.assertTrue(scoringOk);
    } finally {
      if (iris != null) {
        iris.delete();
      }
      if (preds != null) {
        preds.delete();
      }
      if (trainSplit != null) {
        trainSplit.delete();
      }
      if (testSplit != null) {
        testSplit.delete();
      }
      if (nb != null) {
        nb.delete();
      }
    }
  }

  /**
   * Trains on the prostate data set after converting columns 1, 3, 4 and 5 to categorical
   * and dropping the "ID" column; the first remaining column is the response. No Laplace
   * smoothing, metrics enabled. Scores the training frame.
   */
  @Test
  public void testProstate() throws InterruptedException, ExecutionException {
    NaiveBayesModel nb = null;
    Frame prostate = null;
    Frame preds = null;
    // Categoricals: CAPSULE, RACE, DPROS, DCAPS
    final int[] categoricalCols = new int[] { 1, 3, 4, 5 };
    try {
      Scope.enter();
      prostate = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");

      // Swap each listed column for its categorical version; the value returned by
      // replace() is handed to Scope.track().
      for (int col : categoricalCols) {
        Scope.track(prostate.replace(col, prostate.vec(col).toCategoricalVec()));
      }
      prostate.remove("ID").remove();
      // Store the modified frame back under its own key.
      DKV.put(prostate._key, prostate);

      NaiveBayesParameters params = new NaiveBayesParameters();
      params._train = prostate._key;
      params._laplace = 0;
      params._response_column = prostate._names[0];
      params._compute_metrics = true;

      NaiveBayes builder = new NaiveBayes(params);
      nb = builder.trainModel().get();

      // Model is built; generate the prediction frame and verify it.
      preds = nb.score(prostate);
      boolean scoringOk = nb.testJavaScoring(prostate, preds, 1e-6);
      Assert.assertTrue(scoringOk);
    } finally {
      if (prostate != null) {
        prostate.delete();
      }
      if (preds != null) {
        preds.delete();
      }
      if (nb != null) {
        nb.delete();
      }
      Scope.exit();
    }
  }

  /**
   * Trains on the 20k-row covtype sample with column 54 converted to categorical and used
   * as the response. No Laplace smoothing, metrics disabled. Scores the training frame.
   */
  @Test
  public void testCovtype() throws InterruptedException, ExecutionException {
    NaiveBayesModel nb = null;
    Frame covtype = null;
    Frame preds = null;
    try {
      Scope.enter();
      covtype = parse_test_file(Key.make("covtype.hex"), "smalldata/covtype/covtype.20k.data");

      // Change response to categorical
      Scope.track(covtype.replace(54, covtype.vecs()[54].toCategoricalVec()));
      DKV.put(covtype);

      NaiveBayesParameters params = new NaiveBayesParameters();
      params._train = covtype._key;
      params._laplace = 0;
      params._response_column = covtype._names[54];
      params._compute_metrics = false;

      NaiveBayes builder = new NaiveBayes(params);
      nb = builder.trainModel().get();

      // Model is built; generate the prediction frame and verify it.
      preds = nb.score(covtype);
      boolean scoringOk = nb.testJavaScoring(covtype, preds, 1e-6);
      Assert.assertTrue(scoringOk);
    } finally {
      if (covtype != null) {
        covtype.delete();
      }
      if (preds != null) {
        preds.delete();
      }
      if (nb != null) {
        nb.delete();
      }
      Scope.exit();
    }
  }
}