package org.trifort.rootbeer.runtime;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import org.trifort.rootbeer.configuration.Configuration;
import org.trifort.rootbeer.runtime.util.Stopwatch;
import org.trifort.rootbeer.runtimegpu.GpuException;
import org.trifort.rootbeer.util.ResourceReader;

import com.lmax.disruptor.EventHandler;
import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;

/**
 * {@link Context} implementation that drives a single {@link GpuDevice}
 * through native (JNI) calls.
 *
 * <p>Every GPU command (build state, run template, run list) is published as a
 * {@link GpuEvent} on a Disruptor ring buffer and executed by the single
 * {@link GpuEventHandler} registered in the constructor. Callers wait for
 * completion through the {@link GpuFuture} carried by the event.
 *
 * <p>Typical call order, as far as this class shows: configure the context
 * ({@link #setKernel}, {@link #setThreadConfig}, optionally
 * {@link #setMemorySize}, ...), call {@link #buildState()}, then call
 * {@link #run()} / {@link #runAsync()} (or the list variants), and finally
 * {@link #close()}.
 */
public class CUDAContext implements Context {

  final private GpuDevice gpuDevice;
  final private boolean is32bit;

  private long nativeContext;
  /** Size of the object heap in bytes; -1 means "derive it in buildState". */
  private long memorySize;
  private byte[] cubinFile;
  private Memory objectMemory;
  private Memory handlesMemory;
  private Memory textureMemory;
  private Memory exceptionsMemory;
  private Memory classMemory;
  private boolean usingUncheckedMemory;
  /** Heap end pointer of the object memory recorded after the last GPU run. */
  private long requiredMemorySize;
  private CacheConfig cacheConfig;
  private ThreadConfig threadConfig;
  private Kernel kernelTemplate;
  private CompiledKernel compiledKernel;
  private boolean usingHandles;

  final private StatsRow stats;
  final private Stopwatch writeBlocksStopwatch;
  final private Stopwatch runStopwatch;
  final private Stopwatch runOnGpuStopwatch;
  final private Stopwatch readBlocksStopwatch;

  final private ExecutorService exec;
  final private Disruptor<GpuEvent> disruptor;
  final private EventHandler<GpuEvent> handler;
  final private RingBuffer<GpuEvent> ringBuffer;

  static {
    initializeDriver();
  }

  /**
   * Creates a context bound to {@code device}, starts the event-processing
   * machinery and allocates the native context.
   *
   * @param device the GPU device this context will run kernels on
   */
  public CUDAContext(GpuDevice device){
    // Daemon threads so an un-closed context does not keep the JVM alive.
    exec = Executors.newCachedThreadPool(new ThreadFactory() {
      public Thread newThread(Runnable r) {
        Thread t = new Thread(r);
        t.setDaemon(true);
        return t;
      }
    });
    disruptor = new Disruptor<GpuEvent>(GpuEvent.EVENT_FACTORY, 64, exec);
    handler = new GpuEventHandler();
    disruptor.handleEventsWith(handler);
    ringBuffer = disruptor.start();
    gpuDevice = device;
    memorySize = -1;

    String arch = System.getProperty("os.arch");
    is32bit = arch.equals("x86") || arch.equals("i386");

    usingUncheckedMemory = true;
    usingHandles = false;
    nativeContext = allocateNativeContext();
    cacheConfig = CacheConfig.PREFER_NONE;

    stats = new StatsRow();
    writeBlocksStopwatch = new Stopwatch();
    runStopwatch = new Stopwatch();
    runOnGpuStopwatch = new Stopwatch();
    readBlocksStopwatch = new Stopwatch();
  }

  /** @return the device passed to the constructor */
  @Override
  public GpuDevice getDevice() {
    return gpuDevice;
  }

  /**
   * Shuts down the disruptor and its executor, frees the native context and
   * closes every memory region that {@link #buildState()} allocated.
   */
  @Override
  public void close() {
    disruptor.shutdown();
    exec.shutdown();
    freeNativeContext(nativeContext);

    if(objectMemory != null){
      objectMemory.close();
    }
    if(handlesMemory != null){
      handlesMemory.close();
    }
    if(exceptionsMemory != null){
      exceptionsMemory.close();
    }
    if(classMemory != null){
      classMemory.close();
    }
    if(textureMemory != null){
      textureMemory.close();
    }
  }

