All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.errorprone.bugpatterns; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.errorprone.BugPattern.Category.MOCKITO; import static com.google.errorprone.BugPattern.SeverityLevel.ERROR; import static com.google.errorprone.matchers.method.MethodMatchers.staticMethod; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.errorprone.BugPattern; import com.google.errorprone.VisitorState; import com.google.errorprone.bugpatterns.BugChecker.CompilationUnitTreeMatcher; import com.google.errorprone.fixes.SuggestedFix; import com.google.errorprone.fixes.SuggestedFixes; import com.google.errorprone.matchers.Description; import com.google.errorprone.matchers.Matcher; import com.google.errorprone.util.ASTHelpers; import com.sun.source.tree.AssignmentTree; import com.sun.source.tree.CompilationUnitTree; import com.sun.source.tree.ExpressionTree; import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.Tree; import com.sun.source.tree.VariableTree; import com.sun.source.util.TreePathScanner; import com.sun.source.util.TreeScanner; import com.sun.tools.javac.code.Attribute; import com.sun.tools.javac.code.Attribute.Compound; import com.sun.tools.javac.code.Scope.LookupKind; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Symbol.MethodSymbol; import 
com.sun.tools.javac.code.Symbol.VarSymbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import java.util.LinkedHashSet;
import java.util.Map.Entry;
import java.util.Set;
import javax.lang.model.element.ElementKind;

/**
 * Flags {@code Mockito.when(call)} where {@code call} is invoked on a mock configured with one of
 * {@link #BAD_ANSWER_STRATEGIES}, and the erasure of the call's instantiated (inferred) return type
 * differs from the erasure of its declared return type. The suggested fix casts {@code call} to the
 * declared erasure (or {@code Object}).
 *
 * @author Liam Miller-Cushon (cushon@google.com)
 */
@BugPattern(
  name = "MockitoCast",
  category = MOCKITO,
  summary = "A bug in Mockito will cause this test to fail at runtime with a ClassCastException",
  severity = ERROR
)
public class MockitoCast extends BugChecker implements CompilationUnitTreeMatcher {

  private static final String MOCKITO_CLASS = "org.mockito.Mockito";

  private static final String UI_FIELD_ANNOTATION = "com.google.gwt.uibinder.client.UiField";

  private static final String MOCK_ANNOTATION = "org.mockito.Mock";

  /** Answer strategies that always return an instance of the erased return type. */
  private static final ImmutableSet<String> BAD_ANSWER_STRATEGIES =
      ImmutableSet.of("RETURNS_SMART_NULLS", "RETURNS_MOCKS", "RETURNS_DEEP_STUBS");

  /**
   * Scans the whole compilation unit. Findings are reported directly through {@link
   * VisitorState#reportMatch} by {@link WhenNeedsCastScanner}, so this method itself always
   * returns {@link Description#NO_MATCH}.
   */
  @Override
  public Description matchCompilationUnit(CompilationUnitTree tree, final VisitorState state) {
    Symbol mockitoSym = state.getSymbolFromString(MOCKITO_CLASS);
    if (mockitoSym == null) {
      // fast path if mockito isn't being used
      return Description.NO_MATCH;
    }

    // collect variable symbols for standard Answer constants that don't support generics
    final Set<Symbol> badAnswers = new LinkedHashSet<>();
    for (Symbol member : mockitoSym.members().getSymbols(LookupKind.NON_RECURSIVE)) {
      if (member.getKind() == ElementKind.FIELD
          && BAD_ANSWER_STRATEGIES.contains(member.getSimpleName().toString())) {
        badAnswers.add(member);
      }
    }

    // collect mocks that are initialized in this compilation unit using a bad answer strategy
    final Set<VarSymbol> mockVariables = MockInitializationScanner.scan(state, badAnswers);

    // check for when(...) calls on mocks using a bad answer strategy
    new WhenNeedsCastScanner(mockVariables, state).scan(state.getPath(), null);

    // errors are reported in WhenNeedsCastScanner
    return Description.NO_MATCH;
  }

  /**
   * Records declarations of and assignments to mock variables where the initializer references
   * an answer strategy that does not support generics.
   */
  static class MockInitializationScanner extends TreeScanner<Void, Void> {

