/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Multimap;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
/**
 * Merges together the functions of WindowNodes that have identical {@link WindowNode.Specification}s.
 * For example:
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: A, Functions: [sum(something)])
 *       `--WindowNode(Specification: B, Functions: [sum(something)])
 *          `--WindowNode(Specification: A, Functions: [avg(something)])
 *             `--...
 * </pre>
 * will be transformed into
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: B, Functions: [sum(something)])
 *       `--WindowNode(Specification: A, Functions: [avg(something), sum(something)])
 *          `--...
 * </pre>
 * This will NOT merge the functions of WindowNodes that have identical WindowNode.Specifications
 * but have a node between them that is not a WindowNode. In the following example, the functions
 * in the WindowNodes with specification {@code A} will not be merged into a single WindowNode:
 * <pre>
 * OutputNode
 * `--...
 *    `--WindowNode(Specification: A, Functions: [sum(something)])
 *       `--WindowNode(Specification: B, Functions: [sum(something)])
 *          `--ProjectNode(...)
 *             `--WindowNode(Specification: A, Functions: [avg(something)])
 *                `--...
 * </pre>
 * Likewise, windows are not merged past a WindowNode whose output symbols are used by the
 * partitioning, ordering, or function arguments of a WindowNode above it.
 */
public class MergeWindows
        implements PlanOptimizer
{
    @Override
    public PlanNode optimize(PlanNode plan,
            Session session,
            Map<Symbol, Type> types,
            SymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator)
    {
        // ImmutableListMultimap preserves the insertion order of keys and values, which the
        // rewriter relies on to re-create the window nodes in a well-defined order
        return SimplePlanRewriter.rewriteWith(new Rewriter(), plan, ImmutableListMultimap.of());
    }

    /**
     * Rewrites the plan top-down. The context carries the WindowNodes accumulated from the current
     * run of directly stacked WindowNodes, grouped by specification. Within each specification the
     * most recently visited (lowest) node comes first and becomes the canonical node when the group
     * is collapsed.
     */
    private static class Rewriter
            extends SimplePlanRewriter<Multimap<WindowNode.Specification, WindowNode>>
    {
        /**
         * Handles every non-window node. Such a node ends the current run of windows: its children
         * are rewritten with an empty context, and the windows accumulated above it are re-created
         * on top of the rewritten node.
         */
        @Override
        protected PlanNode visitPlan(
                PlanNode node,
                RewriteContext<Multimap<WindowNode.Specification, WindowNode>> context)
        {
            PlanNode newNode = context.defaultRewrite(node, ImmutableListMultimap.of());
            return collapseWindowsWithinSpecification(context.get(), newNode);
        }

        /**
         * Adds {@code windowNode} to the accumulated windows and drops it from its current position.
         * If one of the accumulated windows reads a symbol produced by {@code windowNode}, that window
         * cannot be moved below it. In that case the accumulated windows are re-created directly above
         * {@code windowNode}'s rewritten subtree, and a new run starts with {@code windowNode}.
         *
         * @throws IllegalStateException if the node already carries a hash symbol or
         *         pre-partitioning/pre-sorting information, or if its functions use different frames
         */
        @Override
        public PlanNode visitWindow(
                WindowNode windowNode,
                RewriteContext<Multimap<WindowNode.Specification, WindowNode>> context)
        {
            checkState(!windowNode.getHashSymbol().isPresent(), "MergeWindows should be run before HashGenerationOptimizer");
            checkState(windowNode.getPrePartitionedInputs().isEmpty() && windowNode.getPreSortedOrderPrefix() == 0, "MergeWindows should be run before AddExchanges");
            // Compare the frames rather than the whole Function objects. Different functions over the
            // same frame are legal within one WindowNode, e.g. a node produced by collapseWindows below.
            checkState(windowNode.getWindowFunctions().values().stream()
                            .map(WindowNode.Function::getFrame)
                            .distinct()
                            .count() == 1,
                    "Frames expected to be identical");

            Multimap<WindowNode.Specification, WindowNode> accumulated = context.get();
            if (accumulated.values().stream().anyMatch(upper -> dependsOn(upper, windowNode))) {
                return collapseWindowsWithinSpecification(
                        accumulated,
                        context.rewrite(
                                windowNode.getSource(),
                                ImmutableListMultimap.of(windowNode.getSpecification(), windowNode)));
            }

            return context.rewrite(
                    windowNode.getSource(),
                    ImmutableListMultimap.<WindowNode.Specification, WindowNode>builder()
                            // Add the current window first so that it gets precedence in iteration order
                            .put(windowNode.getSpecification(), windowNode)
                            .putAll(accumulated)
                            .build());
        }

        /**
         * Re-creates the accumulated windows on top of {@code sourceNode}, one WindowNode per
         * specification. Specifications are applied in key iteration order, so the first key ends up
         * lowest in the resulting plan.
         */
        private static PlanNode collapseWindowsWithinSpecification(Multimap<WindowNode.Specification, WindowNode> windowsMap, PlanNode sourceNode)
        {
            PlanNode result = sourceNode;
            for (WindowNode.Specification specification : windowsMap.keySet()) {
                result = collapseWindows(result, specification, windowsMap.get(specification));
            }
            return result;
        }

        /**
         * Builds a single WindowNode over {@code source} that computes the functions of all
         * {@code windows}. The id, hash symbol and pre-partitioning/pre-sorting properties come from
         * the first (canonical) window. The merged functions keep the iteration order of
         * {@code windows}, which keeps the resulting plan deterministic.
         *
         * @throws IllegalStateException if two windows define a function with the same output symbol
         */
        private static WindowNode collapseWindows(PlanNode source, WindowNode.Specification specification, Collection<WindowNode> windows)
        {
            // Non-empty: a multimap key is only present when it has at least one value
            WindowNode canonical = windows.iterator().next();
            Map<Symbol, WindowNode.Function> functions = windows.stream()
                    .map(WindowNode::getWindowFunctions)
                    .flatMap(map -> map.entrySet().stream())
                    .collect(Collectors.toMap(
                            Map.Entry::getKey,
                            Map.Entry::getValue,
                            (left, right) -> {
                                throw new IllegalStateException("Duplicate window function output symbol");
                            },
                            LinkedHashMap::new));
            return new WindowNode(
                    canonical.getId(),
                    source,
                    specification,
                    functions,
                    canonical.getHashSymbol(),
                    canonical.getPrePartitionedInputs(),
                    canonical.getPreSortedOrderPrefix());
        }

        /**
         * Returns whether {@code parent} reads a symbol created by {@code child}, either through its
         * partitioning symbols, its ordering symbols, or the arguments of its function calls.
         * NOTE(review): frame bound values (if a frame can reference symbols) are not inspected here.
         * Confirm they cannot refer to {@code child}'s outputs.
         */
        private static boolean dependsOn(WindowNode parent, WindowNode child)
        {
            Set<Symbol> childOutputs = child.getCreatedSymbols();
            Stream<Symbol> arguments = parent.getWindowFunctions().values().stream()
                    .map(WindowNode.Function::getFunctionCall)
                    .flatMap(functionCall -> functionCall.getArguments().stream())
                    .map(DependencyExtractor::extractUnique)
                    .flatMap(Collection::stream);
            return parent.getPartitionBy().stream().anyMatch(childOutputs::contains)
                    || parent.getOrderBy().stream().anyMatch(childOutputs::contains)
                    || arguments.anyMatch(childOutputs::contains);
        }
    }
}