/**
* BSD-style license; for more info see http://pmd.sourceforge.net/license.html
*/
package net.sourceforge.pmd.lang.java.rule.design;
import net.sourceforge.pmd.lang.ast.Node;
import net.sourceforge.pmd.lang.java.ast.ASTBlock;
import net.sourceforge.pmd.lang.java.ast.ASTBooleanLiteral;
import net.sourceforge.pmd.lang.java.ast.ASTIfStatement;
import net.sourceforge.pmd.lang.java.ast.ASTMethodDeclaration;
import net.sourceforge.pmd.lang.java.ast.ASTPrimitiveType;
import net.sourceforge.pmd.lang.java.ast.ASTResultType;
import net.sourceforge.pmd.lang.java.ast.ASTReturnStatement;
import net.sourceforge.pmd.lang.java.ast.ASTUnaryExpressionNotPlusMinus;
import net.sourceforge.pmd.lang.java.rule.AbstractJavaRule;
public class SimplifyBooleanReturnsRule extends AbstractJavaRule {
public Object visit(ASTMethodDeclaration node, Object data) {
// only boolean methods should be inspected
ASTResultType r = node.getResultType();
if (!r.isVoid()) {
Node t = r.jjtGetChild(0);
if (t.jjtGetNumChildren() == 1) {
t = t.jjtGetChild(0);
if (t instanceof ASTPrimitiveType && ((ASTPrimitiveType) t).isBoolean()) {
return super.visit(node, data);
}
}
}
// skip method
return data;
}
public Object visit(ASTIfStatement node, Object data) {
// that's the case: if..then..return; return;
if (!node.hasElse() && isIfJustReturnsBoolean(node) && isJustReturnsBooleanAfter(node)) {
addViolation(data, node);
return super.visit(node, data);
}
// only deal with if..then..else stmts
if (node.jjtGetNumChildren() != 3) {
return super.visit(node, data);
}
// don't bother if either the if or the else block is empty
if (node.jjtGetChild(1).jjtGetNumChildren() == 0 || node.jjtGetChild(2).jjtGetNumChildren() == 0) {
return super.visit(node, data);
}
Node returnStatement1 = node.jjtGetChild(1).jjtGetChild(0);
Node returnStatement2 = node.jjtGetChild(2).jjtGetChild(0);
if (returnStatement1 instanceof ASTReturnStatement && returnStatement2 instanceof ASTReturnStatement) {
Node expression1 = returnStatement1.jjtGetChild(0).jjtGetChild(0);
Node expression2 = returnStatement2.jjtGetChild(0).jjtGetChild(0);
if (terminatesInBooleanLiteral(returnStatement1) && terminatesInBooleanLiteral(returnStatement2)) {
addViolation(data, node);
} else if (expression1 instanceof ASTUnaryExpressionNotPlusMinus
^ expression2 instanceof ASTUnaryExpressionNotPlusMinus) {
// We get the nodes under the '!' operator
// If they are the same => error
if (isNodesEqualWithUnaryExpression(expression1, expression2)) {
// second case:
// If
// Expr
// Statement
// ReturnStatement
// UnaryExpressionNotPlusMinus '!'
// Expression E
// Statement
// ReturnStatement
// Expression E
// i.e.,
// if (foo)
// return !a;
// else
// return a;
addViolation(data, node);
}
}
} else if (hasOneBlockStmt(node.jjtGetChild(1)) && hasOneBlockStmt(node.jjtGetChild(2))) {
// We have blocks so we must go down three levels (BlockStatement,
// Statement, ReturnStatement)
returnStatement1 = returnStatement1.jjtGetChild(0).jjtGetChild(0).jjtGetChild(0);
returnStatement2 = returnStatement2.jjtGetChild(0).jjtGetChild(0).jjtGetChild(0);
// if we have 2 return;
if (isSimpleReturn(returnStatement1) && isSimpleReturn(returnStatement2)) {
// third case
// If
// Expr
// Statement
// Block
// BlockStatement
// Statement
// ReturnStatement
// Statement
// Block
// BlockStatement
// Statement
// ReturnStatement
// i.e.,
// if (foo) {
// return true;
// } else {
// return false;
// }
addViolation(data, node);
} else {
Node expression1 = getDescendant(returnStatement1, 4);
Node expression2 = getDescendant(returnStatement2, 4);
if (terminatesInBooleanLiteral(node.jjtGetChild(1).jjtGetChild(0))
&& terminatesInBooleanLiteral(node.jjtGetChild(2).jjtGetChild(0))) {
addViolation(data, node);
} else if (expression1 instanceof ASTUnaryExpressionNotPlusMinus
^ expression2 instanceof ASTUnaryExpressionNotPlusMinus) {
// We get the nodes under the '!' operator
// If they are the same => error
if (isNodesEqualWithUnaryExpression(expression1, expression2)) {
// forth case
// If
// Expr
// Statement
// Block
// BlockStatement
// Statement
// ReturnStatement
// UnaryExpressionNotPlusMinus '!'
// Expression E
// Statement
// Block
// BlockStatement
// Statement
// ReturnStatement
// Expression E
// i.e.,
// if (foo) {
// return !a;
// } else {
// return a;
// }
addViolation(data, node);
}
}
}
}
return super.visit(node, data);
}
/**
* Checks, whether there is a statement after the given if statement, and if
* so, whether this is just a return boolean statement.
*
* @param node
* the if statement
* @return
*/
private boolean isJustReturnsBooleanAfter(ASTIfStatement ifNode) {
Node blockStatement = ifNode.jjtGetParent().jjtGetParent();
Node block = blockStatement.jjtGetParent();
if (block.jjtGetNumChildren() != blockStatement.jjtGetChildIndex() + 1 + 1) {
return false;
}
Node nextBlockStatement = block.jjtGetChild(blockStatement.jjtGetChildIndex() + 1);
return terminatesInBooleanLiteral(nextBlockStatement);
}
/**
* Checks whether the given ifstatement just returns a boolean in the if
* clause.
*
* @param node
* the if statement
* @return
*/
private boolean isIfJustReturnsBoolean(ASTIfStatement ifNode) {
Node node = ifNode.jjtGetChild(1);
return node.jjtGetNumChildren() == 1
&& (hasOneBlockStmt(node) || terminatesInBooleanLiteral(node.jjtGetChild(0)));
}
private boolean hasOneBlockStmt(Node node) {
return node.jjtGetChild(0) instanceof ASTBlock && node.jjtGetChild(0).jjtGetNumChildren() == 1
&& terminatesInBooleanLiteral(node.jjtGetChild(0).jjtGetChild(0));
}
/**
* Returns the first child node going down 'level' levels or null if level
* is invalid
*/
private Node getDescendant(Node node, int level) {
Node n = node;
for (int i = 0; i < level; i++) {
if (n.jjtGetNumChildren() == 0) {
return null;
}
n = n.jjtGetChild(0);
}
return n;
}
private boolean terminatesInBooleanLiteral(Node node) {
return eachNodeHasOneChild(node) && getLastChild(node) instanceof ASTBooleanLiteral;
}
private boolean eachNodeHasOneChild(Node node) {
if (node.jjtGetNumChildren() > 1) {
return false;
}
if (node.jjtGetNumChildren() == 0) {
return true;
}
return eachNodeHasOneChild(node.jjtGetChild(0));
}
private Node getLastChild(Node node) {
if (node.jjtGetNumChildren() == 0) {
return node;
}
return getLastChild(node.jjtGetChild(0));
}
private boolean isNodesEqualWithUnaryExpression(Node n1, Node n2) {
Node node1;
Node node2;
if (n1 instanceof ASTUnaryExpressionNotPlusMinus) {
node1 = n1.jjtGetChild(0);
} else {
node1 = n1;
}
if (n2 instanceof ASTUnaryExpressionNotPlusMinus) {
node2 = n2.jjtGetChild(0);
} else {
node2 = n2;
}
return isNodesEquals(node1, node2);
}
private boolean isNodesEquals(Node n1, Node n2) {
int numberChild1 = n1.jjtGetNumChildren();
int numberChild2 = n2.jjtGetNumChildren();
if (numberChild1 != numberChild2) {
return false;
}
if (!n1.getClass().equals(n2.getClass())) {
return false;
}
if (!n1.toString().equals(n2.toString())) {
return false;
}
for (int i = 0; i < numberChild1; i++) {
if (!isNodesEquals(n1.jjtGetChild(i), n2.jjtGetChild(i))) {
return false;
}
}
return true;
}
private boolean isSimpleReturn(Node node) {
return node instanceof ASTReturnStatement && node.jjtGetNumChildren() == 0;
}
}