/** * BSD-style license; for more info see http://pmd.sourceforge.net/license.html */ package net.sourceforge.pmd.lang.apex.rule.security; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.regex.Pattern; import net.sourceforge.pmd.lang.apex.ast.ASTAssignmentExpression; import net.sourceforge.pmd.lang.apex.ast.ASTBinaryExpression; import net.sourceforge.pmd.lang.apex.ast.ASTFieldDeclaration; import net.sourceforge.pmd.lang.apex.ast.ASTLiteralExpression; import net.sourceforge.pmd.lang.apex.ast.ASTMethod; import net.sourceforge.pmd.lang.apex.ast.ASTMethodCallExpression; import net.sourceforge.pmd.lang.apex.ast.ASTStandardCondition; import net.sourceforge.pmd.lang.apex.ast.ASTUserClass; import net.sourceforge.pmd.lang.apex.ast.ASTVariableDeclaration; import net.sourceforge.pmd.lang.apex.ast.ASTVariableExpression; import net.sourceforge.pmd.lang.apex.ast.AbstractApexNode; import net.sourceforge.pmd.lang.apex.rule.AbstractApexRule; import apex.jorje.semantic.ast.member.Parameter; import apex.jorje.semantic.ast.statement.VariableDeclaration; /** * Detects if variables in Database.query(variable) is escaped with * String.escapeSingleQuotes * * @author sergey.gorbaty * */ public class ApexSOQLInjectionRule extends AbstractApexRule { private static final String DOUBLE = "double"; private static final String LONG = "long"; private static final String DECIMAL = "decimal"; private static final String BOOLEAN = "boolean"; private static final String ID = "id"; private static final String INTEGER = "integer"; private static final String JOIN = "join"; private static final String ESCAPE_SINGLE_QUOTES = "escapeSingleQuotes"; private static final String STRING = "String"; private static final String DATABASE = "Database"; private static final String QUERY = "query"; private static final Pattern SELECT_PATTERN = Pattern.compile("^select[\\s]+?.*?$", Pattern.CASE_INSENSITIVE); private final HashSet<String> safeVariables = new HashSet<>(); private final HashMap<String, Boolean> selectContainingVariables = new HashMap<>(); public ApexSOQLInjectionRule() { setProperty(CODECLIMATE_CATEGORIES, new String[] { "Security" }); setProperty(CODECLIMATE_REMEDIATION_MULTIPLIER, 100); setProperty(CODECLIMATE_BLOCK_HIGHLIGHTING, false); } @Override public Object visit(ASTUserClass node, Object data) { if (Helper.isTestMethodOrClass(node) || Helper.isSystemLevelClass(node)) { return data; // stops all the rules } final List<ASTMethod> methodExpr = node.findDescendantsOfType(ASTMethod.class); for (ASTMethod m : methodExpr) { findSafeVariablesInSignature(m); } final List<ASTFieldDeclaration> fieldExpr = node.findDescendantsOfType(ASTFieldDeclaration.class); for (ASTFieldDeclaration a : fieldExpr) { findSanitizedVariables(a); findSelectContainingVariables(a); } // String foo = String.escapeSignleQuotes(...); final List<ASTVariableDeclaration> variableDecl = node.findDescendantsOfType(ASTVariableDeclaration.class); for (ASTVariableDeclaration a : variableDecl) { findSanitizedVariables(a); findSelectContainingVariables(a); } // baz = String.escapeSignleQuotes(...); final List<ASTAssignmentExpression> assignmentCalls = node.findDescendantsOfType(ASTAssignmentExpression.class); for (ASTAssignmentExpression a : assignmentCalls) { findSanitizedVariables(a); findSelectContainingVariables(a); } // Database.query(...) check final List<ASTMethodCallExpression> potentialDbQueryCalls = node .findDescendantsOfType(ASTMethodCallExpression.class); for (ASTMethodCallExpression m : potentialDbQueryCalls) { if (!Helper.isTestMethodOrClass(m) && Helper.isMethodName(m, DATABASE, QUERY)) { reportStrings(m, data); reportVariables(m, data); } } safeVariables.clear(); selectContainingVariables.clear(); return data; } private void findSafeVariablesInSignature(ASTMethod m) { List<Parameter> parameters = m.getNode().getMethodInfo().getParameters(); for (Parameter p : parameters) { switch (p.getType().getApexName().toLowerCase()) { case ID: case INTEGER: case BOOLEAN: case DECIMAL: case LONG: case DOUBLE: safeVariables.add(Helper.getFQVariableName(p)); break; default: break; } } } private void findSanitizedVariables(AbstractApexNode<?> node) { final ASTVariableExpression left = node.getFirstChildOfType(ASTVariableExpression.class); final ASTLiteralExpression literal = node.getFirstChildOfType(ASTLiteralExpression.class); final ASTMethodCallExpression right = node.getFirstChildOfType(ASTMethodCallExpression.class); // look for String a = 'b'; if (literal != null) { if (left != null) { Object o = literal.getNode().getLiteral(); if (o instanceof Integer || o instanceof Boolean || o instanceof Double) { safeVariables.add(Helper.getFQVariableName(left)); } if (o instanceof String) { if (SELECT_PATTERN.matcher((String) o).matches()) { selectContainingVariables.put(Helper.getFQVariableName(left), Boolean.TRUE); } else { safeVariables.add(Helper.getFQVariableName(left)); } } } } // look for String a = String.escapeSingleQuotes(foo); if (right != null) { if (Helper.isMethodName(right, STRING, ESCAPE_SINGLE_QUOTES)) { if (left != null) { safeVariables.add(Helper.getFQVariableName(left)); } } } if (node instanceof ASTVariableDeclaration) { VariableDeclaration o = (VariableDeclaration) node.getNode(); switch (o.getLocalInfo().getType().getApexName().toLowerCase()) { case INTEGER: case ID: case BOOLEAN: case DECIMAL: case LONG: case DOUBLE: safeVariables.add(Helper.getFQVariableName(left)); break; default: break; } } } private void findSelectContainingVariables(AbstractApexNode<?> node) { final ASTVariableExpression left = node.getFirstChildOfType(ASTVariableExpression.class); final ASTBinaryExpression right = node.getFirstChildOfType(ASTBinaryExpression.class); if (left != null && right != null) { recursivelyCheckForSelect(left, right); } } private void recursivelyCheckForSelect(final ASTVariableExpression var, final ASTBinaryExpression node) { final ASTBinaryExpression right = node.getFirstChildOfType(ASTBinaryExpression.class); if (right != null) { recursivelyCheckForSelect(var, right); } final ASTVariableExpression concatenatedVar = node.getFirstChildOfType(ASTVariableExpression.class); boolean isSafeVariable = false; if (concatenatedVar != null) { if (safeVariables.contains(Helper.getFQVariableName(concatenatedVar))) { isSafeVariable = true; } } final ASTMethodCallExpression methodCall = node.getFirstChildOfType(ASTMethodCallExpression.class); if (methodCall != null) { if (Helper.isMethodName(methodCall, STRING, ESCAPE_SINGLE_QUOTES)) { isSafeVariable = true; } } final ASTLiteralExpression literal = node.getFirstChildOfType(ASTLiteralExpression.class); if (literal != null) { Object o = literal.getNode().getLiteral(); if (o instanceof String) { if (SELECT_PATTERN.matcher((String) o).matches()) { if (!isSafeVariable) { // select literal + other unsafe vars selectContainingVariables.put(Helper.getFQVariableName(var), Boolean.FALSE); } else { safeVariables.add(Helper.getFQVariableName(var)); } } } } else { if (!isSafeVariable) { selectContainingVariables.put(Helper.getFQVariableName(var), Boolean.FALSE); } } } private void reportStrings(ASTMethodCallExpression m, Object data) { final HashSet<ASTVariableExpression> setOfSafeVars = new HashSet<>(); final List<ASTStandardCondition> conditions = m.findDescendantsOfType(ASTStandardCondition.class); for (ASTStandardCondition c : conditions) { List<ASTVariableExpression> vars = c.findDescendantsOfType(ASTVariableExpression.class); setOfSafeVars.addAll(vars); } final List<ASTBinaryExpression> binaryExpr = m.findChildrenOfType(ASTBinaryExpression.class); for (ASTBinaryExpression b : binaryExpr) { List<ASTVariableExpression> vars = b.findDescendantsOfType(ASTVariableExpression.class); for (ASTVariableExpression v : vars) { String fqName = Helper.getFQVariableName(v); if (selectContainingVariables.containsKey(fqName)) { boolean isLiteral = selectContainingVariables.get(fqName); if (isLiteral) { continue; } } if (setOfSafeVars.contains(v) || safeVariables.contains(fqName)) { continue; } final ASTMethodCallExpression parentCall = v.getFirstParentOfType(ASTMethodCallExpression.class); boolean isSafeMethod = Helper.isMethodName(parentCall, STRING, ESCAPE_SINGLE_QUOTES) || Helper.isMethodName(parentCall, STRING, JOIN); if (!isSafeMethod) { addViolation(data, v); } } } } private void reportVariables(final ASTMethodCallExpression m, Object data) { final ASTVariableExpression var = m.getFirstChildOfType(ASTVariableExpression.class); if (var != null) { String nameFQ = Helper.getFQVariableName(var); if (selectContainingVariables.containsKey(nameFQ)) { boolean isLiteral = selectContainingVariables.get(nameFQ); if (!isLiteral) { addViolation(data, var); } } } } }