/*
 *  Licensed to the Apache Software Foundation (ASF) under one
 *  or more contributor license agreements.  See the NOTICE file
 *  distributed with this work for additional information
 *  regarding copyright ownership.  The ASF licenses this file
 *  to you under the Apache License, Version 2.0 (the
 *  "License"); you may not use this file except in compliance
 *  with the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing,
 *  software distributed under the License is distributed on an
 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 *  KIND, either express or implied.  See the License for the
 *  specific language governing permissions and limitations
 *  under the License.
 */
package org.codehaus.groovy.transform;

import groovy.lang.Newify;
import org.codehaus.groovy.GroovyBugError;
import org.codehaus.groovy.ast.ASTNode;
import org.codehaus.groovy.ast.AnnotatedNode;
import org.codehaus.groovy.ast.AnnotationNode;
import org.codehaus.groovy.ast.ClassCodeExpressionTransformer;
import org.codehaus.groovy.ast.ClassNode;
import org.codehaus.groovy.ast.FieldNode;
import org.codehaus.groovy.ast.MethodNode;
import org.codehaus.groovy.ast.expr.ClassExpression;
import org.codehaus.groovy.ast.expr.ClosureExpression;
import org.codehaus.groovy.ast.expr.ConstantExpression;
import org.codehaus.groovy.ast.expr.ConstructorCallExpression;
import org.codehaus.groovy.ast.expr.DeclarationExpression;
import org.codehaus.groovy.ast.expr.Expression;
import org.codehaus.groovy.ast.expr.ListExpression;
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.control.CompilePhase;
import org.codehaus.groovy.control.SourceUnit;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.codehaus.groovy.ast.ClassHelper.make;
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
import static org.codehaus.groovy.ast.tools.GeneralUtils.classX;

/**
 * Handles generation of code for the @Newify annotation.
 * <p>
 * The transformation walks the annotated class, method, field or declaration and
 * replaces qualifying method call expressions with constructor call expressions:
 * either {@code SomeClass.new(args)} when the 'auto' flag is on, or
 * {@code SomeClass(args)} when the call is made on the shared
 * {@link VariableExpression#THIS_EXPRESSION} and the method name matches the simple
 * name of one of the classes listed in the annotation.
 *
 * @author Paul King
 */
@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
public class NewifyASTTransformation extends ClassCodeExpressionTransformer implements ASTTransformation {
    private static final ClassNode MY_TYPE = make(Newify.class);
    private static final String MY_NAME = MY_TYPE.getNameWithoutPackage();
    private static final String BASE_BAD_PARAM_ERROR = "Error during @" + MY_NAME +
            " processing. Annotation parameter must be a class or list of classes but found ";

    private SourceUnit source;
    /** Classes currently eligible for the {@code SomeClass(args)} style; may be null. */
    private ListExpression classesToNewify;
    /** Annotated declaration still being searched for; reset to null once {@link #transform} reaches it. */
    private DeclarationExpression candidate;
    /** Whether the {@code SomeClass.new(args)} style is currently enabled. */
    private boolean auto;

