/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Predicate; import java.util.stream.Collectors; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.optimizations.Predicates.isInstanceOfAny; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** * Scalar aggregation is aggregation with GROUP BY 'a constant' (or empty GROUP BY). * It always returns single row in Presto. * <p> * This optimizer can rewrite correlated scalar aggregation subquery to left outer join in a way described here: * https://github.com/prestodb/presto/wiki/Correlated-subqueries * <p> * From: * <pre> * - Apply (with correlation list: [C]) * - (input) plan which produces symbols: [A, B, C] * - (subquery) Aggregation(GROUP BY (); functions: [sum(F), count(), ...] * - Filter(D = C AND E > 5) * - plan which produces symbols: [D, E, F] * </pre> * to: * <pre> * - Aggregation(GROUP BY A, B, C, U; functions: [sum(F), count(non_null), ...] * - Join(LEFT_OUTER, D = C) * - AssignUniqueId(adds symbol U) * - (input) plan which produces symbols: [A, B, C] * - Filter(E > 5) * - projection which adds no null symbol used for count() function * - plan which produces symbols: [D, E, F] * </pre> * <p> * Note only conjunction predicates in FilterNode are supported */ public class TransformCorrelatedScalarAggregationToJoin implements PlanOptimizer { private final Metadata metadata; public TransformCorrelatedScalarAggregationToJoin(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public PlanNode optimize( PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { return rewriteWith(new Rewriter(idAllocator, symbolAllocator, metadata), plan, null); } private static class Rewriter extends SimplePlanRewriter<PlanNode> { private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; private final Metadata metadata; public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public PlanNode visitApply(ApplyNode node, RewriteContext<PlanNode> context) { ApplyNode rewrittenNode = (ApplyNode) context.defaultRewrite(node, context.get()); if (!rewrittenNode.getCorrelation().isEmpty() && rewrittenNode.isResolvedScalarSubquery()) { Optional<AggregationNode> aggregation = searchFrom(rewrittenNode.getSubquery()) .where(AggregationNode.class::isInstance) .skipOnlyWhen(isInstanceOfAny(ProjectNode.class, EnforceSingleRowNode.class)) .findFirst(); if (aggregation.isPresent() && aggregation.get().getGroupingKeys().isEmpty()) { return rewriteScalarAggregation(rewrittenNode, aggregation.get()); } } return rewrittenNode; } private PlanNode rewriteScalarAggregation(ApplyNode apply, AggregationNode aggregation) { List<Symbol> correlation = apply.getCorrelation(); Optional<DecorrelatedNode> source = decorrelateFilters(aggregation.getSource(), correlation); if (!source.isPresent()) { return apply; } Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = Assignments.builder() .putAll(Assignments.identity(source.get().getNode().getOutputSymbols())) .put(nonNull, TRUE_LITERAL) .build(); ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( idAllocator.getNextId(), source.get().getNode(), scalarAggregationSourceAssignments); return rewriteScalarAggregation( apply, aggregation, scalarAggregationSourceWithNonNullableSymbol, source.get().getCorrelatedPredicates(), nonNull); } private PlanNode rewriteScalarAggregation( ApplyNode applyNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, Symbol nonNull) { AssignUniqueId inputWithUniqueColumns = new AssignUniqueId( idAllocator.getNextId(), applyNode.getInput(), symbolAllocator.newSymbol("unique", BigintType.BIGINT)); JoinNode leftOuterJoin = new JoinNode( idAllocator.getNextId(), JoinNode.Type.LEFT, inputWithUniqueColumns, scalarAggregationSource, ImmutableList.of(), ImmutableList.<Symbol>builder() .addAll(inputWithUniqueColumns.getOutputSymbols()) .addAll(scalarAggregationSource.getOutputSymbols()) .build(), joinExpression, Optional.empty(), Optional.empty(), Optional.empty()); Optional<AggregationNode> aggregationNode = createAggregationNode( scalarAggregation, leftOuterJoin, nonNull); if (!aggregationNode.isPresent()) { return applyNode; } Optional<ProjectNode> subqueryProjection = searchFrom(applyNode.getSubquery()) .where(ProjectNode.class::isInstance) .skipOnlyWhen(EnforceSingleRowNode.class::isInstance) .findFirst(); if (subqueryProjection.isPresent()) { Assignments assignments = Assignments.builder() .putAll(Assignments.identity(aggregationNode.get().getOutputSymbols())) .putAll(subqueryProjection.get().getAssignments()) .build(); return new ProjectNode( idAllocator.getNextId(), aggregationNode.get(), assignments); } else { return aggregationNode.get(); } } private Optional<AggregationNode> createAggregationNode( AggregationNode scalarAggregation, JoinNode leftOuterJoin, Symbol nonNullableAggregationSourceSymbol) { ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder(); ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder(); FunctionRegistry functionRegistry = metadata.getFunctionRegistry(); for (Map.Entry<Symbol, FunctionCall> entry : scalarAggregation.getAggregations().entrySet()) { FunctionCall call = entry.getValue(); QualifiedName count = QualifiedName.of("count"); Symbol symbol = entry.getKey(); if (call.getName().equals(count)) { aggregations.put(symbol, new FunctionCall( count, ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference()))); List<TypeSignature> scalarAggregationSourceTypeSignatures = ImmutableList.of( symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature()); functions.put(symbol, functionRegistry.resolveFunction( count, fromTypeSignatures(scalarAggregationSourceTypeSignatures))); } else { aggregations.put(symbol, entry.getValue()); functions.put(symbol, scalarAggregation.getFunctions().get(symbol)); } } List<Symbol> groupBySymbols = leftOuterJoin.getLeft().getOutputSymbols(); return Optional.of(new AggregationNode( idAllocator.getNextId(), leftOuterJoin, aggregations.build(), functions.build(), scalarAggregation.getMasks(), ImmutableList.of(groupBySymbols), scalarAggregation.getStep(), scalarAggregation.getHashSymbol(), Optional.empty())); } private Optional<DecorrelatedNode> decorrelateFilters(PlanNode node, List<Symbol> correlation) { PlanNodeSearcher filterNodeSearcher = searchFrom(node) .where(FilterNode.class::isInstance) .skipOnlyWhen(isInstanceOfAny(ProjectNode.class)); List<FilterNode> filterNodes = filterNodeSearcher.findAll(); if (filterNodes.isEmpty()) { return decorrelatedNode(ImmutableList.of(), node, correlation); } if (filterNodes.size() > 1) { return Optional.empty(); } FilterNode filterNode = filterNodes.get(0); Expression predicate = filterNode.getPredicate(); if (!isSupportedPredicate(predicate)) { return Optional.empty(); } if (!DependencyExtractor.extractUnique(predicate).containsAll(correlation)) { return Optional.empty(); } Map<Boolean, List<Expression>> predicates = ExpressionUtils.extractConjuncts(predicate).stream() .collect(Collectors.partitioningBy(isUsingPredicate(correlation))); List<Expression> correlatedPredicates = ImmutableList.copyOf(predicates.get(true)); List<Expression> uncorrelatedPredicates = ImmutableList.copyOf(predicates.get(false)); node = updateFilterNode(filterNodeSearcher, uncorrelatedPredicates); node = ensureJoinSymbolsAreReturned(node, correlatedPredicates); return decorrelatedNode(correlatedPredicates, node, correlation); } private static Optional<DecorrelatedNode> decorrelatedNode( List<Expression> correlatedPredicates, PlanNode node, List<Symbol> correlation) { if (DependencyExtractor.extractUnique(node).stream().anyMatch(correlation::contains)) { // node is still correlated ; / return Optional.empty(); } return Optional.of(new DecorrelatedNode(correlatedPredicates, node)); } private static Predicate<Expression> isUsingPredicate(List<Symbol> symbols) { return expression -> symbols.stream().anyMatch(DependencyExtractor.extractUnique(expression)::contains); } private PlanNode updateFilterNode(PlanNodeSearcher filterNodeSearcher, List<Expression> newPredicates) { if (newPredicates.isEmpty()) { return filterNodeSearcher.removeAll(); } FilterNode oldFilterNode = Iterables.getOnlyElement(filterNodeSearcher.findAll()); FilterNode newFilterNode = new FilterNode( idAllocator.getNextId(), oldFilterNode.getSource(), ExpressionUtils.combineConjuncts(newPredicates)); return filterNodeSearcher.replaceAll(newFilterNode); } private PlanNode ensureJoinSymbolsAreReturned(PlanNode scalarAggregationSource, List<Expression> joinPredicate) { Set<Symbol> joinExpressionSymbols = DependencyExtractor.extractUnique(joinPredicate); ExtendProjectionRewriter extendProjectionRewriter = new ExtendProjectionRewriter( idAllocator, joinExpressionSymbols); return rewriteWith(extendProjectionRewriter, scalarAggregationSource); } private static boolean isSupportedPredicate(Expression predicate) { AtomicBoolean isSupported = new AtomicBoolean(true); new DefaultTraversalVisitor<Void, AtomicBoolean>() { @Override protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, AtomicBoolean context) { if (node.getType() != LogicalBinaryExpression.Type.AND) { context.set(false); } return null; } }.process(predicate, isSupported); return isSupported.get(); } } private static class DecorrelatedNode { private final List<Expression> correlatedPredicates; private final PlanNode node; public DecorrelatedNode(List<Expression> correlatedPredicates, PlanNode node) { requireNonNull(correlatedPredicates, "correlatedPredicates is null"); this.correlatedPredicates = ImmutableList.copyOf(correlatedPredicates); this.node = requireNonNull(node, "node is null"); } Optional<Expression> getCorrelatedPredicates() { if (correlatedPredicates.isEmpty()) { return Optional.empty(); } return Optional.of(ExpressionUtils.and(correlatedPredicates)); } public PlanNode getNode() { return node; } } private static class ExtendProjectionRewriter extends SimplePlanRewriter<PlanNode> { private final PlanNodeIdAllocator idAllocator; private final Set<Symbol> symbols; ExtendProjectionRewriter(PlanNodeIdAllocator idAllocator, Set<Symbol> symbols) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbols = requireNonNull(symbols, "symbols is null"); } @Override public PlanNode visitProject(ProjectNode node, RewriteContext<PlanNode> context) { ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); List<Symbol> symbolsToAdd = symbols.stream() .filter(rewrittenNode.getSource().getOutputSymbols()::contains) .filter(symbol -> !rewrittenNode.getOutputSymbols().contains(symbol)) .collect(toImmutableList()); Assignments assignments = Assignments.builder() .putAll(rewrittenNode.getAssignments()) .putAll(Assignments.identity(symbolsToAdd)) .build(); return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); } } }