package hex.genmodel; import hex.genmodel.easy.EasyPredictModelWrapper; import hex.genmodel.easy.RowData; import hex.genmodel.easy.prediction.RegressionModelPrediction; import org.junit.Test; import java.io.Closeable; import java.io.File; import java.net.URL; import static org.junit.Assert.*; import static hex.genmodel.MojoReaderBackendFactory.CachingStrategy; public class MojoReaderBackendFactoryTest { @Test public void testCreateReaderBackend_URL_Memory() throws Exception { URL dumjo = MojoReaderBackendFactoryTest.class.getResource("dumjo.zip"); assertNotNull(dumjo); MojoReaderBackend r = MojoReaderBackendFactory.createReaderBackend(dumjo, CachingStrategy.MEMORY); assertTrue(r instanceof InMemoryMojoReaderBackend); try { assertTrue(r.exists("binary-file")); assertTrue(r.exists("text-file")); assertEquals("line1", r.getTextFile("text-file").readLine()); } finally { ((Closeable) r).close(); } } @Test public void testCreateReaderBackend_URL_Disk() throws Exception { URL dumjo = MojoReaderBackendFactoryTest.class.getResource("dumjo.zip"); assertNotNull(dumjo); MojoReaderBackend r = MojoReaderBackendFactory.createReaderBackend(dumjo, CachingStrategy.DISK); assertTrue(r instanceof TmpMojoReaderBackend); File tempFile = ((TmpMojoReaderBackend) r)._tempZipFile; assertTrue(tempFile.exists()); try { assertTrue(r.exists("binary-file")); assertTrue(r.exists("text-file")); assertEquals("line1", r.getTextFile("text-file").readLine()); } finally { ((Closeable) r).close(); } assertFalse(tempFile.exists()); } @Test public void testMojoE2E_Memory() throws Exception { testMojoE2E(CachingStrategy.MEMORY); } @Test public void testMojoE2E_Disk() throws Exception { testMojoE2E(CachingStrategy.DISK); } private void testMojoE2E(CachingStrategy cachingStrategy) throws Exception { URL mojoSource = MojoReaderBackendFactoryTest.class.getResource("mojo.zip"); assertNotNull(mojoSource); MojoReaderBackend reader = MojoReaderBackendFactory.createReaderBackend(mojoSource, cachingStrategy); MojoModel model = ModelMojoReader.readFrom(reader); EasyPredictModelWrapper modelWrapper = new EasyPredictModelWrapper(model); RowData testRow = makeTestRow(); RegressionModelPrediction prediction = (RegressionModelPrediction) modelWrapper.predict(testRow); assertEquals(71.085d, prediction.value, 0.001d); } private static RowData makeTestRow() { RowData testRow = new RowData(); String[] row = ("75,0,190,80,91,193,371,174,121,-16,13,64,-2,0,63,0,52,44,0,0,32,0,0,0,0,0,0,0,44,20,36,0,28,0,0,0,0," + "0,0,52,40,0,0,0,60,0,0,0,0,0,0,52,0,0,0,0,0,0,0,0,0,0,0,0,56,36,0,0,32,0,0,0,0,0,0,48,32,0,0,0,56,0,0,0,0,0," + "0,80,0,0,0,0,0,0,0,0,0,0,0,0,40,52,0,0,28,0,0,0,0,0,0,0,48,48,0,0,32,0,0,0,0,0,0,0,52,52,0,0,36,0,0,0,0,0,0," + "0,52,48,0,0,32,0,0,0,0,0,0,0,56,44,0,0,32,0,0,0,0,0,0,-0.2,0.0,6.1,-1.0,0.0,0.0,0.6,2.1,13.6,30.8,0.0,0.0,1.7," + "-1.0,0.6,0.0,1.3,1.5,3.7,14.5,0.1,-5.2,1.4,0.0,0.0,0.0,0.8,-0.6,-10.7,-15.6,0.4,-3.9,0.0,0.0,0.0,0.0,-0.8,-1.7," + "-10.1,-22.0,0.0,0.0,5.7,-1.0,0.0,0.0,-0.1,1.2,14.1,22.5,0.0,-2.5,0.8,0.0,0.0,0.0,1.0,0.4,-4.8,-2.7,0.1,-6.0,0.0" + ",0.0,0.0,0.0,-0.8,-0.6,-24.0,-29.7,0.0,0.0,2.0,-6.4,0.0,0.0,0.2,2.9,-12.6,15.2,-0.1,0.0,8.4,-10.0,0.0,0.0,0.6,5.9," + "-3.9,52.7,-0.3,0.0,15.2,-8.4,0.0,0.0,0.9,5.1,17.7,70.7,-0.4,0.0,13.5,-4.0,0.0,0.0,0.9,3.9,25.5,62.9,-0.3,0.0,9.0," + "-0.9,0.0,0.0,0.9,2.9,23.3,49.4,8") .split(","); for (int i = 0; i < row.length; i++) testRow.put("C" + (i+1), Double.valueOf(row[i])); return testRow; } }