package org.deeplearning4j.parallelism;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;

import static org.junit.Assert.*;

/**
 * Tests for {@link ParallelInference}: sequential and batched inference modes against a
 * pre-trained LeNet MNIST model, plus unit tests for
 * {@code ParallelInference.ObservablesProvider} batching/observable-sharing behavior.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class ParallelInferenceTest {
    // Model and iterator are loaded once (lazily) and shared across all tests in this class.
    private static MultiLayerNetwork model;
    private static DataSetIterator iterator;

    @Before
    public void setUp() throws Exception {
        if (model == null) {
            File file = new ClassPathResource("models/LenetMnistMLN.zip").getFile();
            model = ModelSerializer.restoreMultiLayerNetwork(file, true);
            iterator = new MnistDataSetIterator(1, false, 12345);
        }
    }

    @After
    public void tearDown() throws Exception {
        // Reset so the next test sees the iterator from the beginning.
        iterator.reset();
    }

    /**
     * SEQUENTIAL mode, single-threaded evaluation: outputs must be detached from any
     * workspace, and both workers must end up processing a non-trivial share of requests.
     */
    @Test
    public void testInferenceSequential1() throws Exception {
        ParallelInference inf =
                        new ParallelInference.Builder(model).inferenceMode(InferenceMode.SEQUENTIAL).workers(2).build();

        log.info("Features shape: {}",
                        Arrays.toString(iterator.next().getFeatureMatrix().shapeInfoDataBuffer().asInt()));

        INDArray array1 = inf.output(iterator.next().getFeatureMatrix());
        INDArray array2 = inf.output(iterator.next().getFeatureMatrix());

        // Results handed back to the caller must not be attached to an internal workspace.
        assertFalse(array1.isAttached());
        assertFalse(array2.isAttached());

        INDArray array3 = inf.output(iterator.next().getFeatureMatrix());
        assertFalse(array3.isAttached());

        iterator.reset();

        evalClassificationSingleThread(inf, iterator);

        // Both worker threads should have handled a non-trivial number of requests.
        assertTrue(inf.getWorkerCounter(0) > 100L);
        assertTrue(inf.getWorkerCounter(1) > 100L);
    }

    /**
     * SEQUENTIAL mode under concurrent callers (10 client threads): same detachment and
     * load-distribution guarantees as the single-threaded case.
     */
    @Test
    public void testInferenceSequential2() throws Exception {
        ParallelInference inf =
                        new ParallelInference.Builder(model).inferenceMode(InferenceMode.SEQUENTIAL).workers(2).build();

        log.info("Features shape: {}",
                        Arrays.toString(iterator.next().getFeatureMatrix().shapeInfoDataBuffer().asInt()));

        INDArray array1 = inf.output(iterator.next().getFeatureMatrix());
        INDArray array2 = inf.output(iterator.next().getFeatureMatrix());

        assertFalse(array1.isAttached());
        assertFalse(array2.isAttached());

        INDArray array3 = inf.output(iterator.next().getFeatureMatrix());
        assertFalse(array3.isAttached());

        iterator.reset();

        evalClassificationMultipleThreads(inf, iterator, 10);

        // Both worker threads should have handled a non-trivial number of requests.
        assertTrue(inf.getWorkerCounter(0) > 100L);
        assertTrue(inf.getWorkerCounter(1) > 100L);
    }

    /**
     * BATCHED mode (batch limit 8) under 20 concurrent callers. Per-worker counters are
     * lower than in sequential mode because requests are coalesced into batches.
     */
    @Test
    public void testInferenceBatched1() throws Exception {
        ParallelInference inf = new ParallelInference.Builder(model).inferenceMode(InferenceMode.BATCHED).batchLimit(8)
                        .workers(2).build();

        log.info("Features shape: {}",
                        Arrays.toString(iterator.next().getFeatureMatrix().shapeInfoDataBuffer().asInt()));

        INDArray array1 = inf.output(iterator.next().getFeatureMatrix());
        INDArray array2 = inf.output(iterator.next().getFeatureMatrix());

        assertFalse(array1.isAttached());
        assertFalse(array2.isAttached());

        INDArray array3 = inf.output(iterator.next().getFeatureMatrix());
        assertFalse(array3.isAttached());

        iterator.reset();

        evalClassificationMultipleThreads(inf, iterator, 20);

        // Both worker threads should have handled at least some batches.
        assertTrue(inf.getWorkerCounter(0) > 10L);
        assertTrue(inf.getWorkerCounter(1) > 10L);
    }

    /**
     * Two inputs submitted within the batch limit must be coalesced into the same observable.
     */
    @Test
    public void testProvider1() throws Exception {
        LinkedBlockingQueue<InferenceObservable> queue = new LinkedBlockingQueue<>();
        BasicInferenceObserver observer = new BasicInferenceObserver();
        ParallelInference.ObservablesProvider provider =
                        new ParallelInference.ObservablesProvider(10000000L, 100, queue);

        InferenceObservable observable1 = provider.setInput(observer, Nd4j.create(100));
        InferenceObservable observable2 = provider.setInput(observer, Nd4j.create(100));

        assertNotNull(observable1);
        assertSame(observable1, observable2);
    }

    /**
     * Two distinct inputs coalesced into one observable must be stacked into a single
     * [2, 100] batch, preserving per-row values in submission order.
     */
    @Test
    public void testProvider2() throws Exception {
        LinkedBlockingQueue<InferenceObservable> queue = new LinkedBlockingQueue<>();
        BasicInferenceObserver observer = new BasicInferenceObserver();
        ParallelInference.ObservablesProvider provider =
                        new ParallelInference.ObservablesProvider(10000000L, 100, queue);

        InferenceObservable observable1 = provider.setInput(observer, Nd4j.create(100).assign(1.0));
        InferenceObservable observable2 = provider.setInput(observer, Nd4j.create(100).assign(2.0));

        assertNotNull(observable1);
        assertSame(observable1, observable2);

        INDArray[] input = observable1.getInput();

        assertEquals(1, input.length);
        assertArrayEquals(new int[] {2, 100}, input[0].shape());
        // Row 0 came from the first submission (all 1.0), row 1 from the second (all 2.0).
        assertEquals(1.0f, input[0].tensorAlongDimension(0, 1).meanNumber().floatValue(), 0.001);
        assertEquals(2.0f, input[0].tensorAlongDimension(1, 1).meanNumber().floatValue(), 0.001);
    }

    /**
     * With a batch limit of 2, the third input must overflow into a NEW observable.
     */
    @Test
    public void testProvider3() throws Exception {
        LinkedBlockingQueue<InferenceObservable> queue = new LinkedBlockingQueue<>();
        BasicInferenceObserver observer = new BasicInferenceObserver();
        ParallelInference.ObservablesProvider provider =
                        new ParallelInference.ObservablesProvider(10000000L, 2, queue);

        InferenceObservable observable1 = provider.setInput(observer, Nd4j.create(100).assign(1.0));
        InferenceObservable observable2 = provider.setInput(observer, Nd4j.create(100).assign(2.0));
        InferenceObservable observable3 = provider.setInput(observer, Nd4j.create(100).assign(3.0));

        assertNotNull(observable1);
        assertNotNull(observable3);

        assertSame(observable1, observable2);
        assertNotSame(observable1, observable3);

        INDArray[] input = observable1.getInput();

        assertEquals(1.0f, input[0].tensorAlongDimension(0, 1).meanNumber().floatValue(), 0.001);
        assertEquals(2.0f, input[0].tensorAlongDimension(1, 1).meanNumber().floatValue(), 0.001);

        input = observable3.getInput();
        assertEquals(3.0f, input[0].tensorAlongDimension(0, 1).meanNumber().floatValue(), 0.001);
    }

    /**
     * De-batching: after a batched output is set, each position must expose only its own
     * [1, 10] slice of the combined [3, 10] result.
     */
    @Test
    public void testProvider4() throws Exception {
        LinkedBlockingQueue<InferenceObservable> queue = new LinkedBlockingQueue<>();
        BasicInferenceObserver observer = new BasicInferenceObserver();
        ParallelInference.ObservablesProvider provider =
                        new ParallelInference.ObservablesProvider(10000000L, 4, queue);

        BatchedInferenceObservable observable1 =
                        (BatchedInferenceObservable) provider.setInput(observer, Nd4j.create(100).assign(1.0));
        BatchedInferenceObservable observable2 =
                        (BatchedInferenceObservable) provider.setInput(observer, Nd4j.create(100).assign(2.0));
        BatchedInferenceObservable observable3 =
                        (BatchedInferenceObservable) provider.setInput(observer, Nd4j.create(100).assign(3.0));

        // Simulated model output for the whole batch: row i is filled with the value i.
        INDArray bigOutput = Nd4j.create(3, 10);
        for (int i = 0; i < bigOutput.rows(); i++)
            bigOutput.getRow(i).assign((float) i);

        observable3.setOutput(bigOutput);

        INDArray out = null;
        observable3.setPosition(0);
        out = observable3.getOutput()[0];
        assertArrayEquals(new int[] {1, 10}, out.shape());
        assertEquals(0.0f, out.meanNumber().floatValue(), 0.01f);

        observable3.setPosition(1);
        out = observable3.getOutput()[0];
        assertArrayEquals(new int[] {1, 10}, out.shape());
        assertEquals(1.0f, out.meanNumber().floatValue(), 0.01f);

        observable3.setPosition(2);
        out = observable3.getOutput()[0];
        assertArrayEquals(new int[] {1, 10}, out.shape());
        assertEquals(2.0f, out.meanNumber().floatValue(), 0.01f);
    }

    /**
     * Runs classification over the whole iterator from a single caller thread, logging the
     * resulting {@link Evaluation} stats.
     *
     * @param inf      inference instance under test
     * @param iterator dataset source; consumed fully, reset before use
     */
    protected void evalClassificationSingleThread(@NonNull ParallelInference inf, @NonNull DataSetIterator iterator) {
        DataSet ds = iterator.next();
        log.info("NumColumns: {}", ds.getLabels().columns());
        iterator.reset();
        Evaluation eval = new Evaluation(ds.getLabels().columns());
        while (iterator.hasNext()) {
            ds = iterator.next();
            INDArray output = inf.output(ds.getFeatureMatrix());
            eval.eval(ds.getLabels(), output);
        }
        log.info(eval.stats());
    }

    /**
     * Runs classification with {@code numThreads} concurrent caller threads draining a shared
     * queue of up to 256 examples, then evaluates the collected (labels, output) pairs in a
     * single thread and logs the {@link Evaluation} stats.
     *
     * @param inf        inference instance under test
     * @param iterator   dataset source; consumed (up to 256 examples), reset before use
     * @param numThreads number of concurrent client threads to spawn
     */
    protected void evalClassificationMultipleThreads(@NonNull ParallelInference inf, @NonNull DataSetIterator iterator,
                    int numThreads) throws Exception {
        DataSet ds = iterator.next();
        log.info("NumColumns: {}", ds.getLabels().columns());
        iterator.reset();
        Evaluation eval = new Evaluation(ds.getLabels().columns());

        // Concurrent queues: work items for the client threads, and their collected results.
        final Queue<DataSet> dataSets = new LinkedBlockingQueue<>();
        final Queue<Pair<INDArray, INDArray>> outputs = new LinkedBlockingQueue<>();

        int cnt = 0;
        // First of all we'll build datasets (capped at 256 to keep the test fast).
        while (iterator.hasNext() && cnt < 256) {
            ds = iterator.next();
            dataSets.add(ds);
            cnt++;
        }

        // Now we'll build outputs in parallel: each thread drains the shared queue.
        Thread[] threads = new Thread[numThreads];
        for (int i = 0; i < numThreads; i++) {
            threads[i] = new Thread(new Runnable() {
                @Override
                public void run() {
                    DataSet ds;
                    while ((ds = dataSets.poll()) != null) {
                        INDArray output = inf.output(ds);
                        outputs.add(Pair.makePair(ds.getLabels(), output));
                    }
                }
            });
        }

        for (int i = 0; i < numThreads; i++) {
            threads[i].start();
        }

        for (int i = 0; i < numThreads; i++) {
            threads[i].join();
        }

        // And now we'll evaluate in single thread once again.
        Pair<INDArray, INDArray> output;
        while ((output = outputs.poll()) != null) {
            eval.eval(output.getFirst(), output.getSecond());
        }

        log.info(eval.stats());
    }
}