/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TestingTableHandle; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TableWriterNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Consumer; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; public class PlanBuilder { private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final Map<Symbol, Type> symbols = new HashMap<>(); public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) { this.idAllocator = idAllocator; this.metadata = metadata; } public ValuesNode values(Symbol... columns) { return new ValuesNode( idAllocator.getNextId(), ImmutableList.copyOf(columns), ImmutableList.of()); } public ValuesNode values(List<Symbol> columns, List<List<Expression>> rows) { return new ValuesNode(idAllocator.getNextId(), columns, rows); } public LimitNode limit(long limit, PlanNode source) { return new LimitNode(idAllocator.getNextId(), source, limit, false); } public SampleNode sample(double sampleRatio, SampleNode.Type type, PlanNode source) { return new SampleNode(idAllocator.getNextId(), source, sampleRatio, type); } public ProjectNode project(Assignments assignments, PlanNode source) { return new ProjectNode(idAllocator.getNextId(), source, assignments); } public FilterNode filter(Expression predicate, PlanNode source) { return new FilterNode(idAllocator.getNextId(), source, predicate); } public AggregationNode aggregation(Consumer<AggregationBuilder> aggregationBuilderConsumer) { AggregationBuilder aggregationBuilder = new AggregationBuilder(); aggregationBuilderConsumer.accept(aggregationBuilder); return aggregationBuilder.build(); } public class AggregationBuilder { private PlanNode source; private Map<Symbol, Aggregation> assignments = new HashMap<>(); private List<List<Symbol>> groupingSets; private Step step; private Optional<Symbol> hashSymbol = Optional.empty(); private Optional<Symbol> groupIdSymbol = Optional.empty(); public AggregationBuilder source(PlanNode source) { this.source = source; return this; } public AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes) { checkArgument(expression instanceof FunctionCall); FunctionCall aggregation = (FunctionCall) expression; Signature signature = metadata.getFunctionRegistry().resolveFunction(aggregation.getName(), TypeSignatureProvider.fromTypes(inputTypes)); return addAggregation(output, new Aggregation(aggregation, signature, Optional.empty())); } public AggregationBuilder addAggregation(Symbol output, Aggregation aggregation) { assignments.put(output, aggregation); return this; } public AggregationBuilder globalGrouping() { return groupingSets(ImmutableList.of(ImmutableList.of())); } public AggregationBuilder groupingSets(List<List<Symbol>> groupingSets) { this.groupingSets = ImmutableList.copyOf(groupingSets); return this; } public AggregationBuilder step(Step step) { this.step = step; return this; } public AggregationBuilder hashSymbol(Symbol hashSymbol) { this.hashSymbol = Optional.of(hashSymbol); return this; } public AggregationBuilder groupIdSymbol(Symbol groupIdSymbol) { this.groupIdSymbol = Optional.of(groupIdSymbol); return this; } protected AggregationNode build() { return new AggregationNode( idAllocator.getNextId(), source, assignments, groupingSets, step, hashSymbol, groupIdSymbol); } } public ApplyNode apply(Assignments subqueryAssignments, List<Symbol> correlation, PlanNode input, PlanNode subquery) { return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation); } public TableScanNode tableScan(List<Symbol> symbols, Map<Symbol, ColumnHandle> assignments) { Expression originalConstraint = null; return new TableScanNode( idAllocator.getNextId(), new TableHandle( new ConnectorId("testConnector"), new TestingTableHandle()), symbols, assignments, Optional.empty(), TupleDomain.all(), originalConstraint ); } public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode deleteSource, Symbol deleteRowId) { TableWriterNode.DeleteHandle deleteHandle = new TableWriterNode.DeleteHandle( new TableHandle( new ConnectorId("testConnector"), new TestingTableHandle()), schemaTableName ); return new TableFinishNode( idAllocator.getNextId(), exchange(e -> e .addSource(new DeleteNode( idAllocator.getNextId(), deleteSource, deleteHandle, deleteRowId, ImmutableList.of(deleteRowId) )) .addInputsSet(deleteRowId) .singleDistributionPartitioningScheme(deleteRowId) ), deleteHandle, ImmutableList.of(deleteRowId) ); } public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) { return exchange(builder -> builder.type(ExchangeNode.Type.GATHER) .scope(scope) .singleDistributionPartitioningScheme(child.getOutputSymbols()) .addSource(child) .addInputsSet(child.getOutputSymbols())); } public ExchangeNode exchange(Consumer<ExchangeBuilder> exchangeBuilderConsumer) { ExchangeBuilder exchangeBuilder = new ExchangeBuilder(); exchangeBuilderConsumer.accept(exchangeBuilder); return exchangeBuilder.build(); } public class ExchangeBuilder { private ExchangeNode.Type type = ExchangeNode.Type.GATHER; private ExchangeNode.Scope scope = ExchangeNode.Scope.REMOTE; private PartitioningScheme partitioningScheme; private List<PlanNode> sources = new ArrayList<>(); private List<List<Symbol>> inputs = new ArrayList<>(); public ExchangeBuilder type(ExchangeNode.Type type) { this.type = type; return this; } public ExchangeBuilder scope(ExchangeNode.Scope scope) { this.scope = scope; return this; } public ExchangeBuilder singleDistributionPartitioningScheme(Symbol... outputSymbols) { return singleDistributionPartitioningScheme(Arrays.asList(outputSymbols)); } public ExchangeBuilder singleDistributionPartitioningScheme(List<Symbol> outputSymbols) { return partitioningScheme(new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), outputSymbols)); } public ExchangeBuilder fixedHashDistributionParitioningScheme(List<Symbol> outputSymbols, List<Symbol> partitioningSymbols) { return partitioningScheme(new PartitioningScheme(Partitioning.create(FIXED_HASH_DISTRIBUTION, ImmutableList.copyOf(partitioningSymbols)), ImmutableList.copyOf(outputSymbols))); } public ExchangeBuilder fixedHashDistributionParitioningScheme(List<Symbol> outputSymbols, List<Symbol> partitioningSymbols, Symbol hashSymbol) { return partitioningScheme(new PartitioningScheme(Partitioning.create(FIXED_HASH_DISTRIBUTION, ImmutableList.copyOf(partitioningSymbols)), ImmutableList.copyOf(outputSymbols), Optional.of(hashSymbol))); } public ExchangeBuilder partitioningScheme(PartitioningScheme partitioningScheme) { this.partitioningScheme = partitioningScheme; return this; } public ExchangeBuilder addSource(PlanNode source) { this.sources.add(source); return this; } public ExchangeBuilder addInputsSet(Symbol... inputs) { return addInputsSet(Arrays.asList(inputs)); } public ExchangeBuilder addInputsSet(List<Symbol> inputs) { this.inputs.add(inputs); return this; } protected ExchangeNode build() { return new ExchangeNode(idAllocator.getNextId(), type, scope, partitioningScheme, sources, inputs); } } public Symbol symbol(String name, Type type) { Symbol symbol = new Symbol(name); Type old = symbols.get(symbol); if (old != null && !old.equals(type)) { throw new IllegalArgumentException(format("Symbol '%s' already registered with type '%s'", name, old)); } if (old == null) { symbols.put(symbol, type); } return symbol; } public WindowNode window(WindowNode.Specification specification, Map<Symbol, WindowNode.Function> functions, PlanNode source) { return new WindowNode( idAllocator.getNextId(), source, specification, ImmutableMap.copyOf(functions), Optional.empty(), ImmutableSet.of(), 0); } public static Expression expression(String sql) { return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql)); } public static List<Expression> expressions(String... expressions) { return Stream.of(expressions) .map(PlanBuilder::expression) .collect(toImmutableList()); } public Map<Symbol, Type> getSymbols() { return Collections.unmodifiableMap(symbols); } }