package org.trifort.rootbeer.sort; import org.trifort.rootbeer.runtime.Rootbeer; import org.trifort.rootbeer.runtime.GpuDevice; import org.trifort.rootbeer.runtime.Context; import org.trifort.rootbeer.runtime.ThreadConfig; import org.trifort.rootbeer.runtime.StatsRow; import org.trifort.rootbeer.runtime.CacheConfig; import java.util.List; import java.util.Arrays; import java.util.Random; public class GPUSort { private int[] newArray(int size){ int[] ret = new int[size]; for(int i = 0; i < size; ++i){ ret[i] = i; } return ret; } public void checkSorted(int[] array, int outerIndex){ for(int index = 0; index < array.length; ++index){ if(array[index] != index){ for(int index2 = 0; index2 < array.length; ++index2){ System.out.println("array["+index2+"]: "+array[index2]); } throw new RuntimeException("not sorted: "+outerIndex); } } } public void fisherYates(int[] array) { Random random = new Random(); for (int i = array.length - 1; i > 0; i--){ int index = random.nextInt(i + 1); int a = array[index]; array[index] = array[i]; array[i] = a; } } public void sort(){ //should have 192 threads per SM int size = 2048; int sizeBy2 = size / 2; //int numMultiProcessors = 14; //int blocksPerMultiProcessor = 512; int numMultiProcessors = 2; int blocksPerMultiProcessor = 256; int outerCount = numMultiProcessors*blocksPerMultiProcessor; int[][] array = new int[outerCount][]; for(int i = 0; i < outerCount; ++i){ array[i] = newArray(size); } Rootbeer rootbeer = new Rootbeer(); List<GpuDevice> devices = rootbeer.getDevices(); GpuDevice device0 = devices.get(0); Context context0 = device0.createContext(4212880); context0.setCacheConfig(CacheConfig.PREFER_SHARED); context0.setThreadConfig(sizeBy2, outerCount, outerCount * sizeBy2); context0.setKernel(new GPUSortKernel(array)); context0.buildState(); while(true){ for(int i = 0; i < outerCount; ++i){ fisherYates(array[i]); } long gpuStart = System.currentTimeMillis(); context0.run(); long gpuStop = System.currentTimeMillis(); long gpuTime = gpuStop - gpuStart; StatsRow row0 = context0.getStats(); System.out.println("serialization_time: "+row0.getSerializationTime()); System.out.println("driver_memcopy_to_device_time: "+row0.getDriverMemcopyToDeviceTime()); System.out.println("driver_execution_time: "+row0.getDriverExecTime()); System.out.println("driver_memcopy_from_device_time: "+row0.getDriverMemcopyFromDeviceTime()); System.out.println("total_driver_execution_time: "+row0.getTotalDriverExecutionTime()); System.out.println("deserialization_time: "+row0.getDeserializationTime()); System.out.println("gpu_required_memory: "+context0.getRequiredMemory()); System.out.println("gpu_time: "+gpuTime); for(int i = 0; i < outerCount; ++i){ checkSorted(array[i], i); fisherYates(array[i]); } long cpuStart = System.currentTimeMillis(); for(int i = 0; i < outerCount; ++i){ Arrays.sort(array[i]); } long cpuStop = System.currentTimeMillis(); long cpuTime = cpuStop - cpuStart; System.out.println("cpu_time: "+cpuTime); double ratio = (double) cpuTime / (double) gpuTime; System.out.println("ratio: "+ratio); } //context0.close(); } public static void main(String[] args){ GPUSort sorter = new GPUSort(); while(true){ sorter.sort(); } } }