/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExceptNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.IntersectNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SetOperationNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.metadata.FunctionKind.AGGREGATE;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN_OR_EQUAL;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.concat;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
/**
 * Converts INTERSECT and EXCEPT queries into UNION ALL ... GROUP BY ... WHERE.
 * <p>
 * E.g. {@code SELECT a FROM foo INTERSECT SELECT x FROM bar} is rewritten to:
 * <pre>{@code
 * SELECT a
 * FROM
 *   (SELECT a,
 *      COUNT(foo_marker) AS foo_cnt,
 *      COUNT(bar_marker) AS bar_cnt
 *    FROM
 *      (
 *        SELECT a, true as foo_marker, null as bar_marker FROM foo
 *        UNION ALL
 *        SELECT x, null as foo_marker, true as bar_marker FROM bar
 *      ) T1
 *    GROUP BY a) T2
 * WHERE foo_cnt >= 1 AND bar_cnt >= 1;
 * }</pre>
 * <p>
 * E.g. {@code SELECT a FROM foo EXCEPT SELECT x FROM bar} is rewritten to:
 * <pre>{@code
 * SELECT a
 * FROM
 *   (SELECT a,
 *      COUNT(foo_marker) AS foo_cnt,
 *      COUNT(bar_marker) AS bar_cnt
 *    FROM
 *      (
 *        SELECT a, true as foo_marker, null as bar_marker FROM foo
 *        UNION ALL
 *        SELECT x, null as foo_marker, true as bar_marker FROM bar
 *      ) T1
 *    GROUP BY a) T2
 * WHERE foo_cnt >= 1 AND bar_cnt = 0;
 * }</pre>
 */