    /** Scans the current compilation unit and returns every variable initialized with a bad answer. */
    static Set<VarSymbol> scan(VisitorState state, Set<Symbol> badAnswers) {
      MockInitializationScanner scanner = new MockInitializationScanner(badAnswers);
      state.getPath().getCompilationUnit().accept(scanner, null);
      return scanner.mockVariables;
    }

    private final Set<VarSymbol> mockVariables = new LinkedHashSet<>();
    private final Set<Symbol> badAnswers;

    public MockInitializationScanner(Set<Symbol> badAnswers) {
      this.badAnswers = badAnswers;
    }

    @Override
    public Void visitVariable(VariableTree node, Void aVoid) {
      recordInitialization(node, node.getInitializer());
      return super.visitVariable(node, aVoid);
    }

    @Override
    public Void visitAssignment(AssignmentTree node, Void aVoid) {
      recordInitialization(node.getVariable(), node.getExpression());
      return super.visitAssignment(node, aVoid);
    }

    /**
     * Adds the variable denoted by {@code varTree} to {@link #mockVariables} if {@code initializer}
     * references any symbol in {@link #badAnswers} anywhere in its subtree.
     */
    private void recordInitialization(Tree varTree, ExpressionTree initializer) {
      if (initializer == null) {
        return;
      }
      Symbol sym = ASTHelpers.getSymbol(varTree);
      if (!(sym instanceof VarSymbol)) {
        return;
      }
      Boolean initializedWithBadAnswer =
          initializer.accept(
              new TreeScanner<Boolean, Void>() {
                @Override
                public Boolean scan(Tree tree, Void unused) {
                  if (badAnswers.contains(ASTHelpers.getSymbol(tree))) {
                    return true;
                  }
                  return super.scan(tree, null);
                }

                @Override
                public Boolean reduce(Boolean r1, Boolean r2) {
                  // children with no result yield null; treat that as "not found"
                  return firstNonNull(r1, false) || firstNonNull(r2, false);
                }
              },
              null);
      if (firstNonNull(initializedWithBadAnswer, false)) {
        mockVariables.add((VarSymbol) sym);
      }
    }
  }

  private static final Matcher<ExpressionTree> WHEN_MATCHER =
      staticMethod().onClass(MOCKITO_CLASS).named("when");

  /** Scans for when(...) calls that need a cast added, and emits fixes. */
  class WhenNeedsCastScanner extends TreePathScanner<Void, Void> {

    final Set<VarSymbol> badMocks;
    final VisitorState state;

    WhenNeedsCastScanner(Set<VarSymbol> badMocks, VisitorState state) {
      this.badMocks = badMocks;
      this.state = state;
    }

    @Override
    public Void visitMethodInvocation(MethodInvocationTree node, Void unused) {
      Description description = matchMethodInvocation(node, state.withPath(getCurrentPath()));
      if (description != Description.NO_MATCH) {
        state.reportMatch(description);
      }
      return super.visitMethodInvocation(node, null);
    }

    /**
     * Returns a finding (with a cast fix) for {@code Mockito.when(call)} when the erasure of the
     * instantiated return type of {@code call} differs from the erasure of its declared return
     * type and {@code call} is made on a mock with a bad answer strategy; otherwise {@link
     * Description#NO_MATCH}.
     *
     * @param state the visitor state positioned at {@code tree} (shadows the scanner's field)
     */
    public Description matchMethodInvocation(MethodInvocationTree tree, final VisitorState state) {
      // look for a call to Mockito.when(arg)
      if (!WHEN_MATCHER.matches(tree, state)) {
        return Description.NO_MATCH;
      }
      // where the only arg is an invocation
      if (tree.getArguments().size() != 1) {
        return Description.NO_MATCH;
      }
      ExpressionTree arg = Iterables.getOnlyElement(tree.getArguments());
      if (!(arg instanceof JCMethodInvocation)) {
        return Description.NO_MATCH;
      }

      // and the invocation's inferred erased and uninstantiated erased return types differ
      JCMethodInvocation call = (JCMethodInvocation) arg;
      Types types = state.getTypes();
      // Type.getReturnType() yields null for non-method types, and Types.erasure does not accept
      // null, so the null checks have to happen on the return type *before* erasing it.
      if (call.meth.type == null || call.meth.type.getReturnType() == null) {
        return Description.NO_MATCH;
      }
      Type instantiatedReturnType = types.erasure(call.meth.type.getReturnType());
      MethodSymbol methodSym = ASTHelpers.getSymbol(call);
      if (methodSym == null
          || methodSym.type == null
          || methodSym.type.getReturnType() == null) {
        return Description.NO_MATCH;
      }
      Type uninstantiatedReturnType = types.erasure(methodSym.type.getReturnType());
      if (types.isSameType(instantiatedReturnType, uninstantiatedReturnType)) {
        return Description.NO_MATCH;
      }
      if (!MockAnswerStrategyScanner.scan(call.getMethodSelect(), state, badMocks)) {
        return Description.NO_MATCH;
      }

      final SuggestedFix.Builder fixBuilder = SuggestedFix.builder();
      String qual = castTarget(uninstantiatedReturnType, fixBuilder, state);
      fixBuilder.prefixWith(arg, String.format("(%s) ", qual));
      return describeMatch(tree, fixBuilder.build());
    }

