package edu.berkeley.cs.nlp.ocular.model.em; import static org.jocl.CL.*; import org.jocl.*; import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate; import tberg.murphy.gpu.CudaUtil; /** * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu) */ public class JOCLInnerLoop implements EmissionCacheInnerLoop { public static final int GPU_BLOCK_SIZE_X = 1; public static final int GPU_ROLL_X = 32; public static final int GPU_BLOCK_SIZE_Y = 64; public static final int CPU_BLOCK_SIZE_X = 1; public static final int CPU_ROLL_X = 8; public static final int CPU_BLOCK_SIZE_Y = 1; int blockSizeX; int rollX; int blockSizeY; int numThreads; int[] templateNumIndices; int[] templateIndicesOffsets; int maxTemplateWidth; int minTemplateWidth; cl_context context; cl_command_queue queue; cl_program program; cl_mem d_Ow; cl_mem d_Ob; cl_mem d_scores; cl_mem[] d_Tw; cl_mem[] d_Tb; cl_kernel[] kernels; private static String getString(cl_device_id device, int paramName) { // Obtain the length of the string that will be queried long size[] = new long[1]; clGetDeviceInfo(device, paramName, 0, null, size); // Create a buffer of the appropriate size and fill it with the info byte buffer[] = new byte[(int)size[0]]; clGetDeviceInfo(device, paramName, buffer.length, Pointer.to(buffer), null); // Create a string from the buffer (excluding the trailing \0 byte) return new String(buffer, 0, buffer.length-1); } public JOCLInnerLoop(int numThreads) { this.numThreads = numThreads; final int platformIndex = 0; CL.setExceptionsEnabled(true); int numPlatformsArray[] = new int[1]; clGetPlatformIDs(0, null, numPlatformsArray); int numPlatforms = numPlatformsArray[0]; cl_platform_id platforms[] = new cl_platform_id[numPlatforms]; clGetPlatformIDs(platforms.length, platforms, null); cl_platform_id platform = platforms[platformIndex]; cl_context_properties contextProperties = new cl_context_properties(); contextProperties.addProperty(CL_CONTEXT_PLATFORM, platform); cl_device_id device = null; boolean isGPU = false; { int numDevicesArray[] = new int[1]; final long deviceType = CL_DEVICE_TYPE_GPU; clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray); int numDevices = numDevicesArray[0]; cl_device_id devices[] = new cl_device_id[numDevices]; clGetDeviceIDs(platform, deviceType, numDevices, devices, null); for (int i=0; i<devices.length; ++i) { String deviceName = getString(devices[i], CL_DEVICE_NAME).toLowerCase(); if (deviceName.contains("radeon") || deviceName.contains("nvidia")) { device = devices[i]; isGPU = true; break; } } } if (!isGPU) { int numDevicesArray[] = new int[1]; final long deviceType = CL_DEVICE_TYPE_CPU; clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray); int numDevices = numDevicesArray[0]; cl_device_id devices[] = new cl_device_id[numDevices]; clGetDeviceIDs(platform, deviceType, numDevices, devices, null); device = devices[0]; isGPU = false; } if (isGPU) { this.blockSizeX = GPU_BLOCK_SIZE_X; this.rollX = GPU_ROLL_X; this.blockSizeY = GPU_BLOCK_SIZE_Y; } else { this.blockSizeX = CPU_BLOCK_SIZE_X; this.rollX = CPU_ROLL_X; this.blockSizeY = CPU_BLOCK_SIZE_Y; } System.out.printf("Device name: %s\n", getString(device, CL_DEVICE_NAME)); System.out.println("Block size x: "+blockSizeX); System.out.println("Roll x: "+rollX); System.out.println("Block size y: "+blockSizeY); // Create a context for the selected device context = clCreateContext(contextProperties, 1, new cl_device_id[]{device}, null, null, null); // Create a command-queue queue = clCreateCommandQueue(context, device, 0, null); // Create the program from the source code program = clCreateProgramWithSource(context, 1, new String[]{ kernelSrc() }, null, null); // Build the program clBuildProgram(program, 0, null, "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math -cl-no-signed-zeros", null, null); } public void startup(float[][] whiteTemplates, float[][] blackTemplates, int[] templateNumIndices, int[] templateIndicesOffsets, int minTemplateWidth, int maxTemplateWidth, int maxSequenceLength, int totalTemplateNumIndices) { this.templateNumIndices = templateNumIndices; this.templateIndicesOffsets = templateIndicesOffsets; this.maxTemplateWidth = maxTemplateWidth; this.minTemplateWidth = minTemplateWidth; int numTemplateWidths = (maxTemplateWidth-minTemplateWidth)+1; // Build kernels kernels = new cl_kernel[numTemplateWidths]; for (int tw=minTemplateWidth; tw<=maxTemplateWidth; ++tw) { if (templateNumIndices[tw-minTemplateWidth] > 0) { kernels[tw-minTemplateWidth] = clCreateKernel(program, "compute_emissions_"+tw, null); } } // Allocate the device input data int extendedMaxSeqLength = (blockSizeX*rollX) * (int) Math.ceil(((double) maxSequenceLength) / (blockSizeX*rollX)); this.d_Ow = clCreateBuffer(context, CL_MEM_READ_WRITE, Sizeof.cl_float * (extendedMaxSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT, null, null); // this.d_Ow = context.createFloatBuffer(Usage.Input, (extendedMaxSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT); this.d_Ob = clCreateBuffer(context, CL_MEM_READ_WRITE, Sizeof.cl_float * (extendedMaxSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT, null, null); // this.d_Ob = context.createFloatBuffer(Usage.Input, (extendedMaxSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT); this.d_scores = clCreateBuffer(context, CL_MEM_READ_WRITE, Sizeof.cl_float * maxSequenceLength*totalTemplateNumIndices, null, null); // this.d_scores = context.createFloatBuffer(Usage.Output, maxSequenceLength*totalTemplateNumIndices); this.d_Tw = new cl_mem[numTemplateWidths]; this.d_Tb = new cl_mem[numTemplateWidths]; for (int tw=minTemplateWidth; tw<=maxTemplateWidth; ++tw) { if (templateNumIndices[tw-minTemplateWidth] > 0) { d_Tw[tw-minTemplateWidth] = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Sizeof.cl_float * whiteTemplates[tw-minTemplateWidth].length, Pointer.to(whiteTemplates[tw-minTemplateWidth]), null); // d_Tw[tw-minTemplateWidth] = context.createFloatBuffer(Usage.Input, whiteTemplates[tw-minTemplateWidth].length); // d_Tw[tw-minTemplateWidth].write(queue, pc.capture(Pointer.pointerToFloats(whiteTemplates[tw-minTemplateWidth])), false); d_Tb[tw-minTemplateWidth] = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Sizeof.cl_float * blackTemplates[tw-minTemplateWidth].length, Pointer.to(blackTemplates[tw-minTemplateWidth]), null); // d_Tb[tw-minTemplateWidth] = context.createFloatBuffer(Usage.Input, whiteTemplates[tw-minTemplateWidth].length); // d_Tb[tw-minTemplateWidth].write(queue, pc.capture(Pointer.pointerToFloats(blackTemplates[tw-minTemplateWidth])), false); } } } public void shutdown() { clReleaseMemObject(d_Ow); clReleaseMemObject(d_Ob); clReleaseMemObject(d_scores); for (int tw=minTemplateWidth; tw<=maxTemplateWidth; ++tw) { if (templateNumIndices[tw-minTemplateWidth] > 0) { clReleaseMemObject(d_Tw[tw-minTemplateWidth]); clReleaseMemObject(d_Tb[tw-minTemplateWidth]); } } for (cl_kernel kernel : kernels) clReleaseKernel(kernel); } public void compute(final float[] scores, final float[] whiteObservations, final float[] blackObservations, final int sequenceLength) { int numTemplateWidths = (maxTemplateWidth-minTemplateWidth)+1; int gridSizeX = (int) Math.ceil(((double) sequenceLength) / (blockSizeX*rollX)); int extendedSeqLength = gridSizeX * (blockSizeX*rollX); cl_event[] writeEvents = new cl_event[] {new cl_event(), new cl_event()}; clEnqueueWriteBuffer(queue, d_Ow, CL_TRUE, 0, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT * Sizeof.cl_float, Pointer.to(CudaUtil.extendWithZeros(whiteObservations, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT)), 0, null, writeEvents[0]); // d_Ow.write(queue, pc.capture(Pointer.pointerToFloats(CudaUtil.extendWithZeros(whiteObservations, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT))), false); clEnqueueWriteBuffer(queue, d_Ob, CL_TRUE, 0, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT * Sizeof.cl_float, Pointer.to(CudaUtil.extendWithZeros(blackObservations, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT)), 0, null, writeEvents[1]); // d_Ob.write(queue, pc.capture(Pointer.pointerToFloats(CudaUtil.extendWithZeros(blackObservations, (extendedSeqLength+maxTemplateWidth-1)*CharacterTemplate.LINE_HEIGHT))), false); cl_event[] kernelEvents = new cl_event[numTemplateWidths]; for (int tw=minTemplateWidth; tw<=maxTemplateWidth; ++tw) { if (templateNumIndices[tw-minTemplateWidth] > 0) { int gridSizeY = (int) Math.ceil(((double) templateNumIndices[tw-minTemplateWidth]) / blockSizeY); cl_kernel kernel = kernels[tw-minTemplateWidth]; clSetKernelArg(kernel, 0, Sizeof.cl_int, Pointer.to(new int[] {templateIndicesOffsets[tw-minTemplateWidth]*sequenceLength})); clSetKernelArg(kernel, 1, Sizeof.cl_int, Pointer.to(new int[] {sequenceLength})); clSetKernelArg(kernel, 2, Sizeof.cl_int, Pointer.to(new int[] {templateNumIndices[tw-minTemplateWidth]})); clSetKernelArg(kernel, 3, Sizeof.cl_mem, Pointer.to(d_Tw[tw-minTemplateWidth])); clSetKernelArg(kernel, 4, Sizeof.cl_mem, Pointer.to(d_Tb[tw-minTemplateWidth])); clSetKernelArg(kernel, 5, Sizeof.cl_mem, Pointer.to(d_Ow)); clSetKernelArg(kernel, 6, Sizeof.cl_mem, Pointer.to(d_Ob)); clSetKernelArg(kernel, 7, Sizeof.cl_mem, Pointer.to(d_scores)); // computeKernel.setArgs(templateIndicesOffsets[tw-minTemplateWidth]*sequenceLength, sequenceLength, templateNumIndices[tw-minTemplateWidth], d_Tw[tw-minTemplateWidth], d_Tb[tw-minTemplateWidth], d_Ow, d_Ob, d_scores); kernelEvents[tw-minTemplateWidth] = new cl_event(); clEnqueueNDRangeKernel(queue, kernel, 2, null, new long[] {gridSizeX*blockSizeX, gridSizeY*blockSizeY}, new long[] {blockSizeX, blockSizeY}, 2, writeEvents, kernelEvents[tw-minTemplateWidth]); // computeKernel.enqueueNDRange(queue, new int[] {gridSizeX*blockSizeX, gridSizeY*blockSizeY}, new int[] {blockSizeX, blockSizeY}); } } cl_event readEvent = new cl_event(); clEnqueueReadBuffer(queue, d_scores, CL_TRUE, 0, scores.length * Sizeof.cl_float, Pointer.to(scores), kernelEvents.length, kernelEvents, readEvent); clWaitForEvents(1, new cl_event[] {readEvent}); // d_scores.read(queue).getFloats(scores); } public int numOuterThreads() { return 1; } public int numPopulateThreads() { return numThreads; } public String kernelSrc() { StringBuffer buf = new StringBuffer(); for (int tw=1; tw<=CharacterTemplate.LINE_HEIGHT; ++tw) { buf.append("__kernel void compute_emissions_"+tw+"(__const int scoresOffset, __const int Olength, __const int Tlength, __global float const* __restrict__ Tw, __global float const* __restrict__ Tb, __global float const* __restrict__ Ow, __global float const* __restrict__ Ob, __global float* scores) {\n"); buf.append("int Tindex = get_global_id(1);\n"); buf.append("if (Tindex < Tlength) {\n"); for (int r=0; r<rollX; ++r) { buf.append("float o"+r+" = 0;\n"); buf.append("float score"+r+" = 0;\n"); } buf.append("for (int i=0; i<"+CharacterTemplate.LINE_HEIGHT*tw+"; ++i) {\n"); buf.append("float tw = Tw[Tindex * "+CharacterTemplate.LINE_HEIGHT*tw+" + i];\n"); for (int r=0; r<rollX; ++r) { buf.append("o"+r+" = Ow[(get_group_id(0) * "+blockSizeX*rollX+" + get_local_id(0) * "+rollX+" + "+r+") * "+CharacterTemplate.LINE_HEIGHT+" + i];\n"); // buf.append("score"+r+" = fma(o"+r+", tw, score"+r+");\n"); // buf.append("score"+r+" = mad(o"+r+", tw, score"+r+");\n"); buf.append("score"+r+" += o"+r+" * tw;\n"); } buf.append("}\n"); buf.append("for (int i=0; i<"+CharacterTemplate.LINE_HEIGHT*tw+"; ++i) {\n"); buf.append("float tb = Tb[Tindex * "+CharacterTemplate.LINE_HEIGHT*tw+" + i];\n"); for (int r=0; r<rollX; ++r) { buf.append("o"+r+" = Ob[(get_group_id(0) * "+blockSizeX*rollX+" + get_local_id(0) * "+rollX+" + "+r+") * "+CharacterTemplate.LINE_HEIGHT+" + i];\n"); // buf.append("score"+r+" = fma(o"+r+", tb, score"+r+");\n"); // buf.append("score"+r+" = mad(o"+r+", tb, score"+r+");\n"); buf.append("score"+r+" += o"+r+" * tb;\n"); } buf.append("}\n"); buf.append("int Oindex;\n"); for (int r=0; r<rollX; ++r) { buf.append("Oindex = get_group_id(0) * "+blockSizeX*rollX+" + get_local_id(0) * "+rollX+" + "+r+";\n"); buf.append("if (Oindex < Olength) scores[scoresOffset + Oindex * Tlength + Tindex] = score"+r+";\n"); } buf.append("}\n"); buf.append("}\n"); } return buf.toString(); } }