/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.ExceptNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.IntersectNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.planner.plan.SetOperationNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Sets.intersection; import static java.util.Objects.requireNonNull; /** * Removes all computation that does is not referenced transitively from the root of the plan * <p> * E.g., * <p> * {@code Output[$0] -> Project[$0 := $1 + $2, $3 = $4 / $5] -> ...} * <p> * gets rewritten as * <p> * {@code Output[$0] -> Project[$0 := $1 + $2] -> ...} */ public class PruneUnreferencedOutputs implements PlanOptimizer { @Override public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith(new Rewriter(), plan, ImmutableSet.of()); } private static class Rewriter extends SimplePlanRewriter<Set<Symbol>> { @Override public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext<Set<Symbol>> context) { return context.defaultRewrite(node, ImmutableSet.copyOf(node.getSource().getOutputSymbols())); } @Override public PlanNode visitExchange(ExchangeNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedOutputSymbols = Sets.newHashSet(context.get()); node.getPartitioningScheme().getHashColumn().ifPresent(expectedOutputSymbols::add); node.getPartitioningScheme().getPartitioning().getColumns().stream() .forEach(expectedOutputSymbols::add); List<List<Symbol>> inputsBySource = new ArrayList<>(node.getInputs().size()); for (int i = 0; i < node.getInputs().size(); i++) { inputsBySource.add(new ArrayList<>()); } List<Symbol> newOutputSymbols = new ArrayList<>(node.getOutputSymbols().size()); for (int i = 0; i < node.getOutputSymbols().size(); i++) { Symbol outputSymbol = node.getOutputSymbols().get(i); if (expectedOutputSymbols.contains(outputSymbol)) { newOutputSymbols.add(outputSymbol); for (int source = 0; source < node.getInputs().size(); source++) { inputsBySource.get(source).add(node.getInputs().get(source).get(i)); } } } // newOutputSymbols contains all partition and hash symbols so simply swap the output layout PartitioningScheme partitioningScheme = new PartitioningScheme( node.getPartitioningScheme().getPartitioning(), newOutputSymbols, node.getPartitioningScheme().getHashColumn(), node.getPartitioningScheme().isReplicateNulls(), node.getPartitioningScheme().getBucketToPartition()); ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(inputsBySource.get(i)); rewrittenSources.add(context.rewrite( node.getSources().get(i), expectedInputs.build())); } return new ExchangeNode( node.getId(), node.getType(), node.getScope(), partitioningScheme, rewrittenSources.build(), inputsBySource); } @Override public PlanNode visitJoin(JoinNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedFilterInputs = new HashSet<>(); if (node.getFilter().isPresent()) { expectedFilterInputs = ImmutableSet.<Symbol>builder() .addAll(DependencyExtractor.extractUnique(node.getFilter().get())) .addAll(context.get()) .build(); } ImmutableSet.Builder<Symbol> leftInputsBuilder = ImmutableSet.builder(); leftInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft)); if (node.getLeftHashSymbol().isPresent()) { leftInputsBuilder.add(node.getLeftHashSymbol().get()); } leftInputsBuilder.addAll(expectedFilterInputs); Set<Symbol> leftInputs = leftInputsBuilder.build(); ImmutableSet.Builder<Symbol> rightInputsBuilder = ImmutableSet.builder(); rightInputsBuilder.addAll(context.get()).addAll(Iterables.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight)); if (node.getRightHashSymbol().isPresent()) { rightInputsBuilder.add(node.getRightHashSymbol().get()); } rightInputsBuilder.addAll(expectedFilterInputs); Set<Symbol> rightInputs = rightInputsBuilder.build(); PlanNode left = context.rewrite(node.getLeft(), leftInputs); PlanNode right = context.rewrite(node.getRight(), rightInputs); List<Symbol> outputSymbols; if (node.isCrossJoin()) { // do not prune nested joins output since it is not supported // TODO: remove this "if" branch when output symbols selection is supported by nested loop join outputSymbols = ImmutableList.<Symbol>builder() .addAll(left.getOutputSymbols()) .addAll(right.getOutputSymbols()) .build(); } else { Set<Symbol> seenSymbol = new HashSet<>(); outputSymbols = node.getOutputSymbols().stream() .filter(context.get()::contains) .filter(seenSymbol::add) .collect(toImmutableList()); } return new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria(), outputSymbols, node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); } @Override public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> sourceInputsBuilder = ImmutableSet.builder(); sourceInputsBuilder.addAll(context.get()).add(node.getSourceJoinSymbol()); if (node.getSourceHashSymbol().isPresent()) { sourceInputsBuilder.add(node.getSourceHashSymbol().get()); } Set<Symbol> sourceInputs = sourceInputsBuilder.build(); ImmutableSet.Builder<Symbol> filteringSourceInputBuilder = ImmutableSet.builder(); filteringSourceInputBuilder.add(node.getFilteringSourceJoinSymbol()); if (node.getFilteringSourceHashSymbol().isPresent()) { filteringSourceInputBuilder.add(node.getFilteringSourceHashSymbol().get()); } Set<Symbol> filteringSourceInputs = filteringSourceInputBuilder.build(); PlanNode source = context.rewrite(node.getSource(), sourceInputs); PlanNode filteringSource = context.rewrite(node.getFilteringSource(), filteringSourceInputs); return new SemiJoinNode(node.getId(), source, filteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput(), node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol(), node.getDistributionType()); } @Override public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> probeInputsBuilder = ImmutableSet.builder(); probeInputsBuilder.addAll(context.get()) .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe)); if (node.getProbeHashSymbol().isPresent()) { probeInputsBuilder.add(node.getProbeHashSymbol().get()); } Set<Symbol> probeInputs = probeInputsBuilder.build(); ImmutableSet.Builder<Symbol> indexInputBuilder = ImmutableSet.builder(); indexInputBuilder.addAll(context.get()) .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getIndex)); if (node.getIndexHashSymbol().isPresent()) { indexInputBuilder.add(node.getIndexHashSymbol().get()); } Set<Symbol> indexInputs = indexInputBuilder.build(); PlanNode probeSource = context.rewrite(node.getProbeSource(), probeInputs); PlanNode indexSource = context.rewrite(node.getIndexSource(), indexInputs); return new IndexJoinNode(node.getId(), node.getType(), probeSource, indexSource, node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol()); } @Override public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext<Set<Symbol>> context) { List<Symbol> newOutputSymbols = node.getOutputSymbols().stream() .filter(context.get()::contains) .collect(toImmutableList()); Set<Symbol> newLookupSymbols = node.getLookupSymbols().stream() .filter(context.get()::contains) .collect(toImmutableSet()); Set<Symbol> requiredAssignmentSymbols = context.get(); if (!node.getEffectiveTupleDomain().isNone()) { Set<Symbol> requiredSymbols = Maps.filterValues(node.getAssignments(), in(node.getEffectiveTupleDomain().getDomains().get().keySet())).keySet(); requiredAssignmentSymbols = Sets.union(context.get(), requiredSymbols); } Map<Symbol, ColumnHandle> newAssignments = Maps.filterKeys(node.getAssignments(), in(requiredAssignmentSymbols)); return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), node.getLayout(), newLookupSymbols, newOutputSymbols, newAssignments, node.getEffectiveTupleDomain()); } @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(node.getGroupingKeys()); if (node.getHashSymbol().isPresent()) { expectedInputs.add(node.getHashSymbol().get()); } ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder(); ImmutableMap.Builder<Symbol, FunctionCall> functionCalls = ImmutableMap.builder(); ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); if (context.get().contains(symbol)) { FunctionCall call = entry.getValue(); expectedInputs.addAll(DependencyExtractor.extractUnique(call)); if (node.getMasks().containsKey(symbol)) { expectedInputs.add(node.getMasks().get(symbol)); masks.put(symbol, node.getMasks().get(symbol)); } functionCalls.put(symbol, call); functions.put(symbol, node.getFunctions().get(symbol)); } } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new AggregationNode(node.getId(), source, functionCalls.build(), functions.build(), masks.build(), node.getGroupingSets(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol()); } @Override public PlanNode visitWindow(WindowNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(context.get()) .addAll(node.getPartitionBy()) .addAll(node.getOrderBy()); for (WindowNode.Frame frame : node.getFrames()) { if (frame.getStartValue().isPresent()) { expectedInputs.add(frame.getStartValue().get()); } if (frame.getEndValue().isPresent()) { expectedInputs.add(frame.getEndValue().get()); } } if (node.getHashSymbol().isPresent()) { expectedInputs.add(node.getHashSymbol().get()); } ImmutableMap.Builder<Symbol, WindowNode.Function> functionsBuilder = ImmutableMap.builder(); for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) { Symbol symbol = entry.getKey(); WindowNode.Function function = entry.getValue(); if (context.get().contains(symbol)) { FunctionCall call = function.getFunctionCall(); expectedInputs.addAll(DependencyExtractor.extractUnique(call)); functionsBuilder.put(symbol, entry.getValue()); } } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); Map<Symbol, WindowNode.Function> functions = functionsBuilder.build(); if (functions.size() == 0) { return source; } return new WindowNode( node.getId(), source, node.getSpecification(), functions, node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix()); } @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> requiredTableScanOutputs = context.get().stream() .filter(node.getOutputSymbols()::contains) .collect(toImmutableSet()); List<Symbol> newOutputSymbols = node.getOutputSymbols().stream() .filter(requiredTableScanOutputs::contains) .collect(toImmutableList()); Map<Symbol, ColumnHandle> newAssignments = Maps.filterKeys(node.getAssignments(), in(requiredTableScanOutputs)); return new TableScanNode( node.getId(), node.getTable(), newOutputSymbols, newAssignments, node.getLayout(), node.getCurrentConstraint(), node.getOriginalConstraint()); } @Override public PlanNode visitFilter(FilterNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(DependencyExtractor.extractUnique(node.getPredicate())) .addAll(context.get()) .build(); PlanNode source = context.rewrite(node.getSource(), expectedInputs); return new FilterNode(node.getId(), source, node.getPredicate()); } @Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.builder(); Map<Symbol, Symbol> newArgumentMappings = node.getArgumentMappings().entrySet().stream() .filter(entry -> context.get().contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); expectedInputs.addAll(newArgumentMappings.values()); ImmutableList.Builder<List<Symbol>> newGroupingSets = ImmutableList.builder(); Map<Symbol, Symbol> newGroupingMapping = new HashMap<>(); for (List<Symbol> groupingSet : node.getGroupingSets()) { ImmutableList.Builder<Symbol> newGroupingSet = ImmutableList.builder(); for (Symbol output : groupingSet) { if (context.get().contains(output)) { newGroupingSet.add(output); newGroupingMapping.putIfAbsent(output, node.getGroupingSetMappings().get(output)); expectedInputs.add(node.getGroupingSetMappings().get(output)); } } newGroupingSets.add(newGroupingSet.build()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new GroupIdNode(node.getId(), source, newGroupingSets.build(), newGroupingMapping, newArgumentMappings, node.getGroupIdSymbol()); } @Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Set<Symbol>> context) { if (!context.get().contains(node.getMarkerSymbol())) { return context.rewrite(node.getSource(), context.get()); } ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(node.getDistinctSymbols()) .addAll(context.get().stream() .filter(symbol -> !symbol.equals(node.getMarkerSymbol())) .collect(toImmutableList())); if (node.getHashSymbol().isPresent()) { expectedInputs.add(node.getHashSymbol().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new MarkDistinctNode(node.getId(), source, node.getMarkerSymbol(), node.getDistinctSymbols(), node.getHashSymbol()); } @Override public PlanNode visitUnnest(UnnestNode node, RewriteContext<Set<Symbol>> context) { List<Symbol> replicateSymbols = node.getReplicateSymbols().stream() .filter(context.get()::contains) .collect(toImmutableList()); Optional<Symbol> ordinalitySymbol = node.getOrdinalitySymbol(); if (ordinalitySymbol.isPresent() && !context.get().contains(ordinalitySymbol.get())) { ordinalitySymbol = Optional.empty(); } Map<Symbol, List<Symbol>> unnestSymbols = node.getUnnestSymbols(); ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(replicateSymbols) .addAll(unnestSymbols.keySet()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol); } @Override public PlanNode visitProject(ProjectNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.builder(); Assignments.Builder builder = Assignments.builder(); for (int i = 0; i < node.getOutputSymbols().size(); i++) { Symbol output = node.getOutputSymbols().get(i); Expression expression = node.getAssignments().get(output); if (context.get().contains(output)) { expectedInputs.addAll(DependencyExtractor.extractUnique(expression)); builder.put(output, expression); } } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new ProjectNode(node.getId(), source, builder.build()); } @Override public PlanNode visitOutput(OutputNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols()); PlanNode source = context.rewrite(node.getSource(), expectedInputs); return new OutputNode(node.getId(), source, node.getColumnNames(), node.getOutputSymbols()); } @Override public PlanNode visitLimit(LimitNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(context.get()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new LimitNode(node.getId(), source, node.getCount(), node.isPartial()); } @Override public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedInputs; if (node.getHashSymbol().isPresent()) { expectedInputs = ImmutableSet.copyOf(concat(node.getOutputSymbols(), ImmutableList.of(node.getHashSymbol().get()))); } else { expectedInputs = ImmutableSet.copyOf(node.getOutputSymbols()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs); return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), node.getHashSymbol()); } @Override public PlanNode visitTopN(TopNNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(context.get()) .addAll(node.getOrderBy()); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new TopNNode(node.getId(), source, node.getCount(), node.getOrderBy(), node.getOrderings(), node.isPartial()); } @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> inputsBuilder = ImmutableSet.builder(); ImmutableSet.Builder<Symbol> expectedInputs = inputsBuilder .addAll(context.get()) .addAll(node.getPartitionBy()); if (node.getHashSymbol().isPresent()) { inputsBuilder.add(node.getHashSymbol().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new RowNumberNode(node.getId(), source, node.getPartitionBy(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol()); } @Override public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(context.get()) .addAll(node.getPartitionBy()) .addAll(node.getOrderBy()); if (node.getHashSymbol().isPresent()) { expectedInputs.add(node.getHashSymbol().get()); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new TopNRowNumberNode(node.getId(), source, node.getSpecification(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.isPartial(), node.getHashSymbol()); } @Override public PlanNode visitSort(SortNode node, RewriteContext<Set<Symbol>> context) { Set<Symbol> expectedInputs = ImmutableSet.copyOf(concat(context.get(), node.getOrderBy())); PlanNode source = context.rewrite(node.getSource(), expectedInputs); return new SortNode(node.getId(), source, node.getOrderBy(), node.getOrderings()); } @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Set<Symbol>> context) { ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() .addAll(node.getColumns()); if (node.getPartitioningScheme().isPresent()) { PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); partitioningScheme.getPartitioning().getColumns().stream() .forEach(expectedInputs::add); partitioningScheme.getHashColumn().ifPresent(expectedInputs::add); } PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); return new TableWriterNode( node.getId(), source, node.getTarget(), node.getColumns(), node.getColumnNames(), node.getOutputSymbols(), node.getPartitioningScheme()); } @Override public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<Set<Symbol>> context) { // Maintain the existing inputs needed for TableCommitNode PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputSymbols())); return new TableFinishNode(node.getId(), source, node.getTarget(), node.getOutputSymbols()); } @Override public PlanNode visitDelete(DeleteNode node, RewriteContext<Set<Symbol>> context) { PlanNode source = context.rewrite(node.getSource(), ImmutableSet.of(node.getRowId())); return new DeleteNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getOutputSymbols()); } @Override public PlanNode visitUnion(UnionNode node, RewriteContext<Set<Symbol>> context) { ListMultimap<Symbol, Symbol> rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); ImmutableList<PlanNode> rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); return new UnionNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); } @Override public PlanNode visitIntersect(IntersectNode node, RewriteContext<Set<Symbol>> context) { ListMultimap<Symbol, Symbol> rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); ImmutableList<PlanNode> rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); return new IntersectNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); } @Override public PlanNode visitExcept(ExceptNode node, RewriteContext<Set<Symbol>> context) { ListMultimap<Symbol, Symbol> rewrittenSymbolMapping = rewriteSetOperationSymbolMapping(node, context); ImmutableList<PlanNode> rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenSymbolMapping); return new ExceptNode(node.getId(), rewrittenSubPlans, rewrittenSymbolMapping, ImmutableList.copyOf(rewrittenSymbolMapping.keySet())); } private ListMultimap<Symbol, Symbol> rewriteSetOperationSymbolMapping(SetOperationNode node, RewriteContext<Set<Symbol>> context) { // Find out which output symbols we need to keep ImmutableListMultimap.Builder<Symbol, Symbol> rewrittenSymbolMappingBuilder = ImmutableListMultimap.builder(); for (Symbol symbol : node.getOutputSymbols()) { if (context.get().contains(symbol)) { rewrittenSymbolMappingBuilder.putAll(symbol, node.getSymbolMapping().get(symbol)); } } return rewrittenSymbolMappingBuilder.build(); } private ImmutableList<PlanNode> rewriteSetOperationSubPlans(SetOperationNode node, RewriteContext<Set<Symbol>> context, ListMultimap<Symbol, Symbol> rewrittenSymbolMapping) { // Find the corresponding input symbol to the remaining output symbols and prune the subplans ImmutableList.Builder<PlanNode> rewrittenSubPlans = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { ImmutableSet.Builder<Symbol> expectedInputSymbols = ImmutableSet.builder(); for (Collection<Symbol> symbols : rewrittenSymbolMapping.asMap().values()) { expectedInputSymbols.add(Iterables.get(symbols, i)); } rewrittenSubPlans.add(context.rewrite(node.getSources().get(i), expectedInputSymbols.build())); } return rewrittenSubPlans.build(); } @Override public PlanNode visitValues(ValuesNode node, RewriteContext<Set<Symbol>> context) { ImmutableList.Builder<Symbol> rewrittenOutputSymbolsBuilder = ImmutableList.builder(); ImmutableList.Builder<ImmutableList.Builder<Expression>> rowBuildersBuilder = ImmutableList.builder(); // Initialize builder for each row for (int i = 0; i < node.getRows().size(); i++) { rowBuildersBuilder.add(ImmutableList.builder()); } ImmutableList<ImmutableList.Builder<Expression>> rowBuilders = rowBuildersBuilder.build(); for (int i = 0; i < node.getOutputSymbols().size(); i++) { Symbol outputSymbol = node.getOutputSymbols().get(i); // If output symbol is used if (context.get().contains(outputSymbol)) { rewrittenOutputSymbolsBuilder.add(outputSymbol); // Add the value of the output symbol for each row for (int j = 0; j < node.getRows().size(); j++) { rowBuilders.get(j).add(node.getRows().get(j).get(i)); } } } List<List<Expression>> rewrittenRows = rowBuilders.stream() .map((rowBuilder) -> rowBuilder.build()) .collect(toImmutableList()); return new ValuesNode(node.getId(), rewrittenOutputSymbolsBuilder.build(), rewrittenRows); } @Override public PlanNode visitApply(ApplyNode node, RewriteContext<Set<Symbol>> context) { // remove unused apply nodes if (intersection(node.getSubqueryAssignments().getSymbols(), context.get()).isEmpty()) { return context.rewrite(node.getInput(), context.get()); } // extract symbols required subquery plan ImmutableSet.Builder<Symbol> subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder(); Assignments.Builder subqueryAssignments = Assignments.builder(); for (Map.Entry<Symbol, Expression> entry : node.getSubqueryAssignments().getMap().entrySet()) { Symbol output = entry.getKey(); Expression expression = entry.getValue(); if (context.get().contains(output)) { subqueryAssignmentsSymbolsBuilder.addAll(DependencyExtractor.extractUnique(expression)); subqueryAssignments.put(output, expression); } } Set<Symbol> subqueryAssignmentsSymbols = subqueryAssignmentsSymbolsBuilder.build(); PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols); // prune not used correlation symbols Set<Symbol> subquerySymbols = DependencyExtractor.extractUnique(subquery); List<Symbol> newCorrelation = node.getCorrelation().stream() .filter(subquerySymbols::contains) .collect(toImmutableList()); Set<Symbol> inputContext = ImmutableSet.<Symbol>builder() .addAll(context.get()) .addAll(newCorrelation) .addAll(subqueryAssignmentsSymbols) // need to include those: e.g: "expr" from "expr IN (SELECT 1)" .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation); } @Override public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext<Set<Symbol>> context) { if (!context.get().contains(node.getIdColumn())) { return context.rewrite(node.getSource(), context.get()); } return context.defaultRewrite(node, context.get()); } } }