package jcuda.jcublas.ops; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.accum.Max; import org.nd4j.linalg.api.ops.impl.accum.Mean; import org.nd4j.linalg.api.ops.impl.accum.Min; import org.nd4j.linalg.api.ops.impl.accum.Sum; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; import java.util.Arrays; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; /** * @author raver119@gmail.com */ @Ignore public class CudaAccumTests { @Before public void setUp() { CudaEnvironment.getInstance().getConfiguration() .setExecutionModel(Configuration.ExecutionModel.ASYNCHRONOUS) .setFirstMemory(AllocationStatus.DEVICE) .setMaximumBlockSize(128) .setMaximumGridSize(256) .enableDebug(false) .setVerbose(false); System.out.println("Init called"); } @Test public void testBiggerSum() throws Exception { INDArray array = Nd4j.ones(128000, 512); array.sum(0); } /** * Sum call * @throws Exception */ @Test public void testPinnedSum() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{2.01f, 2.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); Sum sum = new Sum(array1); Nd4j.getExecutioner().exec(sum, 1); Number resu = sum.getFinalResult(); System.out.println("Result: " + resu); assertEquals(17.15f, resu.floatValue(), 0.01f); } @Test public void testPinnedSum2() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.linspace(1, 10000, 100000).reshape(100,1000); Sum sum = new Sum(array1); INDArray result;/* = Nd4j.getExecutioner().exec(sum, 0); assertEquals(495055.44f, result.getFloat(0), 0.01f); */ result = Nd4j.getExecutioner().exec(sum, 1); result = Nd4j.getExecutioner().exec(sum, 1); assertEquals(50945.52f, result.getFloat(0), 0.01f); } @Test public void testPinnedSum3() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.linspace(1, 100000, 100000).reshape(100,1000); for (int x = 0; x < 100000; x++ ){ assertEquals("Failed on iteration [" + x + "]", x+1, array1.getFloat(x), 0.01f); } } @Test public void testPinnedSumNumber() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.linspace(1, 10000, 10000); float sum = array1.sumNumber().floatValue(); assertEquals(5.0005E7, sum, 1f); } @Test public void testPinnedSumNumber2() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.ones(128000); long time1 = System.currentTimeMillis(); float sum = array1.sumNumber().floatValue(); long time2 = System.currentTimeMillis(); System.out.println("Execution time: " + (time2 - time1)); assertEquals(128000f, sum, 0.01f); } @Test public void testPinnedSumNumber3() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.ones(12800000); float sum = array1.sumNumber().floatValue(); assertEquals(12800000f, sum, 0.01f); } @Test public void testStdev0(){ double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}}; INDArray in = Nd4j.create(ind); INDArray stdev = in.std(0); INDArray exp = Nd4j.create(new double[]{0.2, 0.25166114784, 0.05773502692}); System.out.println("Exp dtype: " + exp.data().dataType()); System.out.println("Exp dtype: " + exp.data().dataType()); System.out.println("Array: " + Arrays.toString(exp.data().asFloat())); assertEquals(exp,stdev); } @Test public void testStdev1(){ double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}}; INDArray in = Nd4j.create(ind); INDArray stdev = in.std(1); INDArray exp = Nd4j.create(new double[]{1.8556220880, 1.7521415468, 1.7039170559}); assertEquals(exp,stdev); } @Test public void testStdevNum(){ INDArray in = Nd4j.linspace(1, 1000, 10000); float stdev = in.stdNumber().floatValue(); assertEquals(288.42972f, stdev, 0.001f); } /** * Mean call * @throws Exception */ @Test public void testPinnedMean() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{2.01f, 2.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); Mean mean = new Mean(array1); Nd4j.getExecutioner().exec(mean, 1); Number resu = mean.getFinalResult(); // INDArray result = Nd4j.getExecutioner().exec(new Mean(array1), 1); System.out.println("Array1: " + array1); System.out.println("Result: " + resu); assertEquals(1.14f, resu.floatValue(), 0.01f); } @Test public void testSum2() { INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{2, 2, 2}); System.out.println("N result: " + n); INDArray test = Nd4j.create(new float[]{3, 7, 11, 15}, new int[]{2, 2}); System.out.println("Test result: " + test); INDArray sum = n.sum(-1); System.out.println("Sum result: " + sum); assertEquals(test, sum); } @Test public void testMax() { INDArray n = Nd4j.linspace(1, 15, 15); float max = n.maxNumber().floatValue(); assertEquals(15f, max, 0.001f); } @Test public void testSum3() { INDArray n = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); long time1 = System.currentTimeMillis(); INDArray sum = n.sum(new int[]{0}); long time2 = System.currentTimeMillis(); System.out.println("Time elapsed: "+ (time2 - time1) ); System.out.println("Sum: " + sum); System.out.println("Sum.Length: " + sum.length()); System.out.println("elementWiseStride: " + n.elementWiseStride()); System.out.println("elementStride: " + n.elementStride()); assertEquals(63565.02f, sum.getFloat(0), 0.01f); assertEquals(63566.02f, sum.getFloat(1), 0.01f); } @Test public void testSum3_1() throws Exception { INDArray n = Nd4j.linspace(1, 128000, 128000).reshape(128, 1000); long time1 = System.currentTimeMillis(); INDArray sum = n.sum(new int[]{0}); long time2 = System.currentTimeMillis(); System.out.println("Time elapsed: "+ (time2 - time1) ); System.out.println("Sum: " + sum); System.out.println("Sum.Length: " + sum.length()); System.out.println("elementWiseStride: " + n.elementWiseStride()); System.out.println("elementStride: " + n.elementStride()); assertEquals(8128128.0f, sum.getFloat(0), 0.01f); assertEquals(8128256.0f, sum.getFloat(1), 0.01f); assertEquals(8128512.0f, sum.getFloat(3), 0.01f); assertEquals(8128640.0f, sum.getFloat(4), 0.01f); } @Test public void testSum4() { INDArray n = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); long time1 = System.currentTimeMillis(); INDArray sum = n.sum(new int[]{1}); long time2 = System.currentTimeMillis(); System.out.println("Execution time: " + (time2 - time1)); System.out.println("elementWiseStride: " + n.elementWiseStride()); System.out.println("elementStride: " + n.elementStride()); assertEquals(4898.4707f, sum.getFloat(0), 0.01f); assertEquals(12703.209f, sum.getFloat(1), 0.01f); } @Test public void testSum5() { INDArray n = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); INDArray sum = n.sum(new int[]{1}); INDArray sum2 = n.sum(new int[]{-1}); INDArray sum3 = n.sum(new int[]{0}); System.out.println("elementWiseStride: " + n.elementWiseStride()); System.out.println("elementStride: " + n.elementStride()); assertEquals(4898.4707f, sum.getFloat(0), 0.01f); assertEquals(12703.209f, sum.getFloat(1), 0.01f); assertEquals(sum, sum2); assertNotEquals(sum, sum3); assertEquals(63565.023f, sum3.getFloat(0), 0.01f); assertEquals(63570.008f, sum3.getFloat(5), 0.01f); } @Test public void testSum6() { INDArray n = Nd4j.linspace(1, 1000, 128000).reshape(128, 10, 10, 10); INDArray sum0 = n.sum(new int[]{0}); INDArray sum1 = n.sum(new int[]{1}); INDArray sum3 = n.sum(new int[]{3}); INDArray sumN = n.sum(new int[]{-1}); INDArray sum2 = n.sum(new int[]{2}); System.out.println("elementWiseStride: " + n.elementWiseStride()); System.out.println("elementStride: " + n.elementStride()); assertEquals(63565.023f, sum0.getFloat(0), 0.01f); assertEquals(63570.008f, sum0.getFloat(5), 0.01f); assertEquals(45.12137f, sum1.getFloat(0), 0.01f); assertEquals(45.511604f, sum1.getFloat(5), 0.01f); assertEquals(10.351214f, sum3.getFloat(0), 0.01f); assertEquals(14.25359f, sum3.getFloat(5), 0.01f); assertEquals(14.25359f, sumN.getFloat(5), 0.01f); assertEquals(13.74628f, sum2.getFloat(3), 0.01f); } @Test public void testSum3Of4_2222() { int[] shape = {2, 2, 2, 2}; int length = ArrayUtil.prod(shape); INDArray arrC = Nd4j.linspace(1, length, length).reshape(shape); INDArray arrF = Nd4j.create(arrC.shape()).reshape('f', arrC.shape()).assign(arrC); System.out.println("Arrf: " + arrF); System.out.println("Arrf: " + Arrays.toString(arrF.data().asFloat())); System.out.println("ArrF shapeInfo: " + arrF.shapeInfoDataBuffer()); System.out.println("----------------------------"); int[][] dimsToSum = new int[][]{{0, 1, 2}, {0, 1, 3}, {0, 2, 3}, {1, 2, 3}}; double[][] expD = new double[][]{{64, 72}, {60, 76}, {52, 84}, {36, 100}}; for (int i = 0; i < dimsToSum.length; i++) { int[] d = dimsToSum[i]; INDArray outC = arrC.sum(d); INDArray outF = arrF.sum(d); INDArray exp = Nd4j.create(expD[i],outC.shape()); assertEquals(exp, outC); assertEquals(exp, outF); System.out.println("PASSED:" + Arrays.toString(d) + "\t" + outC + "\t" + outF); } } @Test public void testDimensionMax() { INDArray linspace = Nd4j.linspace(1, 6, 6).reshape('f', 2, 3); int axis = 0; INDArray row = linspace.slice(axis); System.out.println("Linspace: " + linspace); System.out.println("Row: " + row); System.out.println("Row shapeInfo: " + row.shapeInfoDataBuffer()); Max max = new Max(row); double max2 = Nd4j.getExecutioner().execAndReturn(max).getFinalResult().doubleValue(); assertEquals(5.0, max2, 1e-1); Min min = new Min(row); double min2 = Nd4j.getExecutioner().execAndReturn(min).getFinalResult().doubleValue(); assertEquals(1.0, min2, 1e-1); } @Test public void testNorm2() throws Exception { INDArray array1 = Nd4j.create(new float[]{2.01f, 2.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray result = array1.norm2(1); System.out.println(result); assertEquals(4.62f, result.getDouble(0), 0.001); } @Test public void testSumF() throws Exception { INDArray arrc = Nd4j.linspace(1,6,6).reshape('c',3,2); INDArray arrf = Nd4j.create(new double[6],new int[]{3,2},'f').assign(arrc); System.out.println("ArrC: " + arrc); System.out.println("ArrC buffer: " + Arrays.toString(arrc.data().asFloat())); System.out.println("ArrF: " + arrf); System.out.println("ArrF buffer: " + Arrays.toString(arrf.data().asFloat())); System.out.println("ArrF shape: " + arrf.shapeInfoDataBuffer()); INDArray cSum = arrc.sum(0); INDArray fSum = arrf.sum(0); assertEquals(Nd4j.create(new float[]{9f,12f}),fSum); } @Test public void testMax1() throws Exception { INDArray array1 = Nd4j.linspace(1, 76800,76800).reshape(256, 300); long time1 = System.currentTimeMillis(); INDArray array = array1.max(1); long time2 = System.currentTimeMillis(); System.out.println("Time elapsed: "+ (time2 - time1) ); assertEquals(256, array.length()); for (int x = 0; x < 256; x++) { assertEquals((x + 1) * 300, array.getFloat(x), 0.01f); } } @Test public void testMax0() throws Exception { INDArray array1 = Nd4j.linspace(1, 76800,76800).reshape(256, 300); long time1 = System.currentTimeMillis(); INDArray array = array1.max(0); long time2 = System.currentTimeMillis(); System.out.println("Array1 shapeInfo: " + array1.shapeInfoDataBuffer()); System.out.println("Result shapeInfo: " + array.shapeInfoDataBuffer()); System.out.println("Time elapsed: "+ (time2 - time1) ); assertEquals(300, array.length()); for (int x = 0; x < 300; x++) { assertEquals("Failed on x: " + x, 76800 - (array1.columns() - x) + 1 , array.getFloat(x), 0.01f); } } @Test public void testMax1_2() throws Exception { INDArray array1 = Nd4j.linspace(1, 7680000,7680000).reshape(2560, 3000); /* for (int x = 0; x < 7680000; x++) { assertEquals(x+1, array1.getFloat(x), 0.001f); } */ long time1 = System.currentTimeMillis(); INDArray array = array1.max(1); long time2 = System.currentTimeMillis(); System.out.println("Time elapsed: "+ (time2 - time1) ); assertEquals(2560, array.length()); //System.out.println("Array: " + array); for (int x = 0; x < 2560; x++) { assertEquals("Failed on x:" + x,(x + 1) * 3000, array.getFloat(x), 0.01f); } } }