/* * This file is part of the X10 project (http://x10-lang.org). * * This file is licensed to You under the Eclipse Public License (EPL); * You may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.opensource.org/licenses/eclipse-1.0.php * * (C) Copyright IBM Corporation 2006-2010. */ package x10cuda.types; import java.util.ArrayList; import java.util.Iterator; import polyglot.ast.Block; import polyglot.ast.Expr; import polyglot.ast.Local; import polyglot.ast.LocalDecl; import polyglot.ast.Node; import polyglot.ast.Return; import polyglot.ast.Stmt; import polyglot.types.Context; import polyglot.types.Name; import polyglot.types.Type; import polyglot.visit.NodeVisitor; import polyglot.visit.Translator; import x10.ast.Closure; import polyglot.types.TypeSystem; import x10.optimizations.inlining.InliningRewriter; import x10.util.ClassifiedStream; import x10.util.StreamWrapper; /** * This class holds information about Shared memory definitions in a CUDA kernel (also used for constant memory). * * @author Dave Cunningham */ public class SharedMem implements Cloneable { ArrayList<Decl> decls = new ArrayList<Decl>(); public String toString() { return decls.toString(); } public void addDecls(Context c) { for (Decl decl : decls) { decl.addDecls(c); } } private abstract static class Decl { public final LocalDecl ast; public Decl (LocalDecl ast) { this.ast = ast; } abstract public void visitChildren(Node parent, NodeVisitor v); abstract public String generateDef(StreamWrapper out, String offset, Translator tr); abstract public String generateInit(StreamWrapper out, String offset, Translator tr); abstract public void generateSize(StreamWrapper inc, Translator tr); abstract public void generateCMemPop(StreamWrapper out, Translator tr); abstract public void addDecls(Context c); abstract public String toString(); } private static class Array extends Decl { public Expr numElements; public Expr init; public final String elementType; public String toString() { return "["+numElements+" of "+elementType+" init to "+init+"]"; } public Array (LocalDecl ast, Expr numElements, Expr init, String elementType) { super(ast); this.numElements = numElements; this.init = init; this.elementType = elementType; } public String generateDef(StreamWrapper out, String raw, Translator tr) { String name = ast.name().id().toString(); out.write("x10aux::cuda_array<"+elementType+"> "+name+" = { "); if (numElements!=null) { tr.print(null, numElements, out); } else { tr.print(null, init, out); out.write(".FMGL(size)"); } out.write(", ("+elementType+"*) "+raw+" };"); out.newline(); return "&"+ast.name().id()+".raw["+ast.name().id()+".FMGL(size)]"; } public String generateInit(StreamWrapper out, String offset, Translator tr) { out.write("{"); out.newline(4); out.begin(0); out.write("x10_int __len = "); if (numElements!=null) { tr.print(null, numElements, out); } else { tr.print(null, init, out); out.write(".FMGL(size)"); } out.write(";"); out.newline(); out.write("for (int __i=0 ; __i<__len ; ++__i) {"); out.newline(4); out.begin(0); // TODO: assumes Array initialised with another Array -- closure version also possible if (numElements == null) { out.write("const "+elementType+" &__v = "); tr.print(null, init, out); out.write(".raw[__i];"); } else { if (init instanceof Closure) { Closure lit = (Closure) init; // Use the InlininingRewriter to get rid of the early returns. // Then strip off the final return and assign __v instead. try { ((X10CUDAContext_c) tr.context()).shmIterationVar(lit.formals().get(0)); Closure init_c_norm = (Closure) lit.visit(new InliningRewriter(lit, tr.job(), tr.typeSystem(), tr.nodeFactory(), tr.context())); Block b = init_c_norm.body(); for (int i=0; i<b.statements().size()-1 ; ++i ) { tr.print(null, b.statements().get(i), out); } Stmt return_stmt = b.statements().get(b.statements().size()-1); Return return_ = (Return)return_stmt; out.write(elementType+" __v = "); tr.print(null, return_.expr(), out); out.write(";"); } finally { ((X10CUDAContext_c) tr.context()).shmIterationVar(null); } } else { out.write("const "+elementType+" &__v = "); tr.print(null, init, out); out.write(";"); } } out.newline(); out.write(ast.name().id()+".raw[__i] = __v;"); out.newline(); out.end(); out.newline(); out.write("}"); out.end(); out.newline(); out.write("}"); return "&"+ast.name().id()+".raw["+ast.name().id()+".FMGL(size)]"; } public void generateSize(StreamWrapper inc, Translator tr) { if (numElements!=null) { tr.print(null, numElements, inc); } else { tr.print(null, init, inc); inc.write("->FMGL(size)"); } // FIXME: x10_float is baked in here inc.write("*sizeof("+elementType+")"); } @Override public void generateCMemPop(StreamWrapper out, Translator tr) { // TODO Auto-generated method stub out.write("pop.populateArr<"+elementType+", x10::array::Array<"+elementType+">*>("); tr.print(null, init, out); out.write(");"); } @Override public void visitChildren(Node parent, NodeVisitor v) { numElements = (Expr) parent.visitChild(numElements, v); if (init!=null) init = (Expr) parent.visitChild(init, v); } public void addDecls(Context c) { ast.addDecls(c); } } private static class Var extends Decl { public String toString() { return "[???]"; } public Var (LocalDecl ast) { super(ast); } public String generateDef(StreamWrapper out, String offset, Translator tr) { // TODO: not implemented assert false: "not implemented"; return ""; } public String generateInit(StreamWrapper out, String offset, Translator tr) { // TODO: not implemented assert false: "not implemented"; return ""; } public void generateSize(StreamWrapper inc, Translator tr) { // TODO: not implemented assert false: "not implemented"; } @Override public void generateCMemPop(StreamWrapper out, Translator tr) { // TODO Auto-generated method stub } @Override public void visitChildren(Node parent, NodeVisitor v) { // TODO Auto-generated method stub } public void addDecls(Context c) { // TODO: not implemented assert false: "not implemented"; } } public void addArrayInitClosure(LocalDecl ast, Expr numElements, Expr init, String type) { decls.add(new Array(ast, numElements, init, type)); } public void addArrayInitArray(LocalDecl ast, Expr init, String type) { decls.add(new Array(ast, null, init, type)); } public void addVar(LocalDecl ast) { decls.add(new Var(ast)); } public boolean has(Name n) { for (Decl d : decls) { if (d.ast.name().id() == n) { return true; } } return false; } public void generateCodeSharedMem(StreamWrapper out, Translator tr) { out.write("// shm"); out.newline(); if (decls.size()==0) return; String raw = "__shm"; for (SharedMem.Decl d : decls) { d.generateDef(out, raw, tr); out.write("if (threadIdx.x == 0) {"); out.newline(4); out.begin(0); raw = d.generateInit(out, raw, tr); out.end(); out.newline(); out.write("}"); out.newline(); } } public void generateCodeConstantMemory(StreamWrapper out, Translator tr) { out.write("// cmem"); out.newline(); String raw = "&__cmem[0]"; for (SharedMem.Decl d : decls) { raw = d.generateDef(out, raw, tr); } } public void generateHostCodeConstantMemory(StreamWrapper out, Translator tr) { out.write("// cmem"); out.newline(); if (decls.size()==0) return; out.write("x10aux::CMemPopulator pop(__cmemv);"); out.newline(); for (SharedMem.Decl d : decls) { d.generateCMemPop(out, tr); out.newline(); } } public void generateSize(StreamWrapper inc, Translator tr) { // TODO Auto-generated method stub String prefix = ""; for (SharedMem.Decl d : decls) { inc.write(prefix); d.generateSize(inc, tr); prefix = " + "; } if (prefix.equals("")) inc.write("0"); } public void visitChildren(Node parent, NodeVisitor v) { for (SharedMem.Decl d : decls) { d.visitChildren(parent, v); } } public SharedMem clone () { try { SharedMem this_ = (SharedMem) super.clone(); ArrayList<Decl> decls = new ArrayList<Decl>(); for (Decl d : this.decls) { decls.add(d); } this_.decls = decls; return this_; } catch (CloneNotSupportedException e) { return null; } } }