/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.isOptimizeDistinctAggregationEnabled;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
import static java.util.Objects.requireNonNull;
/*
* This optimizer convert query of form:
*
* SELECT a1, a2,..., an, F1(b1), F2(b2), F3(b3), ...., Fm(bm), F(distinct c) FROM Table GROUP BY a1, a2, ..., an
*
* INTO
*
* SELECT a1, a2,..., an, arbitrary(if(group = 0, f1)),...., arbitrary(if(group = 0, fm)), F(if(group = 1, c)) FROM
* SELECT a1, a2,..., an, F1(b1) as f1, F2(b2) as f2,...., Fm(bm) as fm, c, group FROM
* SELECT a1, a2,..., an, b1, b2, ... ,bn, c FROM Table GROUP BY GROUPING SETS ((a1, a2,..., an, b1, b2, ... ,bn), (a1, a2,..., an, c))
* GROUP BY a1, a2,..., an, c, group
* GROUP BY a1, a2,..., an
*/
public class OptimizeMixedDistinctAggregations
implements PlanOptimizer
{
private final Metadata metadata;
public OptimizeMixedDistinctAggregations(Metadata metadata)
{
this.metadata = metadata;
}
@Override
public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
{
if (isOptimizeDistinctAggregationEnabled(session)) {
return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata), plan, Optional.empty());
}
return plan;
}
private static class Optimizer
extends SimplePlanRewriter<Optional<AggregateInfo>>
{
private final PlanNodeIdAllocator idAllocator;
private final SymbolAllocator symbolAllocator;
private final Metadata metadata;
private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata)
{
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
}
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<AggregateInfo>> context)
{
// optimize if and only if
// some aggregation functions have a distinct mask symbol
// and if not all aggregation functions on same distinct mask symbol (this case handled by SingleDistinctOptimizer)
Set<Symbol> masks = ImmutableSet.copyOf(node.getMasks().values());
if (masks.size() != 1 || node.getMasks().size() == node.getAggregations().size()) {
return context.defaultRewrite(node, Optional.empty());
}
if (node.getAggregations().values().stream().map(FunctionCall::getFilter).anyMatch(Optional::isPresent)) {
// Skip if any aggregation contains a filter
return context.defaultRewrite(node, Optional.empty());
}
AggregateInfo aggregateInfo = new AggregateInfo(
node.getGroupingKeys(),
Iterables.getOnlyElement(masks),
node.getAggregations(),
node.getFunctions());
if (!checkAllEquatableTypes(aggregateInfo)) {
// This optimization relies on being able to GROUP BY arguments
// of the original aggregation functions. If they their types are
// not comparable, we have to skip it.
return context.defaultRewrite(node, Optional.empty());
}
PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
// make sure there's a markdistinct associated with this aggregation
if (!aggregateInfo.isFoundMarkDistinct()) {
return context.defaultRewrite(node, Optional.empty());
}
// Change aggregate node to do second aggregation, handles this part of optimized plan mentioned above:
// SELECT a1, a2,..., an, arbitrary(if(group = 0, f1)),...., arbitrary(if(group = 0, fm)), F(if(group = 1, c))
ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
FunctionCall functionCall = entry.getValue();
if (entry.getValue().isDistinct()) {
aggregations.put(
entry.getKey(), new FunctionCall(
functionCall.getName(),
functionCall.getWindow(),
false,
ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference())));
functions.put(entry.getKey(), node.getFunctions().get(entry.getKey()));
}
else {
// Aggregations on non-distinct are already done by new node, just extract the non-null value
Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey());
QualifiedName functionName = QualifiedName.of("arbitrary");
aggregations.put(entry.getKey(), new FunctionCall(
functionName,
functionCall.getWindow(),
false,
ImmutableList.of(argument.toSymbolReference())));
functions.put(entry.getKey(), getFunctionSignature(functionName, argument));
}
}
return new AggregationNode(
idAllocator.getNextId(),
source,
aggregations.build(),
functions.build(),
Collections.emptyMap(),
node.getGroupingSets(),
node.getStep(),
Optional.empty(),
node.getGroupIdSymbol());
}
@Override
public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Optional<AggregateInfo>> context)
{
Optional<AggregateInfo> aggregateInfo = context.get();
// presence of aggregateInfo => mask also present
if (!aggregateInfo.isPresent() || !aggregateInfo.get().getMask().equals(node.getMarkerSymbol())) {
return context.defaultRewrite(node, Optional.empty());
}
aggregateInfo.get().foundMarkDistinct();
PlanNode source = context.rewrite(node.getSource(), Optional.empty());
Set<Symbol> allSymbols = new HashSet<>();
List<Symbol> groupBySymbols = aggregateInfo.get().getGroupBySymbols(); // a
List<Symbol> nonDistinctAggregateSymbols = aggregateInfo.get().getOriginalNonDistinctAggregateArgs(); //b
Symbol distinctSymbol = Iterables.getOnlyElement(aggregateInfo.get().getOriginalDistinctAggregateArgs()); // c
// If same symbol present in aggregations on distinct and non-distinct values, e.g. select sum(a), count(distinct a),
// then we need to create a duplicate stream for this symbol
Symbol duplicatedDistinctSymbol = distinctSymbol;
if (nonDistinctAggregateSymbols.contains(distinctSymbol)) {
Symbol newSymbol = symbolAllocator.newSymbol(distinctSymbol.getName(), symbolAllocator.getTypes().get(distinctSymbol));
nonDistinctAggregateSymbols.set(nonDistinctAggregateSymbols.indexOf(distinctSymbol), newSymbol);
duplicatedDistinctSymbol = newSymbol;
}
allSymbols.addAll(groupBySymbols);
allSymbols.addAll(nonDistinctAggregateSymbols);
allSymbols.add(distinctSymbol);
// 1. Add GroupIdNode
Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT); // g
GroupIdNode groupIdNode = createGroupIdNode(
groupBySymbols,
nonDistinctAggregateSymbols,
distinctSymbol,
duplicatedDistinctSymbol,
groupSymbol,
allSymbols,
source);
// 2. Add aggregation node
Set<Symbol> groupByKeys = new HashSet<>();
groupByKeys.addAll(groupBySymbols);
groupByKeys.add(distinctSymbol);
groupByKeys.add(groupSymbol);
ImmutableMap.Builder aggregationOutputSymbolsMapBuilder = ImmutableMap.builder();
AggregationNode aggregationNode = createNonDistinctAggregation(
aggregateInfo.get(),
distinctSymbol,
duplicatedDistinctSymbol,
groupByKeys,
groupIdNode,
node,
aggregationOutputSymbolsMapBuilder);
// This map has mapping only for aggregation on non-distinct symbols which the new AggregationNode handles
Map<Symbol, Symbol> aggregationOutputSymbolsMap = aggregationOutputSymbolsMapBuilder.build();
// 3. Add new project node that adds if expressions
ProjectNode projectNode = createProjectNode(
aggregationNode,
aggregateInfo.get(),
distinctSymbol,
groupSymbol,
groupBySymbols,
aggregationOutputSymbolsMap);
return projectNode;
}
// Returns false if either mask symbol or any of the symbols in aggregations is not comparable
private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo)
{
for (Symbol symbol : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
Type type = symbolAllocator.getTypes().get(symbol);
if (!type.isComparable()) {
return false;
}
}
if (!symbolAllocator.getTypes().get(aggregateInfo.getMask()).isComparable()) {
return false;
}
return true;
}
/*
* This Project is useful for cases when we aggregate on distinct and non-distinct values of same symbol, eg:
* select a, sum(b), count(c), sum(distinct c) group by a
* Without this Project, we would count additional values for count(c)
*
* This method also populates maps of old to new symbols. For each key of outputNonDistinctAggregateSymbols,
* Higher level aggregation node's aggregation <key, AggregateExpression> will now have to run AggregateExpression on value of outputNonDistinctAggregateSymbols
* Same for outputDistinctAggregateSymbols map
*/
private ProjectNode createProjectNode(
AggregationNode source,
AggregateInfo aggregateInfo,
Symbol distinctSymbol,
Symbol groupSymbol,
List<Symbol> groupBySymbols,
Map<Symbol, Symbol> aggregationOutputSymbolsMap)
{
Assignments.Builder outputSymbols = Assignments.builder();
ImmutableMap.Builder<Symbol, Symbol> outputNonDistinctAggregateSymbols = ImmutableMap.builder();
for (Symbol symbol : source.getOutputSymbols()) {
if (distinctSymbol.equals(symbol)) {
Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
aggregateInfo.setNewDistinctAggregateSymbol(newSymbol);
Expression expression = createIfExpression(
groupSymbol.toSymbolReference(),
new Cast(new LongLiteral("1"), "bigint"), // TODO: this should use GROUPING() when that's available instead of relying on specific group numbering
ComparisonExpressionType.EQUAL,
symbol.toSymbolReference(),
symbolAllocator.getTypes().get(symbol));
outputSymbols.put(newSymbol, expression);
}
else if (aggregationOutputSymbolsMap.containsKey(symbol)) {
Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
// key of outputNonDistinctAggregateSymbols is key of an aggregation in AggrNode above, it will now aggregate on this Map's value
outputNonDistinctAggregateSymbols.put(aggregationOutputSymbolsMap.get(symbol), newSymbol);
Expression expression = createIfExpression(
groupSymbol.toSymbolReference(),
new Cast(new LongLiteral("0"), "bigint"), // TODO: this should use GROUPING() when that's available instead of relying on specific group numbering
ComparisonExpressionType.EQUAL,
symbol.toSymbolReference(),
symbolAllocator.getTypes().get(symbol));
outputSymbols.put(newSymbol, expression);
}
// A symbol can appear both in groupBy and distinct/non-distinct aggregation
if (groupBySymbols.contains(symbol)) {
Expression expression = symbol.toSymbolReference();
outputSymbols.put(symbol, expression);
}
}
// add null assignment for mask
// unused mask will be removed by PruneUnreferencedOutputs
outputSymbols.put(aggregateInfo.getMask(), new NullLiteral());
aggregateInfo.setNewNonDistinctAggregateSymbols(outputNonDistinctAggregateSymbols.build());
return new ProjectNode(idAllocator.getNextId(), source, outputSymbols.build());
}
private GroupIdNode createGroupIdNode(
List<Symbol> groupBySymbols,
List<Symbol> nonDistinctAggregateSymbols,
Symbol distinctSymbol,
Symbol duplicatedDistinctSymbol,
Symbol groupSymbol,
Set<Symbol> allSymbols,
PlanNode source)
{
List<List<Symbol>> groups = new ArrayList<>();
// g0 = {group-by symbols + allNonDistinctAggregateSymbols}
// g1 = {group-by symbols + Distinct Symbol}
// symbols present in Group_i will be set, rest will be Null
//g0
Set<Symbol> group0 = new HashSet<>();
group0.addAll(groupBySymbols);
group0.addAll(nonDistinctAggregateSymbols);
groups.add(ImmutableList.copyOf(group0));
// g1
Set<Symbol> group1 = new HashSet<>();
group1.addAll(groupBySymbols);
group1.add(distinctSymbol);
groups.add(ImmutableList.copyOf(group1));
return new GroupIdNode(
idAllocator.getNextId(),
source,
groups,
allSymbols.stream().collect(Collectors.toMap(
symbol -> symbol,
symbol -> (symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol))),
ImmutableMap.of(),
groupSymbol);
}
/*
* This method returns a new Aggregation node which has aggregations on non-distinct symbols from original plan. Generates
* SELECT a1, a2,..., an, F1(b1) as f1, F2(b2) as f2,...., Fm(bm) as fm, c, group
* part in the optimized plan mentioned above
*
* It also populates the mappings of new function's output symbol to corresponding old function's output symbol, e.g.
* { f1 -> F1, f2 -> F2, ... }
* The new AggregateNode aggregates on the symbols that original AggregationNode aggregated on
* Original one will now aggregate on the output symbols of this new node
*/
private AggregationNode createNonDistinctAggregation(
AggregateInfo aggregateInfo,
Symbol distinctSymbol,
Symbol duplicatedDistinctSymbol,
Set<Symbol> groupByKeys,
GroupIdNode groupIdNode,
MarkDistinctNode originalNode,
ImmutableMap.Builder aggregationOutputSymbolsMapBuilder
)
{
ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
for (Map.Entry<Symbol, FunctionCall> entry : aggregateInfo.getAggregations().entrySet()) {
FunctionCall functionCall = entry.getValue();
if (!functionCall.isDistinct()) {
Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey()));
aggregationOutputSymbolsMapBuilder.put(newSymbol, entry.getKey());
if (duplicatedDistinctSymbol.equals(distinctSymbol)) {
// Mask symbol was not present in aggregations without mask
aggregations.put(newSymbol, functionCall);
}
else {
// Handling for cases when mask symbol appears in non distinct aggregations too
// Now the aggregation should happen over the duplicate symbol added before
if (functionCall.getArguments().contains(distinctSymbol.toSymbolReference())) {
ImmutableList.Builder arguments = ImmutableList.builder();
for (Expression argument : functionCall.getArguments()) {
if (distinctSymbol.toSymbolReference().equals(argument)) {
arguments.add(duplicatedDistinctSymbol.toSymbolReference());
}
else {
arguments.add(argument);
}
}
aggregations.put(newSymbol, new FunctionCall(functionCall.getName(), functionCall.getWindow(), false, arguments.build()));
}
else {
aggregations.put(newSymbol, functionCall);
}
}
functions.put(newSymbol, aggregateInfo.getFunctions().get(entry.getKey()));
}
}
return new AggregationNode(
idAllocator.getNextId(),
groupIdNode,
aggregations.build(),
functions.build(),
Collections.emptyMap(),
ImmutableList.of(ImmutableList.copyOf(groupByKeys)),
SINGLE,
originalNode.getHashSymbol(),
Optional.empty());
}
private Signature getFunctionSignature(QualifiedName functionName, Symbol argument)
{
return metadata.getFunctionRegistry()
.resolveFunction(
functionName,
ImmutableList.of(new TypeSignatureProvider(symbolAllocator.getTypes().get(argument).getTypeSignature())));
}
// creates if clause specific to use case here, default value always null
private static IfExpression createIfExpression(Expression left, Expression right, ComparisonExpressionType type, Expression result, Type trueValueType)
{
return new IfExpression(
new ComparisonExpression(type, left, right),
result,
new Cast(new NullLiteral(), trueValueType.getTypeSignature().toString()));
}
}
private static class AggregateInfo
{
private final List<Symbol> groupBySymbols;
private final Symbol mask;
private final Map<Symbol, FunctionCall> aggregations;
private final Map<Symbol, Signature> functions;
// Filled on the way back, these are the symbols corresponding to their distinct or non-distinct original symbols
private Map<Symbol, Symbol> newNonDistinctAggregateSymbols;
private Symbol newDistinctAggregateSymbol;
private boolean foundMarkDistinct;
public AggregateInfo(List<Symbol> groupBySymbols, Symbol mask, Map<Symbol, FunctionCall> aggregations, Map<Symbol, Signature> functions)
{
this.groupBySymbols = ImmutableList.copyOf(groupBySymbols);
this.mask = mask;
this.aggregations = ImmutableMap.copyOf(aggregations);
this.functions = ImmutableMap.copyOf(functions);
}
public List<Symbol> getOriginalNonDistinctAggregateArgs()
{
return aggregations.values().stream()
.filter(function -> !function.isDistinct())
.flatMap(function -> function.getArguments().stream())
.distinct()
.map(Symbol::from)
.collect(Collectors.toList());
}
public List<Symbol> getOriginalDistinctAggregateArgs()
{
return aggregations.values().stream()
.filter(FunctionCall::isDistinct)
.flatMap(function -> function.getArguments().stream())
.distinct()
.map(Symbol::from)
.collect(Collectors.toList());
}
public Symbol getNewDistinctAggregateSymbol()
{
return newDistinctAggregateSymbol;
}
public void setNewDistinctAggregateSymbol(Symbol newDistinctAggregateSymbol)
{
this.newDistinctAggregateSymbol = newDistinctAggregateSymbol;
}
public Map<Symbol, Symbol> getNewNonDistinctAggregateSymbols()
{
return newNonDistinctAggregateSymbols;
}
public void setNewNonDistinctAggregateSymbols(Map<Symbol, Symbol> newNonDistinctAggregateSymbols)
{
this.newNonDistinctAggregateSymbols = newNonDistinctAggregateSymbols;
}
public Symbol getMask()
{
return mask;
}
public List<Symbol> getGroupBySymbols()
{
return groupBySymbols;
}
public Map<Symbol, FunctionCall> getAggregations()
{
return aggregations;
}
public Map<Symbol, Signature> getFunctions()
{
return functions;
}
public void foundMarkDistinct()
{
foundMarkDistinct = true;
}
public boolean isFoundMarkDistinct()
{
return foundMarkDistinct;
}
}
}