package x10cuda.visit;

import java.util.List;

import polyglot.ast.Binary;
import polyglot.ast.Block;
import polyglot.ast.Call;
import polyglot.ast.CanonicalTypeNode;
import polyglot.ast.Eval;
import polyglot.ast.Expr;
import polyglot.ast.For;
import polyglot.ast.Formal;
import polyglot.ast.IntLit;
import polyglot.ast.Local;
import polyglot.ast.LocalDecl;
import polyglot.ast.Node;
import polyglot.ast.NodeFactory;
import polyglot.ast.Receiver;
import polyglot.ast.Stmt;
import polyglot.ast.Term;
import polyglot.ast.Try;
import polyglot.ast.TypeNode;
import polyglot.frontend.Job;
import polyglot.types.ClassType;
import polyglot.types.Name;
import polyglot.types.QName;
import polyglot.types.SemanticException;
import polyglot.types.Type;
import polyglot.types.TypeSystem;
import polyglot.util.ErrorInfo;
import polyglot.visit.ContextVisitor;
import polyglot.visit.NodeVisitor;
import x10.ast.Async;
import x10.ast.AtStmt;
import x10.ast.Closure;
import x10.ast.Finish;
import x10.ast.X10Call;
import x10.ast.X10Formal;
import x10.ast.X10Loop;
import x10.ast.X10New;
import x10.extension.X10Ext;
import x10.types.MethodInstance;
import x10.types.TypeParamSubst;
import x10.types.X10ClassType;
import x10.visit.NodeTransformingVisitor;
import x10.visit.Reinstantiator;
import x10cpp.visit.Emitter;
import x10cuda.ast.CUDAKernel;
import x10cuda.types.CUDAData;
import x10cuda.types.SharedMem;

public class CUDAPatternMatcher extends ContextVisitor {

    private static final String ANN_KERNEL = "x10.compiler.CUDA";
    private static final String ANN_DIRECT_PARAMS = "x10.compiler.CUDADirectParams";

    public CUDAPatternMatcher(Job job, TypeSystem ts, NodeFactory nf) {
        super(job, ts, nf);
    }

    private TypeSystem xts() {
        return ts;
    }

    // Type from name
    private Type getType(String name) throws SemanticException {
        return xts().systemResolver().findOne(QName.make(name));
    }

    // does the block have the given annotation
    private boolean nodeHasAnnotation(Node n, String ann) {
        X10Ext ext = (X10Ext) n.ext();
        try {
            return !ext.annotationMatching(getType(ann)).isEmpty();
        } catch (SemanticException e) {
            assert false : e;
            return false; // in case asserts are off
        }
    }
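    /*
     * Illustrative sketch (not taken from the original sources) of the X10 input shape
     * that leaveCall() below tries to match. Identifier names such as gpu, blocks,
     * threads, cached, and shm are placeholders; only the structural elements checked
     * by the matcher matter: the @CUDA-annotated at body, optional
     * CUDAUtilities.autoBlocks()/autoThreads() and constant-cache val declarations,
     * and the finish/for/async over blocks wrapping a clocked finish/for/async over
     * threads.
     *
     *   at (gpu) @CUDA {
     *       val blocks = CUDAUtilities.autoBlocks();
     *       val threads = CUDAUtilities.autoThreads();
     *       val cached = someArray.sequence();              // constant cache
     *       finish for (block in 0..(blocks-1)) async {
     *           val shm = new Array[Float](...);            // shared memory
     *           clocked finish for (thread in 0..(threads-1)) clocked async {
     *               // kernel body
     *           }
     *       }
     *   }
     */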
    // does the block have the annotation that denotes that it should be
    // split-compiled to cuda?
    private boolean blockIsKernel(Node n) {
        return nodeHasAnnotation(n, ANN_KERNEL);
    }

    // does the block have the annotation that denotes that it should be
    // compiled to use conventional cuda kernel params
    private boolean kernelWantsDirectParams(Node n) {
        return nodeHasAnnotation(n, ANN_DIRECT_PARAMS);
    }

    private static class Complaint extends RuntimeException {
    }

    private void complainIfNot(boolean cond, String exp, Node n, boolean except) {
        complainIfNot2(cond, "@CUDA Expected: " + exp, n, except);
    }

    private void complainIfNot2(boolean cond, String exp, Node n, boolean except) {
        if (!cond) {
            job.compiler().errorQueue().enqueue(ErrorInfo.SEMANTIC_ERROR, exp, n.position());
            if (except)
                throw new Complaint();
        }
    }

    private void complainIfNot(boolean cond, String exp, Node n) {
        complainIfNot(cond, exp, n, true);
    }

    private void complainIfNot2(boolean cond, String exp, Node n) {
        complainIfNot2(cond, exp, n, true);
    }

    // Element type T of an Array[T] or RemoteRef[Array[T]], or null if typ is neither.
    private Type arrayCargo(Type typ) {
        if (xts().isArray(typ)) {
            typ = typ.toClass();
            X10ClassType ctyp = (X10ClassType) typ;
            assert ctyp.typeArguments() != null && ctyp.typeArguments().size() == 1; // Array[T]
            return ctyp.typeArguments().get(0);
        }
        if (xts().isRemoteArray(typ)) {
            typ = typ.toClass();
            X10ClassType ctyp = (X10ClassType) typ;
            assert ctyp.typeArguments() != null && ctyp.typeArguments().size() == 1; // RemoteRef[Array[T]]
            Type type2 = ctyp.typeArguments().get(0);
            X10ClassType ctyp2 = (X10ClassType) type2.toClass();
            assert ctyp2.typeArguments() != null && ctyp2.typeArguments().size() == 1; // Array[T]
            return ctyp2.typeArguments().get(0);
        }
        return null;
    }

    private boolean isFloatArray(Type typ) {
        Type cargo = arrayCargo(typ);
        return cargo != null && cargo.isFloat();
    }

    private boolean isIntArray(Type typ) {
        Type cargo = arrayCargo(typ);
        return cargo != null && cargo.isInt();
    }

    // Java cannot return multiple values from a function
    class MultipleValues {
        public Expr max;
        public Formal var;
        public Block body;
    }

    // Decompose a loop of the form "for (x in 0..max) body" into its formal, bound, and body.
    protected MultipleValues processLoop(Block b) {
        Node loop_ = b.statements().get(0);
        complainIfNot(loop_ instanceof X10Loop, "A 1-dimensional iteration of the form 0..", loop_);
        X10Loop loop = (X10Loop) loop_;
        MultipleValues r = new MultipleValues();
        Formal loop_formal = loop.formal();
        complainIfNot(loop_formal instanceof X10Formal, "named loop formal", loop);
        X10Formal loop_x10_formal = (X10Formal) loop_formal;
        r.var = loop_x10_formal;
        Expr domain = loop.domain();
        complainIfNot(domain instanceof Binary, "An iteration over an int range literal of the form 0..", domain);
        Binary region = (Binary) domain;
        complainIfNot(region.operator() == Binary.DOT_DOT, "An iteration over an int range literal of the form 0..", domain);
        MethodInstance mi = region.methodInstance();
        ClassType target_type = mi.container().toClass();
        complainIfNot(target_type.isInt(), "An iteration over an int range literal of the form 0..", domain);
        Expr from_ = region.left();
        Expr to_ = region.right();
        complainIfNot(from_ instanceof IntLit, "An iteration over an int range literal of the form 0..", from_);
        IntLit from = (IntLit) from_;
        complainIfNot(from.value() == 0, "An iteration over an int range literal of the form 0..", from_);
        r.max = to_;
        r.body = (Block) loop.body();
        return r;
    }

    // Is e a static call class_name.method_name(...) with the given number of arguments?
    public static boolean checkStaticCall(Expr e, String class_name, String method_name, int args) {
        try {
            Call call = (Call) e;
            if (!call.name().toString().equals(method_name))
                return false;
            CanonicalTypeNode async_target_type_node = (CanonicalTypeNode) call.target();
            if (!async_target_type_node.nameString().equals(class_name))
                return false;
            if (call.arguments().size() != args)
                return false;
        } catch (ClassCastException exc) {
            return false;
        }
        return true;
    }

    public static boolean checkStaticCall(Stmt s, String class_name, String method_name, int args) {
        try {
            Expr e = ((Eval) s).expr();
            return checkStaticCall(e, class_name, method_name, args);
        } catch (ClassCastException e) {
            return false;
        }
    }

    // If s is a finish with the given clockedness, return its body; otherwise null.
    public static Stmt checkFinish(Stmt s, boolean clocked) {
        try {
            Finish s1 = (Finish) s;
            if (s1.clocked() != clocked)
                return null;
            return s1.body();
        } catch (ClassCastException e) {
            return null;
        }
    }

    // If s is an async with the given clockedness, return its body; otherwise null.
    public static Block checkAsync(Stmt s, boolean clocked) {
        try {
            Async s1 = (Async) s;
            if (s1.clocked() != clocked)
                return null;
            return (Block) s1.body();
        } catch (ClassCastException e) {
            return null;
        }
    }

    // Match @CUDA-annotated blocks against the expected kernel pattern and replace
    // them with a CUDAKernel node; otherwise leave the tree unchanged.
    public Node leaveCall(Node parent, Node old, Node child, NodeVisitor visitor) {
        if (child instanceof Block) {
            Block b = (Block) child;
            if (blockIsKernel(b)) {
                try {
                    //System.out.println("Got kernel: ");
                    //parent.prettyPrint(System.out);
                    //System.out.println();
                    Reinstantiator reinstantiator = new Reinstantiator(TypeParamSubst.IDENTITY);
                    Block kernel_block = (Block) b.visit(new NodeTransformingVisitor(job, ts, nf, reinstantiator).context(context()));
                    boolean direct = kernelWantsDirectParams(b);
                    complainIfNot2(parent instanceof AtStmt, "@CUDA annotation must be on an at body", kernel_block);
                    // if there are no autoblocks/threads statements, this will be 1
                    complainIfNot(kernel_block.statements().size() >= 1, "A block containing at least one statement.", kernel_block);
                    LocalDecl autoBlocks = null, autoThreads = null;
                    // handle autoblocks/autothreads and constant memory declarations
                    SharedMem cmem = new SharedMem();
                    for (int i = 0; i < kernel_block.statements().size() - 1; ++i) {
                        Stmt ld_ = kernel_block.statements().get(i);
                        complainIfNot(ld_ instanceof LocalDecl,
                                "val <something> = <autoBlocks/Threads or constant cache definition", ld_);
                        LocalDecl ld = (LocalDecl) ld_;
                        Expr init_expr = ld.init();
                        if (init_expr instanceof X10Call) {
                            X10Call init_call = (X10Call) init_expr;
                            Receiver init_call_target = init_call.target();
                            if (init_call_target instanceof CanonicalTypeNode) {
                                CanonicalTypeNode init_call_target_node = (CanonicalTypeNode) init_call_target;
                                String classname = init_call_target_node.nameString();
                                int targs = init_call.typeArguments().size();
                                int args = init_call.arguments().size();
                                String methodname = init_call.name().toString();
                                if (classname.equals("CUDAUtilities") && targs == 0 && args == 0
                                        && methodname.equals("autoBlocks")) {
                                    complainIfNot2(autoBlocks == null, "@CUDA: Already have autoBlocks", init_call);
                                    autoBlocks = ld.init(null);
                                } else if (classname.equals("CUDAUtilities") && targs == 0 && args == 0
                                        && methodname.equals("autoThreads")) {
                                    complainIfNot2(autoThreads == null, "@CUDA: Already have autoThreads", init_call);
                                    autoThreads = ld.init(null);
                                } else {
                                    complainIfNot(false, "A call to CUDAUtilities.autoBlocks/autoThreads", init_call);
                                }
                            } else if (init_call_target instanceof Expr) {
                                Expr arr_ = (Expr) init_call_target;
                                complainIfNot(arr_ instanceof Local, "val <something> = some_array.sequence()", arr_);
                                Local arr = (Local) arr_;
                                complainIfNot(init_call.name().id().toString().equals("sequence"),
                                        "constant cache definition to call 'sequence'", init_expr);
                                Type cargo = arrayCargo(arr.type());
                                cmem.addArrayInitArray((LocalDecl) setReachable(ld), arr, Emitter.translateType(cargo, true));
                            } else {
                                complainIfNot(false,
                                        "val <something> = CUDAUtilities.autoBlocks/Threads() or constant cache definition",
                                        init_call_target);
                            }
                            /* Not doing this anymore because we're using sequences instead of
                               array (constant cache is immutable)
                        } else {
                            complainIfNot(init_expr instanceof X10New_c, "val <something> = new Array(...)", init_expr);
                            X10New_c init_new = (X10New_c) init_expr;
                            Type instantiatedType = init_new.objectType().type();
                            complainIfNot(xts().isArray(instantiatedType), "Initialisation expression to have Array[T] type.", init_new);
                            TypeNode rail_type_arg_node = init_new.typeArguments().get(0);
                            Type rail_type_arg = rail_type_arg_node.type();
                            String rail_type_arg_ = Emitter.translateType(rail_type_arg, true); // TODO: support other types
                            if (init_new.arguments().size() == 2) {
                                Expr num_elements = init_new.arguments().get(0);
                                Expr rail_init_closure = init_new.arguments().get(1);
                                cmem.addArrayInitClosure(ld, num_elements, rail_init_closure, rail_type_arg_);
                            } else {
                                complainIfNot(init_new.arguments().size() == 1, "val <var> = new Array[T](other_array)", init_new);
                                Expr src_array = init_new.arguments().get(0);
                                complainIfNot(xts().isArray(src_array.type()) || xts().isRemoteArray(src_array.type()),
                                        "Constant memory to be initialised from array or remote array type", src_array);
                                cmem.addArrayInitArray(ld, src_array, rail_type_arg_);
                            }
                            */
                        }
                    }
                    Stmt finish = kernel_block.statements().get(kernel_block.statements().size() - 1);
                    Stmt finish_body = checkFinish(finish, false);
                    complainIfNot(finish_body != null, "A finish statement", finish);
                    complainIfNot(finish_body instanceof Block, "A single loop at the root of the kernel", finish_body);
                    Block finish_body_block = (Block) finish_body;
                    complainIfNot(finish_body_block.statements().size() == 1, "A single loop at the root of the CUDA kernel", finish_body_block);
                    MultipleValues outer = processLoop(finish_body_block);
                    Block outer_b = (Block) outer.body;
                    outer_b = (Block) checkAsync(outer_b.statements().get(0), false);
                    complainIfNot(outer_b != null, "An async for the block", outer.body);
                    Stmt last = outer_b.statements().get(outer_b.statements().size() - 1);
                    SharedMem shm = new SharedMem();
                    // look at all but the last statement to find shm decls
                    for (Stmt st : outer_b.statements()) {
                        if (st == last)
                            continue;
                        complainIfNot(st instanceof LocalDecl, "Shared memory definition", st);
                        LocalDecl ld = (LocalDecl) st;
                        Expr init_expr = ld.init();
                        // TODO: primitive vals and shared vars
                        complainIfNot(init_expr instanceof X10New, "val <var> = new Array[T](...)", init_expr);
                        X10New init_new = (X10New) init_expr;
                        Type instantiatedType = init_new.objectType().type();
                        complainIfNot(xts().isArray(instantiatedType), "Initialisation expression to have Array[T] type.", init_new);
                        TypeNode rail_type_arg_node = init_new.typeArguments().get(0);
                        Type rail_type_arg = rail_type_arg_node.type();
                        String rail_type_arg_ = Emitter.translateType(rail_type_arg, true); // TODO: support other types
                        if (init_new.arguments().size() == 2) {
                            Expr num_elements = init_new.arguments().get(0);
                            Expr rail_init_closure = init_new.arguments().get(1);
                            shm.addArrayInitClosure((LocalDecl) setReachable(ld), (Expr) setReachable(num_elements),
                                    (Expr) setReachable(rail_init_closure), rail_type_arg_);
                        } else {
                            complainIfNot(init_new.arguments().size() == 1, "val <var> = new Array[T](other_array)", init_new);
                            Expr src_array = init_new.arguments().get(0);
                            complainIfNot(xts().isArray(src_array.type()) || xts().isRemoteArray(src_array.type()),
                                    "SHM to be initialised from array or remote array type", src_array);
                            shm.addArrayInitArray((LocalDecl) setReachable(ld), (Expr) setReachable(src_array), rail_type_arg_);
                        }
                    }
                    Stmt for_block2_ = checkFinish(last, true);
                    complainIfNot(for_block2_ != null, "A clocked finish statement", last);
                    complainIfNot(for_block2_ instanceof Block, "A loop over CUDA threads", for_block2_);
                    Block for_block2 = (Block) for_block2_;
                    MultipleValues inner = processLoop(for_block2);
                    Block inner_b = inner.body;
                    complainIfNot(inner_b.statements().size() == 1, "A block with a single statement", inner_b);
                    Stmt async = inner_b.statements().get(0);
                    Block async_body = (Block) checkAsync(async, true);
                    CUDAKernel ck = nf.CUDAKernel(b.position(), b.statements(), (Block) setReachable(async_body));
                    ck.autoBlocks = (LocalDecl) setReachable(autoBlocks);
                    ck.autoThreads = (LocalDecl) setReachable(autoThreads);
                    ck.blocks = (Expr) setReachable(outer.max);
                    ck.blocksVar = (Formal) setReachable(outer.var);
                    ck.threads = (Expr) setReachable(inner.max);
                    ck.threadsVar = (Formal) setReachable(inner.var);
                    ck.shm = shm;
                    ck.cmem = cmem;
                    ck.directParams = direct;
                    return ck;
                } catch (Complaint e) {
                    e.printStackTrace();
                }
            }
        }
        return child;
    }

    // Mark every Term in the subtree rooted at x as reachable.
    private static Node setReachable(Term x) {
        if (x == null)
            return null;
        return x.visit(new NodeVisitor() {
            public Node leave(Node parent, Node old, Node child, NodeVisitor v) {
                if (child instanceof Term) {
                    Term child_term = (Term) child;
                    return child_term.reachable(true);
                }
                return child;
            }
        });
    }
}