package com.nativelibs4java.opencl; import com.nativelibs4java.opencl.CLMem.Usage; import java.lang.reflect.Array; import static org.junit.Assert.*; import org.bridj.*; import static org.bridj.Pointer.*; import org.junit.*; public class KernelTest { CLContext context; CLQueue queue; @Before public void setup() { context = JavaCL.createBestContext(CLPlatform.DeviceFeature.CPU); queue = context.createDefaultQueue(); } public <T> Pointer<T> testArg(String type, Object value, Class<T> targetType) { long size = BridJ.sizeOf(targetType); CLBuffer<Byte> out = context.createByteBuffer(Usage.Output, size) ; CLKernel k = context.createProgram( // "#if __OPENCL_VERSION__ <= CL_VERSION_1_1\n" + " #pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + // "#endif\n" + "kernel void f(" + type + " arg, global " + type + "* out, long size) {\n" + "char* in = (char*) &arg;\n" + "for (long i = 0; i < size; i++) {\n" + "out[i] = in[i];\n" + "}\n" + "}" ).createKernel("f", value, out, size); CLEvent e = k.enqueueTask(queue); return out.as(targetType).read(queue, e); } public <T> Object testArrayArg(String type, Object array, Class<T> targetType) { long size = BridJ.sizeOf(targetType); long length = Array.getLength(array); CLBuffer<Byte> out = context.createByteBuffer(Usage.Output, size * length); StringBuilder b = new StringBuilder( "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + "kernel void f(" + type + length + " arg, global " + type + "* out, long length) {\n"); for (long i = 0; i < length; i++) { b.append("out[" + i + "] = arg.s" + (i < 10 ? i + "" : ((char)((int)'a' + (i - 10))) + "") + ";\n"); } b.append("}\n"); System.out.println(b); CLKernel k = context.createProgram(b.toString()).createKernel("f", array, out, length); CLEvent e = k.enqueueTask(queue); return out.as(targetType).read(queue, e).getArray(); } @Test public void nullArg() { CLBuffer<Byte> out = context.createByteBuffer(Usage.InputOutput, 2) ; CLKernel isInputNull = context.createProgram( "kernel void isInputNull(global int* in, global bool* out) {\n" + "*out = !in;\n" + "}" ).createKernel("isInputNull"); isInputNull.setArgs(CLKernel.NULL_POINTER_KERNEL_ARGUMENT, out); assertTrue(out.read(queue, isInputNull.enqueueTask(queue)).as(Boolean.class).get()); isInputNull.setArgs(out, out); assertFalse(out.read(queue, isInputNull.enqueueTask(queue)).as(Boolean.class).get()); } byte[] byteTup(int n) { byte[] a = new byte[n]; for (int i = 0; i < n; i++) a[i] = (byte)(i + 1); return a; } @Test public void byteArg() { assertArrayEquals(new byte[] { 2 }, testArg("char", (byte) 2, byte.class).getBytes()); for (byte[] tup : new byte[][] { byteTup(2), byteTup(3), byteTup(4), byteTup(8), byteTup(16) }) { assertArrayEquals(tup, (byte[]) testArrayArg("char", tup, byte.class)); } } short[] shortTup(int n) { short[] a = new short[n]; for (int i = 0; i < n; i++) a[i] = (short)(i + 1); return a; } @Test public void shortArg() { assertArrayEquals(new short[] { 2 }, testArg("short", (short) 2, short.class).getShorts()); for (short[] tup : new short[][] { shortTup(2), shortTup(3), shortTup(4), shortTup(8), shortTup(16) }) { assertArrayEquals(tup, (short[]) testArrayArg("short", tup, short.class)); } } int[] intTup(int n) { int[] a = new int[n]; for (int i = 0; i < n; i++) a[i] = i + 1; return a; } @Test public void intArg() { assertArrayEquals(new int[] { 2 }, testArg("int", (int) 2, int.class).getInts()); for (int[] tup : new int[][] { intTup(2), intTup(3), intTup(4), intTup(8), intTup(16) }) { assertArrayEquals(tup, (int[]) testArrayArg("int", tup, int.class)); } } long[] longTup(int n) { long[] a = new long[n]; for (int i = 0; i < n; i++) a[i] = i + 1; return a; } @Test public void longArg() { assertArrayEquals(new long[] { 2 }, testArg("long", (long) 2, long.class).getLongs()); for (long[] tup : new long[][] { longTup(2), longTup(3), longTup(4), longTup(8), longTup(16) }) { assertArrayEquals(tup, (long[]) testArrayArg("long", tup, long.class)); } } float[] floatTup(int n) { float[] a = new float[n]; for (int i = 0; i < n; i++) a[i] = i + 1; return a; } @Ignore @Test public void floatArg() { assertArrayEquals(new float[] { 2f }, testArg("float", (float) 4, float.class).getFloats(), 0); for (float[] tup : new float[][] { floatTup(2), floatTup(3), floatTup(4), floatTup(8), floatTup(16) }) { assertArrayEquals(tup, (float[]) testArrayArg("float", tup, float.class), 0); } } double[] doubleTup(int n) { double[] a = new double[n]; for (int i = 0; i < n; i++) a[i] = i + 1; return a; } @Ignore @Test public void doubleArg() { assertArrayEquals(new double[] { 2d }, testArg("double", (double) 8, double.class).getDoubles(), 0); for (double[] tup : new double[][] { doubleTup(2), doubleTup(3), doubleTup(4), doubleTup(8), doubleTup(16) }) { assertArrayEquals(tup, (double[]) testArrayArg("double", tup, double.class), 0); } } // void assertArrayEquals(Object exp, Object act) { // assertEquals(Arrays) // } }