package hex.deepwater; import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeModel; import org.junit.Assume; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; import java.io.DataInputStream; import java.io.File; import java.io.FileInputStream; import java.util.Random; import java.util.zip.GZIPInputStream; public class DeepWaterCaffeIntegrationTest extends DeepWaterAbstractIntegrationTest { @Override DeepWaterParameters.Backend getBackend() { return DeepWaterParameters.Backend.caffe; } @BeforeClass public static void checkBackend() { Assume.assumeTrue(DeepWater.haveBackend(DeepWaterParameters.Backend.caffe)); }; @Ignore @Test public void run() throws Exception { /* MNIST demo. Get the data first in your home folder: cd wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz */ final int PIXELS = 28 * 28; String home = System.getProperty("user.home"); DataInputStream pixels = new DataInputStream(new GZIPInputStream(new FileInputStream(new File( home + "/train-images-idx3-ubyte.gz")))); DataInputStream labels = new DataInputStream(new GZIPInputStream(new FileInputStream(new File( home + "/train-labels-idx1-ubyte.gz")))); pixels.readInt(); // Magic int count = pixels.readInt(); pixels.readInt(); // Rows pixels.readInt(); // Cols labels.readInt(); // Magic labels.readInt(); // Count System.out.println("Read " + count + " samples"); byte[][] rawI = new byte[count][PIXELS]; byte[] rawL = new byte[count]; for (int i = 0; i < count; i++) { pixels.readFully(rawI[i]); rawL[i] = labels.readByte(); } System.out.println("Randomize"); Random rand = new Random(); for (int i = count - 1; i >= 0; i--) { int shuffle = rand.nextInt(i + 1); byte[] image = rawI[shuffle]; rawI[shuffle] = rawI[i]; rawI[i] = image; byte label = rawL[shuffle]; rawL[shuffle] = rawL[i]; rawL[i] = label; } System.out.println("Create model"); final int batch = 256; DeepwaterCaffeModel model = new DeepwaterCaffeModel( batch, new int[] {PIXELS, 4024, 4024, 4048, 10}, new String[] {"data", "relu", "relu", "relu", "loss"}, new double[] {.9, .5, .5, .5, 0.}, 1234, true // GPU ); System.out.println("Train"); float[] ps = new float[batch * PIXELS]; float[] ls = new float[batch]; for (int iter = 0; iter < 10; iter++) { for (int b = 0; b < batch; b++) { for (int i = 0; i < PIXELS; i++) ps[b * PIXELS + i] = (rawI[b][i] & 0xff) * 0.00390625f; ls[b] = rawL[b]; } model.train(ps, ls); model.predict(ps); } model.saveModel("/tmp/graph"); model.saveParam("/tmp/params"); model.loadParam("/tmp/params"); } }