package com.amd.aparapi.device; import com.amd.aparapi.ProfileInfo; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.lang.annotation.Annotation; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import com.amd.aparapi.Range; import com.amd.aparapi.internal.opencl.OpenCLArgDescriptor; import com.amd.aparapi.internal.opencl.OpenCLKernel; import com.amd.aparapi.internal.opencl.OpenCLPlatform; import com.amd.aparapi.internal.opencl.OpenCLProgram; import com.amd.aparapi.opencl.OpenCL; import com.amd.aparapi.opencl.OpenCL.Arg; import com.amd.aparapi.opencl.OpenCL.Constant; import com.amd.aparapi.opencl.OpenCL.GlobalReadOnly; import com.amd.aparapi.opencl.OpenCL.GlobalReadWrite; import com.amd.aparapi.opencl.OpenCL.GlobalWriteOnly; import com.amd.aparapi.opencl.OpenCL.Kernel; import com.amd.aparapi.opencl.OpenCL.Local; import com.amd.aparapi.opencl.OpenCL.Resource; import com.amd.aparapi.opencl.OpenCL.Source; public class OpenCLDevice extends Device{ private final OpenCLPlatform platform; private final long deviceId; private int maxComputeUnits; private long localMemSize; private long globalMemSize; private long maxMemAllocSize; /** * Minimal constructor * * @param _platform * @param _deviceId * @param _type */ public OpenCLDevice(OpenCLPlatform _platform, long _deviceId, TYPE _type) { platform = _platform; deviceId = _deviceId; type = _type; } public OpenCLPlatform getOpenCLPlatform() { return platform; } public int getMaxComputeUnits() { return maxComputeUnits; } public void setMaxComputeUnits(int _maxComputeUnits) { maxComputeUnits = _maxComputeUnits; } public long getLocalMemSize() { return localMemSize; } public void setLocalMemSize(long _localMemSize) { localMemSize = _localMemSize; } public long getMaxMemAllocSize() { return maxMemAllocSize; } public void setMaxMemAllocSize(long _maxMemAllocSize) { maxMemAllocSize = _maxMemAllocSize; } public long getGlobalMemSize() { return globalMemSize; } public void setGlobalMemSize(long _globalMemSize) { globalMemSize = _globalMemSize; } void setMaxWorkItemSize(int _dim, int _value) { maxWorkItemSize[_dim] = _value; } public long getDeviceId() { return (deviceId); } public static class OpenCLInvocationHandler<T extends OpenCL<T>> implements InvocationHandler{ private final Map<String, OpenCLKernel> map; private final OpenCLProgram program; private boolean disposed = false; public OpenCLInvocationHandler(OpenCLProgram _program, Map<String, OpenCLKernel> _map) { program = _program; map = _map; disposed = false; } @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { if (disposed){ throw new IllegalStateException("bound interface already disposed"); } if (!isReservedInterfaceMethod(method)) { final OpenCLKernel kernel = map.get(method.getName()); if (kernel != null) { kernel.invoke(args); } } else { if (method.getName().equals("put")) { System.out.println("put not implemented"); /* for (Object arg : args) { Class<?> argClass = arg.getClass(); if (argClass.isArray()) { if (argClass.getComponentType().isPrimitive()) { OpenCLMem mem = program.getMem(arg, 0L); if (mem == null) { throw new IllegalStateException("can't put an array that has never been passed to a kernel " + argClass); } mem.bits |= OpenCLMem.MEM_DIRTY_BIT; } else { throw new IllegalStateException("Only array args (of primitives) expected for put/get, cant deal with " + argClass); } } else { throw new IllegalStateException("Only array args expected for put/get, cant deal with " + argClass); } } */ } else if (method.getName().equals("get")) { System.out.println("get not implemented"); /* for (Object arg : args) { Class<?> argClass = arg.getClass(); if (argClass.isArray()) { if (argClass.getComponentType().isPrimitive()) { OpenCLMem mem = program.getMem(arg, 0L); if (mem == null) { throw new IllegalStateException("can't get an array that has never been passed to a kernel " + argClass); } OpenCLJNI.getJNI().getMem(program, mem); } else { throw new IllegalStateException("Only array args (of primitives) expected for put/get, cant deal with " + argClass); } } else { throw new IllegalStateException("Only array args expected for put/get, cant deal with " + argClass); } } */ } else if (method.getName().equals("begin")) { System.out.println("begin not implemented"); } else if (method.getName().equals("dispose")) { // System.out.println("dispose"); for (OpenCLKernel k:map.values()){ k.dispose(); } program.dispose(); map.clear(); disposed=true; } else if (method.getName().equals("end")) { System.out.println("end not implemented"); } else if (method.getName().equals("getProfileInfo")){ proxy = (Object)program.getProfileInfo(); } } return proxy; } } public List<OpenCLArgDescriptor> getArgs(Method m) { final List<OpenCLArgDescriptor> args = new ArrayList<OpenCLArgDescriptor>(); final Annotation[][] parameterAnnotations = m.getParameterAnnotations(); final Class<?>[] parameterTypes = m.getParameterTypes(); for (int arg = 0; arg < parameterTypes.length; arg++) { if (parameterTypes[arg].isAssignableFrom(Range.class)) { } else { long bits = 0L; String name = null; for (final Annotation pa : parameterAnnotations[arg]) { if (pa instanceof GlobalReadOnly) { name = ((GlobalReadOnly) pa).value(); bits |= OpenCLArgDescriptor.ARG_GLOBAL_BIT | OpenCLArgDescriptor.ARG_READONLY_BIT; } else if (pa instanceof GlobalWriteOnly) { name = ((GlobalWriteOnly) pa).value(); bits |= OpenCLArgDescriptor.ARG_GLOBAL_BIT | OpenCLArgDescriptor.ARG_WRITEONLY_BIT; } else if (pa instanceof GlobalReadWrite) { name = ((GlobalReadWrite) pa).value(); bits |= OpenCLArgDescriptor.ARG_GLOBAL_BIT | OpenCLArgDescriptor.ARG_READWRITE_BIT; } else if (pa instanceof Local) { name = ((Local) pa).value(); bits |= OpenCLArgDescriptor.ARG_LOCAL_BIT; } else if (pa instanceof Constant) { name = ((Constant) pa).value(); bits |= OpenCLArgDescriptor.ARG_CONST_BIT | OpenCLArgDescriptor.ARG_READONLY_BIT; } else if (pa instanceof Arg) { name = ((Arg) pa).value(); bits |= OpenCLArgDescriptor.ARG_ISARG_BIT; } } if (parameterTypes[arg].isArray()) { if (parameterTypes[arg].isAssignableFrom(float[].class)) { bits |= OpenCLArgDescriptor.ARG_FLOAT_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } else if (parameterTypes[arg].isAssignableFrom(int[].class)) { bits |= OpenCLArgDescriptor.ARG_INT_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } else if (parameterTypes[arg].isAssignableFrom(double[].class)) { bits |= OpenCLArgDescriptor.ARG_DOUBLE_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } else if (parameterTypes[arg].isAssignableFrom(byte[].class)) { bits |= OpenCLArgDescriptor.ARG_BYTE_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } else if (parameterTypes[arg].isAssignableFrom(short[].class)) { bits |= OpenCLArgDescriptor.ARG_SHORT_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } else if (parameterTypes[arg].isAssignableFrom(long[].class)) { bits |= OpenCLArgDescriptor.ARG_LONG_BIT | OpenCLArgDescriptor.ARG_ARRAY_BIT; } } else if (parameterTypes[arg].isPrimitive()) { if (parameterTypes[arg].isAssignableFrom(float.class)) { bits |= OpenCLArgDescriptor.ARG_FLOAT_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } else if (parameterTypes[arg].isAssignableFrom(int.class)) { bits |= OpenCLArgDescriptor.ARG_INT_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } else if (parameterTypes[arg].isAssignableFrom(double.class)) { bits |= OpenCLArgDescriptor.ARG_DOUBLE_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } else if (parameterTypes[arg].isAssignableFrom(byte.class)) { bits |= OpenCLArgDescriptor.ARG_BYTE_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } else if (parameterTypes[arg].isAssignableFrom(short.class)) { bits |= OpenCLArgDescriptor.ARG_SHORT_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } else if (parameterTypes[arg].isAssignableFrom(long.class)) { bits |= OpenCLArgDescriptor.ARG_LONG_BIT | OpenCLArgDescriptor.ARG_PRIMITIVE_BIT; } } else { System.out.println("OUch!"); } if (name == null) { throw new IllegalStateException("no name!"); } final OpenCLArgDescriptor kernelArg = new OpenCLArgDescriptor(name, bits); args.add(kernelArg); } } return (args); } private static boolean isReservedInterfaceMethod(Method _methods) { return ( _methods.getName().equals("put") || _methods.getName().equals("get") || _methods.getName().equals("dispose") || _methods.getName().equals("begin") || _methods.getName().equals("end") || _methods.getName().equals("getProfileInfo")); } private String streamToString(InputStream _inputStream) { final StringBuilder sourceBuilder = new StringBuilder(); if (_inputStream != null) { final BufferedReader reader = new BufferedReader(new InputStreamReader(_inputStream)); try { for (String line = reader.readLine(); line != null; line = reader.readLine()) { sourceBuilder.append(line).append("\n"); } } catch (final IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } try { _inputStream.close(); } catch (final IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } return (sourceBuilder.toString()); } public <T extends OpenCL<T>> T bind(Class<T> _interface, InputStream _inputStream) { return (bind(_interface, streamToString(_inputStream))); } public <T extends OpenCL<T>> T bind(Class<T> _interface) { return (bind(_interface, (String) null)); } public <T extends OpenCL<T>> T bind(Class<T> _interface, String _source) { final Map<String, List<OpenCLArgDescriptor>> kernelNameToArgsMap = new HashMap<String, List<OpenCLArgDescriptor>>(); if (_source == null) { final StringBuilder sourceBuilder = new StringBuilder(); boolean interfaceIsAnnotated = false; for (final Annotation a : _interface.getAnnotations()) { if (a instanceof Source) { final Source source = (Source) a; sourceBuilder.append(source.value()).append("\n"); interfaceIsAnnotated = true; } else if (a instanceof Resource) { final Resource sourceResource = (Resource) a; final InputStream stream = _interface.getClassLoader().getResourceAsStream(sourceResource.value()); sourceBuilder.append(streamToString(stream)); interfaceIsAnnotated = true; } } if (interfaceIsAnnotated) { // just crawl the methods (non put or get) and create kernels for (final Method m : _interface.getDeclaredMethods()) { if (!isReservedInterfaceMethod(m)) { final List<OpenCLArgDescriptor> args = getArgs(m); kernelNameToArgsMap.put(m.getName(), args); } } } else { for (final Method m : _interface.getDeclaredMethods()) { if (!isReservedInterfaceMethod(m)) { for (final Annotation a : m.getAnnotations()) { // System.out.println(" annotation "+a); // System.out.println(" annotation type " + a.annotationType()); if (a instanceof Kernel) { sourceBuilder.append("__kernel void " + m.getName() + "("); final List<OpenCLArgDescriptor> args = getArgs(m); boolean first = true; for (final OpenCLArgDescriptor arg : args) { if (first) { first = false; } else { sourceBuilder.append(","); } sourceBuilder.append("\n " + arg); } sourceBuilder.append(")"); final Kernel kernel = (Kernel) a; sourceBuilder.append(kernel.value()); kernelNameToArgsMap.put(m.getName(), args); } } } } } _source = sourceBuilder.toString(); } else { for (final Method m : _interface.getDeclaredMethods()) { if (!isReservedInterfaceMethod(m)) { final List<OpenCLArgDescriptor> args = getArgs(m); kernelNameToArgsMap.put(m.getName(), args); } } } // System.out.println("opencl{\n" + _source + "\n}opencl"); final OpenCLProgram program = new OpenCLProgram(this, _source).createProgram(this); final Map<String, OpenCLKernel> map = new HashMap<String, OpenCLKernel>(); for (final String name : kernelNameToArgsMap.keySet()) { final OpenCLKernel kernel = OpenCLKernel.createKernel(program, name, kernelNameToArgsMap.get(name)); //final OpenCLKernel kernel = new OpenCLKernel(program, name, kernelNameToArgsMap.get(name)); if (kernel == null) { throw new IllegalStateException("kernel is null"); } map.put(name, kernel); } final OpenCLInvocationHandler<T> invocationHandler = new OpenCLInvocationHandler<T>(program, map); final T instance = (T) Proxy.newProxyInstance(OpenCLDevice.class.getClassLoader(), new Class[] { _interface, OpenCL.class }, invocationHandler); return instance; } public interface DeviceSelector{ OpenCLDevice select(OpenCLDevice _device); } public interface DeviceComparitor{ OpenCLDevice select(OpenCLDevice _deviceLhs, OpenCLDevice _deviceRhs); } public static OpenCLDevice select(DeviceSelector _deviceSelector) { OpenCLDevice device = null; final OpenCLPlatform platform = new OpenCLPlatform(0, null, null, null); for (final OpenCLPlatform p : platform.getOpenCLPlatforms()) { for (final OpenCLDevice d : p.getOpenCLDevices()) { device = _deviceSelector.select(d); if (device != null) { break; } } if (device != null) { break; } } return (device); } public static OpenCLDevice select(DeviceComparitor _deviceComparitor) { OpenCLDevice device = null; final OpenCLPlatform platform = new OpenCLPlatform(0, null, null, null); for (final OpenCLPlatform p : platform.getOpenCLPlatforms()) { for (final OpenCLDevice d : p.getOpenCLDevices()) { if (device == null) { device = d; } else { device = _deviceComparitor.select(device, d); } } } return (device); } @Override public String toString() { final StringBuilder s = new StringBuilder("{"); boolean first = true; for (final int workItemSize : maxWorkItemSize) { if (first) { first = false; } else { s.append(", "); } s.append(workItemSize); } s.append("}"); return ("Device " + deviceId + "\n type:" + type + "\n maxComputeUnits=" + maxComputeUnits + "\n maxWorkItemDimensions=" + maxWorkItemDimensions + "\n maxWorkItemSizes=" + s + "\n maxWorkWorkGroupSize=" + maxWorkGroupSize + "\n globalMemSize=" + globalMemSize + "\n localMemSize=" + localMemSize); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OpenCLDevice that = (OpenCLDevice) o; if (deviceId != that.deviceId) return false; return true; } @Override public int hashCode() { return (int) (deviceId ^ (deviceId >>> 32)); } }