package hex; import hex.CoxPH.*; import water.*; import water.api.CoxPHModelView; import water.deploy.Node; import water.deploy.NodeVM; import water.fvec.Frame; import water.fvec.NFSFileVec; import water.fvec.ParseDataset2; import java.io.File; import java.util.concurrent.ExecutionException; import org.junit.Test; import static org.junit.Assert.assertEquals; public class CoxPHTest extends TestUtil { public static void testHTML(CoxPHModel m) { StringBuilder sb = new StringBuilder(); CoxPHModelView modelView = new CoxPHModelView(); modelView.coxph_model = m; modelView.toHTML(sb); assert(sb.length() > 0); } private static Frame getFrameForFile(Key outputKey, String path) { File f = TestUtil.find_test_file(path); Key k = NFSFileVec.make(f); return ParseDataset2.parse(outputKey, new Key[]{k}); } @Test public void testCoxPHEfron1Var() throws InterruptedException, ExecutionException { Key parsed = Key.make("coxph_efron_test_data_parsed"); Key modelKey = Key.make("coxph_efron_test"); CoxPHModel model = null; Frame fr = null; try { fr = getFrameForFile(parsed, "smalldata/heart.csv"); CoxPH job = new CoxPH(); job.destination_key = modelKey; job.source = fr; job.start_column = fr.vec("start"); job.stop_column = fr.vec("stop"); job.event_column = fr.vec("event"); job.x_columns = new int[] {fr.find("age")}; job.ties = CoxPHTies.efron; job.fork(); job.get(); model = DKV.get(modelKey).get(); testHTML(model); assertEquals(model.coef[0], 0.0307077486571334, 1e-8); assertEquals(model.var_coef[0][0], 0.000203471477951459, 1e-8); assertEquals(model.null_loglik, -298.121355672984, 1e-8); assertEquals(model.loglik, -295.536762216228, 1e-8); assertEquals(model.score_test, 4.64097294749287, 1e-8); assert(model.iter >= 1); assertEquals(model.x_mean_num[0], -2.48402655078554, 1e-8); assertEquals(model.n, 172); assertEquals(model.total_event, 75); assertEquals(model.wald_test, 4.6343882547245, 1e-8); } finally { if (fr != null) fr.delete(); if (model != null) model.delete(); } } @Test public void testCoxPHBreslow1Var() throws InterruptedException, ExecutionException { Key parsed = Key.make("coxph_efron_test_data_parsed"); Key modelKey = Key.make("coxph_efron_test"); CoxPHModel model = null; Frame fr = null; try { fr = getFrameForFile(parsed, "smalldata/heart.csv"); CoxPH job = new CoxPH(); job.destination_key = modelKey; job.source = fr; job.start_column = fr.vec("start"); job.stop_column = fr.vec("stop"); job.event_column = fr.vec("event"); job.x_columns = new int[] {fr.find("age")}; job.ties = CoxPHTies.breslow; job.fork(); job.get(); model = DKV.get(modelKey).get(); testHTML(model); assertEquals(model.coef[0], 0.0306910411003801, 1e-8); assertEquals(model.var_coef[0][0], 0.000203592486905101, 1e-8); assertEquals(model.null_loglik, -298.325606736463, 1e-8); assertEquals(model.loglik, -295.745227177782, 1e-8); assertEquals(model.score_test, 4.63317821557301, 1e-8); assert(model.iter >= 1); assertEquals(model.x_mean_num[0], -2.48402655078554, 1e-8); assertEquals(model.n, 172); assertEquals(model.total_event, 75); assertEquals(model.wald_test, 4.62659510743282, 1e-8); } finally { if (fr != null) fr.delete(); if (model != null) model.delete(); } } @Test public void testCoxPHEfron1VarNoStart() throws InterruptedException, ExecutionException { Key parsed = Key.make("coxph_efron_test_data_parsed"); Key modelKey = Key.make("coxph_efron_test"); CoxPHModel model = null; Frame fr = null; try { fr = getFrameForFile(parsed, "smalldata/heart.csv"); CoxPH job = new CoxPH(); job.destination_key = modelKey; job.source = fr; job.start_column = null; job.stop_column = fr.vec("stop"); job.event_column = fr.vec("event"); job.x_columns = new int[] {fr.find("age")}; job.ties = CoxPHTies.efron; job.fork(); job.get(); model = DKV.get(modelKey).get(); testHTML(model); assertEquals(model.coef[0], 0.0289468187293998, 1e-8); assertEquals(model.var_coef[0][0], 0.000210975113029285, 1e-8); assertEquals(model.null_loglik, -314.148170059513, 1e-8); assertEquals(model.loglik, -311.946958322919, 1e-8); assertEquals(model.score_test, 3.97716015008595, 1e-8); assert(model.iter >= 1); assertEquals(model.x_mean_num[0], -2.48402655078554, 1e-8); assertEquals(model.n, 172); assertEquals(model.total_event, 75); assertEquals(model.wald_test, 3.97164529276219, 1e-8); } finally { if (fr != null) fr.delete(); if (model != null) model.delete(); } } @Test public void testCoxPHBreslow1VarNoStart() throws InterruptedException, ExecutionException { Key parsed = Key.make("coxph_efron_test_data_parsed"); Key modelKey = Key.make("coxph_efron_test"); CoxPHModel model = null; Frame fr = null; try { fr = getFrameForFile(parsed, "smalldata/heart.csv"); CoxPH job = new CoxPH(); job.destination_key = modelKey; job.source = fr; job.start_column = null; job.stop_column = fr.vec("stop"); job.event_column = fr.vec("event"); job.x_columns = new int[] {fr.find("age")}; job.ties = CoxPHTies.breslow; job.fork(); job.get(); model = DKV.get(modelKey).get(); testHTML(model); assertEquals(model.coef[0], 0.0289484855901731, 1e-8); assertEquals(model.var_coef[0][0], 0.000211028794751156, 1e-8); assertEquals(model.null_loglik, -314.296493366900, 1e-8); assertEquals(model.loglik, -312.095342077591, 1e-8); assertEquals(model.score_test, 3.97665282498882, 1e-8); assert(model.iter >= 1); assertEquals(model.x_mean_num[0], -2.48402655078554, 1e-8); assertEquals(model.n, 172); assertEquals(model.total_event, 75); assertEquals(model.wald_test, 3.97109228128153, 1e-8); } finally { if (fr != null) fr.delete(); if (model != null) model.delete(); } } public static void main(String [] args) throws Exception{ System.out.println("Running ParserTest2"); final int nnodes = 1; for (int i = 0; i < nnodes; i++) { Node n = new NodeVM(args); n.inheritIO(); n.start(); } H2O.waitForCloudSize(nnodes); System.out.println("Running..."); new CoxPHTest().testCoxPHEfron1Var(); new CoxPHTest().testCoxPHBreslow1Var(); new CoxPHTest().testCoxPHEfron1VarNoStart(); new CoxPHTest().testCoxPHBreslow1VarNoStart(); System.out.println("DONE!"); } }