package hex.genmodel.algos.glm; import hex.genmodel.ModelMojoReader; import hex.genmodel.MojoReaderBackend; import org.junit.Test; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import static org.junit.Assert.*; public class GlmMojoModelTest { @Test public void testScore0() throws Exception { double[][] data = new double[][]{ new double[]{2,73,2,1,7.9,18,6}, new double[]{1,51,3,1,8.9,0,6}, new double[]{2,57,3,1,3.4,30.8,6}, new double[]{1,65,4,1,6.3,0,6}, new double[]{1,61,3,1,1.5,0,5}, new double[]{1,56,2,2,58,0,6}, new double[]{1,72,2,1,1.4,24.2,6}, new double[]{1,54,2,1,18,43,9}, new double[]{1,62,2,1,7.3,0,7}, new double[]{2,63,3,1,14.3,16,7}, new double[]{1,68,1,1,5.4,34,5}, new double[]{1,Double.NaN,1,1,5.4,34,5} // value should be imputed }; double[][] expPreds = new double[][]{ new double[]{0.0, 0.883740206424754, 0.11625979357524593}, new double[]{1.0, 0.5591006829867439, 0.44089931701325613}, new double[]{0.0, 0.8200793110208472, 0.1799206889791528}, new double[]{1.0, 0.4855023555733662, 0.5144976444266338}, new double[]{0.0, 0.8260781970262484, 0.17392180297375157}, new double[]{1.0, 0.2685796973779421, 0.7314203026220579}, new double[]{0.0, 0.8265057623033865, 0.1734942376966135}, new double[]{1.0, 0.1332488800455477, 0.8667511199544523}, new double[]{1.0, 0.5038183003787983, 0.49618169962120173}, new double[]{1.0, 0.5384202639029669, 0.46157973609703307}, new double[]{0.0, 0.9543248143434919, 0.04567518565650803}, new double[]{0.0, 0.9531416700165544, 0.046858329983445586} }; GlmMojoModel mojo = (GlmMojoModel) ModelMojoReader.readFrom(new ClasspathReaderBackend()); for (int i = 0; i < data.length; i++) { double preds[] = mojo.score0(data[i], new double[3]); assertArrayEquals("Predictions for row #" + i, expPreds[i], preds, 0.0000001); } } private static class ClasspathReaderBackend implements MojoReaderBackend { @Override public BufferedReader getTextFile(String filename) throws IOException { InputStream is = GlmMojoModelTest.class.getResourceAsStream("prostate/" + filename); return new BufferedReader(new InputStreamReader(is)); } @Override public byte[] getBinaryFile(String filename) throws IOException { throw new UnsupportedOperationException("Unexpected call to getBinaryFile()"); } @Override public boolean exists(String name) { throw new UnsupportedOperationException("Unexpected call to exists()"); } } }