package jcuda.jcublas.ops; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ @Ignore public class CublasTests { @Test public void testGemm1() throws Exception { INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array3 = array1.mmul(array2); assertEquals(338350f, array3.getFloat(0), 0.001f); } @Test public void testGemm2() throws Exception { INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array3 = array1.mmul(array2); assertEquals(338350f, array3.getFloat(0), 0.001f); } @Test public void testGemm3() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array3 = array1.mmul(array2); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); assertEquals(3338050.0f, array3.data().getFloat(0),0.001f); assertEquals(8298050.0f, array3.data().getFloat(1),0.001f); assertEquals(3343100.0f, array3.data().getFloat(10),0.001f); assertEquals(8313100.0f, array3.data().getFloat(11),0.001f); assertEquals(3348150.0f, array3.data().getFloat(20),0.001f); assertEquals(8328150.0f, array3.data().getFloat(21),0.001f); } @Test public void testGemm4() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array3 = array1.mmul(array2); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); assertEquals(338350f, array3.data().getFloat(0),0.001f); assertEquals(843350f, array3.data().getFloat(1),0.001f); assertEquals(843350f, array3.data().getFloat(10),0.001f); assertEquals(2348350f, array3.data().getFloat(11),0.001f); assertEquals(1348350f, array3.data().getFloat(20),0.001f); assertEquals(3853350f, array3.data().getFloat(21),0.001f); } @Test public void testGemm5() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array3 = array1.mmul(array2); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); assertEquals(3.293408E7f, array3.data().getFloat(0),10f); assertEquals(3.29837E7f, array3.data().getFloat(1),10f); assertEquals(3.3835E7f, array3.data().getFloat(99),10f); } @Test public void testGemm6() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array3 = array1.mmul(array2); //System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); assertEquals(3338050.0f, array3.data().getFloat(0),0.001f); assertEquals(3343100f, array3.data().getFloat(1),0.001f); assertEquals(8298050f, array3.data().getFloat(10),0.001f); assertEquals(8313100.0f, array3.data().getFloat(11),0.001f); assertEquals(1.325805E7f, array3.data().getFloat(20),5f); assertEquals(1.32831E7f, array3.data().getFloat(21),5f); } @Test public void testGemm7() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array3 = Nd4j.create(10, 10); array1.mmul(array2, array3); // System.out.println("Array3: " + Arrays.toString(array3.data().asFloat())); assertEquals(3338050.0f, array3.data().getFloat(0),0.001f); assertEquals(8298050.0f, array3.data().getFloat(1),0.001f); assertEquals(3343100.0f, array3.data().getFloat(10),0.001f); assertEquals(8313100.0f, array3.data().getFloat(11),0.001f); assertEquals(3348150.0f, array3.data().getFloat(20),0.001f); assertEquals(8328150.0f, array3.data().getFloat(21),0.001f); } @Test public void testGemm8() throws Exception { INDArray array1 = Nd4j.ones(10, 10); INDArray array2 = Nd4j.ones(10, 10); INDArray array3 = Nd4j.create(10, 10); array1.mmul(array2, array3); assertEquals(10.0f, array3.data().getFloat(0),0.001f); assertEquals(10.0f, array3.data().getFloat(1),0.001f); assertEquals(10.0f, array3.data().getFloat(10),0.001f); assertEquals(10.0f, array3.data().getFloat(11),0.001f); assertEquals(10.0f, array3.data().getFloat(20),0.001f); assertEquals(10.0f, array3.data().getFloat(21),0.001f); } @Test public void testGemv1() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1,100, 100).reshape(100,1); INDArray array3 = array1.mmul(array2); assertEquals(10, array3.length()); assertEquals(338350f, array3.getFloat(0), 0.001f); assertEquals(843350f, array3.getFloat(1), 0.001f); assertEquals(1348350f, array3.getFloat(2), 0.001f); assertEquals(1853350f, array3.getFloat(3), 0.001f); } @Test public void testGemv2() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1,100, 100).reshape('f', 100,1); INDArray array3 = array1.mmul(array2); assertEquals(10, array3.length()); assertEquals(338350f, array3.getFloat(0), 0.001f); assertEquals(843350f, array3.getFloat(1), 0.001f); assertEquals(1348350f, array3.getFloat(2), 0.001f); assertEquals(1853350f, array3.getFloat(3), 0.001f); } @Test public void testGemv3() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1,100, 100).reshape('f', 100,1); INDArray array3 = array1.mmul(array2); assertEquals(10, array3.length()); assertEquals(3338050f, array3.getFloat(0), 0.001f); assertEquals(3343100f, array3.getFloat(1), 0.001f); assertEquals(3348150f, array3.getFloat(2), 0.001f); assertEquals(3353200f, array3.getFloat(3), 0.001f); } @Test public void testGemv4() throws Exception { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1,100, 100).reshape(100,1); INDArray array3 = array1.mmul(array2); assertEquals(10, array3.length()); assertEquals(3338050f, array3.getFloat(0), 0.001f); assertEquals(3343100f, array3.getFloat(1), 0.001f); assertEquals(3348150f, array3.getFloat(2), 0.001f); assertEquals(3353200f, array3.getFloat(3), 0.001f); } }