/* * Copyright 2012 Phil Pratt-Szeliga and other contributors * http://chirrup.org/ * * See the file LICENSE for copying permission. */ package org.trifort.rootbeer.generate.opencl; import soot.jimple.NewExpr; import soot.rbclassload.MethodSignatureUtil; import java.io.BufferedReader; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.io.PrintWriter; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.trifort.rootbeer.configuration.Configuration; import org.trifort.rootbeer.configuration.RootbeerPaths; import org.trifort.rootbeer.entry.ForcedFields; import org.trifort.rootbeer.entry.CompilerSetup; import org.trifort.rootbeer.generate.bytecode.MethodCodeSegment; import org.trifort.rootbeer.generate.bytecode.ReadOnlyTypes; import org.trifort.rootbeer.generate.opencl.fields.CompositeField; import org.trifort.rootbeer.generate.opencl.fields.CompositeFieldFactory; import org.trifort.rootbeer.generate.opencl.fields.FieldCodeGeneration; import org.trifort.rootbeer.generate.opencl.fields.FieldTypeSwitch; import org.trifort.rootbeer.generate.opencl.fields.OffsetCalculator; import org.trifort.rootbeer.generate.opencl.fields.OpenCLField; import org.trifort.rootbeer.generate.opencl.tweaks.CompileResult; import org.trifort.rootbeer.generate.opencl.tweaks.CudaTweaks; import org.trifort.rootbeer.generate.opencl.tweaks.Tweaks; import org.trifort.rootbeer.util.ReadFile; import org.trifort.rootbeer.util.ResourceReader; import soot.*; import soot.rbclassload.FieldSignatureUtil; import soot.rbclassload.NumberedType; import soot.rbclassload.RootbeerClassLoader; public class OpenCLScene { private static OpenCLScene m_instance; private static int m_curentIdent; private Map<String, OpenCLClass> m_classes; private Set<OpenCLArrayType> m_arrayTypes; private MethodHierarchies m_methodHierarchies; private boolean m_usesGarbageCollector; private SootClass m_rootSootClass; private int m_endOfStatics; private ReadOnlyTypes m_readOnlyTypes; private Set<OpenCLInstanceof> m_instanceOfs; private List<CompositeField> m_compositeFields; private List<SootMethod> m_methods; private ClassConstantNumbers m_constantNumbers; private FieldCodeGeneration m_fieldCodeGeneration; static { m_curentIdent = 0; } public OpenCLScene(){ } public void init(){ m_classes = new LinkedHashMap<String, OpenCLClass>(); m_arrayTypes = new LinkedHashSet<OpenCLArrayType>(); m_methodHierarchies = new MethodHierarchies(); m_instanceOfs = new HashSet<OpenCLInstanceof>(); m_methods = new ArrayList<SootMethod>(); m_constantNumbers = new ClassConstantNumbers(); m_fieldCodeGeneration = new FieldCodeGeneration(); loadTypes(); } public static OpenCLScene v(){ return m_instance; } public static void setInstance(OpenCLScene scene){ m_instance = scene; } public static void releaseV(){ m_instance = null; m_curentIdent++; } public String getIdent(){ return "" + m_curentIdent; } public String getUuid(){ return "ab850b60f96d11de8a390800200c9a66"; } public int getEndOfStatics(){ return m_endOfStatics; } public int getClassType(SootClass soot_class){ return RootbeerClassLoader.v().getClassNumber(soot_class); } public void addMethod(SootMethod soot_method){ SootClass soot_class = soot_method.getDeclaringClass(); OpenCLClass ocl_class = getOpenCLClass(soot_class); ocl_class.addMethod(new OpenCLMethod(soot_method, soot_class)); //add the method m_methodHierarchies.addMethod(soot_method); m_methods.add(soot_method); } public List<SootMethod> getMethods(){ return m_methods; } public void addArrayType(OpenCLArrayType array_type){ if(m_arrayTypes.contains(array_type)) return; m_arrayTypes.add(array_type); } public void addInstanceof(Type type){ OpenCLInstanceof to_add = new OpenCLInstanceof(type); if(m_instanceOfs.contains(to_add) == false){ m_instanceOfs.add(to_add); } } public OpenCLClass getOpenCLClass(SootClass soot_class){ String class_name = soot_class.getName(); if(m_classes.containsKey(class_name)){ return m_classes.get(class_name); } else { OpenCLClass ocl_class = new OpenCLClass(soot_class); m_classes.put(class_name, ocl_class); return ocl_class; } } public void addField(SootField soot_field){ SootClass soot_class = soot_field.getDeclaringClass(); OpenCLClass ocl_class = getOpenCLClass(soot_class); ocl_class.addField(new OpenCLField(soot_field, soot_class)); } private String getRuntimeBasicBlockClassName(){ SootClass soot_class = m_rootSootClass; OpenCLClass ocl_class = getOpenCLClass(soot_class); return ocl_class.getName(); } private String readCudaCodeFromFile(){ try { BufferedReader reader = new BufferedReader(new FileReader("generated.cu")); String ret = ""; while(true){ String temp = reader.readLine(); if(temp == null) return ret; ret += temp+"\n"; } } catch(Exception ex){ throw new RuntimeException(); } } public void setUsingGarbageCollector(){ m_usesGarbageCollector = true; } public boolean getUsingGarbageCollector(){ return m_usesGarbageCollector; } private void writeTypesToFile(List<NumberedType> types){ try { PrintWriter writer = new PrintWriter(RootbeerPaths.v().getTypeFile()); for(NumberedType type : types){ writer.println(type.getNumber()+" "+type.getType().toString()); } writer.flush(); writer.close(); } catch(Exception ex){ ex.printStackTrace(); } } public int getOutOfMemoryNumber(){ SootClass soot_class = Scene.v().getSootClass("java.lang.OutOfMemoryError"); int ret = RootbeerClassLoader.v().getClassNumber(soot_class); return ret; } private void loadTypes(){ Set<String> methods = RootbeerClassLoader.v().getDfsInfo().getMethods(); MethodSignatureUtil util = new MethodSignatureUtil(); for(String method_sig : methods){ util.parse(method_sig); SootMethod method = util.getSootMethod(); addMethod(method); } CompilerSetup compiler_setup = new CompilerSetup(); for(String extra_method : compiler_setup.getExtraMethods()){ util.parse(extra_method); addMethod(util.getSootMethod()); } Set<SootField> fields = RootbeerClassLoader.v().getDfsInfo().getFields(); for(SootField field : fields){ addField(field); } FieldSignatureUtil field_util = new FieldSignatureUtil(); ForcedFields forced_fields = new ForcedFields(); for(String field_sig : forced_fields.get()){ field_util.parse(field_sig); addField(field_util.getSootField()); } Set<ArrayType> array_types = RootbeerClassLoader.v().getDfsInfo().getArrayTypes(); for(ArrayType array_type : array_types){ OpenCLArrayType ocl_array_type = new OpenCLArrayType(array_type); addArrayType(ocl_array_type); } for(ArrayType array_type : compiler_setup.getExtraArrayTypes()){ OpenCLArrayType ocl_array_type = new OpenCLArrayType(array_type); addArrayType(ocl_array_type); } Set<Type> instanceofs = RootbeerClassLoader.v().getDfsInfo().getInstanceOfs(); for(Type type : instanceofs){ addInstanceof(type); } buildCompositeFields(); } private String[] makeSourceCode() throws Exception { if(Configuration.compilerInstance().isManualCuda()){ String filename = Configuration.compilerInstance().getManualCudaFilename(); String cuda_code = readCode(filename); String[] ret = new String[2]; ret[0] = cuda_code; ret[1] = cuda_code; return ret; } m_usesGarbageCollector = false; List<NumberedType> types = RootbeerClassLoader.v().getDfsInfo().getNumberedTypes(); writeTypesToFile(types); StringBuilder unix_code = new StringBuilder(); StringBuilder windows_code = new StringBuilder(); String method_protos = methodPrototypesString(); String gc_string = garbageCollectorString(); String bodies_string = methodBodiesString(); unix_code.append(headerString(true)); unix_code.append(method_protos); unix_code.append(gc_string); unix_code.append(bodies_string); unix_code.append(kernelString(true)); windows_code.append(headerString(false)); windows_code.append(method_protos); windows_code.append(gc_string); windows_code.append(bodies_string); windows_code.append(kernelString(false)); String cuda_unix = setupEntryPoint(unix_code); String cuda_windows = setupEntryPoint(windows_code); //print out code for debugging PrintWriter writer = new PrintWriter(new FileWriter(RootbeerPaths.v().getRootbeerHome()+"generated_unix.cu")); writer.println(cuda_unix); writer.flush(); writer.close(); //print out code for debugging writer = new PrintWriter(new FileWriter(RootbeerPaths.v().getRootbeerHome()+"generated_windows.cu")); writer.println(cuda_windows); writer.flush(); writer.close(); NameMangling.v().writeTypesToFile(); String[] ret = new String[2]; ret[0] = cuda_unix; ret[1] = cuda_windows; return ret; } private String readCode(String filename){ ReadFile reader = new ReadFile(filename); try { return reader.read(); } catch(Exception ex){ ex.printStackTrace(System.out); throw new RuntimeException(ex); } } private String setupEntryPoint(StringBuilder builder){ String cuda_code = builder.toString(); String mangle = NameMangling.v().mangle(VoidType.v()); String replacement = getRuntimeBasicBlockClassName()+"_gpuMethod"+mangle; //class names can have $ in them, make them regex safe replacement = replacement.replace("$", "\\$"); cuda_code = cuda_code.replaceAll("%%invoke_run%%", replacement); int string_builder_number = RootbeerClassLoader.v().getClassNumber("java.lang.StringBuilder"); String sbn_str = "" + string_builder_number; cuda_code = cuda_code.replaceAll("%%java_lang_StringBuilder_TypeNumber%%", sbn_str); int null_pointer_number = RootbeerClassLoader.v().getClassNumber("java.lang.NullPointerException"); String np_str = "" + null_pointer_number; cuda_code = cuda_code.replaceAll("%%java_lang_NullPointerException_TypeNumber%%", np_str); int out_of_memory_number = RootbeerClassLoader.v().getClassNumber("java.lang.OutOfMemoryError"); String out_of_memory_str = "" + out_of_memory_number; cuda_code = cuda_code.replaceAll("%%java_lang_OutOfMemoryError_TypeNumber%%", out_of_memory_str); int size = Configuration.compilerInstance().getSharedMemSize(); String size_str = ""+size; cuda_code = cuda_code.replaceAll("%%shared_mem_size%%", size_str); boolean exceptions = Configuration.compilerInstance().getExceptions(); String exceptions_str; if(exceptions){ exceptions_str = ""+1; } else { exceptions_str = ""+0; } cuda_code = cuda_code.replaceAll("%%using_exceptions%%", exceptions_str); int string_number = RootbeerClassLoader.v().getClassNumber("java.lang.String"); String string_str = "" + string_number; cuda_code = cuda_code.replaceAll("%%java_lang_String_TypeNumber%%", string_str); int integer_number = RootbeerClassLoader.v().getClassNumber("java.lang.Integer"); String integer_str = "" + integer_number; cuda_code = cuda_code.replaceAll("%%java_lang_Integer_TypeNumber%%", integer_str); int long_number = RootbeerClassLoader.v().getClassNumber("java.lang.Long"); String long_str = "" + long_number; cuda_code = cuda_code.replaceAll("%%java_lang_Long_TypeNumber%%", long_str); int float_number = RootbeerClassLoader.v().getClassNumber("java.lang.Float"); String float_str = "" + float_number; cuda_code = cuda_code.replaceAll("%%java_lang_Float_TypeNumber%%", float_str); int double_number = RootbeerClassLoader.v().getClassNumber("java.lang.Double"); String double_str = "" + double_number; cuda_code = cuda_code.replaceAll("%%java_lang_Double_TypeNumber%%", double_str); int boolean_number = RootbeerClassLoader.v().getClassNumber("java.lang.Boolean"); String boolean_str = "" + boolean_number; cuda_code = cuda_code.replaceAll("%%java_lang_Boolean_TypeNumber%%", boolean_str); return cuda_code; } public String[] getOpenCLCode() throws Exception { String[] source_code = makeSourceCode(); return source_code; } public CompileResult[] getCudaCode() throws Exception { String[] source_code = makeSourceCode(); return new CudaTweaks().compileProgram(source_code[0], Configuration.compilerInstance().getCompileArchitecture()); } private String headerString(boolean unix) throws IOException { String defines = ""; if(Configuration.compilerInstance().getArrayChecks()){ defines += "#define ARRAY_CHECKS\n"; } String specific_path; if(unix){ specific_path = Tweaks.v().getUnixHeaderPath(); } else { specific_path = Tweaks.v().getWindowsHeaderPath(); } if(specific_path == null) return ""; String both_path = Tweaks.v().getBothHeaderPath(); String both_header = ""; if(both_path != null){ both_header = ResourceReader.getResource(both_path); } String specific_header = ResourceReader.getResource(specific_path); String barrier_path = Tweaks.v().getBarrierPath(); String barrier_code = ""; if(barrier_path != null){ barrier_code = ResourceReader.getResource(barrier_path); } return defines + "\n" + specific_header + "\n" + both_header + "\n" + barrier_code; } private String kernelString(boolean unix) throws IOException { String kernel_path; if(unix){ kernel_path = Tweaks.v().getUnixKernelPath(); } else { kernel_path = Tweaks.v().getWindowsKernelPath(); } String specific_kernel_code = ResourceReader.getResource(kernel_path); String both_kernel_code = ""; String both_kernel_path = Tweaks.v().getBothKernelPath(); if(both_kernel_path != null){ both_kernel_code = ResourceReader.getResource(both_kernel_path); } return both_kernel_code + "\n" + specific_kernel_code; } private String garbageCollectorString() throws IOException { String path = Tweaks.v().getGarbageCollectorPath(); String ret = ResourceReader.getResource(path); ret = ret.replace("$$__device__$$", Tweaks.v().getDeviceFunctionQualifier()); ret = ret.replace("$$__global$$", Tweaks.v().getGlobalAddressSpaceQualifier()); return ret; } private String methodPrototypesString(){ //using a set so duplicates get filtered out. Set<String> protos = new HashSet<String>(); StringBuilder ret = new StringBuilder(); ArrayCopyGenerate arr_generate = new ArrayCopyGenerate(); protos.add(arr_generate.getProto()); List<OpenCLMethod> methods = m_methodHierarchies.getMethods(); for(OpenCLMethod method : methods){ protos.add(method.getMethodPrototype()); } List<OpenCLPolymorphicMethod> poly_methods = m_methodHierarchies.getPolyMorphicMethods(); for(OpenCLPolymorphicMethod poly_method : poly_methods){ protos.add(poly_method.getMethodPrototypes()); } protos.add(m_fieldCodeGeneration.prototypes(m_classes)); for(OpenCLArrayType array_type : m_arrayTypes){ protos.add(array_type.getPrototypes()); } for(OpenCLInstanceof type : m_instanceOfs){ protos.add(type.getPrototype()); } Iterator<String> iter = protos.iterator(); while(iter.hasNext()){ ret.append(iter.next()); } return ret.toString(); } private String methodBodiesString() throws IOException{ StringBuilder ret = new StringBuilder(); if(m_usesGarbageCollector){ ret.append("#define USING_GARBAGE_COLLECTOR\n"); } //a set is used so duplicates get filtered out Set<String> bodies = new HashSet<String>(); ArrayCopyTypeReduction reduction = new ArrayCopyTypeReduction(); Set<OpenCLArrayType> new_types = reduction.run(m_arrayTypes, m_methodHierarchies); ArrayCopyGenerate arr_generate = new ArrayCopyGenerate(); bodies.add(arr_generate.get(new_types)); List<OpenCLMethod> methods = m_methodHierarchies.getMethods(); for(OpenCLMethod method : methods){ bodies.add(method.getMethodBody()); } List<OpenCLPolymorphicMethod> poly_methods = m_methodHierarchies.getPolyMorphicMethods(); for(OpenCLPolymorphicMethod poly_method : poly_methods){ bodies.addAll(poly_method.getMethodBodies()); } FieldTypeSwitch type_switch = new FieldTypeSwitch(); String field_bodies = m_fieldCodeGeneration.bodies(m_classes, type_switch); bodies.add(field_bodies); for(OpenCLArrayType array_type : m_arrayTypes){ bodies.add(array_type.getBodies()); } for(OpenCLInstanceof type : m_instanceOfs){ bodies.add(type.getBody()); } Iterator<String> iter = bodies.iterator(); ret.append(type_switch.getFunctions()); while(iter.hasNext()){ ret.append(iter.next()); } return ret.toString(); } public OffsetCalculator getOffsetCalculator(SootClass soot_class){ List<CompositeField> composites = getCompositeFields(); for(CompositeField composite : composites){ List<SootClass> classes = composite.getClasses(); if(classes.contains(soot_class)) return new OffsetCalculator(composite); } throw new RuntimeException("Cannot find composite field for soot_class"); } public void addCodeSegment(MethodCodeSegment codeSegment){ m_rootSootClass = codeSegment.getRootSootClass(); m_readOnlyTypes = new ReadOnlyTypes(codeSegment.getRootMethod()); getOpenCLClass(m_rootSootClass); } public boolean isArrayLocalWrittenTo(Local local){ return true; } public ReadOnlyTypes getReadOnlyTypes(){ return m_readOnlyTypes; } public boolean isRootClass(SootClass soot_class) { return soot_class.getName().equals(m_rootSootClass.getName()); } public Map<String, OpenCLClass> getClassMap(){ return m_classes; } public List<CompositeField> getCompositeFields() { return m_compositeFields; } private void buildCompositeFields() { CompositeFieldFactory factory = new CompositeFieldFactory(); factory.setup(m_classes); m_compositeFields = factory.getCompositeFields(); } public ClassConstantNumbers getClassConstantNumbers(){ return m_constantNumbers; } }