  /**
   * Sets the object heap size used by {@link #buildState()}.
   *
   * @param memorySize heap size in bytes; -1 lets buildState derive a size
   */
  @Override
  public void setMemorySize(long memorySize) {
    this.memorySize = memorySize;
  }

  /**
   * Sets the kernel template.
   *
   * @param kernelTemplate the kernel; it is cast to {@link CompiledKernel},
   *        so it must be an instance of that type
   */
  @Override
  public void setKernel(Kernel kernelTemplate) {
    this.kernelTemplate = kernelTemplate;
    this.compiledKernel = (CompiledKernel) kernelTemplate;
  }

  /** Sets the cache configuration forwarded to the native build-state call. */
  @Override
  public void setCacheConfig(CacheConfig cacheConfig) {
    this.cacheConfig = cacheConfig;
  }

  /**
   * Selects whether the handles buffer holds one handle per thread
   * ({@code true}) or a single handle ({@code false}).
   */
  @Override
  public void setUsingHandles(boolean value){
    usingHandles = value;
  }

  /** Makes {@link #buildState()} allocate {@code CheckedFixedMemory} regions. */
  @Override
  public void useCheckedMemory(){
    this.usingUncheckedMemory = false;
  }

  /** Sets the thread configuration used by {@link #buildState()}. */
  @Override
  public void setThreadConfig(ThreadConfig threadConfig) {
    this.threadConfig = threadConfig;
  }

  /** One-dimensional convenience overload; Y/Z thread counts and Y block count are 1. */
  @Override
  public void setThreadConfig(int threadCountX, int blockCountX,
      int numThreads) {
    setThreadConfig(threadCountX, 1, 1, blockCountX, 1, numThreads);
  }

  /** Two-dimensional convenience overload; the Z thread count is 1. */
  @Override
  public void setThreadConfig(int threadCountX, int threadCountY,
      int blockCountX, int blockCountY, int numThreads) {
    setThreadConfig(threadCountX, threadCountY, 1, blockCountX, blockCountY,
        numThreads);
  }

  /** Builds a {@link ThreadConfig} from the given dimensions and stores it. */
  @Override
  public void setThreadConfig(int threadCountX, int threadCountY,
      int threadCountZ, int blockCountX, int blockCountY, int numThreads) {
    this.threadConfig = new ThreadConfig(threadCountX, threadCountY,
        threadCountZ, blockCountX, blockCountY, numThreads);
  }

  /**
   * Loads the cubin for the current architecture, allocates all host memory
   * regions and asks the event handler to perform the native build-state
   * call, blocking until it has completed.
   *
   * @throws RuntimeException if the kernel's cubin was compiled with an error,
   *         the cubin cannot be read, or no memory is left for the object heap
   */
  @Override
  public void buildState(){
    String filename;
    int size = 0;
    boolean error = false;

    if(is32bit){
      filename = compiledKernel.getCubin32();
      size = compiledKernel.getCubin32Size();
      error = compiledKernel.getCubin32Error();
    } else {
      filename = compiledKernel.getCubin64();
      size = compiledKernel.getCubin64Size();
      error = compiledKernel.getCubin64Error();
    }

    if(error){
      throw new RuntimeException("CUDA code compiled with error");
    }

    cubinFile = readCubinFile(filename, size);

    // 4 bytes per handle. Computed in long arithmetic (like
    // getExceptionsMemSize) so a very large thread count cannot overflow int.
    long handlesSize = usingHandles ? 4L * threadConfig.getNumThreads() : 4L;

    if(usingUncheckedMemory){
      classMemory = new FixedMemory(1024);
      exceptionsMemory = new FixedMemory(getExceptionsMemSize(threadConfig));
      textureMemory = new FixedMemory(8);
      handlesMemory = new FixedMemory(handlesSize);
    } else {
      exceptionsMemory = new CheckedFixedMemory(getExceptionsMemSize(threadConfig));
      classMemory = new CheckedFixedMemory(1024);
      textureMemory = new CheckedFixedMemory(8);
      handlesMemory = new CheckedFixedMemory(handlesSize);
    }

    // findMemorySize reads exceptionsMemory/classMemory, so it must run
    // after the allocations above.
    if(memorySize == -1){
      findMemorySize(cubinFile.length);
    }

    if(usingUncheckedMemory){
      objectMemory = new FixedMemory(memorySize);
    } else {
      objectMemory = new CheckedFixedMemory(memorySize);
    }

    long seq = ringBuffer.next();
    GpuEvent gpuEvent = ringBuffer.get(seq);
    gpuEvent.setValue(GpuEventCommand.NATIVE_BUILD_STATE);
    gpuEvent.getFuture().reset();
    ringBuffer.publish(seq);

    gpuEvent.getFuture().take();
  }

