/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.execution.QueryPerformanceFetcher;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.execution.buffer.OutputBuffer;
import com.facebook.presto.execution.buffer.PagesSerdeFactory;
import com.facebook.presto.index.IndexManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory;
import com.facebook.presto.operator.AssignUniqueIdOperator;
import com.facebook.presto.operator.CursorProcessor;
import com.facebook.presto.operator.DeleteOperator.DeleteOperatorFactory;
import com.facebook.presto.operator.DriverFactory;
import com.facebook.presto.operator.EnforceSingleRowOperator;
import com.facebook.presto.operator.ExchangeClientSupplier;
import com.facebook.presto.operator.ExchangeOperator.ExchangeOperatorFactory;
import com.facebook.presto.operator.ExplainAnalyzeOperator.ExplainAnalyzeOperatorFactory;
import com.facebook.presto.operator.FilterAndProjectOperator;
import com.facebook.presto.operator.GroupIdOperator;
import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import com.facebook.presto.operator.HashBuilderOperator.HashBuilderOperatorFactory;
import com.facebook.presto.operator.HashSemiJoinOperator.HashSemiJoinOperatorFactory;
import com.facebook.presto.operator.JoinOperatorFactory;
import com.facebook.presto.operator.LimitOperator.LimitOperatorFactory;
import com.facebook.presto.operator.LocalPlannerAware;
import com.facebook.presto.operator.LookupJoinOperators;
import com.facebook.presto.operator.LookupSourceFactory;
import com.facebook.presto.operator.MarkDistinctOperator.MarkDistinctOperatorFactory;
import com.facebook.presto.operator.MetadataDeleteOperator.MetadataDeleteOperatorFactory;
import com.facebook.presto.operator.NestedLoopJoinPagesSupplier;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.OrderByOperator.OrderByOperatorFactory;
import com.facebook.presto.operator.OutputFactory;
import com.facebook.presto.operator.PagesIndex;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.operator.PartitionedOutputOperator.PartitionedOutputFactory;
import com.facebook.presto.operator.RowNumberOperator;
import com.facebook.presto.operator.ScanFilterAndProjectOperator;
import com.facebook.presto.operator.SetBuilderOperator.SetBuilderOperatorFactory;
import com.facebook.presto.operator.SetBuilderOperator.SetSupplier;
import com.facebook.presto.operator.SourceOperatorFactory;
import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory;
import com.facebook.presto.operator.TaskOutputOperator.TaskOutputFactory;
import com.facebook.presto.operator.TopNOperator.TopNOperatorFactory;
import com.facebook.presto.operator.TopNRowNumberOperator;
import com.facebook.presto.operator.ValuesOperator.ValuesOperatorFactory;
import com.facebook.presto.operator.WindowFunctionDefinition;
import com.facebook.presto.operator.WindowOperator.WindowOperatorFactory;
import com.facebook.presto.operator.aggregation.AccumulatorFactory;
import com.facebook.presto.operator.exchange.LocalExchange;
import com.facebook.presto.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory;
import com.facebook.presto.operator.exchange.LocalExchangeSourceOperator.LocalExchangeSourceOperatorFactory;
import com.facebook.presto.operator.exchange.PageChannelSelector;
import com.facebook.presto.operator.index.DynamicTupleFilterFactory;
import com.facebook.presto.operator.index.FieldSetFilteringRecordSet;
import com.facebook.presto.operator.index.IndexBuildDriverFactoryProvider;
import com.facebook.presto.operator.index.IndexJoinLookupStats;
import com.facebook.presto.operator.index.IndexLookupSourceFactory;
import com.facebook.presto.operator.index.IndexSourceOperator;
import com.facebook.presto.operator.project.InterpretedCursorProcessor;
import com.facebook.presto.operator.project.InterpretedPageFilter;
import com.facebook.presto.operator.project.InterpretedPageProjection;
import com.facebook.presto.operator.project.PageFilter;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.operator.project.PageProjection;
import com.facebook.presto.operator.window.FrameInfo;
import com.facebook.presto.operator.window.WindowFunctionSupplier;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorIndex;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.block.BlockEncodingSerde;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.predicate.NullableValue;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spiller.SpillerFactory;
import com.facebook.presto.split.MappedRecordSet;
import com.facebook.presto.split.PageSinkManager;
import com.facebook.presto.split.PageSourceProvider;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.DeleteNode;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.IndexSourceNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.MetadataDeleteNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode.DeleteHandle;
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.planner.plan.WindowNode.Frame;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.primitives.Ints;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;

import javax.annotation.Nullable;
import javax.inject.Inject;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getOperatorMemoryLimitBeforeSpill;
import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency;
import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount;
import static com.facebook.presto.SystemSessionProperties.isExchangeCompressionEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.operator.DistinctLimitOperator.DistinctLimitOperatorFactory;
import static com.facebook.presto.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory;
import static com.facebook.presto.operator.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory;
import static com.facebook.presto.operator.TableFinishOperator.TableFinishOperatorFactory;
import static com.facebook.presto.operator.TableFinishOperator.TableFinisher;
import static com.facebook.presto.operator.TableWriterOperator.TableWriterOperatorFactory;
import static com.facebook.presto.operator.UnnestOperator.UnnestOperatorFactory;
import static com.facebook.presto.operator.WindowFunctionDefinition.window;
import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
import static com.facebook.presto.sql.planner.plan.TableWriterNode.CreateHandle;
import static com.facebook.presto.sql.planner.plan.TableWriterNode.InsertHandle;
import static com.facebook.presto.sql.planner.plan.TableWriterNode.WriterTarget;
import static com.google.common.base.Functions.forMap;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.IntStream.range;

public class LocalExecutionPlanner
{
    private static final Logger log = Logger.get(LocalExecutionPlanner.class);

    private final Metadata metadata;
    private final SqlParser sqlParser;
    private final Optional<QueryPerformanceFetcher> queryPerformanceFetcher;
    private final PageSourceProvider pageSourceProvider;
    private final IndexManager indexManager;
    private final NodePartitioningManager nodePartitioningManager;
    private final PageSinkManager pageSinkManager;
    private final ExchangeClientSupplier exchangeClientSupplier;
    private final ExpressionCompiler expressionCompiler;
    private final JoinFilterFunctionCompiler joinFilterFunctionCompiler;
    private final boolean interpreterEnabled;
    private final DataSize maxIndexMemorySize;
    private final IndexJoinLookupStats indexJoinLookupStats;
    private final DataSize maxPartialAggregationMemorySize;
    private final DataSize maxPagePartitioningBufferSize;
    private final SpillerFactory spillerFactory;
    private final BlockEncodingSerde blockEncodingSerde;
    private final PagesIndex.Factory pagesIndexFactory;
    private final JoinCompiler joinCompiler;
    private final LookupJoinOperators lookupJoinOperators;

    @Inject
    public
LocalExecutionPlanner( Metadata metadata, SqlParser sqlParser, Optional<QueryPerformanceFetcher> queryPerformanceFetcher, PageSourceProvider pageSourceProvider, IndexManager indexManager, NodePartitioningManager nodePartitioningManager, PageSinkManager pageSinkManager, ExchangeClientSupplier exchangeClientSupplier, ExpressionCompiler expressionCompiler, JoinFilterFunctionCompiler joinFilterFunctionCompiler, IndexJoinLookupStats indexJoinLookupStats, CompilerConfig compilerConfig, TaskManagerConfig taskManagerConfig, SpillerFactory spillerFactory, BlockEncodingSerde blockEncodingSerde, PagesIndex.Factory pagesIndexFactory, JoinCompiler joinCompiler, LookupJoinOperators lookupJoinOperators) { requireNonNull(compilerConfig, "compilerConfig is null"); this.queryPerformanceFetcher = requireNonNull(queryPerformanceFetcher, "queryPerformanceFetcher is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.indexManager = requireNonNull(indexManager, "indexManager is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.exchangeClientSupplier = exchangeClientSupplier; this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.expressionCompiler = requireNonNull(expressionCompiler, "compiler is null"); this.joinFilterFunctionCompiler = requireNonNull(joinFilterFunctionCompiler, "compiler is null"); this.indexJoinLookupStats = requireNonNull(indexJoinLookupStats, "indexJoinLookupStats is null"); this.maxIndexMemorySize = requireNonNull(taskManagerConfig, "taskManagerConfig is null").getMaxIndexMemoryUsage(); this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.maxPartialAggregationMemorySize = taskManagerConfig.getMaxPartialAggregationMemoryUsage(); this.maxPagePartitioningBufferSize = taskManagerConfig.getMaxPagePartitioningBufferSize(); this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null"); this.lookupJoinOperators = requireNonNull(lookupJoinOperators, "lookupJoinOperators is null"); interpreterEnabled = compilerConfig.isInterpreterEnabled(); } public LocalExecutionPlan plan( Session session, PlanNode plan, Map<Symbol, Type> types, PartitioningScheme partitioningScheme, OutputBuffer outputBuffer) { List<Symbol> outputLayout = partitioningScheme.getOutputLayout(); if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(SINGLE_DISTRIBUTION) || partitioningScheme.getPartitioning().getHandle().equals(COORDINATOR_DISTRIBUTION)) { return plan(session, plan, outputLayout, types, new TaskOutputFactory(outputBuffer)); } // We can convert the symbols directly into channels, because the root must be a sink and therefore the layout is fixed List<Integer> partitionChannels; List<Optional<NullableValue>> partitionConstants; List<Type> partitionChannelTypes; if (partitioningScheme.getHashColumn().isPresent()) { partitionChannels = ImmutableList.of(outputLayout.indexOf(partitioningScheme.getHashColumn().get())); partitionConstants = 
ImmutableList.of(Optional.empty()); partitionChannelTypes = ImmutableList.of(BIGINT); } else { partitionChannels = partitioningScheme.getPartitioning().getArguments().stream() .map(ArgumentBinding::getColumn) .map(outputLayout::indexOf) .collect(toImmutableList()); partitionConstants = partitioningScheme.getPartitioning().getArguments().stream() .map(argument -> { if (argument.isConstant()) { return Optional.of(argument.getConstant()); } return Optional.<NullableValue>empty(); }) .collect(toImmutableList()); partitionChannelTypes = partitioningScheme.getPartitioning().getArguments().stream() .map(argument -> { if (argument.isConstant()) { return argument.getConstant().getType(); } return types.get(argument.getColumn()); }) .collect(toImmutableList()); } PartitionFunction partitionFunction = nodePartitioningManager.getPartitionFunction(session, partitioningScheme, partitionChannelTypes); OptionalInt nullChannel = OptionalInt.empty(); Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns(); // partitioningColumns expected to have one column in the normal case, and zero columns when partitioning on a constant checkArgument(!partitioningScheme.isReplicateNulls() || partitioningColumns.size() <= 1); if (partitioningScheme.isReplicateNulls() && partitioningColumns.size() == 1) { nullChannel = OptionalInt.of(outputLayout.indexOf(getOnlyElement(partitioningColumns))); } return plan( session, plan, outputLayout, types, new PartitionedOutputFactory(partitionFunction, partitionChannels, partitionConstants, nullChannel, outputBuffer, maxPagePartitioningBufferSize)); } public LocalExecutionPlan plan(Session session, PlanNode plan, List<Symbol> outputLayout, Map<Symbol, Type> types, OutputFactory outputOperatorFactory) { LocalExecutionPlanContext context = new LocalExecutionPlanContext(session, types); PhysicalOperation physicalOperation = plan.accept(new Visitor(session), context); Function<Page, Page> pagePreprocessor = enforceLayoutProcessor(outputLayout, physicalOperation.getLayout()); List<Type> outputTypes = outputLayout.stream() .map(types::get) .collect(toImmutableList()); context.addDriverFactory(context.isInputDriver(), true, ImmutableList.<OperatorFactory>builder() .addAll(physicalOperation.getOperatorFactories()) .add(outputOperatorFactory.createOutputOperator( context.getNextOperatorId(), plan.getId(), outputTypes, pagePreprocessor, new PagesSerdeFactory(blockEncodingSerde, isExchangeCompressionEnabled(session)))) .build(), context.getDriverInstanceCount()); addLookupOuterDrivers(context); // notify operator factories that planning has completed context.getDriverFactories().stream() .map(DriverFactory::getOperatorFactories) .flatMap(List::stream) .filter(LocalPlannerAware.class::isInstance) .map(LocalPlannerAware.class::cast) .forEach(LocalPlannerAware::localPlannerComplete); return new LocalExecutionPlan(context.getDriverFactories()); } private static void addLookupOuterDrivers(LocalExecutionPlanContext context) { // For an outer join on the lookup side (RIGHT or FULL) add an additional // driver to output the unused rows in the lookup source for (DriverFactory factory : context.getDriverFactories()) { List<OperatorFactory> operatorFactories = factory.getOperatorFactories(); for (int i = 0; i < operatorFactories.size(); i++) { OperatorFactory operatorFactory = operatorFactories.get(i); if (!(operatorFactory instanceof JoinOperatorFactory)) { continue; } JoinOperatorFactory lookupJoin = (JoinOperatorFactory) operatorFactory; Optional<OperatorFactory> 
outerOperatorFactory = lookupJoin.createOuterOperatorFactory(); if (outerOperatorFactory.isPresent()) { // Add a new driver to output the unmatched rows in an outer join. // We duplicate all of the factories above the JoinOperator (the ones reading from the joins), // and replace the JoinOperator with the OuterOperator (the one that produces unmatched rows). ImmutableList.Builder<OperatorFactory> newOperators = ImmutableList.builder(); newOperators.add(outerOperatorFactory.get()); operatorFactories.subList(i + 1, operatorFactories.size()).stream() .map(OperatorFactory::duplicate) .forEach(newOperators::add); context.addDriverFactory(false, factory.isOutputDriver(), newOperators.build(), OptionalInt.of(1)); } } } } private static class LocalExecutionPlanContext { private final Session session; private final Map<Symbol, Type> types; private final List<DriverFactory> driverFactories; private final Optional<IndexSourceContext> indexSourceContext; // this is shared with all subContexts private AtomicInteger nextPipelineId; private int nextOperatorId; private boolean inputDriver = true; private OptionalInt driverInstanceCount = OptionalInt.empty(); public LocalExecutionPlanContext(Session session, Map<Symbol, Type> types) { this(session, types, new ArrayList<>(), Optional.empty(), new AtomicInteger(0)); } private LocalExecutionPlanContext( Session session, Map<Symbol, Type> types, List<DriverFactory> driverFactories, Optional<IndexSourceContext> indexSourceContext, AtomicInteger nextPipelineId) { this.session = session; this.types = types; this.driverFactories = driverFactories; this.indexSourceContext = indexSourceContext; this.nextPipelineId = nextPipelineId; } public void addDriverFactory(boolean inputDriver, boolean outputDriver, List<OperatorFactory> operatorFactories, OptionalInt driverInstances) { driverFactories.add(new DriverFactory(getNextPipelineId(), inputDriver, outputDriver, operatorFactories, driverInstances)); } private List<DriverFactory> getDriverFactories() { return ImmutableList.copyOf(driverFactories); } public Session getSession() { return session; } public Map<Symbol, Type> getTypes() { return types; } public Optional<IndexSourceContext> getIndexSourceContext() { return indexSourceContext; } private int getNextPipelineId() { return nextPipelineId.getAndIncrement(); } private int getNextOperatorId() { return nextOperatorId++; } private boolean isInputDriver() { return inputDriver; } private void setInputDriver(boolean inputDriver) { this.inputDriver = inputDriver; } public LocalExecutionPlanContext createSubContext() { checkState(!indexSourceContext.isPresent(), "index build plan can not have sub-contexts"); return new LocalExecutionPlanContext(session, types, driverFactories, indexSourceContext, nextPipelineId); } public LocalExecutionPlanContext createIndexSourceSubContext(IndexSourceContext indexSourceContext) { return new LocalExecutionPlanContext(session, types, driverFactories, Optional.of(indexSourceContext), nextPipelineId); } public OptionalInt getDriverInstanceCount() { return driverInstanceCount; } public void setDriverInstanceCount(int driverInstanceCount) { checkArgument(driverInstanceCount > 0, "driverInstanceCount must be > 0"); if (this.driverInstanceCount.isPresent()) { checkState(this.driverInstanceCount.getAsInt() == driverInstanceCount, "driverInstance count already set to " + this.driverInstanceCount.getAsInt()); } this.driverInstanceCount = OptionalInt.of(driverInstanceCount); } } private static class IndexSourceContext { private final 
SetMultimap<Symbol, Integer> indexLookupToProbeInput; public IndexSourceContext(SetMultimap<Symbol, Integer> indexLookupToProbeInput) { this.indexLookupToProbeInput = ImmutableSetMultimap.copyOf(requireNonNull(indexLookupToProbeInput, "indexLookupToProbeInput is null")); } private SetMultimap<Symbol, Integer> getIndexLookupToProbeInput() { return indexLookupToProbeInput; } } public static class LocalExecutionPlan { private final List<DriverFactory> driverFactories; public LocalExecutionPlan(List<DriverFactory> driverFactories) { this.driverFactories = ImmutableList.copyOf(requireNonNull(driverFactories, "driverFactories is null")); } public List<DriverFactory> getDriverFactories() { return driverFactories; } } private class Visitor extends PlanVisitor<LocalExecutionPlanContext, PhysicalOperation> { private final Session session; private Visitor(Session session) { this.session = session; } @Override public PhysicalOperation visitRemoteSource(RemoteSourceNode node, LocalExecutionPlanContext context) { List<Type> types = getSourceOperatorTypes(node, context.getTypes()); if (!context.getDriverInstanceCount().isPresent()) { context.setDriverInstanceCount(getTaskConcurrency(session)); } OperatorFactory operatorFactory = new ExchangeOperatorFactory( context.getNextOperatorId(), node.getId(), exchangeClientSupplier, new PagesSerdeFactory(blockEncodingSerde, isExchangeCompressionEnabled(session)), types); return new PhysicalOperation(operatorFactory, makeLayout(node)); } @Override public PhysicalOperation visitExplainAnalyze(ExplainAnalyzeNode node, LocalExecutionPlanContext context) { checkState(queryPerformanceFetcher.isPresent(), "ExplainAnalyze can only run on coordinator"); PhysicalOperation source = node.getSource().accept(this, context); OperatorFactory operatorFactory = new ExplainAnalyzeOperatorFactory(context.getNextOperatorId(), node.getId(), queryPerformanceFetcher.get(), metadata); return new PhysicalOperation(operatorFactory, makeLayout(node), source); } @Override public PhysicalOperation visitOutput(OutputNode node, LocalExecutionPlanContext context) { return node.getSource().accept(this, context); } @Override public PhysicalOperation visitRowNumber(RowNumberNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Symbol> partitionBySymbols = node.getPartitionBy(); List<Integer> partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout()); List<Type> partitionTypes = partitionChannels.stream() .map(channel -> source.getTypes().get(channel)) .collect(toImmutableList()); ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder(); for (int i = 0; i < source.getTypes().size(); i++) { outputChannels.add(i); } // compute the layout of the output from the window operator ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); outputMappings.putAll(source.getLayout()); // row number function goes in the last channel int channel = source.getTypes().size(); outputMappings.put(node.getRowNumberSymbol(), channel); Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source)); OperatorFactory operatorFactory = new RowNumberOperator.RowNumberOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), outputChannels.build(), partitionChannels, partitionTypes, node.getMaxRowCountPerPartition(), hashChannel, 10_000, joinCompiler); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); } @Override public PhysicalOperation 
visitTopNRowNumber(TopNRowNumberNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Symbol> partitionBySymbols = node.getPartitionBy(); List<Integer> partitionChannels = getChannelsForSymbols(partitionBySymbols, source.getLayout()); List<Type> partitionTypes = partitionChannels.stream() .map(channel -> source.getTypes().get(channel)) .collect(toImmutableList()); List<Symbol> orderBySymbols = node.getOrderBy(); List<Integer> sortChannels = getChannelsForSymbols(orderBySymbols, source.getLayout()); List<SortOrder> sortOrder = orderBySymbols.stream() .map(symbol -> node.getOrderings().get(symbol)) .collect(toImmutableList()); ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder(); for (int i = 0; i < source.getTypes().size(); i++) { outputChannels.add(i); } // compute the layout of the output from the window operator ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); outputMappings.putAll(source.getLayout()); if (!node.isPartial() || !partitionChannels.isEmpty()) { // row number function goes in the last channel int channel = source.getTypes().size(); outputMappings.put(node.getRowNumberSymbol(), channel); } Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source)); OperatorFactory operatorFactory = new TopNRowNumberOperator.TopNRowNumberOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), outputChannels.build(), partitionChannels, partitionTypes, sortChannels, sortOrder, node.getMaxRowCountPerPartition(), node.isPartial(), hashChannel, 1000, joinCompiler); return new PhysicalOperation(operatorFactory, makeLayout(node), source); } @Override public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Symbol> partitionBySymbols = node.getPartitionBy(); List<Symbol> orderBySymbols = node.getOrderBy(); List<Integer> partitionChannels = ImmutableList.copyOf(getChannelsForSymbols(partitionBySymbols, source.getLayout())); List<Integer> preGroupedChannels = ImmutableList.copyOf(getChannelsForSymbols(ImmutableList.copyOf(node.getPrePartitionedInputs()), source.getLayout())); List<Integer> sortChannels = getChannelsForSymbols(orderBySymbols, source.getLayout()); List<SortOrder> sortOrder = orderBySymbols.stream() .map(symbol -> node.getOrderings().get(symbol)) .collect(toImmutableList()); ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder(); for (int i = 0; i < source.getTypes().size(); i++) { outputChannels.add(i); } ImmutableList.Builder<WindowFunctionDefinition> windowFunctionsBuilder = ImmutableList.builder(); ImmutableList.Builder<Symbol> windowFunctionOutputSymbolsBuilder = ImmutableList.builder(); for (Map.Entry<Symbol, WindowNode.Function> entry : node.getWindowFunctions().entrySet()) { Optional<Integer> frameStartChannel = Optional.empty(); Optional<Integer> frameEndChannel = Optional.empty(); Frame frame = entry.getValue().getFrame(); if (frame.getStartValue().isPresent()) { frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get())); } if (frame.getEndValue().isPresent()) { frameEndChannel = Optional.of(source.getLayout().get(frame.getEndValue().get())); } FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel); FunctionCall functionCall = entry.getValue().getFunctionCall(); Signature signature = 
entry.getValue().getSignature(); ImmutableList.Builder<Integer> arguments = ImmutableList.builder(); for (Expression argument : functionCall.getArguments()) { Symbol argumentSymbol = Symbol.from(argument); arguments.add(source.getLayout().get(argumentSymbol)); } Symbol symbol = entry.getKey(); WindowFunctionSupplier windowFunctionSupplier = metadata.getFunctionRegistry().getWindowFunctionImplementation(signature); Type type = metadata.getType(signature.getReturnType()); windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, arguments.build())); windowFunctionOutputSymbolsBuilder.add(symbol); } List<Symbol> windowFunctionOutputSymbols = windowFunctionOutputSymbolsBuilder.build(); // compute the layout of the output from the window operator ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); for (Symbol symbol : node.getSource().getOutputSymbols()) { outputMappings.put(symbol, source.getLayout().get(symbol)); } // window functions go in remaining channels starting after the last channel from the source operator, one per channel int channel = source.getTypes().size(); for (Symbol symbol : windowFunctionOutputSymbols) { outputMappings.put(symbol, channel); channel++; } OperatorFactory operatorFactory = new WindowOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), outputChannels.build(), windowFunctionsBuilder.build(), partitionChannels, preGroupedChannels, sortChannels, sortOrder, node.getPreSortedOrderPrefix(), 10_000, pagesIndexFactory); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); } @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Symbol> orderBySymbols = node.getOrderBy(); List<Integer> sortChannels = new ArrayList<>(); List<SortOrder> sortOrders = new ArrayList<>(); for (Symbol symbol : orderBySymbols) { sortChannels.add(source.getLayout().get(symbol)); sortOrders.add(node.getOrderings().get(symbol)); } OperatorFactory operator = new TopNOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), (int) node.getCount(), sortChannels, sortOrders, node.isPartial(), maxPartialAggregationMemorySize); return new PhysicalOperation(operator, source.getLayout(), source); } @Override public PhysicalOperation visitSort(SortNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Symbol> orderBySymbols = node.getOrderBy(); List<Integer> orderByChannels = getChannelsForSymbols(orderBySymbols, source.getLayout()); ImmutableList.Builder<SortOrder> sortOrder = ImmutableList.builder(); for (Symbol symbol : orderBySymbols) { sortOrder.add(node.getOrderings().get(symbol)); } ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder(); for (int i = 0; i < source.getTypes().size(); i++) { outputChannels.add(i); } OperatorFactory operator = new OrderByOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), outputChannels.build(), 10_000, orderByChannels, sortOrder.build(), pagesIndexFactory); return new PhysicalOperation(operator, source.getLayout(), source); } @Override public PhysicalOperation visitLimit(LimitNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); OperatorFactory operatorFactory = new LimitOperatorFactory(context.getNextOperatorId(), node.getId(), source.getTypes(), node.getCount()); return new 
PhysicalOperation(operatorFactory, source.getLayout(), source); } @Override public PhysicalOperation visitDistinctLimit(DistinctLimitNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source)); List<Integer> distinctChannels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout()); OperatorFactory operatorFactory = new DistinctLimitOperatorFactory( context.getNextOperatorId(), node.getId(), source.getTypes(), distinctChannels, node.getLimit(), hashChannel, joinCompiler); return new PhysicalOperation(operatorFactory, source.getLayout(), source); } @Override public PhysicalOperation visitGroupId(GroupIdNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); Map<Symbol, Integer> newLayout = new HashMap<>(); ImmutableList.Builder<Type> outputTypes = ImmutableList.builder(); int outputChannel = 0; for (Symbol output : node.getGroupingSets().stream().flatMap(Collection::stream).collect(Collectors.toSet())) { newLayout.put(output, outputChannel++); outputTypes.add(source.getTypes().get(source.getLayout().get(node.getGroupingSetMappings().get(output)))); } Map<Symbol, Integer> argumentMappings = new HashMap<>(); for (Symbol output : node.getArgumentMappings().keySet()) { int inputChannel = source.getLayout().get(node.getArgumentMappings().get(output)); newLayout.put(output, outputChannel++); outputTypes.add(source.getTypes().get(inputChannel)); argumentMappings.put(output, inputChannel); } // for every grouping set, create a mapping of all output to input channels (including arguments) ImmutableList.Builder<Map<Integer, Integer>> mappings = ImmutableList.builder(); for (List<Symbol> groupingSet : node.getGroupingSets()) { ImmutableMap.Builder<Integer, Integer> setMapping = ImmutableMap.builder(); for (Symbol output : groupingSet) { setMapping.put(newLayout.get(output), source.getLayout().get(node.getGroupingSetMappings().get(output))); } for (Symbol output : argumentMappings.keySet()) { setMapping.put(newLayout.get(output), argumentMappings.get(output)); } mappings.add(setMapping.build()); } newLayout.put(node.getGroupIdSymbol(), outputChannel); outputTypes.add(BIGINT); OperatorFactory groupIdOperatorFactory = new GroupIdOperator.GroupIdOperatorFactory(context.getNextOperatorId(), node.getId(), outputTypes.build(), mappings.build()); return new PhysicalOperation(groupIdOperatorFactory, newLayout, source); } @Override public PhysicalOperation visitAggregation(AggregationNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); if (node.getGroupingKeys().isEmpty()) { return planGlobalAggregation(context.getNextOperatorId(), node, source); } boolean spillEnabled = isSpillEnabled(context.getSession()); DataSize memoryLimitBeforeSpill = getOperatorMemoryLimitBeforeSpill(context.getSession()); return planGroupByAggregation(node, source, context.getNextOperatorId(), spillEnabled, memoryLimitBeforeSpill); } @Override public PhysicalOperation visitMarkDistinct(MarkDistinctNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); List<Integer> channels = getChannelsForSymbols(node.getDistinctSymbols(), source.getLayout()); Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source)); MarkDistinctOperatorFactory operator = new 
MarkDistinctOperatorFactory(context.getNextOperatorId(), node.getId(), source.getTypes(), channels, hashChannel, joinCompiler); return new PhysicalOperation(operator, makeLayout(node), source); } @Override public PhysicalOperation visitSample(SampleNode node, LocalExecutionPlanContext context) { // For system sample, the splits are already filtered out, so no specific action needs to be taken here if (node.getSampleType() == SampleNode.Type.SYSTEM) { return node.getSource().accept(this, context); } throw new UnsupportedOperationException("not yet implemented: " + node); } @Override public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext context) { PlanNode sourceNode = node.getSource(); Expression filterExpression = node.getPredicate(); List<Symbol> outputSymbols = node.getOutputSymbols(); return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols); } @Override public PhysicalOperation visitProject(ProjectNode node, LocalExecutionPlanContext context) { PlanNode sourceNode; Optional<Expression> filterExpression = Optional.empty(); if (node.getSource() instanceof FilterNode) { FilterNode filterNode = (FilterNode) node.getSource(); sourceNode = filterNode.getSource(); filterExpression = Optional.of(filterNode.getPredicate()); } else { sourceNode = node.getSource(); } List<Symbol> outputSymbols = node.getOutputSymbols(); return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), outputSymbols); } // TODO: This should be refactored, so that there's an optimizer that merges scan-filter-project into a single PlanNode private PhysicalOperation visitScanFilterAndProject( LocalExecutionPlanContext context, PlanNodeId planNodeId, PlanNode sourceNode, Optional<Expression> filterExpression, Assignments assignments, List<Symbol> outputSymbols) { // if source is a table scan we fold it directly into the filter and project // otherwise we plan it as a normal operator Map<Symbol, Integer> sourceLayout; Map<Integer, Type> sourceTypes; List<ColumnHandle> columns = null; PhysicalOperation source = null; if (sourceNode instanceof TableScanNode) { TableScanNode tableScanNode = (TableScanNode) sourceNode; // extract the column handles and channel to type mapping sourceLayout = new LinkedHashMap<>(); sourceTypes = new LinkedHashMap<>(); columns = new ArrayList<>(); int channel = 0; for (Symbol symbol : tableScanNode.getOutputSymbols()) { columns.add(tableScanNode.getAssignments().get(symbol)); Integer input = channel; sourceLayout.put(symbol, input); Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); sourceTypes.put(input, type); channel++; } } else { // plan source source = sourceNode.accept(this, context); sourceLayout = source.getLayout(); sourceTypes = getInputTypes(source.getLayout(), source.getTypes()); } // build output mapping ImmutableMap.Builder<Symbol, Integer> outputMappingsBuilder = ImmutableMap.builder(); for (int i = 0; i < outputSymbols.size(); i++) { Symbol symbol = outputSymbols.get(i); outputMappingsBuilder.put(symbol, i); } Map<Symbol, Integer> outputMappings = outputMappingsBuilder.build(); // compiler uses inputs instead of symbols, so rewrite the expressions first SymbolToInputRewriter symbolToInputRewriter = new SymbolToInputRewriter(sourceLayout); Optional<Expression> rewrittenFilter = filterExpression.map(symbolToInputRewriter::rewrite); List<Expression> rewrittenProjections = 
new ArrayList<>(); for (Symbol symbol : outputSymbols) { rewrittenProjections.add(symbolToInputRewriter.rewrite(assignments.get(symbol))); } IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput( context.getSession(), metadata, sqlParser, sourceTypes, concat(rewrittenFilter.map(ImmutableList::of).orElse(ImmutableList.of()), rewrittenProjections), emptyList()); Optional<RowExpression> translatedFilter = rewrittenFilter.map(filter -> toRowExpression(filter, expressionTypes)); List<RowExpression> translatedProjections = rewrittenProjections.stream() .map(expression -> toRowExpression(expression, expressionTypes)) .collect(toImmutableList()); try { if (columns != null) { Supplier<CursorProcessor> cursorProcessor = expressionCompiler.compileCursorProcessor(translatedFilter, translatedProjections, sourceNode.getId()); Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections); SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, sourceNode.getId(), pageSourceProvider, cursorProcessor, pageProcessor, columns, Lists.transform(rewrittenProjections, forMap(expressionTypes))); return new PhysicalOperation(operatorFactory, outputMappings); } else { Supplier<PageProcessor> pageProcessor = expressionCompiler.compilePageProcessor(translatedFilter, translatedProjections); OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, pageProcessor, Lists.transform(rewrittenProjections, forMap(expressionTypes))); return new PhysicalOperation(operatorFactory, outputMappings, source); } } catch (RuntimeException e) { if (!interpreterEnabled) { throw new PrestoException(COMPILER_ERROR, "Compiler failed and interpreter is disabled", e); } // compilation failed, use interpreter log.error(e, "Compile failed for filter=%s projections=%s sourceTypes=%s error=%s", filterExpression, assignments, sourceTypes, e); } PageProcessor pageProcessor = createInterpretedColumnarPageProcessor( filterExpression, outputSymbols.stream() .map(assignments::get) .collect(toImmutableList()), context.getTypes(), sourceLayout, context.getSession()); if (columns != null) { InterpretedCursorProcessor cursorProcessor = new InterpretedCursorProcessor( filterExpression, outputSymbols.stream() .map(assignments::get) .collect(toImmutableList()), context.getTypes(), sourceLayout, metadata, sqlParser, context.getSession()); OperatorFactory operatorFactory = new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, sourceNode.getId(), pageSourceProvider, () -> cursorProcessor, () -> pageProcessor, columns, Lists.transform(rewrittenProjections, forMap(expressionTypes))); return new PhysicalOperation(operatorFactory, outputMappings); } else { OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, () -> pageProcessor, Lists.transform(rewrittenProjections, forMap(expressionTypes))); return new PhysicalOperation(operatorFactory, outputMappings, source); } } private PageProcessor createInterpretedColumnarPageProcessor( Optional<Expression> filter, List<Expression> projections, Map<Symbol, Type> symbolTypes, Map<Symbol, Integer> symbolToInputMappings, Session session) { Optional<PageFilter> pageFilter = filter .map(expression -> new 
InterpretedPageFilter(expression, symbolTypes, symbolToInputMappings, metadata, sqlParser, session)); List<PageProjection> pageProjections = projections.stream() .map(expression -> new InterpretedPageProjection(expression, symbolTypes, symbolToInputMappings, metadata, sqlParser, session)) .collect(toImmutableList()); return new PageProcessor(pageFilter, pageProjections); } private RowExpression toRowExpression(Expression expression, IdentityLinkedHashMap<Expression, Type> types) { return SqlToRowExpressionTranslator.translate(expression, SCALAR, types, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, true); } private Map<Integer, Type> getInputTypes(Map<Symbol, Integer> layout, List<Type> types) { Builder<Integer, Type> inputTypes = ImmutableMap.builder(); for (Integer input : ImmutableSet.copyOf(layout.values())) { Type type = types.get(input); inputTypes.put(input, type); } return inputTypes.build(); } @Override public PhysicalOperation visitTableScan(TableScanNode node, LocalExecutionPlanContext context) { List<ColumnHandle> columns = new ArrayList<>(); for (Symbol symbol : node.getOutputSymbols()) { columns.add(node.getAssignments().get(symbol)); } List<Type> types = getSourceOperatorTypes(node, context.getTypes()); OperatorFactory operatorFactory = new TableScanOperatorFactory(context.getNextOperatorId(), node.getId(), pageSourceProvider, types, columns); return new PhysicalOperation(operatorFactory, makeLayout(node)); } @Override public PhysicalOperation visitValues(ValuesNode node, LocalExecutionPlanContext context) { // a values node must have a single driver context.setDriverInstanceCount(1); List<Type> outputTypes = new ArrayList<>(); for (Symbol symbol : node.getOutputSymbols()) { Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); outputTypes.add(type); } if (node.getRows().isEmpty()) { OperatorFactory operatorFactory = new ValuesOperatorFactory(context.getNextOperatorId(), node.getId(), outputTypes, ImmutableList.of()); return new PhysicalOperation(operatorFactory, makeLayout(node)); } PageBuilder pageBuilder = new PageBuilder(outputTypes); for (List<Expression> row : node.getRows()) { pageBuilder.declarePosition(); IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypes( context.getSession(), metadata, sqlParser, ImmutableMap.of(), ImmutableList.copyOf(row), emptyList(), false); for (int i = 0; i < row.size(); i++) { // evaluate the literal value Object result = ExpressionInterpreter.expressionInterpreter(row.get(i), metadata, context.getSession(), expressionTypes).evaluate(0); writeNativeValue(outputTypes.get(i), pageBuilder.getBlockBuilder(i), result); } } OperatorFactory operatorFactory = new ValuesOperatorFactory(context.getNextOperatorId(), node.getId(), outputTypes, ImmutableList.of(pageBuilder.build())); return new PhysicalOperation(operatorFactory, makeLayout(node)); } @Override public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); ImmutableList.Builder<Type> replicateTypes = ImmutableList.builder(); for (Symbol symbol : node.getReplicateSymbols()) { replicateTypes.add(context.getTypes().get(symbol)); } List<Symbol> unnestSymbols = ImmutableList.copyOf(node.getUnnestSymbols().keySet()); ImmutableList.Builder<Type> unnestTypes = ImmutableList.builder(); for (Symbol symbol : unnestSymbols) { unnestTypes.add(context.getTypes().get(symbol)); } Optional<Symbol> ordinalitySymbol = 
node.getOrdinalitySymbol(); Optional<Type> ordinalityType = ordinalitySymbol.map(context.getTypes()::get); ordinalityType.ifPresent(type -> checkState(type.equals(BIGINT), "Type of ordinalitySymbol must always be BIGINT.")); List<Integer> replicateChannels = getChannelsForSymbols(node.getReplicateSymbols(), source.getLayout()); List<Integer> unnestChannels = getChannelsForSymbols(unnestSymbols, source.getLayout()); // Source channels are always laid out first, followed by the unnested symbols ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); int channel = 0; for (Symbol symbol : node.getReplicateSymbols()) { outputMappings.put(symbol, channel); channel++; } for (Symbol symbol : unnestSymbols) { for (Symbol unnestedSymbol : node.getUnnestSymbols().get(symbol)) { outputMappings.put(unnestedSymbol, channel); channel++; } } if (ordinalitySymbol.isPresent()) { outputMappings.put(ordinalitySymbol.get(), channel); channel++; } OperatorFactory operatorFactory = new UnnestOperatorFactory( context.getNextOperatorId(), node.getId(), replicateChannels, replicateTypes.build(), unnestChannels, unnestTypes.build(), ordinalityType.isPresent()); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); } private ImmutableMap<Symbol, Integer> makeLayout(PlanNode node) { Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); int channel = 0; for (Symbol symbol : node.getOutputSymbols()) { outputMappings.put(symbol, channel); channel++; } return outputMappings.build(); } @Override public PhysicalOperation visitIndexSource(IndexSourceNode node, LocalExecutionPlanContext context) { checkState(context.getIndexSourceContext().isPresent(), "Must be in an index source context"); IndexSourceContext indexSourceContext = context.getIndexSourceContext().get(); SetMultimap<Symbol, Integer> indexLookupToProbeInput = indexSourceContext.getIndexLookupToProbeInput(); checkState(indexLookupToProbeInput.keySet().equals(node.getLookupSymbols())); // Finalize the symbol lookup layout for the index source List<Symbol> lookupSymbolSchema = ImmutableList.copyOf(node.getLookupSymbols()); // Identify how to remap the probe key Input to match the source index lookup layout ImmutableList.Builder<Integer> remappedProbeKeyChannelsBuilder = ImmutableList.builder(); // Identify overlapping fields that can produce the same lookup symbol. // We will filter incoming keys to ensure that overlapping fields will have the same value. 
ImmutableList.Builder<Set<Integer>> overlappingFieldSetsBuilder = ImmutableList.builder(); for (Symbol lookupSymbol : lookupSymbolSchema) { Set<Integer> potentialProbeInputs = indexLookupToProbeInput.get(lookupSymbol); checkState(!potentialProbeInputs.isEmpty(), "Must have at least one source from the probe input"); if (potentialProbeInputs.size() > 1) { overlappingFieldSetsBuilder.add(potentialProbeInputs.stream().collect(toImmutableSet())); } remappedProbeKeyChannelsBuilder.add(Iterables.getFirst(potentialProbeInputs, null)); } List<Set<Integer>> overlappingFieldSets = overlappingFieldSetsBuilder.build(); List<Integer> remappedProbeKeyChannels = remappedProbeKeyChannelsBuilder.build(); Function<RecordSet, RecordSet> probeKeyNormalizer = recordSet -> { if (!overlappingFieldSets.isEmpty()) { recordSet = new FieldSetFilteringRecordSet(metadata.getFunctionRegistry(), recordSet, overlappingFieldSets); } return new MappedRecordSet(recordSet, remappedProbeKeyChannels); }; // Declare the input and output schemas for the index and acquire the actual Index List<ColumnHandle> lookupSchema = Lists.transform(lookupSymbolSchema, forMap(node.getAssignments())); List<ColumnHandle> outputSchema = Lists.transform(node.getOutputSymbols(), forMap(node.getAssignments())); ConnectorIndex index = indexManager.getIndex(session, node.getIndexHandle(), lookupSchema, outputSchema); List<Type> types = getSourceOperatorTypes(node, context.getTypes()); OperatorFactory operatorFactory = new IndexSourceOperator.IndexSourceOperatorFactory(context.getNextOperatorId(), node.getId(), index, types, probeKeyNormalizer); return new PhysicalOperation(operatorFactory, makeLayout(node)); } /** * This method creates a mapping from each index source lookup symbol (directly applied to the index) * to the corresponding probe key Input */ private SetMultimap<Symbol, Integer> mapIndexSourceLookupSymbolToProbeKeyInput(IndexJoinNode node, Map<Symbol, Integer> probeKeyLayout) { Set<Symbol> indexJoinSymbols = node.getCriteria().stream() .map(IndexJoinNode.EquiJoinClause::getIndex) .collect(toImmutableSet()); // Trace the index join symbols to the index source lookup symbols // Map: Index join symbol => Index source lookup symbol Map<Symbol, Symbol> indexKeyTrace = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), indexJoinSymbols); // Map the index join symbols to the probe key Input Multimap<Symbol, Integer> indexToProbeKeyInput = HashMultimap.create(); for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) { indexToProbeKeyInput.put(clause.getIndex(), probeKeyLayout.get(clause.getProbe())); } // Create the mapping from index source look up symbol to probe key Input ImmutableSetMultimap.Builder<Symbol, Integer> builder = ImmutableSetMultimap.builder(); for (Map.Entry<Symbol, Symbol> entry : indexKeyTrace.entrySet()) { Symbol indexJoinSymbol = entry.getKey(); Symbol indexLookupSymbol = entry.getValue(); builder.putAll(indexLookupSymbol, indexToProbeKeyInput.get(indexJoinSymbol)); } return builder.build(); } @Override public PhysicalOperation visitIndexJoin(IndexJoinNode node, LocalExecutionPlanContext context) { List<IndexJoinNode.EquiJoinClause> clauses = node.getCriteria(); List<Symbol> probeSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe); List<Symbol> indexSymbols = Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex); // Plan probe side PhysicalOperation probeSource = node.getProbeSource().accept(this, context); List<Integer> probeChannels = 
getChannelsForSymbols(probeSymbols, probeSource.getLayout()); Optional<Integer> probeHashChannel = node.getProbeHashSymbol().map(channelGetter(probeSource)); // The probe key channels will be handed to the index according to probeSymbol order Map<Symbol, Integer> probeKeyLayout = new HashMap<>(); for (int i = 0; i < probeSymbols.size(); i++) { // Duplicate symbols can appear and we only need to take take one of the Inputs probeKeyLayout.put(probeSymbols.get(i), i); } // Plan the index source side SetMultimap<Symbol, Integer> indexLookupToProbeInput = mapIndexSourceLookupSymbolToProbeKeyInput(node, probeKeyLayout); LocalExecutionPlanContext indexContext = context.createIndexSourceSubContext(new IndexSourceContext(indexLookupToProbeInput)); PhysicalOperation indexSource = node.getIndexSource().accept(this, indexContext); List<Integer> indexOutputChannels = getChannelsForSymbols(indexSymbols, indexSource.getLayout()); Optional<Integer> indexHashChannel = node.getIndexHashSymbol().map(channelGetter(indexSource)); // Identify just the join keys/channels needed for lookup by the index source (does not have to use all of them). Set<Symbol> indexSymbolsNeededBySource = IndexJoinOptimizer.IndexKeyTracer.trace(node.getIndexSource(), ImmutableSet.copyOf(indexSymbols)).keySet(); Set<Integer> lookupSourceInputChannels = node.getCriteria().stream() .filter(equiJoinClause -> indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getProbe) .map(probeKeyLayout::get) .collect(toImmutableSet()); Optional<DynamicTupleFilterFactory> dynamicTupleFilterFactory = Optional.empty(); if (lookupSourceInputChannels.size() < probeKeyLayout.values().size()) { int[] nonLookupInputChannels = Ints.toArray(node.getCriteria().stream() .filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getProbe) .map(probeKeyLayout::get) .collect(toImmutableList())); int[] nonLookupOutputChannels = Ints.toArray(node.getCriteria().stream() .filter(equiJoinClause -> !indexSymbolsNeededBySource.contains(equiJoinClause.getIndex())) .map(IndexJoinNode.EquiJoinClause::getIndex) .map(indexSource.getLayout()::get) .collect(toImmutableList())); int filterOperatorId = indexContext.getNextOperatorId(); dynamicTupleFilterFactory = Optional.of(new DynamicTupleFilterFactory( filterOperatorId, node.getId(), nonLookupInputChannels, nonLookupOutputChannels, indexSource.getTypes(), metadata)); } IndexBuildDriverFactoryProvider indexBuildDriverFactoryProvider = new IndexBuildDriverFactoryProvider( indexContext.getNextPipelineId(), indexContext.getNextOperatorId(), node.getId(), indexContext.isInputDriver(), indexSource.getOperatorFactories(), dynamicTupleFilterFactory); IndexLookupSourceFactory indexLookupSourceFactory = new IndexLookupSourceFactory( lookupSourceInputChannels, indexOutputChannels, indexHashChannel, indexSource.getTypes(), indexSource.getLayout(), indexBuildDriverFactoryProvider, maxIndexMemorySize, indexJoinLookupStats, SystemSessionProperties.isShareIndexLoading(session), pagesIndexFactory, joinCompiler); ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); outputMappings.putAll(probeSource.getLayout()); // inputs from index side of the join are laid out following the input from the probe side, // so adjust the channel ids but keep the field layouts intact int offset = probeSource.getTypes().size(); for (Map.Entry<Symbol, Integer> entry : indexSource.getLayout().entrySet()) { Integer input = 
            for (Map.Entry<Symbol, Integer> entry : indexSource.getLayout().entrySet()) {
                Integer input = entry.getValue();
                outputMappings.put(entry.getKey(), offset + input);
            }

            OperatorFactory lookupJoinOperatorFactory;
            switch (node.getType()) {
                case INNER:
                    lookupJoinOperatorFactory = lookupJoinOperators.innerJoin(context.getNextOperatorId(), node.getId(), indexLookupSourceFactory, probeSource.getTypes(), probeChannels, probeHashChannel, Optional.empty());
                    break;
                case SOURCE_OUTER:
                    lookupJoinOperatorFactory = lookupJoinOperators.probeOuterJoin(context.getNextOperatorId(), node.getId(), indexLookupSourceFactory, probeSource.getTypes(), probeChannels, probeHashChannel, Optional.empty());
                    break;
                default:
                    throw new AssertionError("Unknown type: " + node.getType());
            }

            return new PhysicalOperation(lookupJoinOperatorFactory, outputMappings.build(), probeSource);
        }

        @Override
        public PhysicalOperation visitJoin(JoinNode node, LocalExecutionPlanContext context)
        {
            List<JoinNode.EquiJoinClause> clauses = node.getCriteria();

            if (node.isCrossJoin()) {
                return createNestedLoopJoin(node, context);
            }

            List<Symbol> leftSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft);
            List<Symbol> rightSymbols = Lists.transform(clauses, JoinNode.EquiJoinClause::getRight);

            switch (node.getType()) {
                case INNER:
                case LEFT:
                case RIGHT:
                case FULL:
                    return createLookupJoin(node, node.getLeft(), leftSymbols, node.getLeftHashSymbol(), node.getRight(), rightSymbols, node.getRightHashSymbol(), context);
                default:
                    throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
            }
        }

        private PhysicalOperation createNestedLoopJoin(JoinNode node, LocalExecutionPlanContext context)
        {
            PhysicalOperation probeSource = node.getLeft().accept(this, context);

            LocalExecutionPlanContext buildContext = context.createSubContext();
            PhysicalOperation buildSource = node.getRight().accept(this, buildContext);
            NestedLoopBuildOperatorFactory nestedLoopBuildOperatorFactory = new NestedLoopBuildOperatorFactory(
                    buildContext.getNextOperatorId(),
                    node.getId(),
                    buildSource.getTypes());

            checkArgument(buildContext.getDriverInstanceCount().orElse(1) == 1, "Expected local execution to not be parallel");
            context.addDriverFactory(
                    buildContext.isInputDriver(),
                    false,
                    ImmutableList.<OperatorFactory>builder()
                            .addAll(buildSource.getOperatorFactories())
                            .add(nestedLoopBuildOperatorFactory)
                            .build(),
                    buildContext.getDriverInstanceCount());

            NestedLoopJoinPagesSupplier nestedLoopJoinPagesSupplier = nestedLoopBuildOperatorFactory.getNestedLoopJoinPagesSupplier();

            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
            outputMappings.putAll(probeSource.getLayout());

            // inputs from build side of the join are laid out following the input from the probe side,
            // so adjust the channel ids but keep the field layouts intact
            int offset = probeSource.getTypes().size();
            for (Map.Entry<Symbol, Integer> entry : buildSource.getLayout().entrySet()) {
                outputMappings.put(entry.getKey(), offset + entry.getValue());
            }

            OperatorFactory operatorFactory = new NestedLoopJoinOperatorFactory(context.getNextOperatorId(), node.getId(), nestedLoopJoinPagesSupplier, probeSource.getTypes());
            PhysicalOperation operation = new PhysicalOperation(operatorFactory, outputMappings.build(), probeSource);
            return operation;
        }
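        // A lookup join is planned as two pipelines: the build side is planned in a sub-context and feeds a
        // LookupSourceFactory, while the probe side consumes that lookup source in the enclosing pipeline.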
        private PhysicalOperation createLookupJoin(JoinNode node, PlanNode probeNode, List<Symbol> probeSymbols, Optional<Symbol> probeHashSymbol, PlanNode buildNode, List<Symbol> buildSymbols, Optional<Symbol> buildHashSymbol, LocalExecutionPlanContext context)
        {
            // Plan probe
            PhysicalOperation probeSource = probeNode.accept(this, context);

            // Plan build
            LookupSourceFactory lookupSourceFactory = createLookupSourceFactory(node, buildNode, buildSymbols, buildHashSymbol, probeSource.getLayout(), context);

            OperatorFactory operator = createLookupJoin(node, probeSource, probeSymbols, probeHashSymbol, lookupSourceFactory, context);

            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
            List<Symbol> outputSymbols = node.getOutputSymbols();
            for (int i = 0; i < outputSymbols.size(); i++) {
                Symbol symbol = outputSymbols.get(i);
                outputMappings.put(symbol, i);
            }

            return new PhysicalOperation(operator, outputMappings.build(), probeSource);
        }

        private LookupSourceFactory createLookupSourceFactory(
                JoinNode node,
                PlanNode buildNode,
                List<Symbol> buildSymbols,
                Optional<Symbol> buildHashSymbol,
                Map<Symbol, Integer> probeLayout,
                LocalExecutionPlanContext context)
        {
            LocalExecutionPlanContext buildContext = context.createSubContext();
            PhysicalOperation buildSource = buildNode.accept(this, buildContext);
            List<Symbol> buildOutputSymbols = node.getOutputSymbols().stream()
                    .filter(symbol -> node.getRight().getOutputSymbols().contains(symbol))
                    .collect(toImmutableList());
            List<Integer> buildOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(buildOutputSymbols, buildSource.getLayout()));
            List<Integer> buildChannels = ImmutableList.copyOf(getChannelsForSymbols(buildSymbols, buildSource.getLayout()));
            Optional<Integer> buildHashChannel = buildHashSymbol.map(channelGetter(buildSource));

            Optional<JoinFilterFunctionFactory> filterFunctionFactory = node.getFilter()
                    .map(filterExpression -> compileJoinFilterFunction(
                            filterExpression,
                            node.getSortExpression(),
                            probeLayout,
                            buildSource.getLayout(),
                            context.getTypes(),
                            context.getSession()));

            HashBuilderOperatorFactory hashBuilderOperatorFactory = new HashBuilderOperatorFactory(
                    buildContext.getNextOperatorId(),
                    node.getId(),
                    buildSource.getTypes(),
                    buildOutputChannels,
                    buildSource.getLayout(),
                    buildChannels,
                    buildHashChannel,
                    node.getType() == RIGHT || node.getType() == FULL,
                    filterFunctionFactory,
                    10_000,
                    buildContext.getDriverInstanceCount().orElse(1),
                    pagesIndexFactory);

            context.addDriverFactory(
                    buildContext.isInputDriver(),
                    false,
                    ImmutableList.<OperatorFactory>builder()
                            .addAll(buildSource.getOperatorFactories())
                            .add(hashBuilderOperatorFactory)
                            .build(),
                    buildContext.getDriverInstanceCount());

            return hashBuilderOperatorFactory.getLookupSourceFactory();
        }

        private JoinFilterFunctionFactory compileJoinFilterFunction(
                Expression filterExpression,
                Optional<Expression> sortExpression,
                Map<Symbol, Integer> probeLayout,
                Map<Symbol, Integer> buildLayout,
                Map<Symbol, Type> types,
                Session session)
        {
            Map<Symbol, Integer> joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout);
            Map<Integer, Type> sourceTypes = joinSourcesLayout.entrySet().stream()
                    .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey())));

            Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression);
            Optional<Expression> rewrittenSortExpression = sortExpression.map(
                    expression -> new SymbolToInputRewriter(buildLayout).rewrite(expression));
            Optional<SortExpression> sortChannel = rewrittenSortExpression.map(SortExpression::fromExpression);

            IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput(
                    session,
                    metadata,
                    sqlParser,
                    sourceTypes,
                    rewrittenFilter,
                    emptyList() /* parameters have already been replaced */);

            RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes);
            return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size(), sortChannel);
        }

        private OperatorFactory createLookupJoin(
                JoinNode node,
                PhysicalOperation probeSource,
                List<Symbol> probeSymbols,
                Optional<Symbol> probeHashSymbol,
                LookupSourceFactory lookupSourceFactory,
                LocalExecutionPlanContext context)
        {
            List<Type> probeTypes = probeSource.getTypes();
            List<Symbol> probeOutputSymbols = node.getOutputSymbols().stream()
                    .filter(symbol -> node.getLeft().getOutputSymbols().contains(symbol))
                    .collect(toImmutableList());
            List<Integer> probeOutputChannels = ImmutableList.copyOf(getChannelsForSymbols(probeOutputSymbols, probeSource.getLayout()));
            List<Integer> probeJoinChannels = ImmutableList.copyOf(getChannelsForSymbols(probeSymbols, probeSource.getLayout()));
            Optional<Integer> probeHashChannel = probeHashSymbol.map(channelGetter(probeSource));

            switch (node.getType()) {
                case INNER:
                    return lookupJoinOperators.innerJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactory, probeTypes, probeJoinChannels, probeHashChannel, Optional.of(probeOutputChannels));
                case LEFT:
                    return lookupJoinOperators.probeOuterJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactory, probeTypes, probeJoinChannels, probeHashChannel, Optional.of(probeOutputChannels));
                case RIGHT:
                    return lookupJoinOperators.lookupOuterJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactory, probeTypes, probeJoinChannels, probeHashChannel, Optional.of(probeOutputChannels));
                case FULL:
                    return lookupJoinOperators.fullOuterJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactory, probeTypes, probeJoinChannels, probeHashChannel, Optional.of(probeOutputChannels));
                default:
                    throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
            }
        }

        private Map<Symbol, Integer> createJoinSourcesLayout(Map<Symbol, Integer> lookupSourceLayout, Map<Symbol, Integer> probeSourceLayout)
        {
            Builder<Symbol, Integer> joinSourcesLayout = ImmutableMap.builder();
            joinSourcesLayout.putAll(lookupSourceLayout);
            for (Map.Entry<Symbol, Integer> probeLayoutEntry : probeSourceLayout.entrySet()) {
                joinSourcesLayout.put(probeLayoutEntry.getKey(), probeLayoutEntry.getValue() + lookupSourceLayout.size());
            }
            return joinSourcesLayout.build();
        }

        @Override
        public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanContext context)
        {
            // Plan probe
            PhysicalOperation probeSource = node.getSource().accept(this, context);

            // Plan build
            LocalExecutionPlanContext buildContext = context.createSubContext();
            PhysicalOperation buildSource = node.getFilteringSource().accept(this, buildContext);
            checkArgument(buildContext.getDriverInstanceCount().orElse(1) == 1, "Expected local execution to not be parallel");

            int probeChannel = probeSource.getLayout().get(node.getSourceJoinSymbol());
            int buildChannel = buildSource.getLayout().get(node.getFilteringSourceJoinSymbol());

            Optional<Integer> buildHashChannel = node.getFilteringSourceHashSymbol().map(channelGetter(buildSource));
            SetBuilderOperatorFactory setBuilderOperatorFactory = new SetBuilderOperatorFactory(
                    buildContext.getNextOperatorId(),
                    node.getId(),
                    buildSource.getTypes().get(buildChannel),
                    buildChannel,
                    buildHashChannel,
                    10_000,
                    joinCompiler);
            SetSupplier setProvider = setBuilderOperatorFactory.getSetProvider();
            context.addDriverFactory(buildContext.isInputDriver(), false, ImmutableList.<OperatorFactory>builder()
                    .addAll(buildSource.getOperatorFactories())
                    .add(setBuilderOperatorFactory)
                    .build(), buildContext.getDriverInstanceCount());
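            // The filtering source runs in its own pipeline and is materialized into a set by the
            // SetBuilderOperator; the probe pipeline reads that set through the shared SetSupplier.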
            // Source channels are always laid out first, followed by the boolean output symbol
            Map<Symbol, Integer> outputMappings = ImmutableMap.<Symbol, Integer>builder()
                    .putAll(probeSource.getLayout())
                    .put(node.getSemiJoinOutput(), probeSource.getLayout().size())
                    .build();

            HashSemiJoinOperatorFactory operator = new HashSemiJoinOperatorFactory(context.getNextOperatorId(), node.getId(), setProvider, probeSource.getTypes(), probeChannel);
            return new PhysicalOperation(operator, outputMappings, probeSource);
        }

        @Override
        public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context)
        {
            // Set table writer count
            if (node.getPartitioningScheme().isPresent()) {
                context.setDriverInstanceCount(1);
            }
            else {
                context.setDriverInstanceCount(getTaskWriterCount(session));
            }

            // serialize writes by forcing data through a single writer
            PhysicalOperation source = node.getSource().accept(this, context);

            List<Integer> inputChannels = node.getColumns().stream()
                    .map(source::symbolToChannel)
                    .collect(toImmutableList());

            OperatorFactory operatorFactory = new TableWriterOperatorFactory(
                    context.getNextOperatorId(),
                    node.getId(),
                    pageSinkManager,
                    node.getTarget(),
                    inputChannels,
                    session);

            Map<Symbol, Integer> layout = ImmutableMap.<Symbol, Integer>builder()
                    .put(node.getOutputSymbols().get(0), 0)
                    .put(node.getOutputSymbols().get(1), 1)
                    .build();

            return new PhysicalOperation(operatorFactory, layout, source);
        }

        @Override
        public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPlanContext context)
        {
            PhysicalOperation source = node.getSource().accept(this, context);

            OperatorFactory operatorFactory = new TableFinishOperatorFactory(context.getNextOperatorId(), node.getId(), createTableFinisher(session, node, metadata));
            Map<Symbol, Integer> layout = ImmutableMap.of(node.getOutputSymbols().get(0), 0);

            return new PhysicalOperation(operatorFactory, layout, source);
        }

        @Override
        public PhysicalOperation visitDelete(DeleteNode node, LocalExecutionPlanContext context)
        {
            PhysicalOperation source = node.getSource().accept(this, context);

            OperatorFactory operatorFactory = new DeleteOperatorFactory(context.getNextOperatorId(), node.getId(), source.getLayout().get(node.getRowId()));

            Map<Symbol, Integer> layout = ImmutableMap.<Symbol, Integer>builder()
                    .put(node.getOutputSymbols().get(0), 0)
                    .put(node.getOutputSymbols().get(1), 1)
                    .build();

            return new PhysicalOperation(operatorFactory, layout, source);
        }

        @Override
        public PhysicalOperation visitMetadataDelete(MetadataDeleteNode node, LocalExecutionPlanContext context)
        {
            OperatorFactory operatorFactory = new MetadataDeleteOperatorFactory(context.getNextOperatorId(), node.getId(), node.getTableLayout(), metadata, session, node.getTarget().getHandle());

            return new PhysicalOperation(operatorFactory, makeLayout(node));
        }

        @Override
        public PhysicalOperation visitUnion(UnionNode node, LocalExecutionPlanContext context)
        {
            throw new UnsupportedOperationException("Union node should not be present in a local execution plan");
        }

        @Override
        public PhysicalOperation visitEnforceSingleRow(EnforceSingleRowNode node, LocalExecutionPlanContext context)
        {
            PhysicalOperation source = node.getSource().accept(this, context);

            List<Type> types = getSourceOperatorTypes(node, context.getTypes());
            OperatorFactory operatorFactory = new EnforceSingleRowOperator.EnforceSingleRowOperatorFactory(context.getNextOperatorId(), node.getId(), types);

            return new PhysicalOperation(operatorFactory, makeLayout(node), source);
        }
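        // AssignUniqueId appends a synthetic unique id column to every input row (see AssignUniqueIdOperator);
        // only the output types are needed to construct the operator factory.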
        @Override
        public PhysicalOperation visitAssignUniqueId(AssignUniqueId node, LocalExecutionPlanContext context)
        {
            PhysicalOperation source = node.getSource().accept(this, context);

            List<Type> types = getSourceOperatorTypes(node, context.getTypes());
            OperatorFactory operatorFactory = new AssignUniqueIdOperator.AssignUniqueIdOperatorFactory(
                    context.getNextOperatorId(),
                    node.getId(),
                    types);

            return new PhysicalOperation(operatorFactory, makeLayout(node), source);
        }

        @Override
        public PhysicalOperation visitExchange(ExchangeNode node, LocalExecutionPlanContext context)
        {
            checkArgument(node.getScope() == LOCAL, "Only local exchanges are supported in the local planner");

            int driverInstanceCount;
            if (node.getType() == ExchangeNode.Type.GATHER) {
                driverInstanceCount = 1;
                context.setDriverInstanceCount(1);
            }
            else if (context.getDriverInstanceCount().isPresent()) {
                driverInstanceCount = context.getDriverInstanceCount().getAsInt();
            }
            else {
                driverInstanceCount = getTaskConcurrency(session);
                context.setDriverInstanceCount(driverInstanceCount);
            }

            List<Type> types = getSourceOperatorTypes(node, context.getTypes());
            List<Integer> channels = node.getPartitioningScheme().getPartitioning().getArguments().stream()
                    .map(argument -> node.getOutputSymbols().indexOf(argument.getColumn()))
                    .collect(toImmutableList());
            Optional<Integer> hashChannel = node.getPartitioningScheme().getHashColumn()
                    .map(symbol -> node.getOutputSymbols().indexOf(symbol));

            LocalExchange localExchange = new LocalExchange(node.getPartitioningScheme().getPartitioning().getHandle(), driverInstanceCount, types, channels, hashChannel);

            for (int i = 0; i < node.getSources().size(); i++) {
                PlanNode sourceNode = node.getSources().get(i);
                List<Symbol> expectedLayout = node.getInputs().get(i);

                LocalExecutionPlanContext subContext = context.createSubContext();
                PhysicalOperation source = sourceNode.accept(this, subContext);
                List<OperatorFactory> operatorFactories = new ArrayList<>(source.getOperatorFactories());

                Function<Page, Page> pagePreprocessor = enforceLayoutProcessor(expectedLayout, source.getLayout());
                operatorFactories.add(new LocalExchangeSinkOperatorFactory(subContext.getNextOperatorId(), node.getId(), localExchange.createSinkFactory(), pagePreprocessor));
                context.addDriverFactory(subContext.isInputDriver(), false, operatorFactories, subContext.getDriverInstanceCount());
            }
            // the main driver is not an input... the exchange sources are the input for the plan
            context.setInputDriver(false);

            // instance count must match the number of partitions in the exchange
            verify(context.getDriverInstanceCount().getAsInt() == localExchange.getBufferCount(), "driver instance count must match the number of exchange partitions");

            return new PhysicalOperation(new LocalExchangeSourceOperatorFactory(context.getNextOperatorId(), node.getId(), localExchange), makeLayout(node));
        }

        @Override
        protected PhysicalOperation visitPlan(PlanNode node, LocalExecutionPlanContext context)
        {
            throw new UnsupportedOperationException("not yet implemented");
        }

        private List<Type> getSourceOperatorTypes(PlanNode node, Map<Symbol, Type> types)
        {
            return getSymbolTypes(node.getOutputSymbols(), types);
        }

        private List<Type> getSymbolTypes(List<Symbol> symbols, Map<Symbol, Type> types)
        {
            return symbols.stream()
                    .map(types::get)
                    .collect(toImmutableList());
        }

        private AccumulatorFactory buildAccumulatorFactory(
                PhysicalOperation source,
                Signature function,
                FunctionCall call,
                @Nullable Symbol mask)
        {
            List<Integer> arguments = new ArrayList<>();
            for (Expression argument : call.getArguments()) {
                Symbol argumentSymbol = Symbol.from(argument);
                arguments.add(source.getLayout().get(argumentSymbol));
            }

            Optional<Integer> maskChannel = Optional.empty();
            if (mask != null) {
                maskChannel = Optional.of(source.getLayout().get(mask));
            }

            return metadata.getFunctionRegistry().getAggregateFunctionImplementation(function).bind(arguments, maskChannel);
        }

        private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode node, PhysicalOperation source)
        {
            int outputChannel = 0;
            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
            List<AccumulatorFactory> accumulatorFactories = new ArrayList<>();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();
                accumulatorFactories.add(buildAccumulatorFactory(source, node.getFunctions().get(symbol), entry.getValue(), node.getMasks().get(entry.getKey())));
                outputMappings.put(symbol, outputChannel); // one aggregation per channel
                outputChannel++;
            }

            OperatorFactory operatorFactory = new AggregationOperatorFactory(operatorId, node.getId(), node.getStep(), accumulatorFactories);
            return new PhysicalOperation(operatorFactory, outputMappings.build(), source);
        }

        private PhysicalOperation planGroupByAggregation(
                AggregationNode node,
                PhysicalOperation source,
                int operatorId,
                boolean spillEnabled,
                DataSize memoryLimitBeforeSpill)
        {
            List<Symbol> groupBySymbols = node.getGroupingKeys();

            List<Symbol> aggregationOutputSymbols = new ArrayList<>();
            List<AccumulatorFactory> accumulatorFactories = new ArrayList<>();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();

                accumulatorFactories.add(buildAccumulatorFactory(
                        source,
                        node.getFunctions().get(symbol),
                        entry.getValue(),
                        node.getMasks().get(entry.getKey())));
                aggregationOutputSymbols.add(symbol);
            }

            ImmutableList.Builder<Integer> globalAggregationGroupIds = ImmutableList.builder();
            for (int i = 0; i < node.getGroupingSets().size(); i++) {
                if (node.getGroupingSets().get(i).isEmpty()) {
                    globalAggregationGroupIds.add(i);
                }
            }

            ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder();
            // add group-by key fields each in a separate channel
            int channel = 0;
            for (Symbol symbol : groupBySymbols) {
                outputMappings.put(symbol, channel);
                channel++;
            }

            // hashChannel follows the group by channels
            if (node.getHashSymbol().isPresent()) {
                outputMappings.put(node.getHashSymbol().get(), channel++);
            }

            // aggregations go in following channels
            for (Symbol symbol : aggregationOutputSymbols) {
                outputMappings.put(symbol, channel);
                channel++;
            }

            List<Integer> groupByChannels = getChannelsForSymbols(groupBySymbols, source.getLayout());
            List<Type> groupByTypes = groupByChannels.stream()
                    .map(entry -> source.getTypes().get(entry))
                    .collect(toImmutableList());

            Optional<Integer> hashChannel = node.getHashSymbol().map(channelGetter(source));
            Map<Symbol, Integer> mappings = outputMappings.build();
            OperatorFactory operatorFactory = new HashAggregationOperatorFactory(
                    operatorId,
                    node.getId(),
                    groupByTypes,
                    groupByChannels,
                    globalAggregationGroupIds.build(),
                    node.getStep(),
                    node.hasDefaultOutput(),
                    accumulatorFactories,
                    hashChannel,
                    node.getGroupIdSymbol().map(mappings::get),
                    10_000,
                    maxPartialAggregationMemorySize,
                    spillEnabled,
                    memoryLimitBeforeSpill,
                    spillerFactory,
                    joinCompiler);

            return new PhysicalOperation(operatorFactory, mappings, source);
        }
    }

    private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata)
    {
        WriterTarget target = node.getTarget();
        return fragments -> {
            if (target instanceof CreateHandle) {
                return metadata.finishCreateTable(session, ((CreateHandle) target).getHandle(), fragments);
            }
            else if (target instanceof InsertHandle) {
                return metadata.finishInsert(session, ((InsertHandle) target).getHandle(), fragments);
            }
            else if (target instanceof DeleteHandle) {
                metadata.finishDelete(session, ((DeleteHandle) target).getHandle(), fragments);
                return Optional.empty();
            }
            else {
                throw new AssertionError("Unhandled target type: " + target.getClass().getName());
            }
        };
    }

    private static Function<Page, Page> enforceLayoutProcessor(List<Symbol> expectedLayout, Map<Symbol, Integer> inputLayout)
    {
        int[] channels = expectedLayout.stream()
                .mapToInt(inputLayout::get)
                .toArray();

        if (Arrays.equals(channels, range(0, inputLayout.size()).toArray())) {
            // this is an identity mapping
            return Function.identity();
        }

        return new PageChannelSelector(channels);
    }

    private static List<Integer> getChannelsForSymbols(List<Symbol> symbols, Map<Symbol, Integer> layout)
    {
        ImmutableList.Builder<Integer> builder = ImmutableList.builder();
        for (Symbol symbol : symbols) {
            builder.add(layout.get(symbol));
        }
        return builder.build();
    }

    private static Function<Symbol, Integer> channelGetter(PhysicalOperation source)
    {
        return input -> {
            checkArgument(source.getLayout().containsKey(input));
            return source.getLayout().get(input);
        };
    }

    /**
     * Encapsulates a physical operator plus the mapping of logical symbols to channel/field
     */
    private static class PhysicalOperation
    {
        private final List<OperatorFactory> operatorFactories;
        private final Map<Symbol, Integer> layout;
        private final List<Type> types;

        public PhysicalOperation(OperatorFactory operatorFactory, Map<Symbol, Integer> layout)
        {
            requireNonNull(operatorFactory, "operatorFactory is null");
            requireNonNull(layout, "layout is null");

            this.operatorFactories = ImmutableList.of(operatorFactory);
            this.layout = ImmutableMap.copyOf(layout);
            this.types = operatorFactory.getTypes();
        }
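        // Chaining constructor: the new operator factory is appended to the source operation's pipeline,
        // and the supplied layout describes the output of the new operator.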
        public PhysicalOperation(OperatorFactory operatorFactory, Map<Symbol, Integer> layout, PhysicalOperation source)
        {
            requireNonNull(operatorFactory, "operatorFactory is null");
            requireNonNull(layout, "layout is null");
            requireNonNull(source, "source is null");

            this.operatorFactories = ImmutableList.<OperatorFactory>builder().addAll(source.getOperatorFactories()).add(operatorFactory).build();
            this.layout = ImmutableMap.copyOf(layout);
            this.types = operatorFactory.getTypes();
        }

        public int symbolToChannel(Symbol input)
        {
            checkArgument(layout.containsKey(input));
            return layout.get(input);
        }

        public List<Type> getTypes()
        {
            return types;
        }

        public Map<Symbol, Integer> getLayout()
        {
            return layout;
        }

        private List<OperatorFactory> getOperatorFactories()
        {
            return operatorFactories;
        }
    }
}