    /**
     * Returns the source text of the type to cast the stubbed call to. Uses the erased declared
     * return type when it is a non-generic class or interface, adding imports to {@code
     * fixBuilder} as needed; otherwise returns {@code "Object"}.
     */
    private String castTarget(
        Type erasedReturnType, SuggestedFix.Builder fixBuilder, VisitorState state) {
      // An array type's tsym is javac's synthetic array class, which cannot be named in source,
      // and a generic class would need type arguments; fall back to Object in both cases.
      if (erasedReturnType instanceof Type.ArrayType
          || !erasedReturnType.tsym.getTypeParameters().isEmpty()) {
        return "Object";
      }
      return SuggestedFixes.qualifyType(state, fixBuilder, erasedReturnType.tsym);
    }
  }

  /**
   * Scans for the mock variable in a when(...), and checks if it has a bad answer strategy.
   *
   * <p>@Mock annotations are handled here instead of in {@link MockInitializationScanner}
   * because they're visible across compilation boundaries, so scanning the declarations in the
   * current compilation would result in false negatives.
   */
  static class MockAnswerStrategyScanner extends TreeScanner<Boolean, Void> {

    /** Returns true if any variable referenced within {@code tree} is a mock with a bad answer. */
    static boolean scan(Tree tree, VisitorState state, Set<VarSymbol> badMocks) {
      return firstNonNull(tree.accept(new MockAnswerStrategyScanner(state, badMocks), null), false);
    }

    private final VisitorState state;
    private final Set<VarSymbol> badMocks;

    public MockAnswerStrategyScanner(VisitorState state, Set<VarSymbol> badMocks) {
      this.state = state;
      this.badMocks = badMocks;
    }

    @Override
    public Boolean scan(Tree tree, Void aVoid) {
      Symbol sym = ASTHelpers.getSymbol(tree);
      if (sym instanceof VarSymbol) {
        VarSymbol varSym = (VarSymbol) sym;
        // initialized with a bad answer somewhere in this compilation unit
        if (badMocks.contains(varSym)) {
          return true;
        }
        // custom answer strategies can be specified using @Mock(answer = ...)
        if (ASTHelpers.hasAnnotation(varSym, MOCK_ANNOTATION, state)
            && !answerHandlesGenerics(varSym, state)) {
          return true;
        }
        // gwtmockito mocks @UiFields
        if (ASTHelpers.hasAnnotation(varSym, UI_FIELD_ANNOTATION, state)) {
          return true;
        }
      }
      return super.scan(tree, aVoid);
    }

    @Override
    public Boolean reduce(Boolean r1, Boolean r2) {
      return firstNonNull(r1, false) || firstNonNull(r2, false);
    }

    /**
     * Returns true unless the variable's {@code @Mock} annotation specifies an {@code answer} that
     * is one of {@link #BAD_ANSWER_STRATEGIES}. Returns true when no {@code answer} element is
     * present, or when the {@code @Mock} attribute cannot be found directly on the symbol.
     */
    static boolean answerHandlesGenerics(VarSymbol varSym, VisitorState state) {
      Compound attribute = varSym.attribute(state.getSymbolFromString(MOCK_ANNOTATION));
      if (attribute == null) {
        // no resolvable @Mock attribute: behave as if no answer was specified
        return true;
      }
      String answer = null;
      for (Entry<MethodSymbol, Attribute> e : attribute.getElementValues().entrySet()) {
        if (e.getKey().getSimpleName().contentEquals("answer")) {
          answer = e.getValue().getValue().toString();
          break;
        }
      }
      return !BAD_ANSWER_STRATEGIES.contains(answer);
    }
  }
}