/*
* Copyright 2017 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.errorprone.bugpatterns;
import static com.google.common.collect.Iterables.getLast;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.matchers.Matchers.expressionStatement;
import static com.google.errorprone.matchers.Matchers.staticMethod;
import static com.google.errorprone.matchers.method.MethodMatchers.instanceMethod;
import static com.google.errorprone.util.ASTHelpers.getReceiver;
import static com.google.errorprone.util.ASTHelpers.isSubtype;
import com.google.common.base.Joiner;
import com.google.common.collect.Iterators;
import com.google.common.collect.PeekingIterator;
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker.MethodTreeMatcher;
import com.google.errorprone.fixes.Fix;
import com.google.errorprone.fixes.SuggestedFix;
import com.google.errorprone.fixes.SuggestedFix.Builder;
import com.google.errorprone.fixes.SuggestedFixes;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.matchers.Matcher;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.BlockTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.ExpressionStatementTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MemberSelectTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.StatementTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.Tree.Kind;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symtab;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.tree.JCTree;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
/**
 * Base class for checks that find uses of JUnit's {@code ExpectedException} rule inside a method
 * and build fixes that replace them with {@code assertThrows} or {@code expectThrows}.
 *
 * <p>Subclasses decide what to report for each match by implementing {@link #handleMatch}.
 *
 * @author cushon@google.com (Liam Miller-Cushon)
 */
