package jcuda.jcublas.ops; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.IndexAccumulation; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ @Ignore public class CudaIndexReduceTests { @Before public void setUp() { CudaEnvironment.getInstance().getConfiguration() .setExecutionModel(Configuration.ExecutionModel.SEQUENTIAL) .setFirstMemory(AllocationStatus.DEVICE) .setMaximumBlockSize(64) .setMaximumGridSize(64) .enableDebug(true); System.out.println("Init called"); } @Test public void testPinnedIMax() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.0f, 0.1f, 2.0f, 3.0f, 4.0f, 5.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(5, idx); } @Test public void testPinnedIMax4() throws Exception { // simple way to stop test if we're not on CUDA backend here INDArray array1 = Nd4j.create(new float[]{0.0f, 0.0f, 0.0f, 2.0f, 2.0f, 0.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(3, idx); } @Test public void testPinnedIMaxLarge() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.linspace(1,1024,1024); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(1023, idx); } @Test public void testIMaxLargeLarge() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 12800); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); assertEquals(12799, idx); } @Test public void testIamaxC() { INDArray linspace = Nd4j.linspace(1, 4, 4).dup('c'); assertEquals(3,Nd4j.getBlasWrapper().iamax(linspace)); } @Test public void testIamaxF() { INDArray linspace = Nd4j.linspace(1, 4, 4).dup('f'); assertEquals(3,Nd4j.getBlasWrapper().iamax(linspace)); } @Test public void testIMax2() { INDArray array1 = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); long time1 = System.currentTimeMillis(); INDArray argMax = Nd4j.argMax(array1, 1); long time2 = System.currentTimeMillis(); System.out.println("Execution time: " + (time2 - time1)); for (int i = 0; i < 128; i++) { assertEquals(999f, argMax.getFloat(i), 0.0001f); } } @Test public void testIMaxAlongDimension() throws Exception { INDArray array = Nd4j.linspace(1, 491520, 491520).reshape(10, 3, 4, 64, 64); INDArray result = Nd4j.argMax(array, 4); System.out.println("Result shapeInfo: " + result.shapeInfoDataBuffer()); System.out.println("Result length: " + result.length()); //System.out.println("Result: " + result); } @Test public void testIMax3() { INDArray array1 = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); INDArray argMax = Nd4j.argMax(array1, 0); System.out.println("ARgmax length: " + argMax.length()); for (int i = 0; i < 1000; i++) { assertEquals("Failed iteration: ["+ i +"]", 127, argMax.getFloat(i), 0.0001f); } } @Test public void testIMax4() { INDArray array1 = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000); long time1 = System.currentTimeMillis(); INDArray argMax = Nd4j.argMax(array1, 0,1); long time2 = System.currentTimeMillis(); System.out.println("Execution time: " + (time2 - time1)); assertEquals(127999f, argMax.getFloat(0), 0.001f); } @Test public void testIMaxDimensional() throws Exception { INDArray toArgMax = Nd4j.linspace(1,24,24).reshape(4, 3, 2); INDArray valueArray = Nd4j.valueArrayOf(new int[]{4, 2}, 2.0); INDArray valueArrayTwo = Nd4j.valueArrayOf(new int[]{3,2},3.0); INDArray valueArrayThree = Nd4j.valueArrayOf(new int[]{4,3},1.0); INDArray argMax = Nd4j.argMax(toArgMax, 1); assertEquals(valueArray, argMax); INDArray argMaxZero = Nd4j.argMax(toArgMax,0); assertEquals(valueArrayTwo, argMaxZero); INDArray argMaxTwo = Nd4j.argMax(toArgMax,2); assertEquals(valueArrayThree,argMaxTwo); } @Test public void testPinnedIMax2() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{6.0f, 0.1f, 2.0f, 3.0f, 7.0f, 5.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(4, idx); } @Test public void testPinnedIMax3() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{6.0f, 0.1f, 2.0f, 3.0f, 7.0f, 9.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(5, idx); } @Test public void testPinnedIMin() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{1.0f, 0.1f, 2.0f, 3.0f, 4.0f, 5.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMin(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(1, idx); } @Test public void testPinnedIMin2() throws Exception { // simple way to stop test if we're not on CUDA backend here assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName()); INDArray array1 = Nd4j.create(new float[]{0.1f, 1.1f, 2.0f, 3.0f, 4.0f, 5.0f}); int idx = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMin(array1))).getFinalResult(); System.out.println("Array1: " + array1); assertEquals(0, idx); } @Test public void testIMaxF1() throws Exception { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand('f',10,2); for( int i=0; i<10; i++ ){ INDArray row = arr.getRow(i); int maxIdx; if(row.getDouble(0) > row.getDouble(1)) maxIdx = 0; else maxIdx = 1; INDArray argmax = Nd4j.argMax(row,1); double argmaxd = argmax.getDouble(0); assertEquals(maxIdx, (int)argmaxd); System.out.println(row); System.out.println(argmax); System.out.println("exp: " + maxIdx + ", act: " + argmaxd); } } }