package hex.deepwater; import deepwater.backends.BackendModel; import deepwater.backends.BackendParams; import deepwater.backends.RuntimeOptions; import deepwater.datasets.ImageDataSet; import org.junit.*; import water.parser.BufferedString; import water.util.FileUtils; import water.util.StringUtils; import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; import java.io.*; import java.nio.file.Paths; public class DeepWaterMXNetIntegrationTest extends DeepWaterAbstractIntegrationTest { static long copy(InputStream var0, OutputStream var1) throws IOException { byte[] var2 = new byte[4096]; long var3 = 0L; while(true) { int var5 = var0.read(var2); if(var5 == -1) { return var3; } var1.write(var2, 0, var5); var3 += (long)var5; } } @Override DeepWaterParameters.Backend getBackend() { return DeepWaterParameters.Backend.mxnet; } @BeforeClass public static void checkBackend() { Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.mxnet)); } public static String extractFile(String path, String file) throws IOException { InputStream in = DeepWaterMXNetIntegrationTest.class.getClassLoader().getResourceAsStream(Paths.get(path, file).toString()); String target = Paths.get(System.getProperty("java.io.tmpdir"), file).toString(); OutputStream out = new FileOutputStream(target); copy(in, out); return target; } // This test has nothing to do with H2O - Pure integration test of deepwater/backends/mxnet @Test public void inceptionPredictionMX() throws IOException { for (boolean gpu : new boolean[]{true, false}) { // Set model parameters int w = 224, h = 224, channels = 3, nclasses=1000; ImageDataSet id = new ImageDataSet(w,h,channels,nclasses); RuntimeOptions opts = new RuntimeOptions(); opts.setSeed(1234); opts.setUseGPU(gpu); BackendParams bparm = new BackendParams(); bparm.set("mini_batch_size", 1); // Load the model String path = "deepwater/backends/mxnet/models/Inception/"; BackendModel _model = backend.buildNet(id, opts, bparm, nclasses, StringUtils.expandPath(extractFile(path, "Inception_BN-symbol.json"))); backend.loadParam(_model, StringUtils.expandPath(extractFile(path, "Inception_BN-0039.params"))); water.fvec.Frame labels = parse_test_file(extractFile(path, "synset.txt")); float[] mean = backend.loadMeanImage(_model, extractFile(path, "mean_224.nd")); // Turn the image into a vector of the correct size File imgFile = FileUtils.getFile("smalldata/deepwater/imagenet/test2.jpg"); BufferedImage img = ImageIO.read(imgFile); BufferedImage scaledImg = new BufferedImage(w, h, img.getType()); Graphics2D g2d = scaledImg.createGraphics(); g2d.drawImage(img, 0, 0, w, h, null); g2d.dispose(); float[] pixels = new float[w * h * channels]; int r_idx = 0; int g_idx = r_idx + w * h; int b_idx = g_idx + w * h; for (int i = 0; i < h; i++) { for (int j = 0; j < w; j++) { Color mycolor = new Color(scaledImg.getRGB(j, i)); int red = mycolor.getRed(); int green = mycolor.getGreen(); int blue = mycolor.getBlue(); pixels[r_idx] = red - mean[r_idx]; r_idx++; pixels[g_idx] = green - mean[g_idx]; g_idx++; pixels[b_idx] = blue - mean[b_idx]; b_idx++; } } float[] preds = backend.predict(_model, pixels); int K = 5; int[] topK = new int[K]; for ( int i = 0; i < preds.length; i++ ) { for ( int j = 0; j < K; j++ ) { if ( preds[i] > preds[topK[j]] ) { topK[j] = i; break; } } } // Display the top 5 predictions StringBuilder sb = new StringBuilder(); sb.append("\nTop " + K + " predictions:\n"); BufferedString str = new BufferedString(); for ( int j = 0; j < K; j++ ) { String label = labels.anyVec().atStr(str, topK[j]).toString(); sb.append(" Score: " + String.format("%.4f", preds[topK[j]]) + "\t" + label + "\n"); } System.out.println("\n\n" + sb.toString() +"\n\n"); Assert.assertTrue("Illegal predictions!", sb.toString().substring(40,60).contains("Pembroke")); labels.remove(); } } @Ignore @Test public void PreTrainedMOJO() { water.fvec.Frame tr = null; water.fvec.Frame preds = null; DeepWaterModel m = null; try { DeepWaterParameters p = new DeepWaterParameters(); //p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cc.csv"))._key; p._train = (tr=parse_test_file("bigdata/laptop/deepwater/imagenet/cat_dog_mouse.csv"))._key; p._response_column = "C2"; // p._problem_type = DeepWaterParameters.ProblemType.image_classification; // p._train.get().remove("C3"); // for (String col : new String[]{p._train.get().name(0)}) { // Vec v = tr.remove(col); // tr.add(col, v.toStringVec()); // v.remove(); // } // for (String col : new String[]{p._response_column}) { // Vec v = tr.remove(col); // tr.add(col, v.toCategoricalVec()); // v.remove(); // } String path = "../deepwater/mxnet/src/main/resources/deepwater/backends/mxnet/models/Inception/"; // p._network = DeepWaterParameters.Network.user; p._image_shape = new int[]{224, 224}; p._channels = 3; p._network_definition_file = path + "Inception_BN-symbol.json"; //TODO: allow loading this 1000-class graph for this 3-class problem p._network_parameters_file = path + "Inception_BN-0039.params"; //TODO: allow loading this parameter file for the 3-class modified graph p._mean_image_file = path + "mean_224.nd"; p._epochs = 0.1; //just make a model, no training needed p._learning_rate = 0; //just make a model, no training needed DeepWater j = new DeepWater(p); m = j.trainModel().get(); preds = m.score(p._train.get()); Assert.assertTrue(m.testJavaScoring(p._train.get(),preds,1e-3)); } finally { if (tr!=null) tr.remove(); if (preds!=null) preds.remove(); if (m!=null) m.remove(); } } }