public abstract class AbstractExpectedExceptionChecker extends BugChecker
    implements MethodTreeMatcher {

  /**
   * Matches expression statements that call an instance method whose name matches {@code
   * expect.*} on exactly {@code org.junit.rules.ExpectedException}.
   */
  static final Matcher<StatementTree> MATCHER =
      expressionStatement(
          instanceMethod()
              .onExactClass("org.junit.rules.ExpectedException")
              .withNameMatching(Pattern.compile("expect.*")));

  /** Matches Hamcrest's {@code isA(Class)} factory from {@code Matchers} or {@code CoreMatchers}. */
  static final Matcher<ExpressionTree> IS_A =
      staticMethod()
          .onClassAny("org.hamcrest.Matchers", "org.hamcrest.CoreMatchers")
          .withSignature("<T>isA(java.lang.Class<T>)");

  /**
   * Scans every block in the method body, including nested blocks and block-bodied lambdas, and
   * reports the description produced by {@link #scanBlock} for each block that matches.
   *
   * @return always {@link Description#NO_MATCH}; matches are reported via {@link
   *     VisitorState#reportMatch}
   */
  @Override
  public Description matchMethod(MethodTree tree, VisitorState state) {
    if (tree.getBody() == null) {
      return NO_MATCH;
    }
    tree.getBody()
        .accept(
            new TreeScanner<Void, Void>() {
              @Override
              public Void visitBlock(BlockTree block, Void unused) {
                Description description = scanBlock(tree, block, state);
                if (description != NO_MATCH) {
                  state.reportMatch(description);
                }
                return super.visitBlock(block, unused);
              }

              @Override
              public Void visitClass(ClassTree classTree, Void unused) {
                // Methods declared in local or anonymous classes get their own matchMethod call;
                // descending into them here would scan those blocks a second time, attributed to
                // the wrong enclosing method.
                return null;
              }
            },
            null);
    return NO_MATCH;
  }

  /**
   * Finds the first run of consecutive {@code ExpectedException} statements in {@code block}. It
   * hands that run, plus every statement that follows it in the block, to {@link #handleMatch}.
   *
   * @return the result of {@link #handleMatch}, or {@link Description#NO_MATCH} if the block
   *     contains no {@code ExpectedException} statements
   */
  Description scanBlock(MethodTree tree, BlockTree block, VisitorState state) {
    PeekingIterator<? extends StatementTree> it =
        Iterators.peekingIterator(block.getStatements().iterator());
    // skip everything before the first expectation
    while (it.hasNext() && !MATCHER.matches(it.peek(), state)) {
      it.next();
    }
    List<Tree> expectations = new ArrayList<>();
    while (it.hasNext() && MATCHER.matches(it.peek(), state)) {
      expectations.add(it.next());
    }
    if (expectations.isEmpty()) {
      return NO_MATCH;
    }
    List<StatementTree> suffix = new ArrayList<>();
    Iterators.addAll(suffix, it);
    return handleMatch(tree, state, expectations, suffix);
  }

  /**
   * Handle a method that contains a use of {@code ExpectedException}.
   *
   * @param tree the method
   * @param state the visitor state
   * @param expectations the statements for the call to {@code thrown.expect(...)}, and any
   *     additional assertions
   * @param suffix the statements after the assertions, which are expected to throw
   */
  protected abstract Description handleMatch(
      MethodTree tree, VisitorState state, List<Tree> expectations, List<StatementTree> suffix);

  /**
   * Translates the {@code ExpectedException} statements into the parts of an {@code
   * assertThrows}/{@code expectThrows} rewrite.
   *
   * <p>The returned base fix deletes the source range from the first expectation through the last
   * one. It also adds the imports needed by the follow-up assertions and type names. The expected
   * exception type comes from the last {@code expect(Class)} call, or from the last {@code
   * expect(Matcher)} call whose type argument is a {@code Throwable}. It defaults to {@code
   * Throwable}.
   *
   * @param state the visitor state
   * @param expectations the {@code ExpectedException} expression statements, in source order
   * @return a partially assembled fix; complete it with {@link BaseFix#build}
   * @throws AssertionError if an expectation calls an {@code expect*} method this class does not
   *     know how to translate
   */
  protected BaseFix buildBaseFix(VisitorState state, List<Tree> expectations) {
    Symtab symtab = state.getSymtab();
    // source name of the expected exception type, used to declare the captured `thrown`
    String exceptionClass = "Throwable";
    // Source of a non-literal Class argument to expect(...), e.g. a variable. When null, the
    // expected class is written as `exceptionClass + ".class"`.
    String exceptionClassExpression = null;
    // additional assertions to perform on the captured exception (if any)
    List<String> newAsserts = new ArrayList<>();
    Builder fix = SuggestedFix.builder();
    for (Tree expectation : expectations) {
      MethodInvocationTree invocation =
          (MethodInvocationTree) ((ExpressionStatementTree) expectation).getExpression();
      MethodSymbol symbol = ASTHelpers.getSymbol(invocation);
      List<? extends ExpressionTree> args = invocation.getArguments();
      switch (symbol.getSimpleName().toString()) {
        case "expect":
          ExpressionTree expected = getOnlyElement(args);
          Type type = ASTHelpers.getType(expected);
          if (isSubtype(type, symtab.classType, state)) {
            // expect(Class<?>)
            if (isClassLiteral(expected)) {
              exceptionClass = state.getSourceForNode(getReceiver(expected));
              exceptionClassExpression = null;
            } else {
              // The receiver of e.g. a variable or field access does not name the exception
              // class, so pass the original expression through unchanged.
              exceptionClass = exceptionTypeOfClass(state, fix, type);
              exceptionClassExpression = state.getSourceForNode(expected);
            }
          } else if (isSubtype(type, state.getTypeFromString("org.hamcrest.Matcher"), state)) {
            Type matcherType =
                state.getTypes().asSuper(type, state.getSymbolFromString("org.hamcrest.Matcher"));
            if (!matcherType.getTypeArguments().isEmpty()) {
              Type matchType = getOnlyElement(matcherType.getTypeArguments());
              if (isSubtype(matchType, symtab.throwableType, state)) {
                exceptionClass = SuggestedFixes.qualifyType(state, fix, matchType);
                exceptionClassExpression = null;
              }
            }
            // expect(Matcher)
            fix.addStaticImport("org.hamcrest.MatcherAssert.assertThat");
            newAsserts.add(
                String.format("assertThat(thrown, %s);", state.getSourceForNode(expected)));
          }
          break;
        case "expectCause":
          ExpressionTree matcher = getOnlyElement(args);
          if (IS_A.matches(matcher, state)) {
            // expectCause(isA(Foo.class)) maps onto a Truth assertion on the cause's type
            fix.addStaticImport("com.google.common.truth.Truth.assertThat");
            newAsserts.add(
                String.format(
                    "assertThat(thrown).hasCauseThat().isInstanceOf(%s);",
                    state.getSourceForNode(
                        getOnlyElement(((MethodInvocationTree) matcher).getArguments()))));
          } else {
            fix.addStaticImport("org.hamcrest.MatcherAssert.assertThat");
            newAsserts.add(
                String.format(
                    "assertThat(thrown.getCause(), %s);", state.getSourceForNode(matcher)));
          }
          break;
        case "expectMessage":
          if (isSubtype(
              getOnlyElement(symbol.getParameters()).asType(), symtab.stringType, state)) {
            // expectMessage(String)
            fix.addStaticImport("com.google.common.truth.Truth.assertThat");
            newAsserts.add(
                String.format(
                    "assertThat(thrown).hasMessageThat().contains(%s);",
                    state.getSourceForNode(getOnlyElement(args))));
          } else {
            // expectMessage(Matcher)
            fix.addStaticImport("org.hamcrest.MatcherAssert.assertThat");
            newAsserts.add(
                String.format(
                    "assertThat(thrown.getMessage(), %s);",
                    state.getSourceForNode(getOnlyElement(args))));
          }
          break;
        default:
          throw new AssertionError("unknown expect method: " + symbol.getSimpleName());
      }
    }
    // remove all interactions with the ExpectedException rule
    fix.replace(
        ((JCTree) expectations.get(0)).getStartPosition(),
        state.getEndPosition(getLast(expectations)),
        "");
    return new BaseFix(
        fix.build(),
        exceptionClass,
        exceptionClassExpression != null ? exceptionClassExpression : exceptionClass + ".class",
        newAsserts);
  }

  /** Returns true if {@code tree} is a class literal such as {@code Foo.class}. */
  private static boolean isClassLiteral(ExpressionTree tree) {
    return tree.getKind() == Kind.MEMBER_SELECT
        && ((MemberSelectTree) tree).getIdentifier().contentEquals("class");
  }

  /**
   * Returns the source name of {@code T} for a {@code Class<T>} type. Returns {@code "Throwable"}
   * when {@code T} is absent (raw {@code Class}) or is not a subtype of {@code Throwable}.
   */
  private static String exceptionTypeOfClass(VisitorState state, Builder fix, Type classType) {
    Type asClass = state.getTypes().asSuper(classType, state.getSymtab().classType.tsym);
    if (asClass == null || asClass.getTypeArguments().isEmpty()) {
      return "Throwable";
    }
    // erasure replaces a wildcard such as `? extends IOException` with its upper bound
    Type exceptionType = state.getTypes().erasure(getOnlyElement(asClass.getTypeArguments()));
    return isSubtype(exceptionType, state.getSymtab().throwableType, state)
        ? SuggestedFixes.qualifyType(state, fix, exceptionType)
        : "Throwable";
  }

  /**
   * A partially assembled fix. It holds the removal of the {@code ExpectedException} statements and
   * the information needed to wrap the throwing statements in {@code assertThrows} or {@code
   * expectThrows}.
   */
  protected static class BaseFix {
    /** Removes the {@code ExpectedException} statements and adds the imports the rewrite needs. */
    final SuggestedFix baseFix;
    /** Source name of the expected exception type, used to declare the captured exception. */
    final String exceptionClass;
    /** Assertions on the captured exception {@code thrown}, one statement per entry. */
    final List<String> newAsserts;
    /** Source of the {@code Class} argument passed to {@code assertThrows}/{@code expectThrows}. */
    final String exceptionClassExpression;

    BaseFix(SuggestedFix baseFix, String exceptionClass, List<String> newAsserts) {
      this(baseFix, exceptionClass, exceptionClass + ".class", newAsserts);
    }

    BaseFix(
        SuggestedFix baseFix,
        String exceptionClass,
        String exceptionClassExpression,
        List<String> newAsserts) {
      this.baseFix = baseFix;
      this.exceptionClass = exceptionClass;
      this.exceptionClassExpression = exceptionClassExpression;
      this.newAsserts = newAsserts;
    }

    /**
     * Completes the fix by wrapping {@code throwingStatements} in a lambda.
     *
     * <ul>
     *   <li>With no follow-up assertions, the lambda is passed to {@code assertThrows}.
     *   <li>Otherwise it is passed to {@code expectThrows}, the result is bound to {@code
     *       thrown}, and the assertions follow.
     * </ul>
     *
     * <p>A single expression statement becomes an expression lambda. Anything else becomes a
     * block lambda.
     *
     * @param throwingStatements the statements expected to throw, in source order
     * @return the completed fix, or just the base fix if there are no throwing statements
     */
    public Fix build(List<? extends StatementTree> throwingStatements) {
      if (throwingStatements.isEmpty()) {
        return baseFix;
      }
      StatementTree first = throwingStatements.get(0);
      StatementTree last = getLast(throwingStatements);
      SuggestedFix.Builder fix = SuggestedFix.builder().merge(baseFix);
      StringBuilder fixPrefix = new StringBuilder();
      if (newAsserts.isEmpty()) {
        fix.addStaticImport("org.junit.Assert.assertThrows");
        fixPrefix.append("assertThrows");
      } else {
        fix.addStaticImport("org.junit.Assert.expectThrows");
        fixPrefix.append(String.format("%s thrown = expectThrows", exceptionClass));
      }
      fixPrefix.append(String.format("(%s, () -> ", exceptionClassExpression));
      boolean useExpressionLambda =
          throwingStatements.size() == 1 && first.getKind() == Kind.EXPRESSION_STATEMENT;
      if (!useExpressionLambda) {
        fixPrefix.append("{");
      }
      fix.prefixWith(first, fixPrefix.toString());
      if (useExpressionLambda) {
        // close the call right after the expression, before the statement's own semicolon
        fix.postfixWith(((ExpressionStatementTree) first).getExpression(), ")");
        if (!newAsserts.isEmpty()) {
          // start the follow-up assertions on a new line after `...);`
          fix.postfixWith(first, "\n" + Joiner.on('\n').join(newAsserts));
        }
      } else {
        fix.postfixWith(last, "});\n" + Joiner.on('\n').join(newAsserts));
      }
      return fix.build();
    }
  }
}