    /**
     * Entry point called by the compiler for each @Newify annotation.
     *
     * @param nodes  expected to be {@code [AnnotationNode, AnnotatedNode]}
     * @param source the source unit being compiled; also used for error reporting
     * @throws GroovyBugError if {@code nodes} has an unexpected shape or the annotation
     *                        is not @Newify
     */
    public void visit(ASTNode[] nodes, SourceUnit source) {
        this.source = source;
        if (nodes.length != 2 || !(nodes[0] instanceof AnnotationNode) || !(nodes[1] instanceof AnnotatedNode)) {
            internalError("Expecting [AnnotationNode, AnnotatedClass] but got: " + Arrays.asList(nodes));
        }

        AnnotatedNode parent = (AnnotatedNode) nodes[1];
        AnnotationNode node = (AnnotationNode) nodes[0];
        if (!MY_TYPE.equals(node.getClassNode())) {
            internalError("Transformation called from wrong annotation: " + node.getClassNode().getName());
        }

        boolean autoFlag = determineAutoFlag(node.getMember("auto"));
        Expression value = node.getMember("value");

        if (parent instanceof ClassNode) {
            newifyClass((ClassNode) parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof MethodNode || parent instanceof FieldNode) {
            newifyMethodOrField(parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof DeclarationExpression) {
            // only declarations look up unresolved names among the source unit's classes
            newifyDeclaration((DeclarationExpression) parent, autoFlag, determineClasses(value, true));
        }
    }

    /**
     * Applies the transformation for an annotated local declaration by visiting its
     * declaring class with {@link #candidate} set to the declaration.
     */
    private void newifyDeclaration(DeclarationExpression de, boolean autoFlag, ListExpression list) {
        ClassNode cNode = de.getDeclaringClass();
        candidate = de;
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            // restore the enclosing settings even if visiting fails
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Returns {@code false} only when the 'auto' member is a constant holding
     * {@code false}; any other expression (including an absent member) yields {@code true}.
     */
    private static boolean determineAutoFlag(Expression autoExpr) {
        // Boolean.FALSE.equals(..) tolerates a constant whose value is null
        return !(autoExpr instanceof ConstantExpression
                && Boolean.FALSE.equals(((ConstantExpression) autoExpr).getValue()));
    }

    /**
     * Converts the annotation's 'value' member into a list of class expressions,
     * reporting an error for anything that is neither a class nor a list of classes.
     * Entries that fail validation are reported but remain in the returned list.
     *
     * allow non-strict mode in scripts because parsing not complete at that point
     *
     * @param expr             the annotation member value; may be null
     * @param searchSourceUnit whether variable expressions should be matched by name
     *                         against the classes of the current source unit
     * @return the (possibly empty) list; when {@code expr} is itself a list, that same
     *         list instance is returned with resolved entries replaced in place
     */
    private ListExpression determineClasses(Expression expr, boolean searchSourceUnit) {
        ListExpression list = new ListExpression();
        if (expr instanceof ClassExpression) {
            list.addExpression(expr);
        } else if (expr instanceof VariableExpression && searchSourceUnit) {
            VariableExpression ve = (VariableExpression) expr;
            ClassNode fromSourceUnit = getSourceUnitClass(ve);
            if (fromSourceUnit != null) {
                ClassExpression found = classX(fromSourceUnit);
                found.setSourcePosition(ve);
                list.addExpression(found);
            } else {
                addError(BASE_BAD_PARAM_ERROR + "an unresolvable reference to '" + ve.getName() + "'.", expr);
            }
        } else if (expr instanceof ListExpression) {
            list = (ListExpression) expr;
            final List<Expression> expressions = list.getExpressions();
            for (int i = 0; i < expressions.size(); i++) {
                Expression next = expressions.get(i);
                if (next instanceof VariableExpression && searchSourceUnit) {
                    VariableExpression ve = (VariableExpression) next;
                    ClassNode fromSourceUnit = getSourceUnitClass(ve);
                    if (fromSourceUnit != null) {
                        ClassExpression found = classX(fromSourceUnit);
                        found.setSourcePosition(ve);
                        expressions.set(i, found);
                    } else {
                        addError(BASE_BAD_PARAM_ERROR + "a list containing an unresolvable reference to '" + ve.getName() + "'.", next);
                    }
                } else if (!(next instanceof ClassExpression)) {
                    addError(BASE_BAD_PARAM_ERROR + "a list containing type: " + next.getType().getName() + ".", next);
                }
            }
            checkDuplicateNameClashes(list);
        } else if (expr != null) {
            addError(BASE_BAD_PARAM_ERROR + "a type: " + expr.getType().getName() + ".", expr);
        }
        return list;
    }

    /**
     * Finds a class in the current source unit whose simple name equals the variable's name.
     *
     * @return the matching class node, or null when there is none
     */
    private ClassNode getSourceUnitClass(VariableExpression ve) {
        List<ClassNode> classes = source.getAST().getClasses();
        for (ClassNode classNode : classes) {
            if (classNode.getNameWithoutPackage().equals(ve.getName())) {
                return classNode;
            }
        }
        return null;
    }

    /**
     * Rewrites qualifying method calls into constructor calls and recurses into
     * closures, anonymous inner classes and declarations.
     *
     * @param expr the expression to transform; may be null
     * @return the transformed expression, or null when {@code expr} is null
     */
    public Expression transform(Expression expr) {
        if (expr == null) return null;
        if (expr instanceof MethodCallExpression && candidate == null) {
            MethodCallExpression mce = (MethodCallExpression) expr;
            Expression args = transform(mce.getArguments());
            if (isNewifyCandidate(mce)) {
                Expression transformed = transformMethodCall(mce, args);
                transformed.setSourcePosition(mce);
                return transformed;
            }
            Expression method = transform(mce.getMethod());
            Expression object = transform(mce.getObjectExpression());
            MethodCallExpression transformed = callX(object, method, args);
            transformed.setImplicitThis(mce.isImplicitThis());
            transformed.setSourcePosition(mce);
            return transformed;
        } else if (expr instanceof ClosureExpression) {
            ClosureExpression ce = (ClosureExpression) expr;
            ce.getCode().visit(this);
        } else if (expr instanceof ConstructorCallExpression) {
            ConstructorCallExpression cce = (ConstructorCallExpression) expr;
            if (cce.isUsingAnonymousInnerClass()) {
                cce.getType().visitContents(this);
            }
        } else if (expr instanceof DeclarationExpression) {
            DeclarationExpression de = (DeclarationExpression) expr;
            if (de == candidate || auto) {
                // the annotated declaration has been reached; method calls may be rewritten again
                candidate = null;
                Expression left = de.getLeftExpression();
                Expression right = transform(de.getRightExpression());
                DeclarationExpression newDecl = new DeclarationExpression(left, de.getOperation(), right);
                // carry over the source position, as is done for rebuilt method calls
                newDecl.setSourcePosition(de);
                newDecl.addAnnotations(de.getAnnotations());
                return newDecl;
            }
            return de;
        }
        return expr.transformExpression(this);
    }

    /**
     * Applies the transformation to a whole class; reports an error when the
     * annotated class is an interface.
     */
    private void newifyClass(ClassNode cNode, boolean autoFlag, ListExpression list) {
        String cName = cNode.getName();
        if (cNode.isInterface()) {
            addError("Error processing interface '" + cName + "'. @" + MY_NAME + " not allowed for interfaces.", cNode);
        }

        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Applies the transformation to a single method or field, after checking the
     * supplied settings against the ones currently in effect.
     */
    private void newifyMethodOrField(AnnotatedNode parent, boolean autoFlag, ListExpression list) {
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        // both checks compare against the settings in effect before they are replaced below
        checkClassLevelClashes(list);
        checkAutoClash(autoFlag, parent);
        classesToNewify = list;
        auto = autoFlag;
        try {
            if (parent instanceof FieldNode) {
                super.visitField((FieldNode) parent);
            } else {
                super.visitMethod((MethodNode) parent);
            }
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Reports an error for every class in the list whose simple name has already
     * been seen earlier in the same list.
     */
    private void checkDuplicateNameClashes(ListExpression list) {
        final Set<String> seen = new HashSet<String>();
        for (Expression expr : list.getExpressions()) {
            // invalid entries were already reported by determineClasses; skip them
            // rather than failing with a ClassCastException
            if (!(expr instanceof ClassExpression)) continue;
            final String name = expr.getType().getNameWithoutPackage();
            if (!seen.add(name)) {
                addError("Duplicate name '" + name + "' found during @" + MY_NAME + " processing.", expr);
            }
        }
    }

    /**
     * Reports an error when 'auto' is being switched off while it is currently on.
     */
    private void checkAutoClash(boolean autoFlag, AnnotatedNode parent) {
        if (auto && !autoFlag) {
            addError("Error during @" + MY_NAME + " processing. The 'auto' flag can't be false at " +
                    "method/constructor/field level if it is true at the class level.", parent);
        }
    }

    /**
     * Reports an error for every class in the list whose simple name is already
     * present in the currently active {@link #classesToNewify}.
     */
    private void checkClassLevelClashes(ListExpression list) {
        for (Expression expr : list.getExpressions()) {
            // only genuine class expressions can clash; invalid entries were already reported
            if (!(expr instanceof ClassExpression)) continue;
            final String name = expr.getType().getNameWithoutPackage();
            if (findClassWithMatchingBasename(name)) {
                addError("Error during @" + MY_NAME + " processing. Class '" + name + "' can't appear at " +
                        "method/constructor/field level if it already appears at the class level.", expr);
            }
        }
    }

    /**
     * Tells whether the currently active class list contains a class with the given simple name.
     */
    private boolean findClassWithMatchingBasename(String nameWithoutPackage) {
        if (classesToNewify == null) return false;
        for (Expression expr : classesToNewify.getExpressions()) {
            if (expr instanceof ClassExpression
                    && expr.getType().getNameWithoutPackage().equals(nameWithoutPackage)) {
                return true;
            }
        }
        return false;
    }

    /**
     * A call is a candidate when its object expression is the shared
     * {@link VariableExpression#THIS_EXPRESSION} instance, or when 'auto' is on and
     * the call has the {@code SomeClass.new(..)} shape.
     */
    private boolean isNewifyCandidate(MethodCallExpression mce) {
        return mce.getObjectExpression() == VariableExpression.THIS_EXPRESSION
                || (auto && isNewMethodStyle(mce));
    }

    /**
     * Tells whether the call is a method named "new" (given as a constant) invoked
     * on a class expression.
     */
    private static boolean isNewMethodStyle(MethodCallExpression mce) {
        final Expression obj = mce.getObjectExpression();
        final Expression meth = mce.getMethod();
        // "new".equals(..) tolerates a constant whose value is null
        return obj instanceof ClassExpression
                && meth instanceof ConstantExpression
                && "new".equals(((ConstantExpression) meth).getValue());
    }

    /**
     * Builds the constructor call replacing {@code mce}, or returns {@code mce} with
     * its arguments updated when no target class can be determined.
     *
     * @param mce  the candidate method call
     * @param args the already transformed arguments of {@code mce}
     */
    private Expression transformMethodCall(MethodCallExpression mce, Expression args) {
        ClassNode classType;
        if (isNewMethodStyle(mce)) {
            classType = mce.getObjectExpression().getType();
        } else {
            classType = findMatchingCandidateClass(mce);
        }
        if (classType != null) {
            return new ConstructorCallExpression(classType, args);
        }
        // set the args as they might have gotten Newify transformed GROOVY-3491
        mce.setArguments(args);
        return mce;
    }

    /**
     * Looks up the active class whose simple name equals the called method's name.
     *
     * @return the matching class node, or null when there is none
     */
    private ClassNode findMatchingCandidateClass(MethodCallExpression mce) {
        if (classesToNewify == null) return null;
        for (Expression expr : classesToNewify.getExpressions()) {
            // ignore entries that failed validation in determineClasses
            if (!(expr instanceof ClassExpression)) continue;
            final ClassNode type = expr.getType();
            if (type.getNameWithoutPackage().equals(mce.getMethodAsString())) {
                return type;
            }
        }
        return null;
    }

    /**
     * Signals a broken invariant in the transformation itself.
     *
     * @throws GroovyBugError always
     */
    private static void internalError(String message) {
        throw new GroovyBugError("Internal error: " + message);
    }

    /**
     * @return the source unit supplied to {@link #visit(ASTNode[], SourceUnit)}
     */
    protected SourceUnit getSourceUnit() {
        return source;
    }
}