/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.spi.predicate.Range;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.predicate.ValueSet;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import static com.facebook.presto.metadata.FunctionKind.WINDOW;
import static com.facebook.presto.spi.predicate.Marker.Bound.BELOW;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.planner.DomainTranslator.ExtractionResult;
import static com.facebook.presto.sql.planner.DomainTranslator.fromPredicate;
import static com.facebook.presto.sql.planner.DomainTranslator.toPredicate;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toMap;
/**
 * Pushes limits and row-number filters down into {@code row_number()} window computations.
 * <p>
 * A {@link WindowNode} whose only window function is {@code row_number()} is rewritten as follows:
 * <ul>
 * <li>without ORDER BY, it is replaced by a {@link RowNumberNode};</li>
 * <li>under a {@link LimitNode}, or a {@link FilterNode} whose predicate bounds the row number from
 * above, an ordered window becomes a {@link TopNRowNumberNode}, and an existing {@link RowNumberNode}
 * is given a maximum row count per partition.</li>
 * </ul>
 */
public class WindowFilterPushDown
        implements PlanOptimizer
{
    private static final Signature ROW_NUMBER_SIGNATURE = new Signature("row_number", WINDOW, parseTypeSignature(StandardTypes.BIGINT), ImmutableList.of());

    private final Metadata metadata;

    public WindowFilterPushDown(Metadata metadata)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(symbolAllocator, "symbolAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");

        return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, metadata, session, types), plan, null);
    }

    /**
     * Bottom-up rewriter: each visitor rewrites its source first, then inspects the rewritten source.
     */
    private static class Rewriter
            extends SimplePlanRewriter<Void>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final Metadata metadata;
        private final Session session;
        private final Map<Symbol, Type> types;

        private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session, Map<Symbol, Type> types)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.metadata = requireNonNull(metadata, "metadata is null");
            this.session = requireNonNull(session, "session is null");
            // Defensive, immutable snapshot of the symbol types
            this.types = ImmutableMap.copyOf(requireNonNull(types, "types is null"));
        }

        /**
         * Replaces an unordered {@code row_number()} window with a {@link RowNumberNode}; any other
         * window is kept with its rewritten source.
         */
        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
        {
            checkState(node.getWindowFunctions().size() == 1, "WindowFilterPushdown requires that WindowNodes contain exactly one window function");
            PlanNode rewrittenSource = context.rewrite(node.getSource());

            if (canReplaceWithRowNumber(node)) {
                return new RowNumberNode(idAllocator.getNextId(),
                        rewrittenSource,
                        node.getPartitionBy(),
                        getOnlyElement(node.getWindowFunctions().keySet()),
                        Optional.empty(),
                        Optional.empty());
            }
            return replaceChildren(node, ImmutableList.of(rewrittenSource));
        }

        /**
         * Folds a LIMIT into a row-number source. When the source has no partitioning, the new node
         * already enforces the limit globally and the {@link LimitNode} is dropped.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
        {
            // Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value.
            // A non-positive count (e.g. LIMIT 0) is not a meaningful per-partition row count, so it is not
            // pushed down either; the LimitNode is kept as-is.
            if (node.getCount() <= 0 || node.getCount() > Integer.MAX_VALUE) {
                return context.defaultRewrite(node);
            }

            PlanNode source = context.rewrite(node.getSource());
            int limit = toIntExact(node.getCount());

            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = mergeLimit((RowNumberNode) source, limit);
                if (rowNumberNode.getPartitionBy().isEmpty()) {
                    return rowNumberNode;
                }
                source = rowNumberNode;
            }
            else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode windowNode = (WindowNode) source;
                // verify that unordered row_number window functions are replaced by RowNumberNode
                verify(!windowNode.getOrderBy().isEmpty());
                TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
                if (windowNode.getPartitionBy().isEmpty()) {
                    return topNRowNumberNode;
                }
                source = topNRowNumberNode;
            }
            return replaceChildren(node, ImmutableList.of(source));
        }

        /**
         * Pushes an upper bound on the row number, extracted from the filter predicate, into a
         * row-number source. The filter is then simplified or removed when the bound is fully absorbed.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
        {
            PlanNode source = context.rewrite(node.getSource());

            // Translate the predicate once; the same result is reused to rewrite the filter below
            ExtractionResult extractionResult = fromPredicate(metadata, session, node.getPredicate(), types);
            TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();

            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = (RowNumberNode) source;
                Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol();
                OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    source = mergeLimit(rowNumberNode, upperBound.getAsInt());
                    return rewriteFilterSource(node, extractionResult, source, rowNumberSymbol, upperBound.getAsInt());
                }
            }
            else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode windowNode = (WindowNode) source;
                Symbol rowNumberSymbol = getOnlyElement(windowNode.getWindowFunctions().keySet());
                OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    source = convertToTopNRowNumber(windowNode, upperBound.getAsInt());
                    return rewriteFilterSource(node, extractionResult, source, rowNumberSymbol, upperBound.getAsInt());
                }
            }
            return replaceChildren(node, ImmutableList.of(source));
        }

        /**
         * Rebuilds the filter over {@code source}.
         * <p>
         * If the row-number domain is exactly "row number at most {@code upperBound}", that domain is
         * removed from the predicate because {@code source} now enforces it. The filter node is dropped
         * entirely when nothing else remains.
         *
         * @param filterNode the original filter
         * @param extractionResult the domain translation of {@code filterNode}'s predicate
         * @param source the rewritten row-number source
         * @param rowNumberSymbol the symbol carrying the row number
         * @param upperBound the bound pushed into {@code source}
         */
        private PlanNode rewriteFilterSource(FilterNode filterNode, ExtractionResult extractionResult, PlanNode source, Symbol rowNumberSymbol, int upperBound)
        {
            TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
            if (!isEqualRange(tupleDomain, rowNumberSymbol, upperBound)) {
                return new FilterNode(filterNode.getId(), source, filterNode.getPredicate());
            }

            // Remove the row number domain because it is absorbed into the node
            Map<Symbol, Domain> newDomains = tupleDomain.getDomains().get().entrySet().stream()
                    .filter(entry -> !entry.getKey().equals(rowNumberSymbol))
                    .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));

            // Construct a new predicate
            TupleDomain<Symbol> newTupleDomain = TupleDomain.withColumnDomains(newDomains);
            Expression newPredicate = ExpressionUtils.combineConjuncts(
                    extractionResult.getRemainingExpression(),
                    toPredicate(newTupleDomain));

            if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                return source;
            }
            return new FilterNode(filterNode.getId(), source, newPredicate);
        }

        /**
         * Returns true when the domain of {@code symbol} is exactly {@code (-inf, upperBound]}.
         * The equivalent form {@code (-inf, upperBound + 1)} also matches, as produced by e.g.
         * {@code rn < upperBound + 1}.
         */
        private static boolean isEqualRange(TupleDomain<Symbol> tupleDomain, Symbol symbol, long upperBound)
        {
            if (tupleDomain.isNone()) {
                return false;
            }
            Domain domain = tupleDomain.getDomains().get().get(symbol);
            if (domain == null) {
                return false;
            }
            ValueSet values = domain.getValues();
            Type type = domain.getType();
            // upperBound never exceeds Integer.MAX_VALUE here, so upperBound + 1 cannot overflow a long
            return values.equals(ValueSet.ofRanges(Range.lessThanOrEqual(type, upperBound)))
                    || values.equals(ValueSet.ofRanges(Range.lessThan(type, upperBound + 1)));
        }

        /**
         * Extracts an inclusive upper bound on {@code symbol} from the tuple domain.
         *
         * @return the bound, or empty when there is no bounded domain for the symbol, or when the
         *         bound is not in {@code [1, Integer.MAX_VALUE]}
         */
        private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol)
        {
            if (tupleDomain.isNone()) {
                return OptionalInt.empty();
            }

            Domain rowNumberDomain = tupleDomain.getDomains().get().get(symbol);
            if (rowNumberDomain == null) {
                return OptionalInt.empty();
            }
            ValueSet values = rowNumberDomain.getValues();
            if (values.isAll() || values.isNone() || values.getRanges().getRangeCount() <= 0) {
                return OptionalInt.empty();
            }

            Range span = values.getRanges().getSpan();
            if (span.getHigh().isUpperUnbounded()) {
                return OptionalInt.empty();
            }

            verify(rowNumberDomain.getType().equals(BIGINT));
            long upperBound = (Long) span.getHigh().getValue();
            // An exclusive bound (x < N) on an integer is the inclusive bound x <= N - 1
            if (span.getHigh().getBound() == BELOW) {
                upperBound--;
            }

            // A non-positive bound (e.g. rn <= 0) cannot be used as a per-partition row count;
            // leave such filters in place instead of pushing them down.
            if (upperBound <= 0 || upperBound > Integer.MAX_VALUE) {
                return OptionalInt.empty();
            }
            return OptionalInt.of(toIntExact(upperBound));
        }

        /**
         * Returns a copy of {@code node} whose max row count per partition is the smaller of its
         * existing cap (if any) and {@code newRowCountPerPartition}.
         */
        private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPartition)
        {
            if (node.getMaxRowCountPerPartition().isPresent()) {
                newRowCountPerPartition = Math.min(node.getMaxRowCountPerPartition().get(), newRowCountPerPartition);
            }
            return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.getRowNumberSymbol(), Optional.of(newRowCountPerPartition), node.getHashSymbol());
        }

        /**
         * Builds a {@link TopNRowNumberNode} with the window's source and specification, capped at
         * {@code limit} rows per partition.
         */
        private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit)
        {
            return new TopNRowNumberNode(idAllocator.getNextId(),
                    windowNode.getSource(),
                    windowNode.getSpecification(),
                    getOnlyElement(windowNode.getWindowFunctions().keySet()),
                    limit,
                    false,
                    Optional.empty());
        }

        /** True for a single {@code row_number()} window with no ORDER BY. */
        private static boolean canReplaceWithRowNumber(WindowNode node)
        {
            return canOptimizeWindowFunction(node) && node.getOrderBy().isEmpty();
        }

        /** True when the window computes exactly one function and it is {@code row_number()}. */
        private static boolean canOptimizeWindowFunction(WindowNode node)
        {
            if (node.getWindowFunctions().size() != 1) {
                return false;
            }
            return isRowNumberSignature(getOnlyElement(node.getWindowFunctions().values()).getSignature());
        }

        private static boolean isRowNumberSignature(Signature signature)
        {
            return signature.equals(ROW_NUMBER_SIGNATURE);
        }
    }
}