package org.nd4j.linalg.profiling;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.util.Pair;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

import java.util.Arrays;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link OpProfiler}: invocation counting, penalty-cause detection for
 * operand combinations and TAD shape info, order-string reporting, and the
 * NaN/Inf panic profiling modes of the executioner.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class OperationProfilerTests {

    /** Switches the executioner to OPERATIONS profiling and resets the profiler before each test. */
    @Before
    public void setUp() {
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.OPERATIONS);
        OpProfiler.getInstance().reset();
    }

    /** Turns profiling off again after each test. */
    @After
    public void tearDown() {
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
    }

    /** Two in-place operations must be counted as two invocations. */
    @Test
    public void testCounter1() {
        INDArray data = Nd4j.createUninitialized(100);

        data.assign(10f);
        data.divi(2f);

        assertEquals(2, OpProfiler.getInstance().getInvocationsCount());
    }

    /** Three assigns under ALL profiling must be counted as three invocations; also prints the dashboard. */
    @Test
    public void testStack1() throws Exception {
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);

        INDArray data = Nd4j.createUninitialized(100);

        data.assign(10f);
        data.assign(20f);
        data.assign(30f);

        OpProfiler profiler = OpProfiler.getInstance();
        assertEquals(3, profiler.getInvocationsCount());

        profiler.printOutDashboard();
    }

    /** Two plain 100-element arrays are expected to report only NONE. */
    @Test
    public void testBadCombos1() throws Exception {
        INDArray first = Nd4j.create(100);
        INDArray second = Nd4j.create(100);

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processOperands(first, second);

        assertCauses(detected, OpProfiler.PenaltyCause.NONE);
    }

    /** An 'f'-reshaped operand combined with a 'c'-reshaped one is expected to report only MIXED_ORDER. */
    @Test
    public void testBadCombos2() throws Exception {
        INDArray first = Nd4j.create(100).reshape('f', 10, 10);
        INDArray second = Nd4j.create(100).reshape('c', 10, 10);

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processOperands(first, second);

        assertCauses(detected, OpProfiler.PenaltyCause.MIXED_ORDER);
    }

    /** A tensor-along-dimension view paired with an 'f' matrix is expected to report MIXED_ORDER and NON_EWS_ACCESS. */
    @Test
    public void testBadCombos3() throws Exception {
        INDArray first = Nd4j.create(27).reshape('c', 3, 3, 3).tensorAlongDimension(0, 1, 2);
        INDArray second = Nd4j.create(100).reshape('f', 10, 10);

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processOperands(first, second);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.MIXED_ORDER, OpProfiler.PenaltyCause.NON_EWS_ACCESS);
    }

    /** Same expectation as {@link #testBadCombos3()}, using the three-operand call. */
    @Test
    public void testBadCombos4() throws Exception {
        INDArray first = Nd4j.create(27).reshape('c', 3, 3, 3).tensorAlongDimension(0, 1, 2);
        INDArray second = Nd4j.create(100).reshape('f', 10, 10);
        INDArray third = Nd4j.create(100).reshape('f', 10, 10);

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processOperands(first, second, third);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.MIXED_ORDER, OpProfiler.PenaltyCause.NON_EWS_ACCESS);
    }

    /** Four operands where exactly one is 'f'-reshaped are expected to report only MIXED_ORDER. */
    @Test
    public void testBadCombos5() throws Exception {
        INDArray first = Nd4j.create(100).reshape('c', 10, 10);
        INDArray second = Nd4j.create(100).reshape('c', 10, 10);
        INDArray third = Nd4j.create(100).reshape('f', 10, 10);
        INDArray fourth = Nd4j.create(100).reshape('c', 10, 10);

        OpProfiler.PenaltyCause[] detected =
                        OpProfiler.getInstance().processOperands(first, second, third, fourth);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.MIXED_ORDER);
    }

    /** Currently ignored: a slice of an 'f' cube paired with an 'f' matrix, expecting only STRIDED_ACCESS. */
    @Test
    @Ignore
    public void testBadCombos6() throws Exception {
        INDArray first = Nd4j.create(27).reshape('f', 3, 3, 3).slice(1);
        INDArray second = Nd4j.create(100).reshape('f', 10, 10);

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processOperands(first, second);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.STRIDED_ACCESS);
    }

    /** TAD over dimensions {0, 2} of a 2x4x5x6 array is expected to report only TAD_NON_EWS_ACCESS. */
    @Test
    public void testBadTad1() throws Exception {
        INDArray source = Nd4j.create(2, 4, 5, 6);
        DataBuffer shapeInfo = tadShapeInfo(source, new int[] {0, 2});

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processTADOperands(shapeInfo);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS);
    }

    /** TAD over dimensions {2, 3} of a 2x4x5x6 array is expected to report only TAD_NON_EWS_ACCESS. */
    @Test
    public void testBadTad2() throws Exception {
        INDArray source = Nd4j.create(2, 4, 5, 6);
        DataBuffer shapeInfo = tadShapeInfo(source, new int[] {2, 3});

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processTADOperands(shapeInfo);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS);
    }

    /** TAD over dimensions {0, 2, 4} of an 'f' 2x4x5x6x7 array is expected to report only TAD_NON_EWS_ACCESS. */
    @Test
    public void testBadTad3() throws Exception {
        INDArray source = Nd4j.create(new int[] {2, 4, 5, 6, 7}, 'f');
        DataBuffer shapeInfo = tadShapeInfo(source, new int[] {0, 2, 4});

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processTADOperands(shapeInfo);
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS);
    }

    /** Currently ignored: TAD over dimension {3} of a 2x4x5x6 array, expecting only NONE. */
    @Test
    @Ignore
    public void testBadTad4() throws Exception {
        INDArray source = Nd4j.create(2, 4, 5, 6);
        DataBuffer shapeInfo = tadShapeInfo(source, new int[] {3});

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processTADOperands(shapeInfo);
        log.info("TAD: {}", Arrays.toString(shapeInfo.asInt()));
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.NONE);
    }

    /** TAD over dimension {4} of an 'f' 2x4x5x6x7 array is expected to report only TAD_STRIDED_ACCESS. */
    @Test
    public void testBadTad5() throws Exception {
        INDArray source = Nd4j.create(new int[] {2, 4, 5, 6, 7}, 'f');
        DataBuffer shapeInfo = tadShapeInfo(source, new int[] {4});

        OpProfiler.PenaltyCause[] detected = OpProfiler.getInstance().processTADOperands(shapeInfo);
        log.info("TAD: {}", Arrays.toString(shapeInfo.asInt()));
        log.info("Causes: {}", Arrays.toString(detected));

        assertCauses(detected, OpProfiler.PenaltyCause.TAD_STRIDED_ACCESS);
    }

    /** Orders f, c, f must be reported as "F x C x F". */
    @Test
    public void testCxFxF1() throws Exception {
        INDArray left = Nd4j.create(10, 10).reshape('f', 10, 10);
        INDArray right = Nd4j.create(10, 10).reshape('c', 10, 10);
        INDArray result = Nd4j.create(10, 10).reshape('f', 10, 10);

        String orders = OpProfiler.getInstance().processOrders(left, right, result);

        assertEquals("F x C x F", orders);
    }

    /** Orders c, c, f must be reported as "C x C x F". */
    @Test
    public void testCxFxF2() throws Exception {
        INDArray left = Nd4j.create(10, 10).reshape('c', 10, 10);
        INDArray right = Nd4j.create(10, 10).reshape('c', 10, 10);
        INDArray result = Nd4j.create(10, 10).reshape('f', 10, 10);

        String orders = OpProfiler.getInstance().processOrders(left, right, result);

        assertEquals("C x C x F", orders);
    }

    /** Orders c, c, c must be reported as "C x C x C". */
    @Test
    public void testCxFxF3() throws Exception {
        INDArray left = Nd4j.create(10, 10).reshape('c', 10, 10);
        INDArray right = Nd4j.create(10, 10).reshape('c', 10, 10);
        INDArray result = Nd4j.create(10, 10).reshape('c', 10, 10);

        String orders = OpProfiler.getInstance().processOrders(left, right, result);

        assertEquals("C x C x C", orders);
    }

    /** Runs an 'f' x 'f' matrix multiplication under ALL profiling and prints the dashboard; asserts nothing. */
    @Test
    public void testBlasFF() throws Exception {
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);

        INDArray left = Nd4j.create(10, 10).reshape('f', 10, 10);
        INDArray right = Nd4j.create(10, 10).reshape('f', 10, 10);

        left.mmul(right);

        OpProfiler.getInstance().printOutDashboard();
    }

    /** NAN_PANIC mode: working on an array that contains NaN must raise {@link ND4JIllegalStateException}. */
    @Test(expected = ND4JIllegalStateException.class)
    public void testNaNPanic1() throws Exception {
        multiplyUnderMode(OpExecutioner.ProfilingMode.NAN_PANIC, Float.NaN);
    }

    /** INF_PANIC mode: working on an array that contains +Infinity must raise {@link ND4JIllegalStateException}. */
    @Test(expected = ND4JIllegalStateException.class)
    public void testNaNPanic2() throws Exception {
        multiplyUnderMode(OpExecutioner.ProfilingMode.INF_PANIC, Float.POSITIVE_INFINITY);
    }

    /** ANY_PANIC mode: working on an array that contains -Infinity must raise {@link ND4JIllegalStateException}. */
    @Test(expected = ND4JIllegalStateException.class)
    public void testNaNPanic3() throws Exception {
        multiplyUnderMode(OpExecutioner.ProfilingMode.ANY_PANIC, Float.NEGATIVE_INFINITY);
    }

    /**
     * Asserts that {@code actual} has exactly as many entries as {@code expected}
     * and that every expected cause is present in it.
     *
     * @param actual   causes returned by the profiler
     * @param expected causes that must all be present
     */
    private static void assertCauses(OpProfiler.PenaltyCause[] actual, OpProfiler.PenaltyCause... expected) {
        assertEquals(expected.length, actual.length);
        for (OpProfiler.PenaltyCause cause : expected) {
            assertTrue(ArrayUtils.contains(actual, cause));
        }
    }

    /**
     * Asks the executioner's TAD manager for the TAD-only shape info of {@code array}
     * along {@code dimensions} and returns the first buffer of the resulting pair.
     *
     * @param array      array to compute the TAD shape info for
     * @param dimensions dimensions passed to the TAD manager
     * @return first element of the pair produced by the TAD manager
     */
    private static DataBuffer tadShapeInfo(INDArray array, int[] dimensions) {
        Pair<DataBuffer, DataBuffer> tadBuffers =
                        Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array, dimensions);
        return tadBuffers.getFirst();
    }

    /**
     * Sets the given profiling mode, builds the array {@code {1, 2, 3, special}}
     * and multiplies it in place by 3.
     *
     * @param mode    profiling mode to activate before the array is created
     * @param special last element of the array (NaN or an infinity in the callers)
     */
    private static void multiplyUnderMode(OpExecutioner.ProfilingMode mode, float special) {
        Nd4j.getExecutioner().setProfilingMode(mode);

        INDArray values = Nd4j.create(new float[] {1f, 2f, 3f, special});

        values.muli(3f);
    }
}