package jcuda.jcublas.ops; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.Assert.assertEquals; /** * These tests are set up just to track changes done to various GPU transfer mechanics * * Each test will address specific memory allocation strategy * * @author raver119@gmail.com */ @Ignore public class CudaScalarsTests { @Before public void setUp() { CudaEnvironment.getInstance().getConfiguration() .setExecutionModel(Configuration.ExecutionModel.SEQUENTIAL) .setFirstMemory(AllocationStatus.DEVICE) .setMaximumBlockSize(64) .setMaximumGridSize(256) .enableDebug(true); System.out.println("Init called"); } // dim3 launchDims = getFlatLaunchParams((int) extraPointers[2], (int *) extraPointers[0], nullptr); // dim3 launchDims = getReduceLaunchParams((int) extraPointers[2], (int *) extraPointers[0], nullptr, nullptr, 1, sizeof(float), 2); // dim3 launchDims = getReduceLaunchParams((int) extraPointers[2], (int *) extraPointers[0], nullptr, resultShapeInfoPointer, dimensionLength, sizeof(float), 2); @Test public void testPinnedScalarDiv() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.divi(0.5f); System.out.println("Divi result: " + array2.getFloat(0)); assertEquals(4.0f, array2.getFloat(0), 0.01f); Thread.sleep(100000000000L); } @Test public void testPinnedScalarAdd() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.addi(0.5f); System.out.println("Addi result: " + array2.getFloat(0)); assertEquals(1.5f, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarSub() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.subi(0.5f); System.out.println("Subi result: " + array2.getFloat(0)); assertEquals(1.5f, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarMul() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.muli(2.0f); System.out.println("Mul result: " + array2.getFloat(0)); assertEquals(2.0f, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarRDiv() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.rdivi(0.5f); System.out.println("RDiv result: " + array2.getFloat(0)); assertEquals(0.25f, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarRSub() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); array2.rsubi(0.5f); System.out.println("RSub result: " + array2.getFloat(0)); assertEquals(-1.5f, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarMax() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); INDArray max = Transforms.max(array2, 0.5f, true); System.out.println("Max result: " + max); assertEquals(2.0f, array2.getFloat(0), 0.01f); assertEquals(1.0f, array2.getFloat(1), 0.01f); } @Test public void testPinnedScalarLessOrEqual() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); INDArray max = array2.lti(1.1f); System.out.println("LTI result: " + max); assertEquals(0.0, array2.getFloat(0), 0.01f); } @Test public void testPinnedScalarGreaterOrEqual() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); INDArray max = array2.gti(1.1f); System.out.println("LTI result: " + max); assertEquals(1.0, array2.getFloat(0), 0.01f); assertEquals(0.0, array2.getFloat(1), 0.01f); } @Test public void testPinnedScalarEquals() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); INDArray max = array2.eqi(2.0f); System.out.println("EQI result: " + max); assertEquals(1.0, array2.getFloat(0), 0.01f); assertEquals(0.0, array2.getFloat(1), 0.01f); } @Test public void testPinnedScalarSet() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f}); INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}); INDArray max = array2.assign(3.0f); System.out.println("EQI result: " + max); System.out.println("Array2 result: " + array2); assertEquals(3.0, array2.getFloat(0), 0.01f); assertEquals(3.0, array2.getFloat(1), 0.01f); } @Test public void testScalMul2() throws Exception { INDArray array1 = Nd4j.linspace(1, 784000, 784000).reshape(784, 1000); INDArray array2 = Nd4j.linspace(1, 784000, 784000).reshape(784, 1000).dup('f'); long time1 = System.currentTimeMillis(); array1.muli(0.5f); long time2 = System.currentTimeMillis(); System.out.println("Execution time 1: " + (time2 - time1)); time1 = System.currentTimeMillis(); array2.muli(0.5f); time2 = System.currentTimeMillis(); System.out.println("Execution time 2: " + (time2 - time1)); //System.out.println("MUL result: " + array1); assertEquals(0.5f, array1.getDouble(0), 0.0001f); assertEquals(392000f, array1.getDouble(783999), 0.0001f); assertEquals(array1, array2); } @Test public void testScalMul3() throws Exception { INDArray array1 = Nd4j.linspace(1, 784, 784).reshape(1, 784); INDArray array2 = Nd4j.linspace(1, 784, 784).reshape(1, 784).dup('f'); array1.divi(0.5f); array2.divi(0.5f); //System.out.println("MUL result: " + array1); assertEquals(2.0f, array1.getDouble(0), 0.0001f); assertEquals(1568f, array1.getDouble(783), 0.0001f); assertEquals(array1, array2); } @Test public void testScalMul4() throws Exception { INDArray array1 = Nd4j.zeros(1, 784); INDArray array2 = Nd4j.zeros(1, 784).dup('f'); array1.divi(255f); array2.divi(255f); //System.out.println("MUL result: " + array1); assertEquals(0.0f, array1.getDouble(0), 0.0001f); assertEquals(0.0f, array1.getDouble(783), 0.0001f); assertEquals(array1, array2); } @Test public void testScalMul5() throws Exception { INDArray array1 = Nd4j.create(new double[] {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 191.0, 253.0, 253.0, 253.0, 60.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 251.0, 251.0, 251.0, 230.0, 166.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.0, 127.0, 0.0, 190.0, 251.0, 251.0, 251.0, 253.0, 220.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 84.0, 221.0, 229.0, 251.0, 139.0, 23.0, 31.0, 225.0, 251.0, 253.0, 248.0, 111.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 194.0, 251.0, 251.0, 251.0, 251.0, 218.0, 39.0, 0.0, 83.0, 193.0, 253.0, 251.0, 126.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 194.0, 255.0, 253.0, 253.0, 253.0, 253.0, 255.0, 63.0, 0.0, 0.0, 100.0, 255.0, 253.0, 173.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 186.0, 251.0, 253.0, 251.0, 251.0, 251.0, 251.0, 221.0, 54.0, 0.0, 0.0, 0.0, 253.0, 251.0, 251.0, 149.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 189.0, 251.0, 251.0, 253.0, 251.0, 251.0, 219.0, 126.0, 0.0, 0.0, 0.0, 0.0, 0.0, 253.0, 251.0, 251.0, 188.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 141.0, 251.0, 251.0, 253.0, 251.0, 251.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 114.0, 251.0, 251.0, 244.0, 83.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 127.0, 251.0, 251.0, 253.0, 231.0, 94.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 251.0, 251.0, 251.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 92.0, 253.0, 253.0, 253.0, 255.0, 63.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 253.0, 253.0, 253.0, 95.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 197.0, 251.0, 251.0, 251.0, 161.0, 16.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 115.0, 251.0, 251.0, 251.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 251.0, 251.0, 251.0, 172.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 253.0, 251.0, 251.0, 219.0, 47.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 251.0, 251.0, 251.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 162.0, 253.0, 251.0, 251.0, 148.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 251.0, 251.0, 251.0, 94.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0, 158.0, 181.0, 251.0, 253.0, 251.0, 172.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 253.0, 253.0, 253.0, 153.0, 96.0, 96.0, 96.0, 96.0, 96.0, 155.0, 253.0, 253.0, 253.0, 253.0, 255.0, 253.0, 126.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 185.0, 251.0, 251.0, 251.0, 253.0, 251.0, 251.0, 251.0, 251.0, 253.0, 251.0, 251.0, 251.0, 251.0, 253.0, 207.0, 31.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 188.0, 251.0, 251.0, 253.0, 251.0, 251.0, 251.0, 251.0, 253.0, 251.0, 251.0, 251.0, 172.0, 126.0, 31.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 95.0, 188.0, 228.0, 253.0, 251.0, 251.0, 251.0, 251.0, 253.0, 243.0, 109.0, 31.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 59.0, 95.0, 94.0, 94.0, 94.0, 193.0, 95.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); array1.divi(255f); System.out.println("Array1: "+ array1); } }