  /**
   * @return 4 bytes per thread when runtime exceptions are enabled in the
   *         configuration, otherwise 4 bytes
   */
  private long getExceptionsMemSize(ThreadConfig thread_config) {
    if(Configuration.runtimeInstance().getExceptions()){
      return 4L*thread_config.getNumThreads();
    } else {
      return 4;
    }
  }

  /**
   * Reads the cubin resource via {@link ResourceReader}.
   *
   * @throws RuntimeException wrapping any exception raised while reading
   */
  private byte[] readCubinFile(String filename, int length) {
    try {
      byte[] buffer = ResourceReader.getResourceArray(filename, length);
      return buffer;
    } catch(Exception ex){
      ex.printStackTrace();
      throw new RuntimeException(ex);
    }
  }

  /**
   * Derives {@link #memorySize} from the smaller of the free GPU global memory
   * and the free JVM heap, minus the cubin, the exceptions and class regions
   * and a fixed 2048-byte reserve.
   *
   * @param cubinFileLength length of the loaded cubin in bytes
   * @throws RuntimeException if nothing is left after the deductions
   */
  private void findMemorySize(int cubinFileLength){
    long freeMemSizeGPU = gpuDevice.getFreeGlobalMemoryBytes();
    long freeMemSizeCPU = Runtime.getRuntime().freeMemory();
    long freeMemSize = Math.min(freeMemSizeGPU, freeMemSizeCPU);

    freeMemSize -= cubinFileLength;
    freeMemSize -= exceptionsMemory.getSize();
    freeMemSize -= classMemory.getSize();
    freeMemSize -= 2048;

    if(freeMemSize <= 0){
      StringBuilder error = new StringBuilder();
      error.append("OutOfMemory while allocating Java CPU and GPU memory.\n");
      error.append(" Try increasing the max Java Heap Size using -Xmx and the initial Java Heap Size using -Xms.\n");
      error.append(" Try reducing the number of threads you are using.\n");
      error.append(" Try using kernel templates.\n");
      error.append(" Debugging Output:\n");
      error.append(" GPU_SIZE: "+freeMemSizeGPU+"\n");
      error.append(" CPU_SIZE: "+freeMemSizeCPU+"\n");
      error.append(" EXCEPTIONS_SIZE: "+exceptionsMemory.getSize()+"\n");
      error.append(" CLASS_MEMORY_SIZE: "+classMemory.getSize());
      throw new RuntimeException(error.toString());
    }
    memorySize = freeMemSize;
  }

  /** @return the object-memory heap end pointer recorded after the last GPU run */
  @Override
  public long getRequiredMemory() {
    return requiredMemorySize;
  }

  /** Runs the kernel template and blocks until the run has finished. */
  @Override
  public void run(){
    GpuFuture future = runAsync();
    future.take();
  }

  /**
   * Publishes a template run and returns immediately.
   *
   * @return the future of the published event
   */
  @Override
  public GpuFuture runAsync() {
    long seq = ringBuffer.next();
    GpuEvent gpuEvent = ringBuffer.get(seq);
    gpuEvent.setValue(GpuEventCommand.NATIVE_RUN);
    gpuEvent.getFuture().reset();
    ringBuffer.publish(seq);
    return gpuEvent.getFuture();
  }

  /** Runs every kernel in {@code work} and blocks until the run has finished. */
  @Override
  public void run(List<Kernel> work) {
    GpuFuture future = runAsync(work);
    future.take();
  }

  /**
   * Publishes a run over a list of kernels and returns immediately.
   *
   * @param work the kernels to serialize, run and read back
   * @return the future of the published event
   */
  @Override
  public GpuFuture runAsync(List<Kernel> work) {
    long seq = ringBuffer.next();
    GpuEvent gpuEvent = ringBuffer.get(seq);
    gpuEvent.setKernelList(work);
    gpuEvent.setValue(GpuEventCommand.NATIVE_RUN_LIST);
    gpuEvent.getFuture().reset();
    ringBuffer.publish(seq);
    return gpuEvent.getFuture();
  }

  /** @return the statistics row updated by the write/run/read phases */
  @Override
  public StatsRow getStats() {
    return stats;
  }

