package jcuda.jcublas.ops;

import lombok.extern.slf4j.Slf4j;
import org.junit.After;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.jcublas.ops.executioner.CudaGridExecutioner;
import org.nd4j.linalg.util.DeviceLocalNDArray;

import java.io.File;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.*;
import static org.nd4j.linalg.api.shape.Shape.newShapeNoCopy;

/**
 * Collection of loosely related regression/debugging tests that exercise ND4J operations
 * through the configured executioner (several tests cast it to {@link CudaGridExecutioner}).
 *
 * <p>Several tests switch the global data type via {@link DataTypeUtil#setDTypeForContext};
 * the type observed before each test is captured in {@link #setUp()} and restored in
 * {@link #tearDown()} so that one test cannot change the precision used by the next.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class SporadicTests {

    /** Data type that was active when the current test started. */
    private DataBuffer.Type initialType;

    /**
     * Records the data type in effect before the test body runs.
     */
    @Before
    public void setUp() throws Exception {
        //CudaEnvironment.getInstance().getConfiguration().enableDebug(true).setVerbose(false);
        initialType = Nd4j.dataType();
    }

    /**
     * Restores the data type recorded by {@link #setUp()}, undoing any
     * {@code setDTypeForContext} call made by the test.
     */
    @After
    public void tearDown() throws Exception {
        if (initialType != null) {
            DataTypeUtil.setDTypeForContext(initialType);
        }
    }

    /**
     * Runs {@link IsMax} along dimension 0 on a 2x2 'c'-ordered array holding 1..4 and
     * checks that the first two elements become 0 and the last two become 1.
     */
    @Test
    public void testIsMax1() throws Exception {
        int[] shape = new int[]{2, 2};
        int length = 4;
        int alongDimension = 0;

        INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', shape);
        Nd4j.getExecutioner().execAndReturn(new IsMax(arrC, alongDimension));

        assertEquals(0.0, arrC.getDouble(0), 0.1);
        assertEquals(0.0, arrC.getDouble(1), 0.1);
        assertEquals(1.0, arrC.getDouble(2), 0.1);
        assertEquals(1.0, arrC.getDouble(3), 0.1);
    }

    /**
     * Sums a rank-9 array (every dimension of size 2) along dimension 0 in DOUBLE mode and
     * prints the result. The data type is restored by {@link #tearDown()}.
     */
    @Test
    public void randomStrangeTest() {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);

        int a = 9;
        int b = 2;
        int[] shapes = new int[a];
        Arrays.fill(shapes, b);

        INDArray c = Nd4j.linspace(1, (int) (100 * 1 + 1 + 2), (int) Math.pow(b, a)).reshape(shapes);
        c = c.sum(0);
        double[] d = c.data().asDouble();
        System.out.println("d: " + Arrays.toString(d));
    }

    /**
     * Applies {@link BroadcastSubOp} to a permuted array and to its 'c'-ordered duplicate,
     * writing into result arrays with differing strides, and compares the outcomes.
     */
    @Test
    public void testBroadcastWithPermute() {
        Nd4j.getRandom().setSeed(12345);
        int length = 4 * 4 * 5 * 2;
        INDArray arr = Nd4j.linspace(1, length, length).reshape('c', 4, 4, 5, 2).permute(2, 3, 1, 0);
//        INDArray arr = Nd4j.linspace(1,length,length).reshape('f',4,4,5,2).permute(2,3,1,0);
        Nd4j.getExecutioner().commit();
        INDArray arrDup = arr.dup('c');
        Nd4j.getExecutioner().commit();

        INDArray row = Nd4j.rand(1, 2);
        assertEquals(row.length(), arr.size(1));
        assertEquals(row.length(), arrDup.size(1));

        assertEquals(arr, arrDup);

        INDArray first = Nd4j.getExecutioner().execAndReturn(
                new BroadcastSubOp(arr, row, Nd4j.createUninitialized(arr.shape(), 'c'), 1));
        INDArray second = Nd4j.getExecutioner().execAndReturn(
                new BroadcastSubOp(arrDup, row, Nd4j.createUninitialized(arr.shape(), 'c'), 1));

        System.out.println("A1: " + Arrays.toString(arr.shapeInfoDataBuffer().asInt()));
        System.out.println("A2: " + Arrays.toString(first.shapeInfoDataBuffer().asInt()));
        System.out.println("B1: " + Arrays.toString(arrDup.shapeInfoDataBuffer().asInt()));
        System.out.println("B2: " + Arrays.toString(second.shapeInfoDataBuffer().asInt()));

        INDArray resultSameStrides = Nd4j.zeros(new int[]{4, 4, 5, 2}, 'c').permute(2, 3, 1, 0);
        assertArrayEquals(arr.stride(), resultSameStrides.stride());
        INDArray third = Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(arr, row, resultSameStrides, 1));
        assertEquals(second, third);    //Original and result w/ same strides: passes
        assertEquals(first, second);    //Original and result w/ different strides: fails
    }

    /**
     * Checks that {@code addiRowVector} and a {@link ScalarAdd} with dimension 0 produce
     * equal results on 'f'-ordered 4x5 arrays.
     */
    @Test
    public void testBroadcastEquality1() {
        INDArray array = Nd4j.zeros(new int[]{4, 5}, 'f');
        INDArray array2 = Nd4j.zeros(new int[]{4, 5}, 'f');
        INDArray row = Nd4j.create(new float[]{1, 2, 3, 4, 5});

        array.addiRowVector(row);

        System.out.println(array);
        System.out.println("-------");

        ScalarAdd add = new ScalarAdd(array2, row, array2, array2.length(), 0.0f);
        add.setDimension(0);
        Nd4j.getExecutioner().exec(add);

        System.out.println(array2);
        assertEquals(array, array2);
    }

    /**
     * Checks that {@code addiColumnVector} and a {@link ScalarAdd} with dimension 1 produce
     * equal results on 'c'-ordered 4x5 arrays.
     */
    @Test
    public void testBroadcastEquality2() {
        INDArray array = Nd4j.zeros(new int[]{4, 5}, 'c');
        INDArray array2 = Nd4j.zeros(new int[]{4, 5}, 'c');
        INDArray column = Nd4j.create(new float[]{1, 2, 3, 4}).reshape(4, 1);

        array.addiColumnVector(column);

        System.out.println(array);
        System.out.println("-------");

        ScalarAdd add = new ScalarAdd(array2, column, array2, array2.length(), 0.0f);
        add.setDimension(1);
        Nd4j.getExecutioner().exec(add);

        System.out.println(array2);
        assertEquals(array, array2);
    }

    /**
     * Times 10000 executions of {@link IAMax} along dimension 1 on a 128000x4 array
     * (after one untimed warm-up call) and prints the average nanoseconds per call.
     */
    @Test
    public void testIAMax1() throws Exception {
        INDArray arrayX = Nd4j.rand('c', 128000, 4);

        // Warm-up call, excluded from the measurement.
        Nd4j.getExecutioner().exec(new IAMax(arrayX), 1);

        long time1 = System.nanoTime();
        for (int i = 0; i < 10000; i++) {
            Nd4j.getExecutioner().exec(new IAMax(arrayX), 1);
        }
        long time2 = System.nanoTime();

        System.out.println("Time: " + ((time2 - time1) / 10000));
    }

    /**
     * Asserts that freshly created, reshaped and no-copy-reshaped arrays all report
     * {@code isActualOnDeviceSide()} on their allocation points.
     */
    @Test
    public void testLocality() {
        INDArray array = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9});
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        assertTrue(point.isActualOnDeviceSide());

        INDArray arrayR = array.reshape('f', 3, 3);
        AllocationPoint pointR = AtomicAllocator.getInstance().getAllocationPoint(arrayR);
        assertTrue(pointR.isActualOnDeviceSide());

        INDArray arrayS = Shape.newShapeNoCopy(array, new int[]{3, 3}, true);
        AllocationPoint pointS = AtomicAllocator.getInstance().getAllocationPoint(arrayS);
        assertTrue(pointS.isActualOnDeviceSide());

        INDArray arrayL = Nd4j.create(new int[]{3, 4, 4, 4}, 'c');
        AllocationPoint pointL = AtomicAllocator.getInstance().getAllocationPoint(arrayL);
        assertTrue(pointL.isActualOnDeviceSide());
    }

    /**
     * Prints the environment information reported by the executioner.
     */
    @Test
    public void testEnvironment() throws Exception {
        // Allocation retained from the original test; its result is intentionally unused.
        Nd4j.zeros(new int[]{4, 5}, 'f');
        Properties properties = Nd4j.getExecutioner().getEnvironmentInformation();
        System.out.println("Props: " + properties.toString());
    }

    /**
     * This is special test that checks for memory alignment
     * @throws Exception
     */
    @Test
    @Ignore
    public void testDTypeSpam() throws Exception {
        Random rnd = new Random();
        for (int i = 0; i < 100; i++) {
            DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
            float[] rand = new float[rnd.nextInt(10) + 1];
            for (int x = 0; x < rand.length; x++) {
                rand[x] = rnd.nextFloat();
            }
            Nd4j.getConstantHandler().getConstantBuffer(rand);

            int[] shape = new int[rnd.nextInt(3) + 2];
            for (int x = 0; x < shape.length; x++) {
                shape[x] = rnd.nextInt(100) + 2;
            }

            DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
            INDArray array = Nd4j.rand(shape);
            BooleanIndexing.applyWhere(array, Conditions.lessThan(rnd.nextDouble()), rnd.nextDouble());
        }
    }

    /**
     * A newly created array must not be reported as a view.
     */
    @Test
    public void testIsView() {
        INDArray array = Nd4j.zeros(100, 100);
        assertFalse(array.isView());
    }

    /**
     * Starts one thread per device; each thread increments its {@link DeviceLocalNDArray}
     * copy by one. Afterwards every per-device copy must equal an array of twos.
     */
    @Test
    public void testReplicate1() throws Exception {
        INDArray array = Nd4j.create(new float[]{1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f,
                1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f});
        INDArray exp = Nd4j.create(new float[]{2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f,
                2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f, 2f});

        log.error("Array length: {}", array.length());

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        final DeviceLocalNDArray locals = new DeviceLocalNDArray(array);

        Thread[] threads = new Thread[numDevices];
        for (int t = 0; t < numDevices; t++) {
            threads[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    locals.get().addi(1f);
                    locals.get().addi(0f);
                }
            });
            threads[t].start();
        }

        for (int t = 0; t < numDevices; t++) {
            threads[t].join();
        }

        for (int t = 0; t < numDevices; t++) {
            exp.addi(0.0f);
            assertEquals(exp, locals.get(t));
        }
    }

    /**
     * Replicates a buffer to device id 1 and reads the first element back.
     * Skipped when fewer than two devices are available, since device id 1 would not exist.
     */
    @Test
    public void testReplicate2() throws Exception {
        Assume.assumeTrue(Nd4j.getAffinityManager().getNumberOfDevices() > 1);

        DataBuffer buffer = Nd4j.createBuffer(new float[]{1f, 1f, 1f, 1f, 1f});

        DataBuffer buffer2 = Nd4j.getAffinityManager().replicateToDevice(1, buffer);

        assertEquals(1f, buffer2.getFloat(0), 0.001f);
    }

    /**
     * Starts one thread per device; each thread sums its {@link DeviceLocalNDArray} copy of a
     * 10x10 array of ones along dimension 1 and expects tens. The same check is then repeated
     * from the test thread for every device along dimension 0.
     *
     * <p>Failures raised inside the worker threads are collected and rethrown on the test
     * thread; otherwise JUnit would never see them.
     */
    @Test
    public void testReplicate3() throws Exception {
        INDArray array = Nd4j.ones(10, 10);
        INDArray exp = Nd4j.create(10).assign(10f);

        log.error("Array length: {}", array.length());

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        final DeviceLocalNDArray locals = new DeviceLocalNDArray(array);
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());

        Thread[] threads = new Thread[numDevices];
        for (int t = 0; t < numDevices; t++) {
            threads[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(locals.get());
                        log.error("Point deviceId: {}; current deviceId: {}", point.getDeviceId(),
                                Nd4j.getAffinityManager().getDeviceForCurrentThread());

                        INDArray sum = locals.get().sum(1);
                        INDArray localExp = Nd4j.create(10).assign(10f);

                        assertEquals(localExp, sum);
                    } catch (Throwable failure) {
                        // Remember the failure so the test thread can report it after join().
                        failures.add(failure);
                    }
                }
            });
            threads[t].start();
        }

        for (int t = 0; t < numDevices; t++) {
            threads[t].join();
        }

        if (!failures.isEmpty()) {
            throw new AssertionError("Worker thread failed: " + failures.get(0), failures.get(0));
        }

        for (int t = 0; t < numDevices; t++) {
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(locals.get(t));
            log.error("Point deviceId: {}; current deviceId: {}", point.getDeviceId(),
                    Nd4j.getAffinityManager().getDeviceForCurrentThread());

            exp.addi(0.0f);
            assertEquals(exp, locals.get(t).sum(0));

            log.error("Point after: {}", point.getDeviceId());
        }
    }

    /**
     * Writes three ones into row 1 of a 3x3 array through row views, replicates it, and
     * checks that every per-device copy sums to 3.
     */
    @Test
    public void testReplicate4() throws Exception {
        INDArray array = Nd4j.create(3, 3);

        array.getRow(1).putScalar(0, 1f);
        array.getRow(1).putScalar(1, 1f);
        array.getRow(1).putScalar(2, 1f);

        final DeviceLocalNDArray locals = new DeviceLocalNDArray(array);

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int t = 0; t < numDevices; t++) {
            assertEquals(3, locals.get(t).sumNumber().floatValue(), 0.001f);
        }
    }

    /**
     * Logs host/device pointer addresses of the original array and of each per-device copy,
     * first from the test thread and then from one thread per device. Makes no assertions.
     */
    @Test
    public void testReplicate5() throws Exception {
        INDArray array = Nd4j.create(3, 3);

        log.error("Original: Host pt: {}; Dev pt: {}",
                AtomicAllocator.getInstance().getAllocationPoint(array).getPointers().getHostPointer().address(),
                AtomicAllocator.getInstance().getAllocationPoint(array).getPointers().getDevicePointer().address());

        final DeviceLocalNDArray locals = new DeviceLocalNDArray(array);

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int t = 0; t < numDevices; t++) {
            log.error("deviceId: {}; Host pt: {}; Dev pt: {}", t,
                    AtomicAllocator.getInstance().getAllocationPoint(locals.get(t)).getPointers().getHostPointer().address(),
                    AtomicAllocator.getInstance().getAllocationPoint(locals.get(t)).getPointers().getDevicePointer().address());
        }

        Thread[] threads = new Thread[numDevices];
        for (int t = 0; t < numDevices; t++) {
            threads[t] = new Thread(new Runnable() {
                @Override
                public void run() {
                    AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(locals.get());
                    log.error("deviceId: {}; Host pt: {}; Dev pt: {}",
                            Nd4j.getAffinityManager().getDeviceForCurrentThread(),
                            point.getPointers().getHostPointer().address(),
                            point.getPointers().getDevicePointer().address());
                }
            });
            threads[t].start();
        }

        for (int t = 0; t < numDevices; t++) {
            threads[t].join();
        }
    }

    /**
     * Logs name, total memory and free memory for every entry found under the
     * {@code cuda.devicesInformation} key of the environment information.
     */
    @Test
    public void testEnvInfo() throws Exception {
        Properties props = Nd4j.getExecutioner().getEnvironmentInformation();

        // The property is stored untyped; the cast mirrors how the entries are consumed below.
        @SuppressWarnings("unchecked")
        List<Map<String, Object>> list = (List<Map<String, Object>>) props.get("cuda.devicesInformation");
        for (Map<String, Object> map : list) {
            log.error("devName: {}", map.get("cuda.deviceName"));
            log.error("totalMem: {}", map.get("cuda.totalMemory"));
            log.error("freeMem: {}", map.get("cuda.freeMemory"));
            System.out.println();
        }
    }

    /**
     * Prints bias-corrected and uncorrected standard deviation of the values 1..4.
     */
    @Test
    public void testStd() {
        INDArray values = Nd4j.linspace(1, 4, 4).transpose();

        double corrected = values.std(true, 0).getDouble(0);
        double notCorrected = values.std(false, 0).getDouble(0);

        System.out.println(String.format("Corrected: %f, non corrected: %f", corrected, notCorrected));
    }

    /**
     * Normalizes a {@link DataSet} of random arrays in HALF mode.
     * The data type is restored by {@link #tearDown()}.
     */
    @Ignore
    @Test
    public void testHalf19() {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
        INDArray first = Nd4j.rand(20, 10);
        INDArray second = Nd4j.rand(3, 20);
        DataSet data = new DataSet(first, second);
        data.normalize();
    }

    /**
     * Computes {@link #scoreArray} for a single row and for the same row repeated three
     * times, and prints both so they can be compared by eye.
     */
    @Test
    public void testDebugEdgeCase() {
        INDArray l1 = Nd4j.create(new double[]{-0.2585039112684677, -0.005179485353710878, 0.4348343401770497, 0.020356532375728764, -0.1970793298488186});
        INDArray l2 = Nd4j.create(3, l1.size(1));

        INDArray p1 = Nd4j.create(new double[]{1.3979850406519119, 0.6169451410155852, 1.128993957530918, 0.21000426084450596, 0.3171215178932696});
        INDArray p2 = Nd4j.create(3, p1.size(1));

        for (int i = 0; i < 3; i++) {
            l2.putRow(i, l1);
            p2.putRow(i, p1);
        }

        INDArray s1 = scoreArray(l1, p1);
        INDArray s2 = scoreArray(l2, p2);

        //Outputs here should be identical:
        System.out.println(Arrays.toString(s1.data().asDouble()));
        System.out.println(Arrays.toString(s2.getRow(0).dup().data().asDouble()));
    }

    /**
     * Multiplies {@code preOutput} element-wise by {@code labels} and divides each row of the
     * product by the norm2 of {@code preOutput} taken along dimension 1.
     *
     * @param labels    array multiplied element-wise with {@code preOutput}
     * @param preOutput array whose per-row norm2 is used as the divisor
     * @return the newly allocated product, divided in place by the column vector of norms
     */
    public static INDArray scoreArray(INDArray labels, INDArray preOutput) {
        INDArray yhatmag = preOutput.norm2(1);

        INDArray scoreArr = preOutput.mul(labels);
        scoreArr.diviColumnVector(yhatmag);

        return scoreArr;
    }

    /**
     * Step-by-step variant of {@link #testDebugEdgeCase()} in DOUBLE mode that prints the
     * grid executioner queue length and intermediate values along the way.
     * Requires the active executioner to be a {@link CudaGridExecutioner} (it is cast).
     * The data type is restored by {@link #tearDown()}.
     */
    @Test
    public void testDebugEdgeCase2() {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        INDArray l1 = Nd4j.create(new double[]{-0.2585039112684677, -0.005179485353710878, 0.4348343401770497, 0.020356532375728764, -0.1970793298488186});
        INDArray l2 = Nd4j.create(2, l1.size(1));

        INDArray p1 = Nd4j.create(new double[]{1.3979850406519119, 0.6169451410155852, 1.128993957530918, 0.21000426084450596, 0.3171215178932696});
        INDArray p2 = Nd4j.create(2, p1.size(1));

        for (int i = 0; i < 2; i++) {
            l2.putRow(i, l1);
            p2.putRow(i, p1);
        }

        INDArray norm2_1 = l1.norm2(1);
        System.out.println("Queue: " + ((CudaGridExecutioner) Nd4j.getExecutioner()).getQueueLength());

        INDArray temp1 = p1.mul(l1);

        System.out.println("Queue: " + ((CudaGridExecutioner) Nd4j.getExecutioner()).getQueueLength());

//        if (Nd4j.getExecutioner() instanceof CudaGridExecutioner)
//            ((CudaGridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

        INDArray out1 = temp1.diviColumnVector(norm2_1);
        System.out.println("------");

        Nd4j.getExecutioner().commit();

        INDArray norm2_2 = l2.norm2(1);

        System.out.println("norm2_1: " + Arrays.toString(norm2_1.data().asDouble()));
        System.out.println("norm2_2: " + Arrays.toString(norm2_2.data().asDouble()));

        INDArray temp2 = p2.mul(l2);

        System.out.println("temp1: " + Arrays.toString(temp1.data().asDouble()));
        System.out.println("temp2: " + Arrays.toString(temp2.data().asDouble()));

        INDArray out2 = temp2.diviColumnVector(norm2_2);

        //Outputs here should be identical:
        System.out.println(Arrays.toString(out1.data().asDouble()));
        System.out.println(Arrays.toString(out2.getRow(0).dup().data().asDouble()));
    }

    /**
     * Builds a 2x108 'c'-ordered DOUBLE array from fixed data, sums each row element by
     * element on the Java side, and prints that next to {@code arr.sum(1)}.
     * The data type is restored by {@link #tearDown()}.
     */
    @Test
    public void testSum() {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        System.out.println("dType: " + Nd4j.dataType());

        double[] d = new double[]{
                0.0018527202080054288, 4.860460399681257E-5, 1.0866740014537973E-4, -0.0067027589698796745,
                3.137875745180366E-4, -0.004068275565124544, -0.01787441478759584, 0.008485829165871582,
                -2.763756128155635E-4, 2.871786662038523E-4, -0.010483973019250817, 0.007756981321203987,
                -0.0017533316296846166, -0.013154552997235138, 0.002664089383318023, 0.003219604745689706,
                -0.017140063751978196, -0.0402780728371146, -0.024062552380901596, 0.0034055167376910362,
                -0.0014209322438402164, -0.0019807697611663373, -2.2838830354674253E-4, 0.00802614947703168,
                6.809417793628905E-5, -3.5682443741929696E-4, 2.3642116161687557E-4, 0.0020248602841291333,
                0.008922488929012673, -1.8730312287067906E-6, 2.6969347916614046E-5, -1.3560804909521155E-5,
                -0.0019803959188298536, -0.011388435316124648, 0.009815024112605878, -6.217212819848868E-5,
                7.198174495385047E-6, 7.859570666088778E-4, 3.2352438373925256E-4, 0.009310926419061586,
                0.001285919484459703, -0.004614530932162568, -5.693929364499898E-4, 4.436935763832914E-5,
                0.010423203318809186, 0.006593852752045009, 0.0063445124848706584, 9.737683314182195E-5,
                -7.002675823907349E-4, 0.0010906784650723032, -9.972152373258224E-6, 0.00871521334612937,
                0.0015878927877041975, 3.5864863535235535E-4, -4.398790476749721E-5, -7.77853455185052E-5,
                1.8862217750434992E-4, 0.0224868440061588, 0.0073318858545188175, -8.220926236861101E-4,
                -2.4336360596325374E-4, -0.0018348955616861627, 0.011423225743787646, -0.0016207645948113378,
                2.289915435117371E-4, -7.122130486259979E-4, -4.94058936059287E-4, 0.004245767850547438,
                2.1598406094246788E-4, 0.0014093429117757112, 3.1948093888499473E-4, -1.327894312872927E-4,
                5.756401064075624E-4, -0.013501868757425933, 0.08022280647460137, -0.025763510735921924,
                0.2147635435756625, -3.570893204705811E-4, 0.23699343725699218, 0.02005726530793397,
                0.2233494849035487, 0.0015628679046820334, 0.03686828571588657, -0.034884254322163376,
                -0.04580585504492872, 0.022492109246861913, 0.6122906576027609, 0.0013512843074173794,
                0.009833469844281123, -0.12754922577196826, -0.05866094108326281, 0.00786015783509335,
                0.012943402024682067, -0.04138337949224019, 0.16422596234194609, -0.003224047448184361,
                -0.013553967826667544, 0.024567523776697443, 0.003119569763001505, 0.06676632404231841,
                0.0019518418481909879, -3.546570152995131E-4, 8.843184961729061E-4, 0.02791605441470666,
                -0.013688049930718094, 0.03237370158354087, 7.749693316768275E-4, -0.006175397798237791,
                -0.001425650837729542, -7.358356122518933E-4, -5.924546696292049E-4, 2.572174974203492E-4,
                0.008635399542952251, -0.00785894020636433, 0.00611004654858908, 3.849937280461072E-4,
                -0.0011280073492511923, -0.014039863611056342, 0.005258910284221449, -0.0012716079353840685,
                -0.005880609969998075, 0.03884904612026859, -0.007808162270479559, 0.13764734350512128,
                -0.0955607452917015, 0.01739042887598923, 0.003176716700283583, -8.845189001196553E-5,
                0.059890266991132, -0.011719738031782573, -0.009720651901132008, -0.020271048497565218,
                3.5861474460486776E-5, 0.003234054136867597, -0.016855942686723118, -0.04109181803561225,
                0.03929335910336556, -0.002045944958743484, 4.986319734224706E-6, -2.0719501403766647E-5,
                0.022377318509545937, -0.007592601387358396, 4.490315393644052E-4, 9.033852118576955E-5,
                4.621091068084668E-6, -5.247702006473915E-4, -7.902654461500924E-4, -0.0011914084606579713,
                0.0030085580689989877, -4.246971856810759E-4, -1.2340215512440867E-4, 0.0019671074593817285,
                -6.010387216740781E-4, 0.013650305790487045, -0.0011454153127967719, -0.007189180788631945,
                2.870289907492301E-6, 4.3693088414999864E-4, 0.01200434591332941, -0.014509596674846678,
                -0.0029357117029629866, -2.1150207332328822E-4, -0.00315536512642124, 1.0374814880225154E-4,
                -0.0034757406691398496, 0.011599985159323294, 1.2969970680596453E-4, 9.964327556021442E-4,
                -0.001849649601501932, 0.002689358375591656, 0.012896200751328621, 0.007476029001401352,
                -0.0033194177760658377, -4.3432827454975976E-4, 3.411369387610943E-4, -4.103832908635317E-4,
                0.007055642948203781, -0.0015501810107658967, -0.005752034090813254, -2.844831713420882E-4,
                -9.563438979460705E-5, -0.02284555356663203, -0.009025504086580169, -0.1559083024105329,
                0.12294355422935457, -8.708345100849238E-4, 0.02784682111718311, 0.09887344727692746,
                0.1984110780215329, -0.0019539047730033083, 0.436534119185953, -0.0022943880212978763,
                0.0033303334626212217, -0.47305986663738375, 0.2870128297740214, -0.4852364244335913,
                -0.1966639932906117, -0.011543131716351632, -0.037961570290375855, 0.7991053621370379,
                -0.0965493466734368, 0.14022527291688097, -0.15353621266599798, 0.032127740955076554,
                -0.03391229079838272, 0.04220928870735664, -0.10022115665949234, -0.0060843857983522015,
                0.05969884290137722, -0.001513774894231756, 0.003573617155056928, -0.030126515163639428,
                0.006604847374388239, -0.01685524264155275, -0.015135550991685925, -0.002122525156000015};

        int[] shape = {2, 108};
        int[] stride = {108, 1};

        char order = 'c';

        INDArray arr = Nd4j.create(d, shape, stride, 0, order);

        double[] exp = new double[2];
        for (int i = 0; i < shape[1]; i++) {
            exp[0] += arr.getDouble(0, i);
            exp[1] += arr.getDouble(1, i);
        }

        System.out.println("Expected: " + Arrays.toString(exp));
        System.out.println("Actual: " + Arrays.toString(arr.sum(1).data().asDouble()));
    }

    /**
     * Saves and reloads a {@link DataSet} 100 times and checks that features and labels
     * survive the round trip. Each temporary file is removed as soon as its iteration ends.
     */
    @Test
    public void testDataSetSaveLost() throws Exception {
        INDArray features = Nd4j.linspace(1, 16 * 784, 16 * 784).reshape(16, 784);
        INDArray labels = Nd4j.linspace(1, 160, 160).reshape(16, 10);

        for (int i = 0; i < 100; i++) {
            DataSet ds = new DataSet(features, labels);

            File tempFile = File.createTempFile("dataset", "temp");
            // Fallback in case the explicit delete below does not succeed.
            tempFile.deleteOnExit();
            try {
                ds.save(tempFile);

                DataSet restore = new DataSet();
                restore.load(tempFile);

                assertEquals(features, restore.getFeatureMatrix());
                assertEquals(labels, restore.getLabels());
            } finally {
                // Best effort: avoid accumulating 100 temp files until JVM exit.
                tempFile.delete();
            }
        }
    }

    /**
     * Prints the result of {@code eps} against 0, 1 and 2 in HALF mode.
     * The data type is restored by {@link #tearDown()}.
     */
    @Test
    public void testEps() throws Exception {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);

        INDArray arr = Nd4j.create(new double[]{0, 0, 0, 1, 1, 1, 2, 2, 2});

        System.out.println(arr.eps(0.0));
        System.out.println(arr.eps(1.0));
        System.out.println(arr.eps(2.0));
    }

    /**
     * Prints whether a random 2x2 array equals the result of its own {@code neq(1)}.
     */
    @Test
    public void testNeg() {
        INDArray rnd = Nd4j.rand(2, 2);
        System.out.println(rnd.equals(rnd.neq(1)));
    }

    /**
     * Prints the result of {@code eq(2)} on a 2x2 array of ones after committing the queue.
     */
    @Test
    public void testEq() {
        INDArray z = Nd4j.ones(2, 2).eq(2);
        Nd4j.getExecutioner().commit();
        System.out.println("Z: " + z);
    }

    /**
     * Builds 16 batches of 128 stacked images/labels via {@link #getBatch}.
     */
    @Test
    public void testCrash() throws Exception {
        System.out.println("Executor: " + Nd4j.getExecutioner().getClass().getSimpleName());

        int[] shape = new int[]{1, 3, 150, 150};
        INDArray img = Nd4j.create(shape);
        INDArray lbl = Nd4j.create(205);

        AtomicInteger cnt = new AtomicInteger(0);
        while (cnt.get() < 16) {
            System.out.println("Iteration: " + cnt.getAndIncrement());
            getBatch(img, lbl, 128);
        }
    }

    /**
     * Sets the auto-GC window through the memory manager and checks that the CUDA
     * configuration reports the same value.
     */
    @Test
    public void testAffinityManager() {
        Nd4j.getMemoryManager().setAutoGcWindow(127);

        assertEquals(127, CudaEnvironment.getInstance().getConfiguration().getNoGcWindowMs());
    }

    /**
     * Prints environment information after allocating a 100-element and then a
     * 500-element array.
     */
    @Test
    public void testPrintOut() throws Exception {
        Nd4j.create(100);
        Nd4j.getExecutioner().printEnvironmentInformation();
        log.info("-------------------------------------");
        Nd4j.create(500);
        Nd4j.getExecutioner().printEnvironmentInformation();
    }

    /**
     * Checks mean reductions along both dimensions and over the whole 500x500 array with the
     * maximum grid size set to 11.
     * NOTE(review): the grid size setting is not restored afterwards — confirm whether later
     * tests are expected to run with it.
     */
    @Test
    public void testReduceX() throws Exception {
        CudaEnvironment.getInstance().getConfiguration().setMaximumGridSize(11);
        INDArray x = Nd4j.create(500, 500);
        INDArray exp_0 = Nd4j.linspace(1, 500, 500);
        INDArray exp_1 = Nd4j.create(500).assign(250.5);

        x.addiRowVector(Nd4j.linspace(1, 500, 500));

        assertEquals(exp_0, x.mean(0));
        assertEquals(exp_1, x.mean(1));

        assertEquals(250.5, x.meanNumber().doubleValue(), 1e-5);
    }

    /**
     * Checks {@code argMax} along both dimensions of a 500x500 array with the maximum grid
     * size set to 11.
     * NOTE(review): the grid size setting is not restored afterwards — confirm whether later
     * tests are expected to run with it.
     */
    @Test
    public void testIndexReduceX() throws Exception {
        CudaEnvironment.getInstance().getConfiguration().setMaximumGridSize(11);
        INDArray x = Nd4j.create(500, 500);
        INDArray exp_0 = Nd4j.create(500).assign(0);
        INDArray exp_1 = Nd4j.create(500).assign(499);

        x.addiRowVector(Nd4j.linspace(1, 500, 500));

        assertEquals(exp_0, Nd4j.argMax(x, 0));
        assertEquals(exp_1, Nd4j.argMax(x, 1));
    }

    /**
     * Multiplies an all-zero array by zero in place and logs the result.
     */
    @Test
    public void testInf() {
        INDArray x = Nd4j.create(10).assign(0.0);

        x.muli(0.0);

        log.error("X: {}", x);
    }

    /**
     * Times 10000 matrix-vector products (including copying the result to a float array)
     * and logs the median elapsed time in milliseconds.
     */
    @Test
    public void testTreo1() {
        INDArray points = Nd4j.rand(100000, 300);
        INDArray q = Nd4j.rand(10000, 300);

        System.out.println("----------------");

        List<Long> results = new ArrayList<>();
        for (int i = 0; i < 10000; i++) {
            long time1 = System.currentTimeMillis();
            INDArray gemm = points.mmul(q.getRow(i).transpose());
            // The copy is part of the measured work; its result is intentionally unused.
            gemm.data().asFloat();
            long time2 = System.currentTimeMillis();

            results.add(time2 - time1);
        }

        // The median is only meaningful on ordered samples.
        Collections.sort(results);
        log.error("p50: {}", results.get(results.size() / 2));
    }

    /**
     * Builds a batch by repeating {@code input} and {@code label} {@code batchSize} times
     * and stacking each list vertically.
     *
     * @param input     array repeated to form the features
     * @param label     array repeated to form the labels
     * @param batchSize number of repetitions
     * @return data set holding the stacked inputs and stacked labels
     */
    public DataSet getBatch(INDArray input, INDArray label, int batchSize) {
        List<INDArray> inp = new ArrayList<>();
        List<INDArray> lab = new ArrayList<>();

        for (int i = 0; i < batchSize; i++) {
            inp.add(input);
            lab.add(label);
        }

        // Pass the label list as labels; previously the input list was passed twice.
        return getTransformation(inp, lab);
    }

    /**
     * Vertically stacks both lists and wraps the two results in a {@link DataSet}.
     *
     * @param inp arrays stacked into the first {@link DataSet} argument
     * @param lab arrays stacked into the second {@link DataSet} argument
     * @return data set built from the two stacked arrays
     */
    public DataSet getTransformation(List<INDArray> inp, List<INDArray> lab) {
        return new DataSet(Nd4j.vstack(inp.toArray(new INDArray[0])), Nd4j.vstack(lab.toArray(new INDArray[0])));
    }
}