public class ImplementIntersectAndExceptAsUnion
        implements PlanOptimizer
{
    /**
     * Runs the INTERSECT/EXCEPT rewriter over {@code plan}.
     *
     * @param plan root of the plan to rewrite
     * @param session only null-checked here
     * @param types only null-checked here
     * @param symbolAllocator supplies the symbols introduced by the rewrite
     * @param idAllocator supplies the ids of the plan nodes introduced by the rewrite
     * @return the plan produced by the rewriter
     * @throws NullPointerException if any argument is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(symbolAllocator, "symbolAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(idAllocator, symbolAllocator);
        return SimplePlanRewriter.rewriteWith(rewriter, plan);
    }

    /**
     * Replaces each IntersectNode / ExceptNode with
     * project(filter(aggregation(union(project(source) ...)))).
     */
    private static class Rewriter
            extends SimplePlanRewriter<Void>
    {
        // name hint used for every marker symbol
        private static final String MARKER = "marker";
        // signature registered for every per-source count aggregation
        private static final Signature COUNT_AGGREGATION = new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN));

        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites an INTERSECT: a row survives only when the count of every source's
         * marker is at least one.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode node, RewriteContext<Void> rewriteContext)
        {
            List<PlanNode> rewrittenSources = rewriteSources(node, rewriteContext);
            List<Symbol> markerSymbols = allocateSymbols(rewrittenSources.size(), MARKER, BOOLEAN);

            // every source gets a projection of its own columns followed by one column per marker
            List<PlanNode> taggedSources = appendMarkers(markerSymbols, rewrittenSources, node);

            // the union exposes the intersect node's own output symbols, followed by the markers
            List<Symbol> resultColumns = node.getOutputSymbols();
            UnionNode taggedUnion = union(taggedSources, ImmutableList.copyOf(concat(resultColumns, markerSymbols)));

            // count each marker per group, then keep the groups in which all counts are >= 1
            List<Symbol> countSymbols = allocateSymbols(markerSymbols.size(), "count", BIGINT);
            AggregationNode perSourceCounts = computeCounts(taggedUnion, resultColumns, markerSymbols, countSymbols);

            return project(addFilterForIntersect(perSourceCounts), resultColumns);
        }

        /**
         * Rewrites an EXCEPT: a row survives only when the first source's marker count is
         * at least one and the marker count of every other source is zero.
         */
        @Override
        public PlanNode visitExcept(ExceptNode node, RewriteContext<Void> rewriteContext)
        {
            List<PlanNode> rewrittenSources = rewriteSources(node, rewriteContext);
            List<Symbol> markerSymbols = allocateSymbols(rewrittenSources.size(), MARKER, BOOLEAN);

            // every source gets a projection of its own columns followed by one column per marker
            List<PlanNode> taggedSources = appendMarkers(markerSymbols, rewrittenSources, node);

            // the union exposes the except node's own output symbols, followed by the markers
            List<Symbol> resultColumns = node.getOutputSymbols();
            UnionNode taggedUnion = union(taggedSources, ImmutableList.copyOf(concat(resultColumns, markerSymbols)));

            // count each marker per group; the first count belongs to the first source
            List<Symbol> countSymbols = allocateSymbols(markerSymbols.size(), "count", BIGINT);
            AggregationNode perSourceCounts = computeCounts(taggedUnion, resultColumns, markerSymbols, countSymbols);

            Symbol firstSourceCount = countSymbols.get(0);
            List<Symbol> otherSourceCounts = countSymbols.subList(1, countSymbols.size());
            return project(addFilterForExcept(perSourceCounts, firstSourceCount, otherSourceCounts), resultColumns);
        }

        // Rewrites the sources of a set operation, in order, through the rewrite context.
        private List<PlanNode> rewriteSources(SetOperationNode node, RewriteContext<Void> rewriteContext)
        {
            return node.getSources().stream()
                    .map(source -> rewriteContext.rewrite(source))
                    .collect(toList());
        }

        // Requests `count` new symbols from the allocator, all with the same name hint and type.
        private List<Symbol> allocateSymbols(int count, String nameHint, Type type)
        {
            ImmutableList.Builder<Symbol> allocated = ImmutableList.builder();
            for (int remaining = count; remaining > 0; remaining--) {
                allocated.add(symbolAllocator.newSymbol(nameHint, type));
            }
            return allocated.build();
        }

        // Wraps the i-th node in a marker projection built from the set operation's i-th source symbol map.
        private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> nodes, SetOperationNode node)
        {
            ImmutableList.Builder<PlanNode> tagged = ImmutableList.builder();
            int sourceIndex = 0;
            for (PlanNode source : nodes) {
                tagged.add(appendMarkers(source, sourceIndex, markers, node.sourceSymbolMap(sourceIndex)));
                sourceIndex++;
            }
            return tagged.build();
        }

        /**
         * Projects {@code source} onto freshly allocated symbols: one per entry of
         * {@code projections}, followed by one boolean column per marker. The column at
         * {@code markerIndex} is TRUE; every other marker column is CAST(NULL AS boolean).
         */
        private PlanNode appendMarkers(PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, SymbolReference> projections)
        {
            Assignments.Builder projection = Assignments.builder();

            // columns of the set operation (INTERSECT or EXCEPT), re-exposed under new symbols
            projections.forEach((outputSymbol, sourceReference) -> {
                Type columnType = symbolAllocator.getTypes().get(outputSymbol);
                projection.put(symbolAllocator.newSymbol(outputSymbol.getName(), columnType), sourceReference);
            });

            // marker columns: only this source's own marker is non-null
            for (int position = 0; position < markers.size(); position++) {
                Expression markerValue;
                if (position == markerIndex) {
                    markerValue = TRUE_LITERAL;
                }
                else {
                    markerValue = new Cast(new NullLiteral(), StandardTypes.BOOLEAN);
                }
                Symbol markerColumn = symbolAllocator.newSymbol(markers.get(position).getName(), BOOLEAN);
                projection.put(markerColumn, markerValue);
            }

            return new ProjectNode(idAllocator.getNextId(), source, projection.build());
        }

        // Builds a union whose i-th output is fed by the i-th output symbol of every source.
        private UnionNode union(List<PlanNode> nodes, List<Symbol> outputs)
        {
            ImmutableListMultimap.Builder<Symbol, Symbol> outputToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                List<Symbol> sourceColumns = source.getOutputSymbols();
                for (int position = 0; position < sourceColumns.size(); position++) {
                    outputToInputs.put(outputs.get(position), sourceColumns.get(position));
                }
            }
            return new UnionNode(idAllocator.getNextId(), nodes, outputToInputs.build(), outputs);
        }

        /**
         * Adds a single-step aggregation grouped by {@code originalColumns} that computes
         * count(marker) for each marker; the i-th count is stored in the i-th aggregation output.
         */
        private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs)
        {
            ImmutableMap.Builder<Symbol, FunctionCall> calls = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, Signature> callSignatures = ImmutableMap.builder();

            int position = 0;
            for (Symbol marker : markers) {
                Symbol countSymbol = aggregationOutputs.get(position);
                FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of(marker.toSymbolReference()));
                calls.put(countSymbol, countCall);
                callSignatures.put(countSymbol, COUNT_AGGREGATION);
                position++;
            }

            return new AggregationNode(
                    idAllocator.getNextId(),
                    sourceNode,
                    calls.build(),
                    callSignatures.build(),
                    ImmutableMap.of(),
                    ImmutableList.of(originalColumns),
                    Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
        }

        // Keeps the groups for which every aggregation output of `aggregation` is >= 1.
        private FilterNode addFilterForIntersect(AggregationNode aggregation)
        {
            ImmutableList<Expression> presentInEverySource = aggregation.getAggregations().keySet().stream()
                    .map(Rewriter::atLeastOne)
                    .collect(toImmutableList());
            return new FilterNode(idAllocator.getNextId(), aggregation, ExpressionUtils.and(presentInEverySource));
        }

        // Keeps the groups for which `firstSource` is >= 1 and each of `remainingSources` is 0.
        private FilterNode addFilterForExcept(AggregationNode aggregation, Symbol firstSource, List<Symbol> remainingSources)
        {
            ImmutableList.Builder<Expression> conditions = ImmutableList.builder();
            conditions.add(atLeastOne(firstSource));
            remainingSources.forEach(absent -> conditions.add(exactlyZero(absent)));
            List<Expression> conjuncts = conditions.build();
            return new FilterNode(idAllocator.getNextId(), aggregation, ExpressionUtils.and(conjuncts));
        }

        // count >= BIGINT '1'
        private static Expression atLeastOne(Symbol count)
        {
            return new ComparisonExpression(GREATER_THAN_OR_EQUAL, count.toSymbolReference(), new GenericLiteral("BIGINT", "1"));
        }

        // count = BIGINT '0'
        private static Expression exactlyZero(Symbol count)
        {
            return new ComparisonExpression(EQUAL, count.toSymbolReference(), new GenericLiteral("BIGINT", "0"));
        }

        // Identity projection that exposes only `columns`, dropping the count columns below it.
        private ProjectNode project(PlanNode node, List<Symbol> columns)
        {
            Assignments identity = Assignments.identity(columns);
            return new ProjectNode(idAllocator.getNextId(), node, identity);
        }
    }
}