/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.plan;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.FunctionCall;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import javax.annotation.concurrent.Immutable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
@Immutable
public class AggregationNode
extends PlanNode
{
private final PlanNode source;
private final Map<Symbol, Aggregation> assignments;
private final List<List<Symbol>> groupingSets;
private final Step step;
private final Optional<Symbol> hashSymbol;
private final Optional<Symbol> groupIdSymbol;
private final List<Symbol> outputs;
public boolean hasEmptyGroupingSet()
{
return groupingSets.stream().anyMatch(List::isEmpty);
}
public boolean hasNonEmptyGroupingSet()
{
return groupingSets.stream().anyMatch(symbols -> !symbols.isEmpty());
}
public enum Step
{
PARTIAL(true, true),
FINAL(false, false),
INTERMEDIATE(false, true),
SINGLE(true, false);
private final boolean inputRaw;
private final boolean outputPartial;
Step(boolean inputRaw, boolean outputPartial)
{
this.inputRaw = inputRaw;
this.outputPartial = outputPartial;
}
public boolean isInputRaw()
{
return inputRaw;
}
public boolean isOutputPartial()
{
return outputPartial;
}
public static Step partialOutput(Step step)
{
if (step.isInputRaw()) {
return Step.PARTIAL;
}
else {
return Step.INTERMEDIATE;
}
}
public static Step partialInput(Step step)
{
if (step.isOutputPartial()) {
return Step.INTERMEDIATE;
}
else {
return Step.FINAL;
}
}
}
@JsonCreator
public AggregationNode(
@JsonProperty("id") PlanNodeId id,
@JsonProperty("source") PlanNode source,
@JsonProperty("assignments") Map<Symbol, Aggregation> assignments,
@JsonProperty("groupingSets") List<List<Symbol>> groupingSets,
@JsonProperty("step") Step step,
@JsonProperty("hashSymbol") Optional<Symbol> hashSymbol,
@JsonProperty("groupIdSymbol") Optional<Symbol> groupIdSymbol)
{
super(id);
this.source = source;
this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "aggregations is null"));
requireNonNull(groupingSets, "groupingSets is null");
checkArgument(!groupingSets.isEmpty(), "grouping sets list cannot be empty");
this.groupingSets = ImmutableList.copyOf(groupingSets);
this.step = step;
this.hashSymbol = hashSymbol;
this.groupIdSymbol = requireNonNull(groupIdSymbol);
ImmutableList.Builder<Symbol> outputs = ImmutableList.builder();
outputs.addAll(getGroupingKeys());
hashSymbol.ifPresent(outputs::add);
outputs.addAll(assignments.keySet());
this.outputs = outputs.build();
}
/**
* @deprecated pass Assignments object instead
*/
@Deprecated
public AggregationNode(
PlanNodeId id,
PlanNode source,
Map<Symbol, FunctionCall> assignments,
Map<Symbol, Signature> functions,
Map<Symbol, Symbol> masks,
List<List<Symbol>> groupingSets,
Step step,
Optional<Symbol> hashSymbol,
Optional<Symbol> groupIdSymbol)
{
this(id, source, makeAssignments(assignments, functions, masks), groupingSets, step, hashSymbol, groupIdSymbol);
}
@Override
public List<PlanNode> getSources()
{
return ImmutableList.of(source);
}
@Override
public List<Symbol> getOutputSymbols()
{
return outputs;
}
@JsonProperty
public Map<Symbol, Aggregation> getAssignments()
{
return assignments;
}
/**
* @deprecated Use getAssignments
*/
@Deprecated
public Map<Symbol, FunctionCall> getAggregations()
{
// use an ImmutableMap.Builder because the output has to preserve
// the iteration order of the original map.
ImmutableMap.Builder<Symbol, FunctionCall> builder = ImmutableMap.builder();
for (Map.Entry<Symbol, Aggregation> entry : assignments.entrySet()) {
builder.put(entry.getKey(), entry.getValue().getCall());
}
return builder.build();
}
/**
* @deprecated Use getAssignments
*/
@Deprecated
public Map<Symbol, Signature> getFunctions()
{
// use an ImmutableMap.Builder because the output has to preserve
// the iteration order of the original map.
ImmutableMap.Builder<Symbol, Signature> builder = ImmutableMap.builder();
for (Map.Entry<Symbol, Aggregation> entry : assignments.entrySet()) {
builder.put(entry.getKey(), entry.getValue().getSignature());
}
return builder.build();
}
/**
* @deprecated Use getAssignments
*/
@Deprecated
public Map<Symbol, Symbol> getMasks()
{
// use an ImmutableMap.Builder because the output has to preserve
// the iteration order of the original map.
ImmutableMap.Builder<Symbol, Symbol> builder = ImmutableMap.builder();
for (Map.Entry<Symbol, Aggregation> entry : assignments.entrySet()) {
entry.getValue()
.getMask()
.ifPresent(symbol -> builder.put(entry.getKey(), symbol));
}
return builder.build();
}
public List<Symbol> getGroupingKeys()
{
List<Symbol> symbols = new ArrayList<>(groupingSets.stream()
.flatMap(Collection::stream)
.distinct()
.collect(Collectors.toList()));
groupIdSymbol.ifPresent(symbols::add);
return symbols;
}
@JsonProperty("groupingSets")
public List<List<Symbol>> getGroupingSets()
{
return groupingSets;
}
/**
* @return whether this node should produce default output in case of no input pages.
* For example for query:
*
* SELECT count(*) FROM nation WHERE nationkey < 0
*
* A default output of "0" is expected to be produced by FINAL aggregation operator.
*/
public boolean hasDefaultOutput()
{
return hasEmptyGroupingSet() && (step.isOutputPartial() || step.equals(SINGLE));
}
@JsonProperty("source")
public PlanNode getSource()
{
return source;
}
@JsonProperty("step")
public Step getStep()
{
return step;
}
@JsonProperty("hashSymbol")
public Optional<Symbol> getHashSymbol()
{
return hashSymbol;
}
@JsonProperty("groupIdSymbol")
public Optional<Symbol> getGroupIdSymbol()
{
return groupIdSymbol;
}
@Override
public <C, R> R accept(PlanVisitor<C, R> visitor, C context)
{
return visitor.visitAggregation(this, context);
}
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), assignments, groupingSets, step, hashSymbol, groupIdSymbol);
}
public boolean isDecomposable(FunctionRegistry functionRegistry)
{
return getFunctions().values().stream()
.map(functionRegistry::getAggregateFunctionImplementation)
.allMatch(InternalAggregationFunction::isDecomposable);
}
private static Map<Symbol, Aggregation> makeAssignments(
Map<Symbol, FunctionCall> aggregations,
Map<Symbol, Signature> functions,
Map<Symbol, Symbol> masks)
{
ImmutableMap.Builder<Symbol, Aggregation> builder = ImmutableMap.builder();
for (Map.Entry<Symbol, FunctionCall> entry : aggregations.entrySet()) {
Symbol output = entry.getKey();
builder.put(output, new Aggregation(
entry.getValue(),
functions.get(output),
Optional.ofNullable(masks.get(output))));
}
return builder.build();
}
public static class Aggregation
{
private final FunctionCall call;
private final Signature signature;
private final Optional<Symbol> mask;
@JsonCreator
public Aggregation(
@JsonProperty("call") FunctionCall call,
@JsonProperty("signature") Signature signature,
@JsonProperty("mask") Optional<Symbol> mask)
{
this.call = call;
this.signature = signature;
this.mask = mask;
}
@JsonProperty
public FunctionCall getCall()
{
return call;
}
@JsonProperty
public Signature getSignature()
{
return signature;
}
@JsonProperty
public Optional<Symbol> getMask()
{
return mask;
}
}
}