/*
* JBoss, Home of Professional Open Source.
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership. Some portions may be licensed
* to Red Hat, Inc. under one or more contributor license agreements.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301 USA.
*/
package org.teiid.query.sql.visitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.teiid.query.sql.LanguageObject;
import org.teiid.query.sql.LanguageVisitor;
import org.teiid.query.sql.lang.*;
import org.teiid.query.sql.lang.ObjectTable.ObjectColumn;
import org.teiid.query.sql.lang.XMLTable.XMLColumn;
import org.teiid.query.sql.navigator.PreOrPostOrderNavigator;
import org.teiid.query.sql.proc.AssignmentStatement;
import org.teiid.query.sql.proc.ExceptionExpression;
import org.teiid.query.sql.proc.ReturnStatement;
import org.teiid.query.sql.symbol.*;
/**
 * Replaces expressions within a {@link LanguageObject} using a map from an original
 * expression to its replacement. Objects are modified in place.
 * <p>
 * It is important to use a post-order navigator with this class, otherwise a
 * replacement that contains the expression it replaces (for example mapping
 * {@code e1} to {@code concat(e1, 'x')}) will not work.
 * {@link #mapExpressions(LanguageObject, Map, boolean)} selects the traversal itself
 * and, when it traverses in pre-order, skips nodes that are replacement values.
 */
public class ExpressionMappingVisitor extends LanguageVisitor {

    /** Original expression -> replacement expression (raw type, matching the protected accessors). */
    private Map symbolMap;

    /** When true, each mapped replacement is cloned before it is substituted. */
    private boolean clone = true;

    /** When true, only {@link ElementSymbol}s are looked up in the map; other expressions are kept. */
    private boolean elementSymbolsOnly;

    /**
     * Creates a visitor that clones each replacement before substituting it.
     * @param symbolMap map of original expression to replacement expression
     */
    public ExpressionMappingVisitor(Map symbolMap) {
        this(symbolMap, true);
    }

    /**
     * Creates a visitor.
     * @param symbolMap map of original expression to replacement expression
     * @param clone whether replacements are cloned before being substituted
     */
    public ExpressionMappingVisitor(Map symbolMap, boolean clone) {
        this.symbolMap = symbolMap;
        this.clone = clone;
    }

    /**
     * Hook for subclasses: whether {@code replaceSymbol} may wrap a replacement symbol
     * whose short name differs from the original in an {@link AliasSymbol}, so the
     * original name is kept.
     * @return true by default
     */
    protected boolean createAliases() {
        return true;
    }

    /**
     * Replaces each select-clause symbol, aliasing the result where needed so the
     * column keeps its original name. {@link MultipleElementSymbol}s are left untouched.
     */
    public void visit(Select obj) {
        List<Expression> symbols = obj.getSymbols();
        for (int i = 0; i < symbols.size(); i++) {
            Expression symbol = symbols.get(i);
            if (symbol instanceof MultipleElementSymbol) {
                continue;
            }
            symbols.set(i, replaceSymbol(symbol, true));
        }
    }

    /**
     * @return whether replacements are cloned before being substituted
     */
    public boolean isClone() {
        return clone;
    }

    /**
     * @param clone whether replacements should be cloned before being substituted
     */
    public void setClone(boolean clone) {
        this.clone = clone;
    }

    /**
     * Replaces the column expression. If the column propagates its name, has no alias
     * and its original expression was an element, that element's short name becomes
     * the alias so the name survives the substitution.
     */
    @Override
    public void visit(DerivedColumn obj) {
        Expression original = obj.getExpression();
        obj.setExpression(replaceExpression(original));
        if (obj.isPropagateName() && obj.getAlias() == null && original instanceof ElementSymbol) {
            obj.setAlias(((ElementSymbol)original).getShortName());
        }
    }

    /**
     * Replaces the default expression of every column that has one.
     */
    @Override
    public void visit(XMLTable obj) {
        for (XMLColumn col : obj.getColumns()) {
            Expression exp = col.getDefaultExpression();
            if (exp != null) {
                col.setDefaultExpression(replaceExpression(exp));
            }
        }
    }

