/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.metamodel.jdbc.dialects; import java.sql.Timestamp; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.Set; import org.apache.metamodel.jdbc.JdbcDataContext; import org.apache.metamodel.query.AggregateFunction; import org.apache.metamodel.query.AverageAggregateFunction; import org.apache.metamodel.query.CountAggregateFunction; import org.apache.metamodel.query.FilterItem; import org.apache.metamodel.query.FromItem; import org.apache.metamodel.query.FunctionType; import org.apache.metamodel.query.MaxAggregateFunction; import org.apache.metamodel.query.MinAggregateFunction; import org.apache.metamodel.query.OperatorType; import org.apache.metamodel.query.Query; import org.apache.metamodel.query.ScalarFunction; import org.apache.metamodel.query.SelectItem; import org.apache.metamodel.query.SumAggregateFunction; import org.apache.metamodel.schema.ColumnType; import org.apache.metamodel.util.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Generic query rewriter that adds syntax enhancements that are only possible * to resolve just before execution time. 
* <p>
 * Besides quoting aliases that contain special characters, this rewriter
 * escapes single quotes in {@link String} operands (including the elements of
 * multi-valued IN / NOT IN operands), renders {@link Timestamp} operands with
 * the JDBC timestamp escape syntax, and reports only COUNT, SUM, MAX, MIN and
 * AVG as natively supported functions.
 */
