/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.ResolvedIndex; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.DomainTranslator; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.base.Functions; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; public class IndexJoinOptimizer implements PlanOptimizer { private final Metadata metadata; public IndexJoinOptimizer(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); requireNonNull(types, "types is null"); requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); return SimplePlanRewriter.rewriteWith(new Rewriter(symbolAllocator, idAllocator, metadata, session), plan, null); } private static class Rewriter extends SimplePlanRewriter<Void> { private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final Session session; private Rewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); } @Override public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context) { PlanNode leftRewritten = context.rewrite(node.getLeft()); PlanNode rightRewritten = context.rewrite(node.getRight()); if (!node.getCriteria().isEmpty()) { // Index join only possible with JOIN criteria List<Symbol> leftJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft); List<Symbol> rightJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight); Optional<PlanNode> leftIndexCandidate = IndexSourceRewriter.rewriteWithIndex( leftRewritten, ImmutableSet.copyOf(leftJoinSymbols), symbolAllocator, idAllocator, metadata, session); if (leftIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key Map<Symbol, Symbol> trace = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinSymbols)); checkState(!trace.isEmpty() && leftJoinSymbols.containsAll(trace.keySet())); } Optional<PlanNode> rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex( rightRewritten, ImmutableSet.copyOf(rightJoinSymbols), symbolAllocator, idAllocator, metadata, session); if (rightIndexCandidate.isPresent()) { // Sanity check that we can trace the path for the index lookup key Map<Symbol, Symbol> trace = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinSymbols)); checkState(!trace.isEmpty() && rightJoinSymbols.containsAll(trace.keySet())); } switch (node.getType()) { case INNER: // Prefer the right candidate over the left candidate PlanNode indexJoinNode = null; if (rightIndexCandidate.isPresent()) { indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), Optional.empty(), Optional.empty()); } else if (leftIndexCandidate.isPresent()) { indexJoinNode = new IndexJoinNode(idAllocator.getNextId(), IndexJoinNode.Type.INNER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), Optional.empty(), Optional.empty()); } if (indexJoinNode != null) { if (node.getFilter().isPresent()) { indexJoinNode = new FilterNode(idAllocator.getNextId(), indexJoinNode, node.getFilter().get()); } if (!indexJoinNode.getOutputSymbols().equals(node.getOutputSymbols())) { indexJoinNode = new ProjectNode( idAllocator.getNextId(), indexJoinNode, Assignments.identity(node.getOutputSymbols())); } return indexJoinNode; } break; case LEFT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && rightIndexCandidate.isPresent()) { return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), idAllocator); } break; case RIGHT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && leftIndexCandidate.isPresent()) { return createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), idAllocator); } break; case FULL: break; default: throw new IllegalArgumentException("Unknown type: " + node.getType()); } } if (leftRewritten != node.getLeft() || rightRewritten != node.getRight()) { return new JoinNode(node.getId(), node.getType(), leftRewritten, rightRewritten, node.getCriteria(), node.getOutputSymbols(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType()); } return node; } private static PlanNode createIndexJoinWithExpectedOutputs(List<Symbol> expectedOutputs, IndexJoinNode.Type type, PlanNode probe, PlanNode index, List<IndexJoinNode.EquiJoinClause> equiJoinClause, PlanNodeIdAllocator idAllocator) { PlanNode result = new IndexJoinNode(idAllocator.getNextId(), type, probe, index, equiJoinClause, Optional.empty(), Optional.empty()); if (!result.getOutputSymbols().equals(expectedOutputs)) { result = new ProjectNode( idAllocator.getNextId(), result, Assignments.identity(expectedOutputs)); } return result; } private static List<IndexJoinNode.EquiJoinClause> createEquiJoinClause(List<Symbol> probeSymbols, List<Symbol> indexSymbols) { checkArgument(probeSymbols.size() == indexSymbols.size()); ImmutableList.Builder<IndexJoinNode.EquiJoinClause> builder = ImmutableList.builder(); for (int i = 0; i < probeSymbols.size(); i++) { builder.add(new IndexJoinNode.EquiJoinClause(probeSymbols.get(i), indexSymbols.get(i))); } return builder.build(); } } /** * Tries to rewrite a PlanNode tree with an IndexSource instead of a TableScan */ private static class IndexSourceRewriter extends SimplePlanRewriter<IndexSourceRewriter.Context> { private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final Session session; private IndexSourceRewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { this.metadata = requireNonNull(metadata, "metadata is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.session = requireNonNull(session, "session is null"); } public static Optional<PlanNode> rewriteWithIndex( PlanNode planNode, Set<Symbol> lookupSymbols, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { AtomicBoolean success = new AtomicBoolean(); IndexSourceRewriter indexSourceRewriter = new IndexSourceRewriter(symbolAllocator, idAllocator, metadata, session); PlanNode rewritten = SimplePlanRewriter.rewriteWith(indexSourceRewriter, planNode, new Context(lookupSymbols, success)); if (success.get()) { return Optional.of(rewritten); } return Optional.empty(); } @Override public PlanNode visitPlan(PlanNode node, RewriteContext<Context> context) { // We don't know how to process this PlanNode in the context of an IndexJoin, so just give up by returning something return node; } @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext<Context> context) { return planTableScan(node, BooleanLiteral.TRUE_LITERAL, context.get()); } private PlanNode planTableScan(TableScanNode node, Expression predicate, Context context) { DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.fromPredicate( metadata, session, predicate, symbolAllocator.getTypes()); TupleDomain<ColumnHandle> simplifiedConstraint = decomposedPredicate.getTupleDomain() .transform(node.getAssignments()::get) .intersect(node.getCurrentConstraint()); checkState(node.getOutputSymbols().containsAll(context.getLookupSymbols())); Set<ColumnHandle> lookupColumns = context.getLookupSymbols().stream() .map(node.getAssignments()::get) .collect(toImmutableSet()); Set<ColumnHandle> outputColumns = node.getOutputSymbols().stream().map(node.getAssignments()::get).collect(toImmutableSet()); Optional<ResolvedIndex> optionalResolvedIndex = metadata.resolveIndex(session, node.getTable(), lookupColumns, outputColumns, simplifiedConstraint); if (!optionalResolvedIndex.isPresent()) { // No index available, so give up by returning something return node; } ResolvedIndex resolvedIndex = optionalResolvedIndex.get(); Map<ColumnHandle, Symbol> inverseAssignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse(); PlanNode source = new IndexSourceNode( idAllocator.getNextId(), resolvedIndex.getIndexHandle(), node.getTable(), node.getLayout(), context.getLookupSymbols(), node.getOutputSymbols(), node.getAssignments(), simplifiedConstraint); Expression resultingPredicate = combineConjuncts( DomainTranslator.toPredicate(resolvedIndex.getUnresolvedTupleDomain().transform(inverseAssignments::get)), decomposedPredicate.getRemainingExpression()); if (!resultingPredicate.equals(TRUE_LITERAL)) { // todo it is likely we end up with redundant filters here because the predicate push down has already been run... the fix is to run predicate push down again source = new FilterNode(idAllocator.getNextId(), source, resultingPredicate); } context.markSuccess(); return source; } @Override public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context) { // Rewrite the lookup symbols in terms of only the pre-projected symbols that have direct translations Set<Symbol> newLookupSymbols = context.get().getLookupSymbols().stream() .map(node.getAssignments()::get) .filter(SymbolReference.class::isInstance) .map(Symbol::from) .collect(toImmutableSet()); if (newLookupSymbols.isEmpty()) { return node; } return context.defaultRewrite(node, new Context(newLookupSymbols, context.get().getSuccess())); } @Override public PlanNode visitFilter(FilterNode node, RewriteContext<Context> context) { if (node.getSource() instanceof TableScanNode) { return planTableScan((TableScanNode) node.getSource(), node.getPredicate(), context.get()); } return context.defaultRewrite(node, new Context(context.get().getLookupSymbols(), context.get().getSuccess())); } @Override public PlanNode visitWindow(WindowNode node, RewriteContext<Context> context) { if (!node.getWindowFunctions().values().stream() .map(function -> function.getFunctionCall().getName()) .allMatch(metadata.getFunctionRegistry()::isAggregationFunction)) { return node; } // Don't need this restriction if we can prove that all order by symbols are deterministically produced if (!node.getOrderBy().isEmpty()) { return node; } // Only RANGE frame type currently supported for aggregation functions because it guarantees the // same value for each peer group. // ROWS frame type requires the ordering to be fully deterministic (e.g. deterministically sorted on all columns) if (node.getFrames().stream().map(WindowNode.Frame::getType).anyMatch(type -> type != WindowFrame.Type.RANGE)) { // TODO: extract frames of type RANGE and allow optimization on them return node; } // Lookup symbols can only be passed through if they are part of the partitioning Set<Symbol> partitionByLookupSymbols = context.get().getLookupSymbols().stream() .filter(node.getPartitionBy()::contains) .collect(toImmutableSet()); if (partitionByLookupSymbols.isEmpty()) { return node; } return context.defaultRewrite(node, new Context(partitionByLookupSymbols, context.get().getSuccess())); } @Override public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext<Context> context) { throw new IllegalStateException("Should not be trying to generate an Index on something that has already been determined to use an Index"); } @Override public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Context> context) { // Lookup symbols can only be passed through the probe side of an index join Set<Symbol> probeLookupSymbols = context.get().getLookupSymbols().stream() .filter(node.getProbeSource().getOutputSymbols()::contains) .collect(toImmutableSet()); if (probeLookupSymbols.isEmpty()) { return node; } PlanNode rewrittenProbeSource = context.rewrite(node.getProbeSource(), new Context(probeLookupSymbols, context.get().getSuccess())); PlanNode source = node; if (rewrittenProbeSource != node.getProbeSource()) { source = new IndexJoinNode(node.getId(), node.getType(), rewrittenProbeSource, node.getIndexSource(), node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol()); } return source; } @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context) { // Lookup symbols can only be passed through if they are part of the group by columns Set<Symbol> groupByLookupSymbols = context.get().getLookupSymbols().stream() .filter(node.getGroupingKeys()::contains) .collect(toImmutableSet()); if (groupByLookupSymbols.isEmpty()) { return node; } return context.defaultRewrite(node, new Context(groupByLookupSymbols, context.get().getSuccess())); } @Override public PlanNode visitSort(SortNode node, RewriteContext<Context> context) { // Sort has no bearing when building an index, so just ignore the sort return context.rewrite(node.getSource(), context.get()); } public static class Context { private final Set<Symbol> lookupSymbols; private final AtomicBoolean success; public Context(Set<Symbol> lookupSymbols, AtomicBoolean success) { checkArgument(!lookupSymbols.isEmpty(), "lookupSymbols can not be empty"); this.lookupSymbols = ImmutableSet.copyOf(requireNonNull(lookupSymbols, "lookupSymbols is null")); this.success = requireNonNull(success, "success is null"); } public Set<Symbol> getLookupSymbols() { return lookupSymbols; } public AtomicBoolean getSuccess() { return success; } public void markSuccess() { checkState(success.compareAndSet(false, true), "Can only have one success per context"); } } } /** * Identify the mapping from the lookup symbols used at the top of the index plan to * the actual symbols produced by the IndexSource. Note that multiple top-level lookup symbols may share the same * underlying IndexSource symbol. Also note that lookup symbols that do not correspond to underlying index source symbols * will be omitted from the returned Map. */ public static class IndexKeyTracer { public static Map<Symbol, Symbol> trace(PlanNode node, Set<Symbol> lookupSymbols) { return node.accept(new Visitor(), lookupSymbols); } private static class Visitor extends PlanVisitor<Set<Symbol>, Map<Symbol, Symbol>> { @Override protected Map<Symbol, Symbol> visitPlan(PlanNode node, Set<Symbol> lookupSymbols) { throw new UnsupportedOperationException("Node not expected to be part of Index pipeline: " + node); } @Override public Map<Symbol, Symbol> visitProject(ProjectNode node, Set<Symbol> lookupSymbols) { // Map from output Symbols to source Symbols Map<Symbol, Symbol> directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); Map<Symbol, Symbol> outputToSourceMap = lookupSymbols.stream() .filter(directSymbolTranslationOutputMap.keySet()::contains) .collect(toImmutableMap(identity(), directSymbolTranslationOutputMap::get)); checkState(!outputToSourceMap.isEmpty(), "No lookup symbols were able to pass through the projection"); // Map from source Symbols to underlying index source Symbols Map<Symbol, Symbol> sourceToIndexMap = node.getSource().accept(this, ImmutableSet.copyOf(outputToSourceMap.values())); // Generate the Map the connects lookup symbols to underlying index source symbols Map<Symbol, Symbol> outputToIndexMap = Maps.transformValues(Maps.filterValues(outputToSourceMap, in(sourceToIndexMap.keySet())), Functions.forMap(sourceToIndexMap)); return ImmutableMap.copyOf(outputToIndexMap); } @Override public Map<Symbol, Symbol> visitFilter(FilterNode node, Set<Symbol> lookupSymbols) { return node.getSource().accept(this, lookupSymbols); } @Override public Map<Symbol, Symbol> visitWindow(WindowNode node, Set<Symbol> lookupSymbols) { Set<Symbol> partitionByLookupSymbols = lookupSymbols.stream() .filter(node.getPartitionBy()::contains) .collect(toImmutableSet()); checkState(!partitionByLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the aggregation group by"); return node.getSource().accept(this, partitionByLookupSymbols); } @Override public Map<Symbol, Symbol> visitIndexJoin(IndexJoinNode node, Set<Symbol> lookupSymbols) { Set<Symbol> probeLookupSymbols = lookupSymbols.stream() .filter(node.getProbeSource().getOutputSymbols()::contains) .collect(toImmutableSet()); checkState(!probeLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the index join probe source"); return node.getProbeSource().accept(this, probeLookupSymbols); } @Override public Map<Symbol, Symbol> visitAggregation(AggregationNode node, Set<Symbol> lookupSymbols) { Set<Symbol> groupByLookupSymbols = lookupSymbols.stream() .filter(node.getGroupingKeys()::contains) .collect(toImmutableSet()); checkState(!groupByLookupSymbols.isEmpty(), "No lookup symbols were able to pass through the aggregation group by"); return node.getSource().accept(this, groupByLookupSymbols); } @Override public Map<Symbol, Symbol> visitSort(SortNode node, Set<Symbol> lookupSymbols) { return node.getSource().accept(this, lookupSymbols); } @Override public Map<Symbol, Symbol> visitIndexSource(IndexSourceNode node, Set<Symbol> lookupSymbols) { checkState(node.getLookupSymbols().equals(lookupSymbols), "lookupSymbols must be the same as IndexSource lookup symbols"); return lookupSymbols.stream().collect(toImmutableMap(identity(), identity())); } } } }