/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.ResolvedField; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; import com.facebook.presto.util.maps.IdentityLinkedHashMap; import com.google.common.collect.ImmutableList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; /** * Keeps track of fields and expressions and their mapping to symbols in the current plan */ class TranslationMap { // all expressions are rewritten in terms of fields declared by this relation plan private final RelationPlan rewriteBase; private final Analysis analysis; private final IdentityLinkedHashMap<LambdaArgumentDeclaration, Symbol> lambdaDeclarationToSymbolMap; // current mappings of underlying field -> symbol for translating direct field references private final Symbol[] fieldSymbols; // current mappings of sub-expressions -> symbol private final Map<Expression, Symbol> expressionToSymbols = new HashMap<>(); private final Map<Expression, Expression> expressionToExpressions = new HashMap<>(); public TranslationMap(RelationPlan rewriteBase, Analysis analysis, IdentityLinkedHashMap<LambdaArgumentDeclaration, Symbol> lambdaDeclarationToSymbolMap) { this.rewriteBase = requireNonNull(rewriteBase, "rewriteBase is null"); this.analysis = requireNonNull(analysis, "analysis is null"); this.lambdaDeclarationToSymbolMap = requireNonNull(lambdaDeclarationToSymbolMap, "lambdaDeclarationToSymbolMap is null"); fieldSymbols = new Symbol[rewriteBase.getFieldMappings().size()]; } public RelationPlan getRelationPlan() { return rewriteBase; } public Analysis getAnalysis() { return analysis; } public IdentityLinkedHashMap<LambdaArgumentDeclaration, Symbol> getLambdaDeclarationToSymbolMap() { return lambdaDeclarationToSymbolMap; } public void setFieldMappings(List<Symbol> symbols) { checkArgument(symbols.size() == fieldSymbols.length, "size of symbols list (%s) doesn't match number of expected fields (%s)", symbols.size(), fieldSymbols.length); for (int i = 0; i < symbols.size(); i++) { this.fieldSymbols[i] = symbols.get(i); } } public void copyMappingsFrom(TranslationMap other) { checkArgument(other.fieldSymbols.length == fieldSymbols.length, "number of fields in other (%s) doesn't match number of expected fields (%s)", other.fieldSymbols.length, fieldSymbols.length); expressionToSymbols.putAll(other.expressionToSymbols); expressionToExpressions.putAll(other.expressionToExpressions); System.arraycopy(other.fieldSymbols, 0, fieldSymbols, 0, other.fieldSymbols.length); } public void putExpressionMappingsFrom(TranslationMap other) { expressionToSymbols.putAll(other.expressionToSymbols); expressionToExpressions.putAll(other.expressionToExpressions); } public Expression rewrite(Expression expression) { // first, translate names from sql-land references to plan symbols Expression mapped = translateNamesToSymbols(expression); // then rewrite subexpressions in terms of the current mappings return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() { @Override public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { if (expressionToSymbols.containsKey(node)) { return expressionToSymbols.get(node).toSymbolReference(); } else if (expressionToExpressions.containsKey(node)) { Expression mapping = getMapping(node); mapping = translateNamesToSymbols(mapping); return treeRewriter.defaultRewrite(mapping, context); } else { return treeRewriter.defaultRewrite(node, context); } } }, mapped); } private Expression getMapping(Expression expression) { if (!expressionToExpressions.containsKey(expression)) { return expression; } Expression mapped = expressionToExpressions.get(expression); Expression translated = translateNamesToSymbols(mapped); if (!translated.equals(expression) && expressionToExpressions.containsKey(translated)) { mapped = getMapping(translated); } return mapped; } public void put(Expression expression, Symbol symbol) { if (expression instanceof FieldReference) { int fieldIndex = ((FieldReference) expression).getFieldIndex(); fieldSymbols[fieldIndex] = symbol; expressionToSymbols.put(rewriteBase.getSymbol(fieldIndex).toSymbolReference(), symbol); return; } Expression translated = translateNamesToSymbols(expression); expressionToSymbols.put(translated, symbol); // also update the field mappings if this expression is a field reference rewriteBase.getScope().tryResolveField(expression) .filter(ResolvedField::isLocal) .ifPresent(field -> fieldSymbols[field.getHierarchyFieldIndex()] = symbol); } public boolean containsSymbol(Expression expression) { if (expression instanceof FieldReference) { int field = ((FieldReference) expression).getFieldIndex(); return fieldSymbols[field] != null; } Expression translated = translateNamesToSymbols(expression); return expressionToSymbols.containsKey(translated); } public Symbol get(Expression expression) { if (expression instanceof FieldReference) { int field = ((FieldReference) expression).getFieldIndex(); checkArgument(fieldSymbols[field] != null, "No mapping for field: %s", field); return fieldSymbols[field]; } Expression translated = translateNamesToSymbols(expression); if (!expressionToSymbols.containsKey(translated)) { checkArgument(expressionToExpressions.containsKey(translated), "No mapping for expression: %s", expression); return get(expressionToExpressions.get(translated)); } return expressionToSymbols.get(translated); } public void put(Expression expression, Expression rewritten) { expressionToExpressions.put(translateNamesToSymbols(expression), rewritten); } public void addIntermediateMapping(Expression expression, Expression rewritten) { if (rewritten.equals(expression)) { return; } Expression translated = translateNamesToSymbols(expression); if (expressionToExpressions.containsKey(translated)) { Expression previousMapping = expressionToExpressions.get(translated); if (!previousMapping.equals(rewritten)) { put(expression, rewritten); addIntermediateMapping(rewritten, previousMapping); } } else { put(expression, rewritten); } } private Expression translateNamesToSymbols(Expression expression) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() { @Override public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { Expression rewrittenExpression = treeRewriter.defaultRewrite(node, context); return coerceIfNecessary(node, rewrittenExpression); } @Override public Expression rewriteFieldReference(FieldReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { Symbol symbol = rewriteBase.getSymbol(node.getFieldIndex()); checkState(symbol != null, "No symbol mapping for node '%s' (%s)", node, node.getFieldIndex()); return symbol.toSymbolReference(); } @Override public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { LambdaArgumentDeclaration referencedLambdaArgumentDeclaration = analysis.getLambdaArgumentReference(node); if (referencedLambdaArgumentDeclaration != null) { Symbol symbol = lambdaDeclarationToSymbolMap.get(referencedLambdaArgumentDeclaration); return coerceIfNecessary(node, symbol.toSymbolReference()); } else { return rewriteExpressionWithResolvedName(node); } } private Expression rewriteExpressionWithResolvedName(Expression node) { return getSymbol(rewriteBase, node) .map(symbol -> coerceIfNecessary(node, symbol.toSymbolReference())) .orElse(coerceIfNecessary(node, node)); } @Override public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { Optional<ResolvedField> resolvedField = rewriteBase.getScope().tryResolveField(node); if (resolvedField.isPresent()) { if (resolvedField.get().isLocal()) { return getSymbol(rewriteBase, node) .map(symbol -> coerceIfNecessary(node, symbol.toSymbolReference())) .orElseThrow(() -> new IllegalStateException("No symbol mapping for node " + node)); } // do not rewrite outer references, it will be handled in outer scope planner return node; } return rewriteExpression(node, context, treeRewriter); } @Override public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) { checkState(analysis.getCoercion(node) == null, "cannot coerce a lambda expression"); ImmutableList.Builder<LambdaArgumentDeclaration> newArguments = ImmutableList.builder(); for (LambdaArgumentDeclaration argument : node.getArguments()) { newArguments.add(new LambdaArgumentDeclaration(lambdaDeclarationToSymbolMap.get(argument).getName())); } Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), null); return new LambdaExpression(newArguments.build(), rewrittenBody); } private Expression coerceIfNecessary(Expression original, Expression rewritten) { Type coercion = analysis.getCoercion(original); if (coercion != null) { rewritten = new Cast( rewritten, coercion.getTypeSignature().toString(), false, analysis.isTypeOnlyCoercion(original)); } return rewritten; } }, expression, null); } Optional<Symbol> getSymbol(RelationPlan plan, Expression expression) { return plan.getScope() .tryResolveField(expression) .filter(ResolvedField::isLocal) .map(field -> plan.getFieldMappings().get(field.getHierarchyFieldIndex())); } }