package org.nd4j.bytebuddy.shape; import net.bytebuddy.ByteBuddy; import net.bytebuddy.dynamic.DynamicType; import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; import net.bytebuddy.implementation.Implementation; import net.bytebuddy.matcher.ElementMatchers; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** * @author Adam Gibson */ public class ShapeMapperTest { @Test public void testShapeMapper() throws Exception { Implementation cImpl = ShapeMapper.getInd2Sub('c', 2); Implementation fImpl = ShapeMapper.getInd2Sub('f', 2); DynamicType.Unloaded<IndexMapper> c = new ByteBuddy().subclass(IndexMapper.class) .method(ElementMatchers.isDeclaredBy(IndexMapper.class)).intercept(cImpl).make(); DynamicType.Unloaded<IndexMapper> f = new ByteBuddy().subclass(IndexMapper.class) .method(ElementMatchers.isDeclaredBy(IndexMapper.class)).intercept(fImpl).make(); Class<?> dynamicType = c.load(IndexMapper.class.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER).getLoaded(); Class<?> dynamicTypeF = f.load(IndexMapper.class.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER).getLoaded(); IndexMapper testC = (IndexMapper) dynamicType.newInstance(); IndexMapper testF = (IndexMapper) dynamicTypeF.newInstance(); int n = 1000; long byteBuddyTotal = 0; for (int i = 0; i < n; i++) { long start = System.nanoTime(); int[] cTest = testC.ind2sub(new int[] {2, 2}, 1, 4, 'c'); long end = System.nanoTime(); byteBuddyTotal += Math.abs((end - start)); } byteBuddyTotal /= n; System.out.println("Took " + byteBuddyTotal); int[] cTest = testC.ind2sub(new int[] {2, 2}, 1, 4, 'c'); int[] fTest = testF.ind2sub(new int[] {2, 2}, 1, 4, 'f'); assertArrayEquals(new int[] {1, 0}, fTest); assertArrayEquals(new int[] {0, 1}, cTest); } @Test public void testOffsetMapper() throws Exception { OffsetMapper mapper = ShapeMapper.getOffsetMapperInstance(2); assertEquals(verifyImpl(0, new int[] {3, 5}, new int[] {4, 1}, new int[] {1, 1}), mapper.getOffset(0, new int[] {3, 5}, new int[] {4, 1}, new int[] {1, 1})); long oldImplTotal = 0; long newImplTotal = 0; int[] timingShape = {1, 5, 1, 1}; int[] timingStride = {4, 1, 1, 1}; int[] timingIndex = {1, 1, 1, 1}; for (int i = 0; i < 1000; i++) { long old = System.nanoTime(); verifyImpl(0, timingShape, timingStride, timingIndex); long newTime = System.nanoTime(); long delta = Math.abs(newTime - old); long oldDelta = delta; oldImplTotal += delta; old = System.nanoTime(); mapper.getOffset(0, timingShape, timingStride, timingIndex); newTime = System.nanoTime(); delta = Math.abs(newTime - old); newImplTotal += delta; System.out.println("Time for old was " + oldDelta + " while new was " + delta + " in nanoseconds at " + i); } oldImplTotal /= 1000; newImplTotal /= 1000; System.out.println("Time for old was " + oldImplTotal + " while new was " + newImplTotal + " in nanoseconds"); } private int verifyImpl(int baseOffset, int[] shape, int[] stride, int[] indices) { int offset = 0; for (int i = 0; i < indices.length; i++) { /** * See: * http://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html * Basically if the size(i) is 1, the stride shouldn't be counted. */ if (shape[i] == 1) continue; offset += indices[i] * stride[i]; } return offset + baseOffset; } }