package org.n3r.eql.dbfieldcryptor.parser;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.*;
import com.alibaba.druid.sql.dialect.oracle.parser.OracleStatementParser;
import com.alibaba.druid.sql.dialect.oracle.visitor.OracleASTVisitorAdapter;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.google.common.base.MoreObjects;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Works out, for one Oracle SQL statement, which bind variables ({@code ?}) and which result
 * columns correspond to fields listed in a set of secure field names.
 *
 * <p>Secure fields are looked up as upper-cased {@code TABLE.FIELD} strings. For procedure calls
 * the "field" is the 1-based parameter position, i.e. {@code PROCEDURE.1}.
 *
 * <p>A statement may start with a hint of the form {@code /*** bind(1,2,3) result(1) ***}{@code /};
 * in that case the SQL is not parsed at all and the indices are taken from the hint.
 *
 * <p>Instances are created through {@link #parseSql(String, Set)}. The accessor methods return the
 * parser's own mutable sets, not copies.
 */
@Slf4j
public class OracleSensitiveFieldsParser implements SensitiveFieldsParser {
    // Hint format: /*** bind(1,2,3) result(1) ***/
    private static final Pattern ENCRYPT_HINT = Pattern.compile("\\s*/\\*{3}\\s*(.*?)\\s*\\*{3}/");
    private static final Pattern BIND_PATTERN = Pattern.compile("bind\\s*\\((.*?)\\)");
    private static final Pattern RESULT_PATTERN = Pattern.compile("result\\s*\\((.*?)\\)");
    private static final Splitter INDEX_SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults();

    /** Alias (or table name when there is no alias) to either a table name or a sub-query parser. */
    private final Map<String, Object> aliasTablesMap = Maps.newHashMap();
    /** 1-based positions of bind variables that carry a secure field. */
    private final Set<Integer> secureBindIndices = Sets.newHashSet();
    /** 1-based positions of select items that yield a secure field. */
    private final Set<Integer> secureResultIndices = Sets.newHashSet();
    /** Upper-cased, unquoted labels of select items that yield a secure field. */
    private final Set<String> secureResultLabels = Sets.newHashSet();
    /** Bind information of FROM sub-queries, applied after the select list has been walked. */
    private final List<BindVariant> subQueryBindAndVariantOfFrom = Lists.<BindVariant>newArrayList();
    private final Set<String> secureFields;
    /** Number of bind variables counted so far. */
    private @Getter int variantIndex = 0;
    private final String sql;

    /**
     * Visitor that counts bind variables and, for every binary expression with a secure field on
     * one side, inspects the other side for a single bind variable.
     */
    private final OracleASTVisitorAdapter adapter = new OracleASTVisitorAdapter() {
        @Override
        public boolean visit(SQLVariantRefExpr x) {
            ++variantIndex;
            return true;
        }

        @Override
        public boolean visit(SQLBinaryOpExpr x) {
            if (isSecureColumn(x.getLeft())) {
                checkOnlyOneAsk(x.getRight());
            } else if (isSecureColumn(x.getRight())) {
                checkOnlyOneAsk(x.getLeft());
            }
            return true;
        }
    };

    private OracleSensitiveFieldsParser(Set<String> secureFields, String sql) {
        this.secureFields = secureFields;
        this.sql = sql;
    }

    /**
     * Analyses {@code sql} against {@code secureFields}.
     *
     * @param sql          the statement, optionally prefixed with an encrypt hint
     * @param secureFields upper-cased {@code TABLE.FIELD} names of the secure fields
     * @return the populated parser, or {@code null} when no secure bind or result was found
     * @throws RuntimeException when there is no hint and the SQL cannot be parsed
     */
    public static OracleSensitiveFieldsParser parseSql(String sql, Set<String> secureFields) {
        OracleSensitiveFieldsParser fieldsParser = tryParseHint(sql, secureFields);
        if (fieldsParser == null) {
            SQLStatement sqlStatement = parseSql(sql);
            fieldsParser = parseStatement(
                    new OracleSensitiveFieldsParser(secureFields, sql), sqlStatement);
        }
        return fieldsParser.haveNonSecureFields() ? null : fieldsParser;
    }

    /**
     * Builds a parser from a leading encrypt hint.
     *
     * @return a parser whose SQL is the statement with the hint stripped, or {@code null} when
     *         the statement does not start with a hint
     */
    private static OracleSensitiveFieldsParser tryParseHint(String sql, Set<String> secureFields) {
        Matcher matcher = ENCRYPT_HINT.matcher(sql);
        // The hint only counts when it is the very first thing in the statement.
        if (!matcher.lookingAt()) return null;

        String convertedSql = sql.substring(matcher.end());
        OracleSensitiveFieldsParser fieldsParser =
                new OracleSensitiveFieldsParser(secureFields, convertedSql);
        fieldsParser.parseHint(matcher.group(1));
        return fieldsParser;
    }

    /** Reads the {@code bind(...)} and {@code result(...)} index lists out of the hint text. */
    private void parseHint(String hint) {
        collectHintIndices(BIND_PATTERN, hint, secureBindIndices);
        collectHintIndices(RESULT_PATTERN, hint, secureResultIndices);
    }

    /** Adds the comma-separated integers captured by {@code pattern} in {@code hint} to {@code target}. */
    private static void collectHintIndices(Pattern pattern, String hint, Set<Integer> target) {
        Matcher matcher = pattern.matcher(hint);
        if (!matcher.find()) return;
        for (String index : INDEX_SPLITTER.split(matcher.group(1)))
            target.add(Integer.parseInt(index));
    }

    /**
     * Parses {@code sql} with the Oracle dialect parser and returns its first statement.
     *
     * @throws RuntimeException when the SQL is syntactically invalid (the {@link ParserException}
     *                          is kept as the cause) or contains no statement
     */
    private static SQLStatement parseSql(String sql) {
        SQLStatementParser parser = new OracleStatementParser(sql);
        List<SQLStatement> stmtList;
        try {
            stmtList = parser.parseStatementList();
        } catch (ParserException exception) {
            throw new RuntimeException(
                    sql + " is invalid, detail " + exception.getMessage(), exception);
        }
        if (stmtList.isEmpty())
            throw new RuntimeException(sql + " is invalid, detail no statement found");
        return stmtList.get(0);
    }

    /** Dispatches on the statement type; unsupported statement types are left unanalysed. */
    private static OracleSensitiveFieldsParser parseStatement(
            OracleSensitiveFieldsParser parser,
            SQLStatement sqlStatement) {
        if (sqlStatement instanceof SQLSelectStatement) {
            parser.parseSelectQuery(((SQLSelectStatement) sqlStatement).getSelect().getQuery());
        } else if (sqlStatement instanceof OracleDeleteStatement) {
            parser.parseDelete((OracleDeleteStatement) sqlStatement);
        } else if (sqlStatement instanceof OracleInsertStatement) {
            parser.parseInsert((OracleInsertStatement) sqlStatement);
        } else if (sqlStatement instanceof OracleUpdateStatement) {
            parser.parseUpdate((OracleUpdateStatement) sqlStatement);
        } else if (sqlStatement instanceof OracleMergeStatement) {
            parser.parseMerge((OracleMergeStatement) sqlStatement);
        } else if (sqlStatement instanceof SQLCallStatement) {
            parser.parseCall((SQLCallStatement) sqlStatement);
        } else if (sqlStatement instanceof OracleMultiInsertStatement) {
            parser.parseMultiInsert((OracleMultiInsertStatement) sqlStatement);
        }
        return parser;
    }

    /** Handles a plain query block or a UNION; any other query kind is ignored. */
    private void parseSelectQuery(SQLSelectQuery query) {
        if (query instanceof SQLSelectQueryBlock) {
            parseQuery((SQLSelectQueryBlock) query);
        } else if (query instanceof SQLUnionQuery) {
            parseUnionQuery((SQLUnionQuery) query);
        }
    }

    /**
     * Walks both operands of a UNION. Either operand may itself be a UNION, so both go through
     * {@link #parseSelectQuery(SQLSelectQuery)} instead of being cast to a query block.
     */
    private void parseUnionQuery(SQLUnionQuery sqlUnionQuery) {
        parseSelectQuery(sqlUnionQuery.getLeft());
        parseSelectQuery(sqlUnionQuery.getRight());
    }

    /** Walks one query block: FROM tables, select list, deferred FROM sub-queries, then WHERE. */
    private void parseQuery(SQLSelectQueryBlock queryBlock) {
        parseTable(queryBlock.getFrom());
        parseSelectItems(queryBlock.getSelectList());
        adjustSubQueryBindIndicesOfFrom();
        if (queryBlock.getWhere() != null)
            queryBlock.getWhere().accept(adapter);
    }

    /**
     * Folds the bind variables of the FROM sub-queries into this parser's numbering, after the
     * select list has been processed. The pending list is emptied afterwards so that a later
     * query block on the same parser (another UNION branch) does not apply the same offsets again.
     */
    private void adjustSubQueryBindIndicesOfFrom() {
        for (BindVariant bindVariant : subQueryBindAndVariantOfFrom) {
            for (Integer index : bindVariant.getBindIndices())
                secureBindIndices.add(variantIndex + index);
            variantIndex += bindVariant.getVariantIndex();
        }
        subQueryBindAndVariantOfFrom.clear();
    }

    /** Registers the target table of a DELETE and scans its WHERE clause. */
    private void parseDelete(OracleDeleteStatement deleteStatement) {
        // NOTE(review): assumes the DELETE target is always a SQLExprTableSource — confirm.
        val tableSource = (SQLExprTableSource) deleteStatement.getTableSource();
        if (tableSource.getExpr() instanceof SQLIdentifierExpr)
            addTableAlias(tableSource, (SQLIdentifierExpr) tableSource.getExpr());
        if (deleteStatement.getWhere() != null)
            deleteStatement.getWhere().accept(adapter);
    }

    /**
     * Handles a procedure or function call. Secure parameters are identified by position; when
     * the call has an out parameter, it takes position 1 and shifts the others by one.
     */
    private void parseCall(SQLCallStatement callStatement) {
        addTableAlias("", callStatement.getProcedureName().toString());
        boolean isOraFunc = callStatement.getOutParameter() != null;
        int outOffset = isOraFunc ? 1 : 0;
        if (isOraFunc && isSecureField(1)) secureBindIndices.add(1);

        List<SQLExpr> parameters = callStatement.getParameters();
        for (int i = 0, ii = parameters.size(); i < ii; ++i) {
            SQLExpr parameter = parameters.get(i);
            parameter.accept(adapter);
            if (!isSecureField(i + 1 + outOffset)) continue;
            if (parameter instanceof SQLVariantRefExpr) {
                secureBindIndices.add(variantIndex + outOffset);
            } else {
                log.warn("secure field is not passed as a single value in sql [{}]", sql);
            }
        }
    }

    /** Handles MERGE: the ON condition, then the optional UPDATE and INSERT clauses. */
    private void parseMerge(OracleMergeStatement mergeStatement) {
        if (mergeStatement.getInto() instanceof SQLIdentifierExpr) {
            SQLIdentifierExpr expr = (SQLIdentifierExpr) mergeStatement.getInto();
            addTableAlias(mergeStatement.getAlias(), expr);
        }
        mergeStatement.getOn().accept(adapter);

        val updateClause = mergeStatement.getUpdateClause();
        if (updateClause != null) walkUpdateItems(updateClause.getItems());

        val insertClause = mergeStatement.getInsertClause();
        if (insertClause != null) {
            val secureFieldsIndices = walkInsertColumns(insertClause.getColumns());
            walkInsertValues(secureFieldsIndices, insertClause.getValues());
        }
    }

    /** Handles UPDATE, both the item-by-item form and {@code set (a,b) = (select ...)}. */
    private void parseUpdate(OracleUpdateStatement updateStatement) {
        // NOTE(review): assumes the UPDATE target is always an OracleSelectTableReference — confirm.
        val tableSource = (OracleSelectTableReference) updateStatement.getTableSource();
        if (tableSource.getExpr() instanceof SQLIdentifierExpr)
            addTableAlias(tableSource, (SQLIdentifierExpr) tableSource.getExpr());

        val items = updateStatement.getItems();
        val item0 = items.get(0);
        boolean listFromSubQuery = items.size() == 1
                && item0.getColumn() instanceof SQLListExpr
                && item0.getValue() instanceof SQLQueryExpr;
        if (listFromSubQuery) {
            // update xxx set (a,b) = (select ... from) where
            walkUpdateSelect(item0);
        } else {
            walkUpdateItems(items);
        }

        if (updateStatement.getWhere() != null)
            updateStatement.getWhere().accept(adapter);
    }

    /**
     * Handles {@code set (a,b) = (select x, y from ...)}: remembers which target columns are
     * secure, then matches them by position against the sub-query's select list.
     */
    private void walkUpdateSelect(SQLUpdateSetItem item) {
        List<SQLExpr> columns = ((SQLListExpr) item.getColumn()).getItems();
        Set<Integer> secureFieldIndices = Sets.newHashSet();
        for (int i = 0, ii = columns.size(); i < ii; ++i) {
            // Qualified and unqualified column names are both recognised, as in walkUpdateItems.
            if (isSecureColumn(columns.get(i))) secureFieldIndices.add(i);
        }

        val value = (SQLQueryExpr) item.getValue();
        // NOTE(review): assumes the sub-query is a plain query block, not a UNION — confirm.
        val queryBlock = (SQLSelectQueryBlock) value.getSubQuery().getQuery();
        parseTable(queryBlock.getFrom());
        parseSelectItemsInUpdate(secureFieldIndices, queryBlock.getSelectList());
        if (queryBlock.getWhere() != null)
            queryBlock.getWhere().accept(adapter);
    }

    /** Marks a select item that is a lone bind variable and sits at a secure column position. */
    private void parseSelectItemsInUpdate(
            Set<Integer> secureFieldIndices, List<SQLSelectItem> selectList) {
        for (int i = 0, ii = selectList.size(); i < ii; ++i) {
            SQLSelectItem item = selectList.get(i);
            item.accept(adapter);
            if (secureFieldIndices.contains(i) && item.getExpr() instanceof SQLVariantRefExpr) {
                secureBindIndices.add(variantIndex);
            }
        }
    }

    /** Marks the bind variable of every {@code secureColumn = ?} assignment. */
    private void walkUpdateItems(List<SQLUpdateSetItem> items) {
        for (SQLUpdateSetItem item : items) {
            item.accept(adapter);
            if (!isSecureColumn(item.getColumn())) continue;
            if (item.getValue() instanceof SQLVariantRefExpr) {
                secureBindIndices.add(variantIndex);
            } else {
                log.warn("secure field is not updated as a single value in sql [{}]", sql);
            }
        }
    }

    /**
     * Marks the next bind variable as secure when {@code right} contains exactly one bind
     * variable, e.g. {@code A.PCARD_CODE = upper(?)}. Expressions with zero or several bind
     * variables are left alone.
     */
    private void checkOnlyOneAsk(SQLExpr right) {
        val rightVariantIndex = new AtomicInteger(0);
        right.accept(new OracleASTVisitorAdapter() {
            @Override
            public boolean visit(SQLVariantRefExpr x) {
                rightVariantIndex.incrementAndGet();
                return true;
            }
        });
        // Uses variantIndex + 1, i.e. the bind variable that the main adapter will count next.
        if (rightVariantIndex.get() == 1)
            secureBindIndices.add(variantIndex + 1);
    }

    /** Returns whether {@code column} is a (qualified or unqualified) name of a secure field. */
    private boolean isSecureColumn(SQLExpr column) {
        if (column instanceof SQLIdentifierExpr) return isSecureField((SQLIdentifierExpr) column);
        if (column instanceof SQLPropertyExpr) return isSecureField((SQLPropertyExpr) column);
        return false;
    }

    /** {@code *} is secure when the single known table has a secure {@code *} entry. */
    private boolean isSecureField(SQLAllColumnExpr field) {
        Object oneTableName = getOneTableName();
        return oneTableName != null && containsInSecureFields(oneTableName, "*");
    }

    /** An unqualified column can only be resolved when exactly one table is known. */
    private boolean isSecureField(SQLIdentifierExpr field) {
        Object oneTableName = getOneTableName();
        return oneTableName != null && containsInSecureFields(oneTableName, field.getName());
    }

    /** A qualified column is resolved through its owner's entry in the alias map. */
    private boolean isSecureField(SQLPropertyExpr expr) {
        Object tableName = aliasTablesMap.get(expr.getOwner().toString());
        return containsInSecureFields(tableName, expr.getName());
    }

    /** Procedure parameters are looked up by their 1-based position. */
    private boolean isSecureField(int procedureParameterIndex) {
        Object oneTableName = getOneTableName();
        return oneTableName != null
                && containsInSecureFields(oneTableName, "" + procedureParameterIndex);
    }

    /** Dispatches on what the alias resolved to; an unknown alias ({@code null}) is not secure. */
    private boolean containsInSecureFields(Object tableName, String fieldName) {
        if (tableName instanceof String)
            return containsInSecureFields((String) tableName, fieldName);
        if (tableName instanceof OracleSensitiveFieldsParser)
            return containsInSecureFields((OracleSensitiveFieldsParser) tableName, fieldName);
        return false;
    }

    /**
     * Checks a column of a FROM sub-query. Result labels are stored unquoted and upper-cased, so
     * the field name is normalised the same way before the lookup.
     */
    private boolean containsInSecureFields(
            OracleSensitiveFieldsParser parser, String fieldName) {
        if ("*".equals(fieldName)) return !parser.getSecureResultIndices().isEmpty();
        return parser.inResultLabels(cleanQuotesAndToUpper(fieldName));
    }

    /** Checks the upper-cased {@code TABLE.FIELD} against the configured secure fields. */
    private boolean containsInSecureFields(String tableName, String fieldName) {
        String secretField = tableName + "." + fieldName;
        return secureFields.contains(secretField.toUpperCase());
    }

    /** Returns the only registered table (or sub-query parser), or {@code null} otherwise. */
    private Object getOneTableName() {
        return aliasTablesMap.size() == 1 ? aliasTablesMap.values().iterator().next() : null;
    }

    /** Registers the tables, joins and sub-queries of a FROM clause in the alias map. */
    private void parseTable(SQLTableSource from) {
        if (from instanceof OracleSelectTableReference) {
            val source = (SQLExprTableSource) from;
            if (source.getExpr() instanceof SQLIdentifierExpr)
                addTableAlias(from, (SQLIdentifierExpr) source.getExpr());
        } else if (from instanceof SQLJoinTableSource) {
            val joinTableSource = (SQLJoinTableSource) from;
            parseTable(joinTableSource.getLeft());
            parseTable(joinTableSource.getRight());
            // The join condition may contain bind variables as well.
            val conditionOn = joinTableSource.getCondition();
            if (conditionOn != null) conditionOn.accept(adapter);
        } else if (from instanceof OracleSelectSubqueryTableSource) {
            val tableSource = (OracleSelectSubqueryTableSource) from;
            val query = tableSource.getSelect().getQuery();
            val subParser = createSubQueryParser(query, QueryBelongs.FROM);
            addTableAlias(from, subParser);
        }
    }

    private void addTableAlias(SQLTableSource from,
                               OracleSensitiveFieldsParser subParser) {
        addTableAlias(from.getAlias(), subParser);
    }

    private void addTableAlias(SQLTableSource from, SQLIdentifierExpr expr) {
        addTableAlias(from.getAlias(), expr);
    }

    private void addTableAlias(String alias, SQLIdentifierExpr expr) {
        addTableAlias(alias, expr.getName());
    }

    /** Registers a table under its alias, or under its own name when it has no alias. */
    private void addTableAlias(String alias, String tableName) {
        aliasTablesMap.put(MoreObjects.firstNonNull(alias, tableName), tableName);
    }

    /** Registers a sub-query parser under the sub-query's alias (which may be {@code null}). */
    private void addTableAlias(String alias,
                               OracleSensitiveFieldsParser subParser) {
        aliasTablesMap.put(alias, subParser);
    }

    /**
     * Strips one pair of matching surrounding double or single quotes and upper-cases the rest.
     * Strings shorter than two characters cannot be quoted and are only upper-cased.
     */
    private String cleanQuotesAndToUpper(String str) {
        int last = str.length() - 1;
        boolean quoted = last >= 1
                && (str.charAt(0) == '"' && str.charAt(last) == '"'
                || str.charAt(0) == '\'' && str.charAt(last) == '\'');
        String cleanString = quoted ? str.substring(1, last) : str;
        return cleanString.toUpperCase();
    }

    /** Records the index and label of every select item that yields a secure field. */
    private void parseSelectItems(List<SQLSelectItem> sqlSelectItems) {
        for (int itemIndex = 0, ii = sqlSelectItems.size(); itemIndex < ii; ++itemIndex) {
            SQLSelectItem item = sqlSelectItems.get(itemIndex);
            SQLExpr itemExpr = item.getExpr();
            String alias = item.getAlias();

            if (itemExpr instanceof SQLIdentifierExpr) {
                val expr = (SQLIdentifierExpr) itemExpr;
                if (isSecureField(expr))
                    addSecureResult(itemIndex, alias == null ? expr.getName() : alias);
            } else if (itemExpr instanceof SQLPropertyExpr) {
                val expr = (SQLPropertyExpr) itemExpr;
                if (!isSecureField(expr)) continue;
                if ("*".equals(expr.getName())) {
                    Object tableName = aliasTablesMap.get(expr.getOwner().toString());
                    copyResultIndicesAndLabels(itemIndex, tableName);
                } else {
                    addSecureResult(itemIndex, alias == null ? expr.getName() : alias);
                }
            } else if (itemExpr instanceof SQLAllColumnExpr) {
                if (isSecureField((SQLAllColumnExpr) itemExpr))
                    copyResultIndicesAndLabels(itemIndex, getOneTableName());
            } else if (itemExpr instanceof SQLQueryExpr) {
                val subQuery = ((SQLQueryExpr) itemExpr).getSubQuery().getQuery();
                val subParser = createSubQueryParser(subQuery, QueryBelongs.SELECT);
                // A scalar sub-query is secure when its first select item is.
                if (subParser.inResultIndices(1)) {
                    Set<String> labels = subParser.getSecureResultLabels();
                    addSecureResult(itemIndex, alias == null ? labels.iterator().next() : alias);
                }
            }
        }
    }

    /** Records the 1-based result position of a 0-based select item together with its label. */
    private void addSecureResult(int itemIndex, String label) {
        secureResultIndices.add(itemIndex + 1);
        secureResultLabels.add(cleanQuotesAndToUpper(label));
    }

    /**
     * Expands {@code alias.*} / {@code *} over a FROM sub-query: its secure result positions are
     * shifted by the position of the star item and its labels are copied. Plain tables are
     * ignored here.
     */
    private void copyResultIndicesAndLabels(int itemIndex, Object tableName) {
        if (!(tableName instanceof OracleSensitiveFieldsParser)) return;
        val parser = (OracleSensitiveFieldsParser) tableName;
        for (Integer resultIndex : parser.getSecureResultIndices()) {
            secureResultIndices.add(resultIndex + itemIndex);
        }
        secureResultLabels.addAll(parser.getSecureResultLabels());
    }

    /**
     * Analyses a sub-query with a fresh parser and merges its bind variables into this one:
     * immediately for a select-list sub-query, deferred (see
     * {@link #adjustSubQueryBindIndicesOfFrom()}) for a FROM sub-query.
     */
    private OracleSensitiveFieldsParser createSubQueryParser(
            SQLSelectQuery subQuery, QueryBelongs mode) {
        val subParser = new OracleSensitiveFieldsParser(secureFields, sql);
        subParser.parseSelectQuery(subQuery);
        switch (mode) {
            case FROM:
                subQueryBindAndVariantOfFrom.add(new BindVariant(
                        subParser.getVariantIndex(), subParser.getSecureBindIndices()));
                break;
            case SELECT:
                for (Integer index : subParser.getSecureBindIndices())
                    secureBindIndices.add(variantIndex + index);
                variantIndex += subParser.getVariantIndex();
                break;
        }
        return subParser;
    }

    /** Handles {@code INSERT ALL}: every entry is processed as an ordinary insert. */
    private void parseMultiInsert(OracleMultiInsertStatement multiInsertStatement) {
        for (OracleMultiInsertStatement.Entry entry : multiInsertStatement.getEntries()) {
            // NOTE(review): assumes every entry is an InsertIntoClause — confirm.
            parseInsert((OracleMultiInsertStatement.InsertIntoClause) entry);
        }
    }

    /** Handles {@code INSERT ... VALUES} and {@code INSERT ... SELECT}. */
    private void parseInsert(SQLInsertInto x) {
        val tableSource = x.getTableSource();
        if (tableSource.getExpr() instanceof SQLIdentifierExpr)
            addTableAlias(tableSource, (SQLIdentifierExpr) tableSource.getExpr());

        List<Integer> secureFieldsIndices = walkInsertColumns(x.getColumns());
        val valuesClause = x.getValues();
        if (valuesClause != null) {
            walkInsertValues(secureFieldsIndices, valuesClause.getValues());
        } else if (x.getQuery() != null) {
            // insert ... select ...
            // NOTE(review): assumes the select is a plain query block, not a UNION — confirm.
            parseQuery4Insert(secureFieldsIndices, (SQLSelectQueryBlock) x.getQuery().getQuery());
        }
    }

    /** Matches the select list of {@code INSERT ... SELECT} by position against the secure columns. */
    private void parseQuery4Insert(
            List<Integer> secureFieldsIndices, SQLSelectQueryBlock queryBlock) {
        val selectList = queryBlock.getSelectList();
        for (int itemIndex = 0, ii = selectList.size(); itemIndex < ii; ++itemIndex) {
            val item = selectList.get(itemIndex);
            item.accept(adapter);
            if (secureFieldsIndices.contains(itemIndex)
                    && item.getExpr() instanceof SQLVariantRefExpr) {
                secureBindIndices.add(variantIndex);
            }
        }
        // NOTE(review): the FROM clause is visited by the adapter and then walked again by
        // parseTable, which visits JOIN conditions itself; bind variables in an ON condition
        // look like they are counted twice here — confirm before changing.
        queryBlock.getFrom().accept(adapter);
        parseTable(queryBlock.getFrom());
        if (queryBlock.getWhere() != null)
            queryBlock.getWhere().accept(adapter);
    }

    /** Marks the bind variable of every value that sits at a secure column position. */
    private void walkInsertValues(
            List<Integer> secureFieldsIndices, List<SQLExpr> values) {
        for (int i = 0, ii = values.size(); i < ii; ++i) {
            SQLExpr expr = values.get(i);
            expr.accept(adapter);
            if (!secureFieldsIndices.contains(i)) continue;
            if (expr instanceof SQLVariantRefExpr)
                secureBindIndices.add(variantIndex);
            else
                log.warn("secure field is not inserted as a single value in sql [{}]", sql);
        }
    }

    /** Returns the 0-based positions of the insert columns that are secure fields. */
    private List<Integer> walkInsertColumns(List<SQLExpr> columns) {
        List<Integer> secureFieldsIndices = Lists.<Integer>newArrayList();
        for (int i = 0, ii = columns.size(); i < ii; ++i) {
            SQLExpr column = columns.get(i);
            if (column instanceof SQLIdentifierExpr && isSecureField((SQLIdentifierExpr) column))
                secureFieldsIndices.add(i);
        }
        return secureFieldsIndices;
    }

    @Override
    public Set<Integer> getSecureBindIndices() {
        return secureBindIndices;
    }

    @Override
    public Set<Integer> getSecureResultIndices() {
        return secureResultIndices;
    }

    @Override
    public Set<String> getSecureResultLabels() {
        return secureResultLabels;
    }

    @Override
    public boolean inBindIndices(int index) {
        return getSecureBindIndices().contains(index);
    }

    @Override
    public boolean inResultIndices(int index) {
        return getSecureResultIndices().contains(index);
    }

    /** Exact lookup; stored labels are upper-cased, so {@code label} must be upper-case too. */
    @Override
    public boolean inResultLabels(String label) {
        return getSecureResultLabels().contains(label);
    }

    /**
     * Looks up a result column by position (any {@link Number}, compared by its int value) or
     * by label (any {@link String}, compared case-insensitively via upper-casing).
     */
    @Override
    public boolean inResultIndicesOrLabel(Object indexOrLabel) {
        if (indexOrLabel instanceof Number) {
            // Set<Integer>.contains(Long) is always false, so normalise to int first.
            return getSecureResultIndices().contains(((Number) indexOrLabel).intValue());
        }
        if (indexOrLabel instanceof String) {
            return getSecureResultLabels().contains(((String) indexOrLabel).toUpperCase());
        }
        return false;
    }

    /** Returns {@code true} when no secure bind variable, result index or result label was found. */
    @Override
    public boolean haveNonSecureFields() {
        return secureResultLabels.isEmpty()
                && secureResultIndices.isEmpty()
                && secureBindIndices.isEmpty();
    }

    @Override
    public String getSql() {
        return sql;
    }

    /** Where a sub-query appears in its enclosing query. */
    enum QueryBelongs {
        FROM, SELECT
    }

    /** Bind-variable count and secure bind positions of one FROM sub-query. */
    @AllArgsConstructor @Getter
    static class BindVariant {
        private final Integer variantIndex;
        private final Set<Integer> bindIndices;
    }
}