package org.deeplearning4j.parallelism.inference.observers;

import lombok.extern.slf4j.Slf4j;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.*;

/**
 * Tests for {@link BatchedInferenceObservable}: how inputs passed one sample at a time via
 * {@code setInput} are exposed by {@code getInput}, and how arrays passed to {@code setOutput}
 * are exposed per sample by {@code getOutputs}.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class BatchedInferenceObservableTest {

    /** Number of samples pushed through the observable in every test. */
    private static final int BATCH_SIZE = 32;

    /** Tolerance used when comparing mean values. */
    private static final float TOLERANCE = 0.001f;

    /** No per-test fixture is required; kept as a JUnit lifecycle hook. */
    @Before
    public void setUp() throws Exception {}

    /** Nothing to release after a test; kept as a JUnit lifecycle hook. */
    @After
    public void tearDown() throws Exception {}

    /**
     * Feeds 32 arrays created with {@code Nd4j.create(100)}, each filled with its own index,
     * and expects a single rank-2 input whose i-th tensor along dimension 1 has mean i.
     */
    @Test
    public void testVerticalBatch1() throws Exception {
        BatchedInferenceObservable batcher = new BatchedInferenceObservable();

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray row = Nd4j.create(100).assign(sample);
            batcher.setInput(row);
        }

        assertEquals(1, batcher.getInput().length);

        INDArray merged = batcher.getInput()[0];
        assertEquals(2, merged.rank());

        String shapeInfo = Arrays.toString(merged.shapeInfoDataBuffer().asInt());
        log.info("Array shape: {}", shapeInfo);

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray slice = merged.tensorAlongDimension(sample, 1);
            assertEquals((float) sample, meanOf(slice), TOLERANCE);
        }
    }

    /**
     * Feeds 32 arrays of shape 3x72x72, each filled with its own index, and expects a single
     * rank-4 input with leading dimension 32 whose i-th tensor along dimensions 1, 2, 3 has
     * mean i.
     */
    @Test
    public void testVerticalBatch2() throws Exception {
        BatchedInferenceObservable batcher = new BatchedInferenceObservable();

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray image = Nd4j.create(3, 72, 72).assign(sample);
            batcher.setInput(image);
        }

        assertEquals(1, batcher.getInput().length);

        INDArray merged = batcher.getInput()[0];
        assertEquals(4, merged.rank());
        assertEquals(BATCH_SIZE, merged.shape()[0]);

        String shapeInfo = Arrays.toString(merged.shapeInfoDataBuffer().asInt());
        log.info("Array shape: {}", shapeInfo);

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray slice = merged.tensorAlongDimension(sample, 1, 2, 3);
            assertEquals((float) sample, meanOf(slice), TOLERANCE);
        }
    }

    /**
     * Feeds 32 pairs of arrays (3x72x72 filled with i, and 100x100 filled with 100 + i) and
     * expects two inputs of shape 32x3x72x72 and 32x100x100 whose i-th slices have means
     * i and 100 + i respectively.
     */
    @Test
    public void testHorizontalBatch1() throws Exception {
        BatchedInferenceObservable batcher = new BatchedInferenceObservable();

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray first = Nd4j.create(3, 72, 72).assign(sample);
            INDArray second = Nd4j.create(100, 100).assign(100 + sample);
            batcher.setInput(first, second);
        }

        assertEquals(2, batcher.getInput().length);

        INDArray[] merged = batcher.getInput();
        INDArray firstBatch = merged[0];
        INDArray secondBatch = merged[1];

        assertArrayEquals(new int[] {BATCH_SIZE, 3, 72, 72}, firstBatch.shape());
        assertArrayEquals(new int[] {BATCH_SIZE, 100, 100}, secondBatch.shape());

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            INDArray firstSlice = firstBatch.tensorAlongDimension(sample, 1, 2, 3);
            assertEquals((float) sample, meanOf(firstSlice), TOLERANCE);

            INDArray secondSlice = secondBatch.tensorAlongDimension(sample, 1, 2);
            assertEquals(100.0f + sample, meanOf(secondSlice), TOLERANCE);
        }
    }

    /**
     * Sets two output arrays (32x10 and 32x15) whose i-th tensor along dimension 1 is filled
     * with i, sets the counter to 32, and expects {@code getOutputs()} to hold, at each index
     * i below 32, a pair of arrays that both have mean i.
     */
    @Test
    public void testTearsBatch1() throws Exception {
        BatchedInferenceObservable batcher = new BatchedInferenceObservable();

        INDArray narrow = Nd4j.create(BATCH_SIZE, 10);
        INDArray wide = Nd4j.create(BATCH_SIZE, 15);

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            narrow.tensorAlongDimension(sample, 1).assign(sample);
            wide.tensorAlongDimension(sample, 1).assign(sample);
        }

        batcher.setCounter(BATCH_SIZE);
        batcher.setOutput(narrow, wide);

        List<INDArray[]> perSample = batcher.getOutputs();

        for (int sample = 0; sample < BATCH_SIZE; sample++) {
            assertEquals(2, perSample.get(sample).length);
            assertEquals((float) sample, meanOf(perSample.get(sample)[0]), TOLERANCE);
            assertEquals((float) sample, meanOf(perSample.get(sample)[1]), TOLERANCE);
        }
    }

    /** Returns the mean of all elements of {@code array} as a float. */
    private static float meanOf(INDArray array) {
        return array.meanNumber().floatValue();
    }
}