  /**
   * Executes the commands published on the ring buffer and signals the
   * event's future when the command has finished or failed.
   */
  private class GpuEventHandler implements EventHandler<GpuEvent>{
    @Override
    public void onEvent(final GpuEvent gpuEvent, final long sequence,
        final boolean endOfBatch){
      try {
        switch(gpuEvent.getValue()){
        case NATIVE_BUILD_STATE:
          boolean usingExceptions = Configuration.runtimeInstance().getExceptions();
          nativeBuildState(nativeContext, gpuDevice.getDeviceId(), cubinFile,
              cubinFile.length,
              threadConfig.getThreadCountX(),
              threadConfig.getThreadCountY(),
              threadConfig.getThreadCountZ(),
              threadConfig.getBlockCountX(),
              threadConfig.getBlockCountY(),
              threadConfig.getNumThreads(),
              objectMemory, handlesMemory, exceptionsMemory, classMemory,
              b2i(usingExceptions), cacheConfig.ordinal());
          gpuEvent.getFuture().signal();
          break;
        case NATIVE_RUN:
          writeBlocksTemplate();
          runGpu();
          readBlocksTemplate();
          gpuEvent.getFuture().signal();
          break;
        case NATIVE_RUN_LIST:
          writeBlocksList(gpuEvent.getKernelList());
          runGpu();
          readBlocksList(gpuEvent.getKernelList());
          gpuEvent.getFuture().signal();
          break;
        }
      } catch(Throwable th){
        // readBlocksSetup throws Errors (e.g. OutOfMemoryError) as well as
        // Exceptions. Catching only Exception would let an Error escape
        // without signalling, leaving the waiting caller blocked forever.
        // Non-Exception throwables are wrapped so the value handed to
        // setException is always an Exception, as before.
        Exception failure;
        if(th instanceof Exception){
          failure = (Exception) th;
        } else {
          failure = new RuntimeException(th);
        }
        gpuEvent.getFuture().setException(failure);
        gpuEvent.getFuture().signal();
      }
    }
  }

  /**
   * Serializes statics and the kernel template into the object memory and
   * stores the resulting handle at the start of the handles memory.
   */
  private void writeBlocksTemplate(){
    writeBlocksStopwatch.start();
    objectMemory.clearHeapEndPtr();
    handlesMemory.setAddress(0);

    Serializer serializer = compiledKernel.getSerializer(objectMemory, textureMemory);
    serializer.writeStaticsToHeap();

    long handle = serializer.writeToHeap(compiledKernel);
    handlesMemory.writeRef(handle);
    objectMemory.align16();

    if(Configuration.getPrintMem()){
      BufferPrinter printer = new BufferPrinter();
      printer.print(objectMemory, 0, 256);
    }

    writeBlocksStopwatch.stop();
    stats.setSerializationTime(writeBlocksStopwatch.elapsedTimeMillis());
  }

  /**
   * Serializes statics and every kernel in {@code work} into the object
   * memory, writing one handle per kernel into the handles memory.
   */
  private void writeBlocksList(List<Kernel> work){
    writeBlocksStopwatch.start();
    objectMemory.clearHeapEndPtr();
    // Rewind so handles are written from the start: readBlocksList reads
    // them back from address 0, and a previous run leaves the position
    // wherever its last read/write ended.
    handlesMemory.setAddress(0);

    Serializer serializer = compiledKernel.getSerializer(objectMemory, textureMemory);
    serializer.writeStaticsToHeap();

    for(Kernel kernel : work){
      long handle = serializer.writeToHeap(kernel);
      handlesMemory.writeRef(handle);
    }
    objectMemory.align16();

    if(Configuration.getPrintMem()){
      BufferPrinter printer = new BufferPrinter();
      printer.print(objectMemory, 0, 256);
    }

    writeBlocksStopwatch.stop();
    stats.setSerializationTime(writeBlocksStopwatch.elapsedTimeMillis());
  }

  /**
   * Invokes the native run, then records the heap end pointer and the
   * execution time.
   */
  private void runGpu(){
    runOnGpuStopwatch.start();
    // The third argument is the native "usingKernelTemplates" flag.
    cudaRun(nativeContext, objectMemory, b2i(!usingHandles), stats);
    runOnGpuStopwatch.stop();
    requiredMemorySize = objectMemory.getHeapEndPtr();
    stats.setExecutionTime(runOnGpuStopwatch.elapsedTimeMillis());
  }

