/* * This file is part of the X10 project (http://x10-lang.org). * * This file is licensed to You under the Eclipse Public License (EPL); * You may not use this file except in compliance with the License. * You may obtain a copy of the License at * http://www.opensource.org/licenses/eclipse-1.0.php * * (C) Copyright IBM Corporation 2006-2010. */ package x10.optimizations; import java.util.ArrayList; import java.util.Collections; import java.util.List; import polyglot.ast.Assign; import polyglot.ast.Binary; import polyglot.ast.Block; import polyglot.ast.Branch; import polyglot.ast.Call; import polyglot.ast.Expr; import polyglot.ast.For; import polyglot.ast.Formal; import polyglot.ast.Id; import polyglot.ast.Labeled; import polyglot.ast.Local; import polyglot.ast.LocalDecl; import polyglot.ast.Loop; import polyglot.ast.Node; import polyglot.ast.NodeFactory; import polyglot.ast.Receiver; import polyglot.ast.Stmt; import polyglot.ast.Switch; import polyglot.frontend.Job; import polyglot.types.Context; import polyglot.types.Flags; import polyglot.types.Name; import polyglot.types.SemanticException; import polyglot.types.Type; import polyglot.types.TypeSystem; import polyglot.types.Types; import polyglot.util.InternalCompilerError; import polyglot.util.Position; import polyglot.visit.ContextVisitor; import polyglot.visit.NodeVisitor; import x10.ast.ClosureCall; import x10.ast.ForLoop; import x10.ast.X10Binary_c; import x10.ast.X10Formal; import x10.ast.SettableAssign; import x10.constraint.XFailure; import x10.constraint.XTerm; import x10.types.ConstrainedType; import x10.types.X10FieldInstance; import x10.types.constants.ConstantValue; import x10.types.constants.IntegralValue; import x10.types.matcher.X10TypeMatcher; import x10.util.AltSynthesizer; import x10.visit.ConstantPropagator; /** * Optimize loops of the form: for (formal in domain) S. * If domain is, or can be lowered to, a Region, that is rectangular with know rank, * tranform into a nest of traditional for-loops with iterates of type int. * * @author vj * @author Bowen Alpern */ public class ForLoopOptimizer extends ContextVisitor { private static final Name ITERATOR = Name.make("iterator"); private static final Name HASNEXT = Name.make("hasNext"); private static final Name REGION = Name.make("region"); private static final Name DIST = Name.make("dist"); private static final Name NEXT = Name.make("next"); private static final Name MAKE = Name.make("make"); private static final Name RANK = Name.make("rank"); private static final Name MIN = Name.make("min"); private static final Name MAX = Name.make("max"); private static final Name SET = SettableAssign.SET; private final TypeSystem xts; private AltSynthesizer syn; public ForLoopOptimizer(Job job, TypeSystem ts, NodeFactory nf) { super(job, ts, nf); xts = ts; syn = new AltSynthesizer(ts, nf); } private Name label = null; protected Name label() { return label; } protected ForLoopOptimizer label(Name label) { ForLoopOptimizer flo = (ForLoopOptimizer) shallowCopy(); flo.label = label; return flo; } @Override protected NodeVisitor enterCall(Node parent, Node n) { // Set the label when seeing a Labeled; clear it for anything but a ForLoop ForLoopOptimizer res = this; if (n instanceof Labeled) { res = res.label(((Labeled) n).labelNode().id()); } else if (!(n instanceof ForLoop)) { res = res.label(null); } return res; } @Override public Node leaveCall(Node old, Node n, NodeVisitor v) { if (n instanceof ForLoop) return visitForLoop((ForLoop) n); if (n instanceof Labeled) { Labeled l = (Labeled) n; ForLoopOptimizer flo = (ForLoopOptimizer) v; assert (l.labelNode().id().equals(flo.label())); if (old instanceof Labeled && ((Labeled) old).statement() instanceof ForLoop && !(l.statement() instanceof ForLoop)) { return l.statement(); // The label will have been propagated onto the loop } } return n; } private static final boolean VERBOSE = false; /** * Transform a ForLoop with a Point iterate over a rectangular region to a nest of simple For loops. * <pre> * for (p:Point(k+1) in r) S -> * { // k=r.rank-1 * point = Rail.make[Int](k+1); // if p is named * mink=r.min(k); maxk=r.max(k); * ... * min0=r.min(0); max0=r.max(0); * for (i0 = min0; i0<=max0; i0+=1) { * point(0) = i0; // if p is named * ... * for (ik = mink; ik<=maxk; ik+=1) { * point(k) = ik; // if p is named * p = Point.make(point); // if p is named * S * } * } * } * </pre> * <tt>r</tt> can be an Array, a DistArray, a Dist, or a Region. * * Also, desugars untransformed ForLoops, TODO: move this to the desugarer * <pre> * for (x in y) S -> for (i=y.iterator(); i.hasNext(); ) { x=i.next(); S } * </pre> * * @param loop the ForLoop to be transformed * @return the transformed loop */ public Node visitForLoop(ForLoop loop) { Position pos = loop.position(); X10Formal formal = (X10Formal) loop.formal(); Expr domain = loop.domain(); Stmt body = loop.body(); if (VERBOSE) { System.out.println("\nOptimizing ForLoop at " +pos); loop.prettyPrint(System.out); System.out.println(); } // if domain <: DistArray, transform to Distribution if (xts.isX10DistArray(domain.type())) { if (VERBOSE) System.out.println(" domain is DistArray, tranforming to Dist"); domain = syn.createFieldRef(pos, domain, DIST); assert (null != domain); } // if domain <: Distribution, transform to Region if (xts.isDistribution(domain.type())) { if (VERBOSE) System.out.println(" domain is Dist, transforming to Region"); domain = syn.createFieldRef(pos, domain, REGION); assert (null != domain); } // if domain <: Array, transform to Region if (xts.isX10Array(domain.type())) { if (VERBOSE) System.out.println(" domain is Array, tranforming to Region"); domain = syn.createFieldRef(pos, domain, REGION); assert (null != domain); } Id label = syn.createLabel(pos); Context context = (Context) context(); List<Formal> formalVars = formal.vars(); boolean named = !formal.isUnnamed(); ConstrainedType domainType = Types.toConstrainedType(domain.type()); boolean isRect = domainType.isRect(context); Integer domainRank = (Integer) getPropertyConstantValue(domain, RANK); int rank = (null != domainRank) ? (int) domainRank : (null != formalVars) ? formalVars.size() : -1; assert null == formalVars || formalVars.isEmpty() || formalVars.size() == rank; // Transform loops over IntRange and LongRange into counted for loops if (xts.isIntRange(domainType) || xts.isLongRange(domainType)) { Name varName = null == formalVars || formalVars.isEmpty() ? Name.makeFresh("i") : Name.makeFresh(formalVars.get(0).name().id()); Name minName = Name.makeFresh(varName+ "min"); Name maxName = Name.makeFresh(varName+ "max"); boolean isLong = xts.isLongRange(domainType); Expr low; Expr high; LocalDecl domLDecl; if (domain instanceof Call && ((Call)domain).name().id().equals(X10Binary_c.binaryMethodName(Binary.DOT_DOT))) { // SPECIAL CASE: if The IntRange is being created in the for loop header itself with the .. operator, avoid creating it entirely. List<Expr> args = ((Call) loop.domain()).arguments(); assert (args.size() == 2); low = args.get(0); high = args.get(1); domLDecl = null; } else { domLDecl = syn.createLocalDecl(domain.position(), Flags.FINAL, Name.makeFresh(varName+"domain"), domain); low = syn.createFieldRef(pos, syn.createLocal(pos, domLDecl), MIN); high = syn.createFieldRef(pos, syn.createLocal(pos, domLDecl), MAX); } LocalDecl minLDecl = syn.createLocalDecl(pos, Flags.FINAL, minName, low); LocalDecl maxLDecl = syn.createLocalDecl(pos, Flags.FINAL, maxName, high); LocalDecl varLDecl = syn.createLocalDecl(pos, Flags.NONE, varName, isLong ? xts.Long() : xts.Int(), syn.createLocal(pos, minLDecl)); Expr cond = syn.createBinary( domain.position(), syn.createLocal(pos, varLDecl), Binary.LE, syn.createLocal(pos, maxLDecl), this ); Expr update = syn.createAssign( domain.position(), syn.createLocal(pos, varLDecl), Assign.ADD_ASSIGN, isLong ? syn.createLongLit(1) : syn.createIntLit(1), this); List<Stmt> bodyStmts = new ArrayList<Stmt>(); if (named) { // declare the formal variable as a local and initialize it LocalDecl formalLDecl = syn.createLocalDecl(formal, syn.createLocal(pos, varLDecl)); bodyStmts.add(formalLDecl); } bodyStmts.add(body); body = syn.createBlock(loop.body().position(), bodyStmts); For forLoop = syn.createStandardFor(pos, varLDecl, cond, update, body); Stmt newLoop = forLoop; if (label() != null) { newLoop = syn.createLabeledStmt(pos, label(), forLoop); } if (domLDecl == null ) { return syn.createBlock(pos, minLDecl, maxLDecl, newLoop); } else { return syn.createBlock(pos, domLDecl, minLDecl, maxLDecl, newLoop); } } // transform loops over rectangular regions of known rank if (xts.isRegion(domainType) && isRect && rank > 0) { assert xts.isPoint(formal.declType()); if (VERBOSE) System.out.println(" rectangular region, rank=" +rank+ " point=" +formal); if (1 < rank) { body = labelFreeBreaks(body, label); } List<Stmt> stmts = new ArrayList<Stmt>(); Name prefix = named ? formal.name().id() : Name.make("p"); // cache the value of domain in a local temporary variable LocalDecl domLDecl = syn.createLocalDecl(domain.position(), Flags.FINAL, Name.makeFresh(prefix), domain); stmts.add(domLDecl); // Prepare to redeclare the formal iterate as local Point variable (if the formal is not anonymous) Type indexType = null; // type of the formal var initializer (if any) LocalDecl indexLDecl = null; // redeclaration of the formal var (if it has a name) if (named) { // create an array to contain the value of the formal at each iteration Name indexName = Name.makeFresh(prefix); indexType = Types.makeArrayRailOf(xts.Int(), rank, pos); Expr indexInit = syn.createTuple(pos, rank, syn.createIntLit(0)); indexLDecl = syn.createLocalDecl(pos, Flags.FINAL, indexName, indexType, indexInit); // add the declaration of the index rail to the list of statements to be executed before the loop nest stmts.add(indexLDecl); } LocalDecl varLDecls[] = new LocalDecl[rank]; // syn.create the loop nest (from the inside out) for (int r=rank-1; 0<=r; r--) { // syn.create new names for the r-th iterate and limits Name varName = null == formalVars || formalVars.isEmpty() ? Name.makeFresh(prefix) : Name.makeFresh(formalVars.get(r).name().id()); Name minName = Name.makeFresh(varName+ "min"); Name maxName = Name.makeFresh(varName+ "max"); // create an AST node for the calls to domain.min() and domain.max() Expr minVal = syn.createInstanceCall(pos, syn.createLocal(domain.position(), domLDecl), MIN, context, syn.createIntLit(r)); Expr maxVal = syn.createInstanceCall(pos, syn.createLocal(domain.position(), domLDecl), MAX, context, syn.createIntLit(r)); // create an AST node for the declaration of the temporary locations for the r-th var, min, and max LocalDecl minLDecl = syn.createLocalDecl(pos, Flags.FINAL, minName, minVal); LocalDecl maxLDecl = syn.createLocalDecl(pos, Flags.FINAL, maxName, maxVal); LocalDecl varLDecl = syn.createLocalDecl(pos, Flags.NONE, varName, xts.Int(), syn.createLocal(pos, minLDecl)); varLDecls[r] = varLDecl; // add the declarations for the r-th min and max to the list of statements to be executed before the loop nest stmts.add(minLDecl); stmts.add(maxLDecl); // create expressions for the second and third positions in the r-th for clause Expr cond = syn.createBinary( domain.position(), syn.createLocal(pos, varLDecl), Binary.LE, syn.createLocal(pos, maxLDecl), this ); Expr update = syn.createAssign( domain.position(), syn.createLocal(pos, varLDecl), Assign.ADD_ASSIGN, syn.createIntLit(1), this ); List<Stmt> bodyStmts = new ArrayList<Stmt>(); if (null != formalVars && !formalVars.isEmpty()) { bodyStmts.add(syn.createLocalDecl((X10Formal) formalVars.get(r), syn.createLocal(pos, varLDecl))); } // concoct declaration for formal, in case it might be referenced in the body if (named) { // set the r-th slot int the index rail to the current value of the r-th iterate Expr setExpr = syn.createInstanceCall( pos, syn.createLocal(pos, indexLDecl), SET, context, syn.createIntLit(r), syn.createLocal(pos, varLDecl) ); bodyStmts.addAll(syn.convertToStmtList(setExpr)); if (r+1 == rank) { // the innermost loop // declare the formal variable as a local and initialize it to the index rail Expr formExpr = syn.createStaticCall(pos, formal.declType(), MAKE, syn.createLocal(pos, indexLDecl)); bodyStmts.add(syn.createLocalDecl(formal, formExpr)); } } bodyStmts.add(body); body = syn.createBlock(pos, bodyStmts); // syn.create the AST node for the r-th concocted for-statement body = syn.createStandardFor(pos, varLDecl, cond, update, body); } if (1 < rank) { Position position = body.position(); body = syn.createLabeledStmt(position, label, body); } if (false) { // 3/11/2011. Dave G. Disabled because transformation doesn't properly update closure environment. body = explodePoint(formal, indexLDecl, varLDecls, body); } if (label() != null) { body = syn.createLabeledStmt(pos, label(), body); } stmts.add(body); Block result = syn.createBlock(pos, stmts); if (VERBOSE) result.dump(System.out); return result; } assert Types.getIterableIndex(domainType, context).size()>=1; // When Iterable was covariant: (xts.isSubtype(domainType, xts.Iterable(xts.Any()), context)); Name iterName = named ? Name.makeFresh(formal.name().id()) : Name.makeFresh(); Expr iterInit = syn.createInstanceCall(pos, domain, ITERATOR, context); LocalDecl iterLDecl = syn.createLocalDecl(pos, Flags.FINAL, iterName, iterInit); Expr hasExpr = syn.createInstanceCall(pos, syn.createLocal(pos, iterLDecl), HASNEXT, context); Expr nextExpr = syn.createInstanceCall(pos, syn.createLocal(pos, iterLDecl), NEXT, context); if (!xts.typeEquals(nextExpr.type(), formal.declType(), context)) { Expr newNextExpr = syn.createCoercion(pos, nextExpr, formal.declType(), this); if (null == newNextExpr) throw new InternalCompilerError("Unable to cast "+nextExpr+" to the iterated type "+formal.declType(), pos); nextExpr = newNextExpr; } List<Stmt> bodyStmts = new ArrayList<Stmt>(); LocalDecl varLDecl = syn.createLocalDecl(formal, nextExpr); bodyStmts.add(varLDecl); if (null != formalVars && !formalVars.isEmpty()) try { bodyStmts.addAll(formal.explode(this)); } catch (SemanticException e) { throw new InternalCompilerError("We cannot explode the formal. Huh?", formal.position(), e); } if (body instanceof Block) { bodyStmts.addAll(((Block) body).statements()); } else { bodyStmts.add(body); } Stmt result = syn.createStandardFor(pos, iterLDecl, hasExpr, syn.createBlock(pos, bodyStmts)); if (label() != null) { result = syn.createLabeledStmt(pos, label(), result); } if (VERBOSE) result.dump(System.out); return result; } /** * Replace calls to the apply method on point with corresponding calls to the corresponding method on rail throughout the body. * * 3/10/2011. Dave G. This transformation is not complete, therefore disabled. * The issue is that if the call is within a closure, then the captured environment * information for the closure (and all lexcially enclosing closures up to the for loop) * must be updated to reflect the additional variables being captured. * * @param point a Point formal variable * @param rail the underlying Rail defining point * @param body the AST containing the usses of point to be replaced * @return a copy of body with every call to point.apply() replaced by a call to rail.apply() */ private Stmt explodePoint(final X10Formal point, final LocalDecl rail, final LocalDecl[] indices, final Stmt body) { assert false : "This transformation is not enabled because it is incomplete"; ContextVisitor pointExploder = new ContextVisitor(job, xts, nodeFactory()) { /* (non-Javadoc) * @see polyglot.visit.ErrorHandlingVisitor#leaveCall(polyglot.ast.Node) */ @Override protected Node leaveCall(Node n) { if (n instanceof Call) { X10Formal p = point; LocalDecl r = rail; LocalDecl[] is = indices; Call call = (Call) n; Receiver target = call.target(); if (target instanceof Local && call.methodInstance().name().equals(ClosureCall.APPLY)) { if (((Local) target).localInstance().def() == point.localDef()) { List<Expr> args = call.arguments(); assert (1 == args.size()); Expr arg = args.get(0); if (arg.isConstant()) { int i = ((IntegralValue) arg.constantValue()).intValue(); return syn.createLocal(n.position(), indices[i]); } call = call.target(syn.createLocal(target.position(), rail)); call = call.methodInstance(syn.createMethodInstance( rail.type().type(), ClosureCall.APPLY, context, Collections.<Type>emptyList(), call.methodInstance().formalTypes())); return call; } } } return n; } }; return (Stmt) body.visit(pointExploder.begin()); } /** * Change free unlabeled breaks in the body to refer to a given label. * * @param body the body of a ForLoop * @param label a label to be attached to the outermost synthesized For * @return aa copy of the body with its free breaks suitably captured */ private Stmt labelFreeBreaks(Stmt body, final Id label) { return (Stmt) body.visit(new NodeVisitor(){ @Override public Node override(Node node) { // these constructs capture free breaks if (node instanceof Loop) return node; if (node instanceof Switch) return node; return null; } @Override public Node leave(Node old, Node n, NodeVisitor v) { if (n instanceof Branch) { Branch b = (Branch) n; if (b.kind().equals(Branch.BREAK) && null == b.labelNode()) { return b.labelNode(label); } } return n; } }); } // General helper methods /** * Obtain the constant value of a property of an expression, if that value is known at compile time. * * @param expr the Expr whose property is to be extracted * @param name the Name of the property to extract * @return the value of the named property of expr if it is a compile-time constant, or null if none * TODO: move into ASTQuery */ public Object getPropertyConstantValue(Expr expr, Name name) { X10FieldInstance propertyFI = Types.getProperty(expr.type(), name); if (null == propertyFI) return null; Expr propertyExpr = syn.createFieldRef(expr.position(), expr, propertyFI); if (null == propertyExpr) return null; return ConstantValue.toJavaObject(ConstantPropagator.constantValue(propertyExpr)); } /** * Add a constraint to the type that binds a given property to a given value. * * @param type the Type to be constrained * @param name the Name of a property of type * @param value the value of the named property for this type * @return the type with the additional constraint {name==value}, or null if no such property * TODO: move into Synthesizer */ /* public static Type addPropertyConstraint(Type type, Name name, Object value) { return addPropertyConstraint(type, name, XTerms.makeLit(value)); }*/ /** * Add a self constraint to the type that binds self to a given value. * * @param type the Type to be constrained * @param value the value of self for this type * @return the type with the additional constraint {self==value}, or null if the proposed * binding is inconsistent * TODO: move into Synthesizer */ public static Type addSelfConstraint(Type type, XTerm value) { return Types.addSelfBinding(type, value); } }