/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * Desugars lambda captures: every symbol a lambda body references that is not one of the
 * lambda's own arguments becomes an additional leading lambda argument, and the lambda is
 * wrapped in one {@link BindExpression} per such symbol.
 *
 * <pre>
 *     x -> f(x, captureSymbol)
 * is rewritten into
 *     "$internal$bind"(captureSymbol, (extraSymbol, x) -> f(x, extraSymbol))
 * </pre>
 */
public class LambdaCaptureDesugaringRewriter
{
    private final Map<Symbol, Type> symbolTypes;
    private final SymbolAllocator symbolAllocator;

    /**
     * @param symbolTypes type lookup used when allocating the replacement symbol of a capture; must not be null
     * @param symbolAllocator source of the freshly allocated replacement symbols; must not be null
     * @throws NullPointerException if either argument is null
     */
    public LambdaCaptureDesugaringRewriter(Map<Symbol, Type> symbolTypes, SymbolAllocator symbolAllocator)
    {
        this.symbolTypes = requireNonNull(symbolTypes, "symbolTypes is null");
        this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
    }

    /**
     * Runs the capture-desugaring visitor over {@code expression}, starting from an empty
     * set of referenced symbols.
     *
     * @param expression expression tree to rewrite
     * @return the result produced by the tree rewriter
     */
    public Expression rewrite(Expression expression)
    {
        Visitor visitor = new Visitor();
        Context rootContext = new Context();
        return ExpressionTreeRewriter.rewriteWith(visitor, expression, rootContext);
    }

    /**
     * Replaces every symbol reference whose symbol is a key of {@code symbolMapping} with a
     * reference to the mapped symbol; other symbol references are returned unchanged.
     */
    private static Expression replaceSymbols(Expression expression, ImmutableMap<Symbol, Symbol> symbolMapping)
    {
        return ExpressionTreeRewriter.rewriteWith(new SymbolSubstitutor(symbolMapping), expression);
    }

    /**
     * Rewriter backing {@link #replaceSymbols}: looks each symbol reference up in a fixed
     * substitution map.
     */
    private static final class SymbolSubstitutor
            extends ExpressionRewriter<Void>
    {
        private final ImmutableMap<Symbol, Symbol> substitutions;

        private SymbolSubstitutor(ImmutableMap<Symbol, Symbol> substitutions)
        {
            this.substitutions = substitutions;
        }

        @Override
        public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
        {
            Symbol replacement = substitutions.get(new Symbol(node.getName()));
            return replacement != null ? replacement.toSymbolReference() : node;
        }
    }

    /**
     * Visitor that records every symbol reference it sees into the current {@link Context}
     * and rewrites lambda expressions into bind-wrapped lambdas without free symbols.
     */
    public class Visitor
            extends ExpressionRewriter<Context>
    {
        @Override
        public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter<Context> treeRewriter)
        {
            // The body is rewritten against its own set so that only symbols referenced inside
            // this lambda are collected. A linked hash set keeps iteration order deterministic.
            LinkedHashSet<Symbol> bodySymbols = new LinkedHashSet<>();
            Context bodyContext = context.withReferencedSymbols(bodySymbols);
            Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), bodyContext);

            List<Symbol> declaredArguments = node.getArguments().stream()
                    .map(argument -> new Symbol(argument.getName()))
                    .collect(toImmutableList());

            // Whatever the body references beyond the lambda's own arguments is a capture.
            // The set is reused in place, hence the alias under a more accurate name.
            bodySymbols.removeAll(declaredArguments);
            Set<Symbol> captures = bodySymbols;

            // One freshly allocated argument per capture, placed ahead of the original arguments.
            ImmutableMap.Builder<Symbol, Symbol> captureToExtra = ImmutableMap.builder();
            ImmutableList.Builder<LambdaArgumentDeclaration> argumentDeclarations = ImmutableList.builder();
            for (Symbol capture : captures) {
                Symbol extra = symbolAllocator.newSymbol(capture.getName(), symbolTypes.get(capture));
                captureToExtra.put(capture, extra);
                argumentDeclarations.add(new LambdaArgumentDeclaration(extra.getName()));
            }
            argumentDeclarations.addAll(node.getArguments());

            List<LambdaArgumentDeclaration> allArguments = argumentDeclarations.build();
            Expression closedBody = replaceSymbols(rewrittenBody, captureToExtra.build());
            Expression desugared = bindCaptures(captures, new LambdaExpression(allArguments, closedBody));

            // From the enclosing scope's point of view, this lambda references its captures.
            context.getReferencedSymbols().addAll(captures);
            return desugared;
        }

        /**
         * Wraps {@code lambda} in one bind expression per capture, in iteration order, so the
         * first capture ends up innermost.
         */
        private Expression bindCaptures(Set<Symbol> captures, Expression lambda)
        {
            Expression bound = lambda;
            for (Symbol capture : captures) {
                SymbolReference capturedValue = new SymbolReference(capture.getName());
                bound = new BindExpression(capturedValue, bound);
            }
            return bound;
        }

        @Override
        public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter<Context> treeRewriter)
        {
            Symbol referenced = new Symbol(node.getName());
            context.getReferencedSymbols().add(referenced);
            // NOTE(review): null presumably tells the tree rewriter to keep the node as is;
            // confirm against ExpressionTreeRewriter.
            return null;
        }
    }

    /**
     * Rewrite context carrying the set of symbols referenced so far in the current scope.
     */
    private static class Context
    {
        // Use linked hash set to guarantee deterministic iteration order
        private final LinkedHashSet<Symbol> referencedSymbols;

        public Context()
        {
            this(new LinkedHashSet<>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols)
        {
            this.referencedSymbols = referencedSymbols;
        }

        /**
         * Returns the live, mutable set backing this context (not a copy).
         */
        public LinkedHashSet<Symbol> getReferencedSymbols()
        {
            return referencedSymbols;
        }

        /**
         * Creates a new context that records into {@code symbols} instead of this context's set.
         */
        public Context withReferencedSymbols(LinkedHashSet<Symbol> symbols)
        {
            return new Context(symbols);
        }
    }
}