package jcuda.jcublas.ops;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import java.util.Arrays;
/**
* @author raver119@gmail.com
*/
@Ignore
public class CudaBlasTests {
@Test
public void testMMuli1() throws Exception {
INDArray array1 = Nd4j.linspace(1, 250, 250).reshape(new int[]{5, 50});
System.out.println("array1: " + array1);
INDArray array2 = Nd4j.linspace(1, 250, 250).reshape(new int[]{50, 5});
System.out.println("array2: " + array2);
INDArray result = Nd4j.create(new int[]{5, 5});
System.out.println("Order1: " + array1.ordering());
System.out.println("Order2: " + array2.ordering());
System.out.println("Result order: " + result.ordering());
array1.mmul(array2, result);
System.out.println("Result: " + result);
Thread.sleep(100000000000L);
}
@Test
public void testDup1() throws Exception {
INDArray array1 = Nd4j.linspace(1, 250, 250).reshape(new int[]{5, 50}).dup('f');
INDArray array2 = array1.dup();
System.out.println("array1 ordering: " + array1.ordering()); // F
System.out.println("array2 ordering: " + array2.ordering()); // C
System.out.println("array1 eq array2: " + array1.equals(array2)); // true
//assertEquals(array1.getFloat(17), array2.getFloat(17), 0.001f );
}
@Test
public void testForAlex() throws Exception {
int[][] shape1s = new int[][]{{10240, 10240}};
int[][] shape2s = new int[][]{{10240, 10240}};
int[] nTestsArr = new int[]{5};
for(int test=0; test<shape1s.length; test++ ) {
int[] shape1 = shape1s[test];
int[] shape2 = shape2s[test];
int nTests = nTestsArr[test];
INDArray c1 = Nd4j.create(shape1, 'c');
INDArray c2 = Nd4j.create(shape2, 'c');
CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
AtomicAllocator.getInstance().getPointer(c1, context);
AtomicAllocator.getInstance().getPointer(c2, context);
//CC
long startCC = System.currentTimeMillis();
for (int i = 0; i < nTests; i++) {
c1.mmul(c2);
}
long endCC = System.currentTimeMillis();
System.out.println("cc");
System.out.println("mmul: " + Arrays.toString(shape1) + "x" + Arrays.toString(shape2) + ", " + nTests + " runs");
System.out.println("cc: " + (endCC - startCC));
}
}
}