public class DefaultQueryRewriter extends AbstractQueryRewriter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultQueryRewriter.class);

    /**
     * Characters whose presence in an alias forces the alias to be wrapped in
     * the identifier quote string. The currency sign (U+00A4) is written as a
     * Unicode escape so the constant does not depend on the encoding the source
     * file happens to be read with (it had previously been corrupted into two
     * unrelated Thai characters).
     */
    private static final String SPECIAL_ALIAS_CHARACTERS = "- ,.|*%()!#\u00A4/\\=?;:~";

    /**
     * Function classes this dialect reports as natively supported. Only
     * aggregate function classes are listed, so scalar functions are presumably
     * reported as unsupported unless a subclass overrides
     * {@link #isScalarFunctionSupported(ScalarFunction)}.
     */
    private static final Set<Class<? extends FunctionType>> SUPPORTED_FUNCTION_CLASSES = new HashSet<>(
            Arrays.<Class<? extends FunctionType>> asList(CountAggregateFunction.class, SumAggregateFunction.class,
                    MaxAggregateFunction.class, MinAggregateFunction.class, AverageAggregateFunction.class));

    /**
     * Creates a rewriter for the given data context.
     *
     * @param dataContext
     *            the JDBC data context whose identifier quote string is used
     *            when quoting aliases
     */
    public DefaultQueryRewriter(JdbcDataContext dataContext) {
        super(dataContext);
    }

    /**
     * Clones the query and, when the data context exposes an identifier quote
     * string, wraps every select-item and from-item alias that needs quoting
     * (see {@link #needsQuoting(String, String)}) in that quote string.
     *
     * @param query
     *            the query about to be rewritten
     * @return the clone, with aliases quoted where necessary
     */
    @Override
    protected Query beforeRewrite(Query query) {
        final Query clone = query.clone();

        final JdbcDataContext dataContext = getDataContext();
        if (dataContext == null) {
            return clone;
        }
        final String quote = dataContext.getIdentifierQuoteString();
        if (quote == null) {
            return clone;
        }

        for (SelectItem item : clone.getSelectClause().getItems()) {
            final String alias = item.getAlias();
            if (needsQuoting(alias, quote)) {
                item.setAlias(quote + alias + quote);
            }
        }
        for (FromItem item : clone.getFromClause().getItems()) {
            final String alias = item.getAlias();
            if (needsQuoting(alias, quote)) {
                item.setAlias(quote + alias + quote);
            }
        }
        return clone;
    }

    /**
     * Maps the generic {@link ColumnType#STRING} and {@link ColumnType#NUMBER}
     * types to {@link ColumnType#VARCHAR} and {@link ColumnType#FLOAT}
     * respectively before delegating to the superclass.
     *
     * @param columnType
     *            the type to render
     * @param columnSize
     *            the optional column size
     * @return the SQL type literal
     */
    @Override
    public String rewriteColumnType(ColumnType columnType, Integer columnSize) {
        if (columnType == ColumnType.STRING) {
            // VARCHAR is the default SQL type for strings
            return rewriteColumnType(ColumnType.VARCHAR, columnSize);
        }
        if (columnType == ColumnType.NUMBER) {
            // FLOAT is the default SQL type for numbers
            return rewriteColumnType(ColumnType.FLOAT, columnSize);
        }
        return super.rewriteColumnType(columnType, columnSize);
    }

    /**
     * Decides whether an alias must be wrapped in identifier quotes: it must be
     * non-null, must not already contain the quote string, and must contain at
     * least one of {@link #SPECIAL_ALIAS_CHARACTERS}.
     *
     * @param alias
     *            the alias to inspect (may be {@code null})
     * @param identifierQuoteString
     *            the dialect's identifier quote string (may be {@code null})
     * @return {@code true} if the alias should be quoted
     */
    private boolean needsQuoting(String alias, String identifierQuoteString) {
        boolean result = false;
        if (alias != null && identifierQuoteString != null && !alias.contains(identifierQuoteString)) {
            for (int i = 0; i < SPECIAL_ALIAS_CHARACTERS.length(); i++) {
                if (alias.indexOf(SPECIAL_ALIAS_CHARACTERS.charAt(i)) != -1) {
                    result = true;
                    break;
                }
            }
        }
        // Parameterized logging: the message is only formatted when DEBUG is on.
        logger.debug("needsQuoting({},{}) = {}", alias, identifierQuoteString, result);
        return result;
    }

    /**
     * Rewrites a filter item, escaping single quotes in string operands,
     * rendering {@link Timestamp} operands via {@link #rewriteTimestamp(Timestamp)},
     * and escaping each element of multi-valued (Iterable or array) operands.
     *
     * @param item
     *            the filter item to render
     * @return the SQL representation of the filter item
     */
    @Override
    public String rewriteFilterItem(FilterItem item) {
        final Object operand = item.getOperand();
        if (operand instanceof String) {
            final String str = (String) operand;
            if (str.indexOf('\'') != -1) {
                final FilterItem escaped = new FilterItem(item.getSelectItem(), item.getOperator(),
                        escapeQuotes(str));
                return super.rewriteFilterItem(escaped);
            }
        } else if (operand instanceof Timestamp) {
            final String timestampLiteral = rewriteTimestamp((Timestamp) operand);
            return rewriteFilterItemWithOperandLiteral(item, timestampLiteral);
        } else if (operand instanceof Iterable || (operand != null && operand.getClass().isArray())) {
            // Operand is a set of values (typically with an IN or NOT IN
            // operator); each individual element must be escaped.
            assert OperatorType.IN.equals(item.getOperator()) || OperatorType.NOT_IN.equals(item.getOperator());
            final List<Object> elements = escapeMultiValuedOperand(item, operand);
            final FilterItem escaped = new FilterItem(item.getSelectItem(), item.getOperator(), elements);
            return super.rewriteFilterItem(escaped);
        }
        return super.rewriteFilterItem(item);
    }

    /**
     * Builds a new list from a multi-valued operand with {@code null} elements
     * removed (SQL does not support them in IN lists) and single quotes escaped
     * in {@link String} elements.
     *
     * @param item
     *            the filter item the operand belongs to (used for logging)
     * @param operand
     *            an Iterable or array operand
     * @return a freshly allocated list of escaped elements
     */
    private List<Object> escapeMultiValuedOperand(FilterItem item, Object operand) {
        // Copy first: CollectionUtils.toList(..) may hand back the operand itself
        // or a fixed-size view (not visible here). Editing that in place would
        // mutate the caller's FilterItem (re-escaping it on a later rewrite) or
        // fail on remove().
        final List<Object> elements = new java.util.ArrayList<Object>(CollectionUtils.toList(operand));
        for (ListIterator<Object> it = elements.listIterator(); it.hasNext();) {
            final Object next = it.next();
            if (next == null) {
                logger.warn(
                        "element in IN list is NULL, which isn't supported by SQL. Stripping the element from the list: {}",
                        item);
                it.remove();
            } else if (next instanceof String) {
                final String str = (String) next;
                if (str.indexOf('\'') != -1) {
                    it.set(escapeQuotes(str));
                }
            }
        }
        return elements;
    }

    /**
     * Rewrites a (non-compound) {@link FilterItem} whose operand has already
     * been rewritten into a SQL literal.
     *
     * @param item
     *            the filter item supplying the select item and operator
     * @param operandLiteral
     *            the operand already rendered as a SQL literal
     * @return the SQL of the form {@code <select item> <operator> <literal>}
     */
    protected String rewriteFilterItemWithOperandLiteral(FilterItem item, String operandLiteral) {
        final OperatorType operator = item.getOperator();
        final SelectItem selectItem = item.getSelectItem();
        final StringBuilder sb = new StringBuilder();
        sb.append(selectItem.getSameQueryAlias(false));
        FilterItem.appendOperator(sb, item.getOperand(), operator);
        sb.append(operandLiteral);
        return sb.toString();
    }

    /**
     * Rewrites a {@link Timestamp} into its literal representation as known by
     * this SQL dialect.
     *
     * This default implementation returns the JDBC spec's escape syntax for a
     * timestamp: {ts 'yyyy-mm-dd hh:mm:ss.f . . .'}
     *
     * @param ts
     *            the timestamp to render
     * @return the timestamp literal
     */
    protected String rewriteTimestamp(Timestamp ts) {
        return "{ts '" + ts.toString() + "'}";
    }

    /**
     * Reports whether the function's exact class is in
     * {@link #SUPPORTED_FUNCTION_CLASSES}.
     */
    @Override
    public boolean isScalarFunctionSupported(ScalarFunction function) {
        return SUPPORTED_FUNCTION_CLASSES.contains(function.getClass());
    }

    /**
     * Reports whether the function's exact class is in
     * {@link #SUPPORTED_FUNCTION_CLASSES}.
     */
    @Override
    public boolean isAggregateFunctionSupported(AggregateFunction function) {
        return SUPPORTED_FUNCTION_CLASSES.contains(function.getClass());
    }

    /** The generic dialect has no FIRST ROW syntax. */
    @Override
    public boolean isFirstRowSupported() {
        return false;
    }

    /** The generic dialect has no MAX ROWS syntax. */
    @Override
    public boolean isMaxRowsSupported() {
        return false;
    }

    /**
     * Escapes single quotes by doubling them ({@code '} becomes {@code ''}).
     *
     * @param item
     *            the raw string
     * @return the escaped string
     */
    @Override
    public String escapeQuotes(String item) {
        // Literal replacement: same result as the former regex-based
        // replaceAll, without compiling a pattern on every call.
        return item.replace("'", "''");
    }
}