  /**
   * Starts the read stopwatch, rethrows the first exception recorded in the
   * exceptions memory (when runtime exceptions are enabled) and then reads
   * statics back from the heap.
   *
   * @throws NullPointerException if a thread recorded the kernel's NPE marker
   * @throws OutOfMemoryError if a thread recorded the kernel's OOM marker
   * @throws ArrayIndexOutOfBoundsException if a {@link GpuException} was recorded
   * @throws RuntimeException wrapping any other recorded throwable
   */
  private void readBlocksSetup(Serializer serializer){
    readBlocksStopwatch.start();
    objectMemory.setAddress(0);
    exceptionsMemory.setAddress(0);

    if(Configuration.runtimeInstance().getExceptions()){
      for(long i = 0; i < threadConfig.getNumThreads(); ++i){
        long ref = exceptionsMemory.readRef();
        if(ref != 0){
          // The shifted reference is compared against the kernel's
          // NPE/OOM marker numbers.
          long ref_num = ref >> 4;
          if(ref_num == compiledKernel.getNullPointerNumber()){
            throw new NullPointerException("NPE while running on GPU");
          } else if(ref_num == compiledKernel.getOutOfMemoryNumber()){
            throw new OutOfMemoryError("OOM error while running on GPU");
          }

          objectMemory.setAddress(ref);
          Object except = serializer.readFromHeap(null, true, ref);
          if(except instanceof Error){
            Error except_th = (Error) except;
            throw except_th;
          } else if(except instanceof GpuException){
            GpuException gpu_except = (GpuException) except;
            throw new ArrayIndexOutOfBoundsException("array_index: "+gpu_except.m_arrayIndex+
                " array_length: "+gpu_except.m_arrayLength+" array: "+gpu_except.m_array);
          } else {
            throw new RuntimeException((Throwable) except);
          }
        }
      }
    }

    serializer.readStaticsFromHeap();
  }

  /**
   * Reads the kernel template back from the heap using the handle stored at
   * the start of the handles memory, and records the deserialization time.
   */
  private void readBlocksTemplate(){
    Serializer serializer = compiledKernel.getSerializer(objectMemory, textureMemory);
    readBlocksSetup(serializer);
    handlesMemory.setAddress(0);

    long handle = handlesMemory.readRef();
    serializer.readFromHeap(compiledKernel, true, handle);

    if(Configuration.getPrintMem()){
      BufferPrinter printer = new BufferPrinter();
      printer.print(objectMemory, 0, 256);
    }
    readBlocksStopwatch.stop();
    stats.setDeserializationTime(readBlocksStopwatch.elapsedTimeMillis());
  }

  /**
   * Reads every kernel in {@code kernelList} back from the heap, consuming
   * one handle per kernel from the start of the handles memory, and records
   * the deserialization time.
   *
   * @param kernelList the kernels to refresh, in the order they were written
   */
  public void readBlocksList(List<Kernel> kernelList) {
    Serializer serializer = compiledKernel.getSerializer(objectMemory, textureMemory);
    readBlocksSetup(serializer);
    handlesMemory.setAddress(0);

    for(Kernel kernel : kernelList){
      long ref = handlesMemory.readRef();
      serializer.readFromHeap(kernel, true, ref);
    }

    if(Configuration.getPrintMem()){
      BufferPrinter printer = new BufferPrinter();
      printer.print(objectMemory, 0, 256);
    }
    readBlocksStopwatch.stop();
    stats.setDeserializationTime(readBlocksStopwatch.elapsedTimeMillis());
  }

  /** @return 1 for {@code true}, 0 for {@code false} (the native calls take int flags) */
  private int b2i(boolean value){
    if(value){
      return 1;
    } else {
      return 0;
    }
  }

  private static native void initializeDriver();
  private native long allocateNativeContext();
  private native void freeNativeContext(long nativeContext);

  private native void nativeBuildState(long nativeContext, int deviceIndex,
      byte[] cubinFile, int cubinLength,
      int threadCountX, int threadCountY, int threadCountZ,
      int blockCountX, int blockCountY, int numThreads,
      Memory objectMem, Memory handlesMem, Memory exceptionsMem,
      Memory classMem, int usingExceptions, int cacheConfig);

  private native void cudaRun(long nativeContext, Memory objectMem,
      int usingKernelTemplates, StatsRow stats);
}