    /**
     * Replaces the default expression of every column that has one.
     */
    @Override
    public void visit(ObjectTable obj) {
        for (ObjectColumn col : obj.getColumns()) {
            Expression exp = col.getDefaultExpression();
            if (exp != null) {
                col.setDefaultExpression(replaceExpression(exp));
            }
        }
    }

    @Override
    public void visit(XMLSerialize obj) {
        obj.setExpression(replaceExpression(obj.getExpression()));
    }

    @Override
    public void visit(XMLParse obj) {
        obj.setExpression(replaceExpression(obj.getExpression()));
    }

    /**
     * Replaces a named symbol while keeping the result nameable.
     * For an {@link ExpressionSymbol} the wrapped expression is what gets replaced.
     * A non-{@link Symbol} result is wrapped in an {@link ExpressionSymbol} carrying the
     * original short name; a {@link Symbol} result with a different short name is wrapped
     * in an {@link AliasSymbol} when {@code alias} is set and {@link #createAliases()} allows it.
     *
     * @param ses the symbol to replace
     * @param alias whether a renamed replacement should be aliased back to the original name
     * @return the replacement symbol
     */
    private Expression replaceSymbol(Expression ses,
            boolean alias) {
        Expression expr = ses;
        String name = Symbol.getShortName(ses);
        if (ses instanceof ExpressionSymbol) {
            expr = ((ExpressionSymbol)ses).getExpression();
        }

        Expression replacementSymbol = replaceExpression(expr);

        if (!(replacementSymbol instanceof Symbol)) {
            replacementSymbol = new ExpressionSymbol(name, replacementSymbol);
        } else if (alias && createAliases() && !Symbol.getShortName(replacementSymbol).equals(name)) {
            replacementSymbol = new AliasSymbol(name, replacementSymbol);
        }
        return replacementSymbol;
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(org.teiid.query.sql.symbol.AliasSymbol)
     */
    public void visit(AliasSymbol obj) {
        Expression replacement = replaceExpression(obj.getSymbol());
        obj.setSymbol(replacement);
    }

    public void visit(ExpressionSymbol expr) {
        expr.setExpression(replaceExpression(expr.getExpression()));
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(BetweenCriteria)
     */
    public void visit(BetweenCriteria obj) {
        obj.setExpression( replaceExpression(obj.getExpression()) );
        obj.setLowerExpression( replaceExpression(obj.getLowerExpression()) );
        obj.setUpperExpression( replaceExpression(obj.getUpperExpression()) );
    }

    /**
     * Replaces the case operand, every WHEN/THEN expression and the optional ELSE expression.
     */
    public void visit(CaseExpression obj) {
        obj.setExpression(replaceExpression(obj.getExpression()));
        final int whenCount = obj.getWhenCount();
        ArrayList<Expression> whens = new ArrayList<Expression>(whenCount);
        ArrayList<Expression> thens = new ArrayList<Expression>(whenCount);
        for (int i = 0; i < whenCount; i++) {
            whens.add(replaceExpression(obj.getWhenExpression(i)));
            thens.add(replaceExpression(obj.getThenExpression(i)));
        }
        obj.setWhen(whens, thens);
        if (obj.getElseExpression() != null) {
            obj.setElseExpression(replaceExpression(obj.getElseExpression()));
        }
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(CompareCriteria)
     */
    public void visit(CompareCriteria obj) {
        obj.setLeftExpression( replaceExpression(obj.getLeftExpression()) );
        obj.setRightExpression( replaceExpression(obj.getRightExpression()) );
    }

    /**
     * Replaces each function argument in place in the argument array.
     * @see org.teiid.query.sql.LanguageVisitor#visit(Function)
     */
    public void visit(Function obj) {
        Expression[] args = obj.getArgs();
        if (args == null) {
            return;
        }
        for (int i = 0; i < args.length; i++) {
            args[i] = replaceExpression(args[i]);
        }
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(IsNullCriteria)
     */
    public void visit(IsNullCriteria obj) {
        obj.setExpression( replaceExpression(obj.getExpression()) );
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(MatchCriteria)
     */
    public void visit(MatchCriteria obj) {
        obj.setLeftExpression( replaceExpression(obj.getLeftExpression()) );
        obj.setRightExpression( replaceExpression(obj.getRightExpression()) );
    }

    /**
     * Replaces every WHEN criteria, THEN expression and the optional ELSE expression.
     * A WHEN criteria whose replacement is not a {@link Criteria} is wrapped in an
     * {@link ExpressionCriteria}.
     */
    public void visit(SearchedCaseExpression obj) {
        int whenCount = obj.getWhenCount();
        ArrayList<Expression> thens = new ArrayList<Expression>(whenCount);
        ArrayList<Criteria> whens = new ArrayList<Criteria>(whenCount);
        for (int i = 0; i < whenCount; i++) {
            thens.add(replaceExpression(obj.getThenExpression(i)));
            Expression ex = replaceExpression(obj.getWhenCriteria(i));
            if (!(ex instanceof Criteria)) {
                whens.add(new ExpressionCriteria(ex));
            } else {
                whens.add((Criteria)ex);
            }
        }
        obj.setWhen(whens, thens);
        if (obj.getElseExpression() != null) {
            obj.setElseExpression(replaceExpression(obj.getElseExpression()));
        }
    }

    /**
     * Replaces the left expression and, unless all values are constants, each value.
     * @see org.teiid.query.sql.LanguageVisitor#visit(SetCriteria)
     */
    public void visit(SetCriteria obj) {
        obj.setExpression( replaceExpression(obj.getExpression()) );

        if (obj.isAllConstants()) {
            // value list left untouched; presumably constants are never mapped
            return;
        }

        Collection<Expression> newValues = new ArrayList<Expression>(obj.getValues().size());
        for (Object value : obj.getValues()) {
            newValues.add(replaceExpression((Expression) value));
        }

        obj.setValues(newValues);
    }

    public void visit(DependentSetCriteria obj) {
        obj.setExpression( replaceExpression(obj.getExpression()) );
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(org.teiid.query.sql.lang.SubqueryCompareCriteria)
     */
    public void visit(SubqueryCompareCriteria obj) {
        obj.setLeftExpression( replaceExpression(obj.getLeftExpression()) );
        if (obj.getArrayExpression() != null) {
            obj.setArrayExpression(replaceExpression(obj.getArrayExpression()));
        }
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(org.teiid.query.sql.lang.SubquerySetCriteria)
     */
    public void visit(SubquerySetCriteria obj) {
        obj.setExpression( replaceExpression(obj.getExpression()) );
    }

    /**
     * Looks up the replacement for a single expression.
     * A {@code null} element is returned as-is without consulting the map, so
     * null-hostile map implementations can be used safely.
     *
     * @param element the expression to replace, may be {@code null}
     * @return the mapped replacement (cloned when {@link #isClone()} is true), or
     *         {@code element} itself when it has no mapping or is not eligible for lookup
     */
    public Expression replaceExpression(Expression element) {
        if (element == null) {
            return null;
        }
        if (elementSymbolsOnly && !(element instanceof ElementSymbol)) {
            return element;
        }
        Expression mapped = (Expression) this.symbolMap.get(element);
        if (mapped == null) {
            return element;
        }
        return clone ? (Expression) mapped.clone() : mapped;
    }

    /**
     * Replaces the expression of each input parameter.
     */
    public void visit(StoredProcedure obj) {
        for (Iterator<SPParameter> paramIter = obj.getInputParameters().iterator(); paramIter.hasNext();) {
            SPParameter param = paramIter.next();
            param.setExpression(replaceExpression(param.getExpression()));
        }
    }

    /**
     * Replaces the aggregate's arguments and its optional condition.
     */
    public void visit(AggregateSymbol obj) {
        visit((Function)obj);
        if (obj.getCondition() != null) {
            obj.setCondition(replaceExpression(obj.getCondition()));
        }
    }

    /**
     * Replaces each grouping expression that has a mapping (when the visitor is
     * restricted to element symbols, only ElementSymbols are looked up).
     * @param obj Object to remap
     */
    public void visit(GroupBy obj) {
        List<Expression> symbols = obj.getSymbols();
        for (int i = 0; i < symbols.size(); i++) {
            symbols.set(i, replaceExpression(symbols.get(i)));
        }
    }

    /**
     * Replaces the order-by symbol; an alias is only created when the item refers to a
     * select-clause position.
     */
    @Override
    public void visit(OrderByItem obj) {
        obj.setSymbol(replaceSymbol(obj.getSymbol(), obj.getExpressionPosition() != -1));
    }

    /**
     * Replaces the offset and row limit, each only when present.
     */
    public void visit(Limit obj) {
        if (obj.getOffset() != null) {
            obj.setOffset(replaceExpression(obj.getOffset()));
        }
        if (obj.getRowLimit() != null) {
            obj.setRowLimit(replaceExpression(obj.getRowLimit()));
        }
    }

    /**
     * Replaces the sql expression and the value of each USING clause.
     */
    public void visit(DynamicCommand obj) {
        obj.setSql(replaceExpression(obj.getSql()));
        if (obj.getUsing() != null) {
            for (SetClause clause : obj.getUsing().getClauses()) {
                visit(clause);
            }
        }
    }

    public void visit(SetClause obj) {
        obj.setValue(replaceExpression(obj.getValue()));
    }

    @Override
    public void visit(QueryString obj) {
        obj.setPath(replaceExpression(obj.getPath()));
    }

    @Override
    public void visit(ExpressionCriteria obj) {
        obj.setExpression(replaceExpression(obj.getExpression()));
    }

    /**
     * Maps expressions in {@code obj} without descending deeply; equivalent to
     * {@code mapExpressions(obj, exprMap, false)}.
     * The object is modified in place, so is not returned.
     * @param obj Language object
     * @param exprMap Expression map, Expression to Expression
     */
    public static void mapExpressions(LanguageObject obj, Map<? extends Expression, ? extends Expression> exprMap) {
        mapExpressions(obj, exprMap, false);
    }

    /**
     * Replaces expressions in {@code obj} according to {@code exprMap}, choosing a safe traversal.
     * Replacements are cloned before substitution. Does nothing when {@code obj} is null or
     * the map is null/empty.
     * The object is modified in place, so is not returned.
     * @param obj Language object
     * @param exprMap Expression map, Expression to Expression
     * @param deep passed through to {@link PreOrPostOrderNavigator}; presumably controls whether
     *        nested commands are also navigated — TODO confirm
     */
    public static void mapExpressions(LanguageObject obj, Map<? extends Expression, ? extends Expression> exprMap, boolean deep) {
        if(obj == null || exprMap == null || exprMap.isEmpty()) {
            return;
        }
        final ExpressionMappingVisitor visitor = new ExpressionMappingVisitor(exprMap);
        // Restrict lookups to ElementSymbols when every key is one.
        visitor.elementSymbolsOnly = true;
        boolean preOrder = true;
        boolean useReverseMapping = true;
        for (Map.Entry<? extends Expression, ? extends Expression> entry : exprMap.entrySet()) {
            if (!(entry.getKey() instanceof ElementSymbol)) {
                visitor.elementSymbolsOnly = false;
                break;
            }
        }
        if (!visitor.elementSymbolsOnly) {
            // If any replacement is not a plain element, only guard against revisiting
            // replacements when key and value groups overlap.
            for (Map.Entry<? extends Expression, ? extends Expression> entry : exprMap.entrySet()) {
                if (!(entry.getValue() instanceof ElementSymbol)) {
                    useReverseMapping = !Collections.disjoint(GroupsUsedByElementsVisitor.getGroups(exprMap.keySet()),
                            GroupsUsedByElementsVisitor.getGroups(exprMap.values()));
                    break;
                }
            }
        } else {
            // Element-only keys: plain post-order traversal, as recommended for this visitor.
            preOrder = false;
            useReverseMapping = false;
        }
        if (useReverseMapping) {
            // Pre-order traversal that does not visit nodes which are replacement values,
            // so substituted expressions are not mapped again.
            final Set<Expression> reverseSet = new HashSet<Expression>(exprMap.values());
            PreOrPostOrderNavigator pon = new PreOrPostOrderNavigator(visitor, PreOrPostOrderNavigator.PRE_ORDER, deep) {
                @Override
                protected void visitNode(LanguageObject obj) {
                    if (!(obj instanceof Expression) || !reverseSet.contains(obj)) {
                        super.visitNode(obj);
                    }
                }
            };
            obj.acceptVisitor(pon);
        } else {
            PreOrPostOrderNavigator.doVisit(obj, visitor, preOrder, deep);
        }
    }

    /**
     * Replaces the mapping used by this visitor.
     * @param variableValues map of original expression to replacement expression
     */
    protected void setVariableValues(Map variableValues) {
        this.symbolMap = variableValues;
    }

    /**
     * @return the mapping used by this visitor
     */
    protected Map getVariableValues() {
        return symbolMap;
    }

    /**
     * @see org.teiid.query.sql.LanguageVisitor#visit(org.teiid.query.sql.proc.AssignmentStatement)
     * @since 5.0
     */
    public void visit(AssignmentStatement obj) {
        obj.setExpression(replaceExpression(obj.getExpression()));
    }

    /**
     * Replaces each value in the insert's value list in place.
     * @see org.teiid.query.sql.LanguageVisitor#visit(org.teiid.query.sql.lang.Insert)
     * @since 5.0
     */
    public void visit(Insert obj) {
        for (int i = 0; i < obj.getValues().size(); i++) {
            obj.getValues().set(i, replaceExpression((Expression)obj.getValues().get(i)));
        }
    }

    /**
     * Replaces each content expression of the element in place.
     */
    @Override
    public void visit(XMLElement obj) {
        List<Expression> content = obj.getContent();
        for (int i = 0; i < content.size(); i++) {
            content.set(i, replaceExpression(content.get(i)));
        }
    }

    /**
     * Replaces each PARTITION BY expression in place, if a partition is present.
     */
    @Override
    public void visit(WindowSpecification windowSpecification) {
        List<Expression> partition = windowSpecification.getPartition();
        if (partition == null) {
            return;
        }
        for (int i = 0; i < partition.size(); i++) {
            partition.set(i, replaceExpression(partition.get(i)));
        }
    }

    /**
     * Replaces each array element expression in place.
     */
    @Override
    public void visit(Array array) {
        List<Expression> exprs = array.getExpressions();
        for (int i = 0; i < exprs.size(); i++) {
            exprs.set(i, replaceExpression(exprs.get(i)));
        }
    }

    /**
     * Replaces the message, SQL state, error code and parent expressions, each only when present.
     */
    @Override
    public void visit(ExceptionExpression exceptionExpression) {
        if (exceptionExpression.getMessage() != null) {
            exceptionExpression.setMessage(replaceExpression(exceptionExpression.getMessage()));
        }
        if (exceptionExpression.getSqlState() != null) {
            exceptionExpression.setSqlState(replaceExpression(exceptionExpression.getSqlState()));
        }
        if (exceptionExpression.getErrorCode() != null) {
            exceptionExpression.setErrorCode(replaceExpression(exceptionExpression.getErrorCode()));
        }
        if (exceptionExpression.getParent() != null) {
            exceptionExpression.setParent(replaceExpression(exceptionExpression.getParent()));
        }
    }

    /**
     * Replaces the returned expression, if any.
     */
    @Override
    public void visit(ReturnStatement obj) {
        if (obj.getExpression() != null) {
            obj.setExpression(replaceExpression(obj.getExpression()));
        }
    }

}