/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static java.util.Objects.requireNonNull;
public class PartialAggregationPushDown
implements PlanOptimizer
{
private final FunctionRegistry functionRegistry;
public PartialAggregationPushDown(FunctionRegistry registry)
{
requireNonNull(registry, "registry is null");
this.functionRegistry = registry;
}
@Override
public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
{
return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator), plan, null);
}
private class Rewriter
extends SimplePlanRewriter<Void>
{
private final SymbolAllocator allocator;
private final PlanNodeIdAllocator idAllocator;
public Rewriter(SymbolAllocator allocator, PlanNodeIdAllocator idAllocator)
{
this.allocator = requireNonNull(allocator, "allocator is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
}
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
PlanNode child = node.getSource();
if (!(child instanceof ExchangeNode)) {
return context.defaultRewrite(node);
}
boolean decomposable = node.isDecomposable(functionRegistry);
if (node.getStep().equals(SINGLE) &&
node.hasEmptyGroupingSet() &&
node.hasNonEmptyGroupingSet()) {
checkState(
decomposable,
"Distributed aggregation with empty grouping set requires partial but functions are not decomposable");
return context.rewrite(split(node));
}
if (!decomposable) {
return context.defaultRewrite(node);
}
// partial aggregation can only be pushed through exchange that doesn't change
// the cardinality of the stream (i.e., gather or repartition)
ExchangeNode exchange = (ExchangeNode) child;
if ((exchange.getType() != GATHER && exchange.getType() != REPARTITION) ||
exchange.getPartitioningScheme().isReplicateNulls()) {
return context.defaultRewrite(node);
}
if (exchange.getType() == REPARTITION) {
// if partitioning columns are not a subset of grouping keys,
// we can't push this through
List<Symbol> partitioningColumns = exchange.getPartitioningScheme()
.getPartitioning()
.getArguments()
.stream()
.filter(Partitioning.ArgumentBinding::isVariable)
.map(Partitioning.ArgumentBinding::getColumn)
.collect(Collectors.toList());
if (!node.getGroupingKeys().containsAll(partitioningColumns)) {
return context.defaultRewrite(node);
}
}
// currently, we only support plans that don't use pre-computed hash functions
if (node.getHashSymbol().isPresent() || exchange.getPartitioningScheme().getHashColumn().isPresent()) {
return context.defaultRewrite(node);
}
switch (node.getStep()) {
case SINGLE:
// Split it into a FINAL on top of a PARTIAL and
// reprocess the resulting plan to push the partial
// below the exchange (see case below).
return context.rewrite(split(node));
case PARTIAL:
// Push it underneath each branch of the exchange
// and reprocess in case it can be pushed further down
// (e.g., if there are local/remote exchanges stacked)
return context.rewrite(pushPartial(node, exchange));
default:
return context.defaultRewrite(node);
}
}
private PlanNode pushPartial(AggregationNode partial, ExchangeNode exchange)
{
List<PlanNode> partials = new ArrayList<>();
for (int i = 0; i < exchange.getSources().size(); i++) {
PlanNode source = exchange.getSources().get(i);
SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder();
for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); outputIndex++) {
Symbol output = exchange.getOutputSymbols().get(outputIndex);
Symbol input = exchange.getInputs().get(i).get(outputIndex);
if (!output.equals(input)) {
mappingsBuilder.put(output, input);
}
}
SymbolMapper symbolMapper = mappingsBuilder.build();
AggregationNode mappedPartial = symbolMapper.map(partial, source, idAllocator);
Assignments.Builder assignments = Assignments.builder();
for (Symbol output : partial.getOutputSymbols()) {
Symbol input = symbolMapper.map(output);
assignments.put(output, input.toSymbolReference());
}
partials.add(new ProjectNode(idAllocator.getNextId(), mappedPartial, assignments.build()));
}
for (PlanNode node : partials) {
verify(partial.getOutputSymbols().equals(node.getOutputSymbols()));
}
// Since this exchange source is now guaranteed to have the same symbols as the inputs to the the partial
// aggregation, we don't need to rewrite symbols in the partitioning function
PartitioningScheme partitioning = new PartitioningScheme(
exchange.getPartitioningScheme().getPartitioning(),
partial.getOutputSymbols(),
exchange.getPartitioningScheme().getHashColumn(),
exchange.getPartitioningScheme().isReplicateNulls(),
exchange.getPartitioningScheme().getBucketToPartition());
return new ExchangeNode(
idAllocator.getNextId(),
exchange.getType(),
exchange.getScope(),
partitioning,
partials,
ImmutableList.copyOf(Collections.nCopies(partials.size(), partial.getOutputSymbols())));
}
private PlanNode split(AggregationNode node)
{
// otherwise, add a partial and final with an exchange in between
Map<Symbol, Symbol> masks = node.getMasks();
Map<Symbol, FunctionCall> finalCalls = new HashMap<>();
Map<Symbol, FunctionCall> intermediateCalls = new HashMap<>();
Map<Symbol, Signature> intermediateFunctions = new HashMap<>();
Map<Symbol, Symbol> intermediateMask = new HashMap<>();
for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
Signature signature = node.getFunctions().get(entry.getKey());
InternalAggregationFunction function = functionRegistry.getAggregateFunctionImplementation(signature);
Symbol intermediateSymbol = allocator.newSymbol(signature.getName(), function.getIntermediateType());
intermediateCalls.put(intermediateSymbol, entry.getValue());
intermediateFunctions.put(intermediateSymbol, signature);
if (masks.containsKey(entry.getKey())) {
intermediateMask.put(intermediateSymbol, masks.get(entry.getKey()));
}
// rewrite final aggregation in terms of intermediate function
finalCalls.put(entry.getKey(), new FunctionCall(QualifiedName.of(signature.getName()), ImmutableList.of(intermediateSymbol.toSymbolReference())));
}
PlanNode partial = new AggregationNode(
idAllocator.getNextId(),
node.getSource(),
intermediateCalls,
intermediateFunctions,
intermediateMask,
node.getGroupingSets(),
PARTIAL,
node.getHashSymbol(),
node.getGroupIdSymbol());
return new AggregationNode(
node.getId(),
partial,
finalCalls,
node.getFunctions(),
ImmutableMap.of(),
node.getGroupingSets(),
FINAL,
node.getHashSymbol(),
node.getGroupIdSymbol());
}
}
}