/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.AggregationNode.Step;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.ExceptNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IntersectNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.OutputNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.WindowFrame;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH;
import static com.facebook.presto.sql.planner.assertions.MatchResult.match;
import static com.facebook.presto.sql.planner.assertions.StrictAssignedSymbolsMatcher.actualAssignments;
import static com.facebook.presto.sql.planner.assertions.StrictSymbolsMatcher.actualOutputs;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
/**
 * Declarative description of an expected plan shape, used by plan-assertion tests.
 * <p>
 * Each pattern holds a list of {@link Matcher}s that are checked against one plan node, plus a list of
 * source patterns that correspond one-to-one with that node's sources. Patterns are mutable builders:
 * the {@code with*} methods append a matcher and return {@code this} so calls can be chained.
 */
public final class PlanMatchPattern
{
    // Matchers checked against the node this pattern is applied to, in the order they were added.
    private final List<Matcher> matchers = new ArrayList<>();
    // Expected source patterns; the non-anyTree branch of shapeMatches requires exactly this many sources.
    private final List<PlanMatchPattern> sourcePatterns;
    // Set by matchToAnyNodeTree(); lets shapeMatches re-apply this pattern to the node's sources.
    private boolean anyTree;

    /**
     * Creates a pattern with the given source patterns and a {@link PlanNodeMatcher} for {@code nodeClass}.
     */
    public static PlanMatchPattern node(Class<? extends PlanNode> nodeClass, PlanMatchPattern... sources)
    {
        return any(sources).with(new PlanNodeMatcher(nodeClass));
    }

    /**
     * Creates a pattern with no matchers of its own, whose sources must match {@code sources}.
     */
    public static PlanMatchPattern any(PlanMatchPattern... sources)
    {
        return new PlanMatchPattern(ImmutableList.copyOf(sources));
    }

    /**
     * Matches any tree of nodes whose children eventually match the given source patterns.
     * For example, {@code anyTree(tableScan("nation"))} matches any plan in which every leaf reached
     * through the tree is a table scan of the {@code nation} table.
     * <p>
     * Note: {@code anyTree} does not match zero nodes. E.g. {@code output(anyTree(tableScan("nation")))}
     * will NOT match a TableScanNode that is the direct source of an OutputNode.
     */
    public static PlanMatchPattern anyTree(PlanMatchPattern... sources)
    {
        return any(sources).matchToAnyNodeTree();
    }

    /**
     * Creates a pattern with the given source patterns and a {@link NotPlanNodeMatcher} for
     * {@code excludeNodeClass}.
     */
    public static PlanMatchPattern anyNot(Class<? extends PlanNode> excludeNodeClass, PlanMatchPattern... sources)
    {
        return any(sources).with(new NotPlanNodeMatcher(excludeNodeClass));
    }

    /**
     * Matches a source-less {@link TableScanNode} checked by a {@link TableScanMatcher} for
     * {@code expectedTableName}.
     */
    public static PlanMatchPattern tableScan(String expectedTableName)
    {
        return node(TableScanNode.class).with(new TableScanMatcher(expectedTableName));
    }

    /**
     * Like {@link #tableScan(String)}, and additionally binds each alias (map key) to the column of
     * {@code expectedTableName} named by the corresponding map value.
     */
    public static PlanMatchPattern tableScan(String expectedTableName, Map<String, String> columnReferences)
    {
        return tableScan(expectedTableName).addColumnReferences(expectedTableName, columnReferences);
    }

    /**
     * Like {@link #tableScan(String, Map)}, and additionally requires the scan's outputs to be exactly the
     * referenced columns (via {@link #withExactAssignedOutputs(Collection)}).
     */
    public static PlanMatchPattern strictTableScan(String expectedTableName, Map<String, String> columnReferences)
    {
        return tableScan(expectedTableName, columnReferences)
                .withExactAssignedOutputs(columnReferences.values().stream()
                        .map(columnName -> columnReference(expectedTableName, columnName))
                        .collect(toImmutableList()));
    }

    /**
     * Matches a source-less {@link TableScanNode} checked by a {@link TableScanMatcher} for
     * {@code expectedTableName} and the expected {@code constraint}.
     */
    public static PlanMatchPattern constrainedTableScan(String expectedTableName, Map<String, Domain> constraint)
    {
        return node(TableScanNode.class).with(new TableScanMatcher(expectedTableName, constraint));
    }

    /**
     * Like {@link #constrainedTableScan(String, Map)}, and additionally binds each alias (map key) to the
     * column of {@code expectedTableName} named by the corresponding map value.
     */
    public static PlanMatchPattern constrainedTableScan(String expectedTableName, Map<String, Domain> constraint, Map<String, String> columnReferences)
    {
        return constrainedTableScan(expectedTableName, constraint).addColumnReferences(expectedTableName, columnReferences);
    }

    /**
     * For every entry of {@code columnReferences}, adds an alias named by the key and bound to a
     * {@link #columnReference(String, String)} for the column named by the value.
     *
     * @return {@code this}, for chaining
     */
    private PlanMatchPattern addColumnReferences(String expectedTableName, Map<String, String> columnReferences)
    {
        columnReferences.forEach((alias, columnName) -> withAlias(alias, columnReference(expectedTableName, columnName)));
        return this;
    }

    /**
     * Matches an {@link AggregationNode} over {@code source}, binding each alias (map key) to an
     * {@link AggregationFunctionMatcher} for the corresponding expected function call.
     */
    public static PlanMatchPattern aggregation(
            Map<String, ExpectedValueProvider<FunctionCall>> aggregations,
            PlanMatchPattern source)
    {
        PlanMatchPattern result = node(AggregationNode.class, source);
        aggregations.forEach((alias, expectedCall) -> result.withAlias(alias, new AggregationFunctionMatcher(expectedCall)));
        return result;
    }

    /**
     * Matches an {@link AggregationNode} over {@code source} that is also checked by an
     * {@link AggregationMatcher} for the given grouping sets, masks, group id symbol and step.
     * Each map key is an optional alias for the {@link AggregationFunctionMatcher} built from its value;
     * an empty key is passed to {@link #withAlias(Optional, RvalueMatcher)} as-is.
     */
    public static PlanMatchPattern aggregation(
            List<List<String>> groupingSets,
            Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregations,
            Map<Symbol, Symbol> masks,
            Optional<Symbol> groupId,
            Step step,
            PlanMatchPattern source)
    {
        PlanMatchPattern result = node(AggregationNode.class, source).with(new AggregationMatcher(groupingSets, masks, groupId, step));
        aggregations.forEach((alias, expectedCall) -> result.withAlias(alias, new AggregationFunctionMatcher(expectedCall)));
        return result;
    }

    /**
     * Matches a {@link WindowNode} over {@code source} checked by a {@link WindowMatcher} for
     * {@code specification}, with one unnamed {@link WindowFunctionMatcher} per expected function.
     */
    public static PlanMatchPattern window(
            ExpectedValueProvider<WindowNode.Specification> specification,
            List<ExpectedValueProvider<FunctionCall>> windowFunctions,
            PlanMatchPattern source)
    {
        PlanMatchPattern result = node(WindowNode.class, source).with(new WindowMatcher(specification));
        windowFunctions.forEach(
                function -> result.withAlias(Optional.empty(), new WindowFunctionMatcher(function)));
        return result;
    }

    /**
     * Matches a {@link WindowNode} over {@code source} checked by a {@link WindowMatcher} for
     * {@code specification}, binding each alias (map key) to a {@link WindowFunctionMatcher} for the
     * corresponding expected function call.
     */
    public static PlanMatchPattern window(
            ExpectedValueProvider<WindowNode.Specification> specification,
            Map<String, ExpectedValueProvider<FunctionCall>> assignments,
            PlanMatchPattern source)
    {
        PlanMatchPattern result = node(WindowNode.class, source).with(new WindowMatcher(specification));
        assignments.forEach((alias, expectedCall) -> result.withAlias(alias, new WindowFunctionMatcher(expectedCall)));
        return result;
    }

    /**
     * Matches a {@link SortNode} over {@code source}.
     */
    public static PlanMatchPattern sort(PlanMatchPattern source)
    {
        return node(SortNode.class, source);
    }

    /**
     * Matches an {@link OutputNode} over {@code source}.
     */
    public static PlanMatchPattern output(PlanMatchPattern source)
    {
        return node(OutputNode.class, source);
    }

    /**
     * Matches an {@link OutputNode} over {@code source}, additionally checked by an {@link OutputMatcher}
     * for the given aliases.
     */
    public static PlanMatchPattern output(List<String> outputs, PlanMatchPattern source)
    {
        return output(source).withOutputs(outputs);
    }

    /**
     * Like {@link #output(List, PlanMatchPattern)}, and additionally requires the node's outputs to be
     * exactly the symbols bound to {@code outputs} (via {@link #withExactOutputs(List)}).
     */
    public static PlanMatchPattern strictOutput(List<String> outputs, PlanMatchPattern source)
    {
        return output(outputs, source).withExactOutputs(outputs);
    }

    /**
     * Matches a {@link ProjectNode} over {@code source}.
     */
    public static PlanMatchPattern project(PlanMatchPattern source)
    {
        return node(ProjectNode.class, source);
    }

    /**
     * Matches a {@link ProjectNode} over {@code source}, binding each alias (map key) to the symbol whose
     * assignment is matched by the corresponding {@link ExpressionMatcher}.
     */
    public static PlanMatchPattern project(Map<String, ExpressionMatcher> assignments, PlanMatchPattern source)
    {
        PlanMatchPattern result = project(source);
        assignments.forEach((alias, matcher) -> result.withAlias(alias, matcher));
        return result;
    }

    /**
     * Like {@link #project(Map, PlanMatchPattern)}, and additionally requires both the node's outputs and
     * its assignments to be exactly the given ones.
     */
    public static PlanMatchPattern strictProject(Map<String, ExpressionMatcher> assignments, PlanMatchPattern source)
    {
        /*
         * Under the current implementation of project, all of the outputs are also in the assignment.
         * If the implementation changes, this will need to change too.
         */
        return project(assignments, source)
                .withExactAssignedOutputs(assignments.values())
                .withExactAssignments(assignments.values());
    }

    /**
     * Matches a {@link SemiJoinNode} with the given source and filtering-source patterns, checked by a
     * {@link SemiJoinMatcher} for the three given aliases.
     */
    public static PlanMatchPattern semiJoin(String sourceSymbolAlias, String filteringSymbolAlias, String outputAlias, PlanMatchPattern source, PlanMatchPattern filtering)
    {
        return node(SemiJoinNode.class, source, filtering).with(new SemiJoinMatcher(sourceSymbolAlias, filteringSymbolAlias, outputAlias));
    }

    /**
     * Same as {@link #join(JoinNode.Type, List, Optional, PlanMatchPattern, PlanMatchPattern)} with no
     * expected filter.
     */
    public static PlanMatchPattern join(JoinNode.Type joinType, List<ExpectedValueProvider<JoinNode.EquiJoinClause>> expectedEquiCriteria, PlanMatchPattern left, PlanMatchPattern right)
    {
        return join(joinType, expectedEquiCriteria, Optional.empty(), left, right);
    }

    /**
     * Matches a {@link JoinNode} with the given left and right source patterns, checked by a
     * {@link JoinMatcher} for the join type, the equi-join criteria and, when present, the filter
     * (parsed as SQL with identifiers rewritten to symbol references).
     */
    public static PlanMatchPattern join(JoinNode.Type joinType, List<ExpectedValueProvider<JoinNode.EquiJoinClause>> expectedEquiCriteria, Optional<String> expectedFilter, PlanMatchPattern left, PlanMatchPattern right)
    {
        return node(JoinNode.class, left, right).with(
                new JoinMatcher(
                        joinType,
                        expectedEquiCriteria,
                        expectedFilter.map(PlanMatchPattern::parseExpressionWithSymbolReferences)));
    }

    /**
     * Matches an {@link ExchangeNode} with the given source patterns.
     */
    public static PlanMatchPattern exchange(PlanMatchPattern... sources)
    {
        return node(ExchangeNode.class, sources);
    }

    /**
     * Matches an {@link ExchangeNode} with the given source patterns, checked by an
     * {@link ExchangeMatcher} for {@code scope} and {@code type}.
     */
    public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, PlanMatchPattern... sources)
    {
        return node(ExchangeNode.class, sources)
                .with(new ExchangeMatcher(scope, type));
    }

    /**
     * Matches a {@link UnionNode} with the given source patterns.
     */
    public static PlanMatchPattern union(PlanMatchPattern... sources)
    {
        return node(UnionNode.class, sources);
    }

    /**
     * Matches an {@link IntersectNode} with the given source patterns.
     */
    public static PlanMatchPattern intersect(PlanMatchPattern... sources)
    {
        return node(IntersectNode.class, sources);
    }

    /**
     * Matches an {@link ExceptNode} with the given source patterns.
     */
    public static PlanMatchPattern except(PlanMatchPattern... sources)
    {
        return node(ExceptNode.class, sources);
    }

    /**
     * Creates an expected equi-join clause between the symbols bound to the {@code left} and
     * {@code right} aliases.
     */
    public static ExpectedValueProvider<JoinNode.EquiJoinClause> equiJoinClause(String left, String right)
    {
        return new EquiJoinClauseProvider(new SymbolAlias(left), new SymbolAlias(right));
    }

    /**
     * Creates a reference to the symbol bound to {@code alias}.
     */
    public static SymbolAlias symbol(String alias)
    {
        return new SymbolAlias(alias);
    }

    /**
     * Matches a {@link FilterNode} over {@code source}, checked by a {@link FilterMatcher} against
     * {@code predicate} parsed as SQL with identifiers rewritten to symbol references.
     */
    public static PlanMatchPattern filter(String predicate, PlanMatchPattern source)
    {
        Expression expectedPredicate = parseExpressionWithSymbolReferences(predicate);
        return node(FilterNode.class, source).with(new FilterMatcher(expectedPredicate));
    }

    /**
     * Matches an {@link ApplyNode} with the given input and subquery patterns, checked by a
     * {@link CorrelationMatcher} for the correlation aliases, and binding each subquery assignment alias
     * (map key) to its expression matcher.
     */
    public static PlanMatchPattern apply(List<String> correlationSymbolAliases, Map<String, ExpressionMatcher> subqueryAssignments, PlanMatchPattern inputPattern, PlanMatchPattern subqueryPattern)
    {
        PlanMatchPattern result = node(ApplyNode.class, inputPattern, subqueryPattern)
                .with(new CorrelationMatcher(correlationSymbolAliases));
        subqueryAssignments.forEach((alias, matcher) -> result.withAlias(alias, matcher));
        return result;
    }

    /**
     * Matches a {@link GroupIdNode} over {@code source}, checked by a {@link GroupIdMatcher} for the
     * given groups and group id alias (and an empty map for the matcher's second argument).
     */
    public static PlanMatchPattern groupingSet(List<List<String>> groups, String groupIdAlias, PlanMatchPattern source)
    {
        return node(GroupIdNode.class, source).with(new GroupIdMatcher(groups, ImmutableMap.of(), groupIdAlias));
    }

    /**
     * Matches a source-less {@link ValuesNode}, binding each alias (map key) to a {@link ValuesMatcher}
     * built from the corresponding integer (presumably an output column index — TODO confirm).
     */
    public static PlanMatchPattern values(Map<String, Integer> values)
    {
        PlanMatchPattern result = node(ValuesNode.class);
        values.forEach((alias, expected) -> result.withAlias(alias, new ValuesMatcher(expected)));
        return result;
    }

    /**
     * Matches a {@link LimitNode} over {@code source}, checked by a {@link LimitMatcher} for {@code limit}.
     */
    public static PlanMatchPattern limit(long limit, PlanMatchPattern source)
    {
        return node(LimitNode.class, source).with(new LimitMatcher(limit));
    }

    /**
     * Creates a pattern with no matchers whose sources must match {@code sourcePatterns}.
     *
     * @param sourcePatterns expected source patterns; copied into an immutable list
     * @throws NullPointerException if {@code sourcePatterns} is null
     */
    public PlanMatchPattern(List<PlanMatchPattern> sourcePatterns)
    {
        requireNonNull(sourcePatterns, "sourcePatterns are null");
        this.sourcePatterns = ImmutableList.copyOf(sourcePatterns);
    }

    /**
     * Returns the candidate matching states for {@code node}, based on its shape only.
     * <p>
     * For an anyTree pattern, one state re-applies this pattern to every source of the node (a single
     * copy when the node has zero or one source). Independently, when the node has exactly as many sources
     * as this pattern has source patterns and every matcher accepts the node's shape, a state pairing the
     * node's sources with {@code sourcePatterns} is added. An empty result means no possible match.
     */
    List<PlanMatchingState> shapeMatches(PlanNode node)
    {
        ImmutableList.Builder<PlanMatchingState> states = ImmutableList.builder();
        if (anyTree) {
            int sourcesCount = node.getSources().size();
            if (sourcesCount > 1) {
                states.add(new PlanMatchingState(nCopies(sourcesCount, this)));
            }
            else {
                states.add(new PlanMatchingState(ImmutableList.of(this)));
            }
        }
        if (node.getSources().size() == sourcePatterns.size() && matchers.stream().allMatch(it -> it.shapeMatches(node))) {
            states.add(new PlanMatchingState(sourcePatterns));
        }
        return states.build();
    }

    /**
     * Runs every matcher's detailed check against {@code node}, in insertion order.
     *
     * @return {@code NO_MATCH} as soon as one matcher fails; otherwise a match carrying all aliases
     *         collected from the matchers
     */
    MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases)
    {
        SymbolAliases.Builder newAliases = SymbolAliases.builder();
        for (Matcher matcher : matchers) {
            MatchResult matchResult = matcher.detailMatches(node, session, metadata, symbolAliases);
            if (!matchResult.isMatch()) {
                return NO_MATCH;
            }
            newAliases.putAll(matchResult.getAliases());
        }
        return match(newAliases.build());
    }

    /**
     * Appends {@code matcher} to this pattern.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern with(Matcher matcher)
    {
        matchers.add(matcher);
        return this;
    }

    /**
     * Appends an {@link Alias} that binds {@code alias} using {@code matcher}.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withAlias(String alias, RvalueMatcher matcher)
    {
        return withAlias(Optional.of(alias), matcher);
    }

    /**
     * Appends an {@link Alias} with an optional name that uses {@code matcher}.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withAlias(Optional<String> alias, RvalueMatcher matcher)
    {
        matchers.add(new Alias(alias, matcher));
        return this;
    }

    /**
     * Appends a {@link SymbolCardinalityMatcher} for {@code numberOfSymbols}.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withNumberOfOutputColumns(int numberOfSymbols)
    {
        matchers.add(new SymbolCardinalityMatcher(numberOfSymbols));
        return this;
    }

    /**
     * Requires the node's outputs to be exactly the symbols bound to {@code expectedAliases}.
     * <p>
     * This is useful if you already know the bindings for the aliases you expect to find
     * in the outputs. This is the case for symbols that are produced by a direct or indirect
     * source of the node you're applying this to.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withExactOutputs(List<String> expectedAliases)
    {
        matchers.add(new StrictSymbolsMatcher(actualOutputs(), expectedAliases));
        return this;
    }

    /**
     * Requires the node's outputs to be exactly the symbols matched by {@code expectedAliases}.
     * <p>
     * withExactAssignments and withExactAssignedOutputs are needed for matching symbols
     * that are produced in the node that you're matching. The name of the symbol bound to
     * the alias is *not* known when the Matcher is run, and so you need to match by what
     * is being assigned to it.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withExactAssignedOutputs(Collection<? extends RvalueMatcher> expectedAliases)
    {
        matchers.add(new StrictAssignedSymbolsMatcher(actualOutputs(), expectedAliases));
        return this;
    }

    /**
     * Requires the node's assignments to be exactly the symbols matched by {@code expectedAliases}.
     * See {@link #withExactAssignedOutputs(Collection)} for when this is needed.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withExactAssignments(Collection<? extends RvalueMatcher> expectedAliases)
    {
        matchers.add(new StrictAssignedSymbolsMatcher(actualAssignments(), expectedAliases));
        return this;
    }

    /**
     * Creates a matcher for column {@code columnName} of table {@code tableName}.
     */
    public static RvalueMatcher columnReference(String tableName, String columnName)
    {
        return new ColumnReference(tableName, columnName);
    }

    /**
     * Creates an {@link ExpressionMatcher} for the given SQL expression text.
     */
    public static ExpressionMatcher expression(String expression)
    {
        return new ExpressionMatcher(expression);
    }

    /**
     * Appends an {@link OutputMatcher} for {@code aliases}.
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern withOutputs(List<String> aliases)
    {
        matchers.add(new OutputMatcher(aliases));
        return this;
    }

    /**
     * Marks this pattern as an anyTree pattern (see {@link #anyTree(PlanMatchPattern...)}).
     *
     * @return {@code this}, for chaining
     */
    public PlanMatchPattern matchToAnyNodeTree()
    {
        anyTree = true;
        return this;
    }

    /**
     * Returns {@code true} when this pattern has no source patterns.
     */
    public boolean isTerminated()
    {
        return sourcePatterns.isEmpty();
    }

    /**
     * Creates an {@link AnySymbol} placeholder.
     */
    public static PlanTestSymbol anySymbol()
    {
        return new AnySymbol();
    }

    /**
     * Creates an expected call to function {@code name} whose arguments are the symbols bound to
     * {@code args}.
     */
    public static ExpectedValueProvider<FunctionCall> functionCall(String name, List<String> args)
    {
        return new FunctionCallProvider(QualifiedName.of(name), toSymbolAliases(args));
    }

    /**
     * Creates an expected non-distinct call to function {@code name} with the given window frame, whose
     * arguments are the symbols bound to {@code args}.
     */
    public static ExpectedValueProvider<FunctionCall> functionCall(
            String name,
            Optional<WindowFrame> frame,
            List<String> args)
    {
        return new FunctionCallProvider(QualifiedName.of(name), frame, false, toSymbolAliases(args));
    }

    /**
     * Creates an expected call to function {@code name} with the given distinct flag and argument
     * symbols.
     */
    public static ExpectedValueProvider<FunctionCall> functionCall(
            String name,
            boolean distinct,
            List<PlanTestSymbol> args)
    {
        return new FunctionCallProvider(QualifiedName.of(name), distinct, args);
    }

    /**
     * Resolves each test symbol against {@code symbolAliases} and returns the corresponding symbol
     * references, in order.
     */
    public static List<Expression> toSymbolReferences(List<PlanTestSymbol> aliases, SymbolAliases symbolAliases)
    {
        return aliases
                .stream()
                .map(arg -> arg.toSymbol(symbolAliases).toSymbolReference())
                .collect(toImmutableList());
    }

    /**
     * Wraps each alias name in a {@link SymbolAlias}, preserving order.
     */
    private static List<PlanTestSymbol> toSymbolAliases(List<String> aliases)
    {
        return aliases
                .stream()
                .map(PlanMatchPattern::symbol)
                .collect(toImmutableList());
    }

    /**
     * Creates an expected window specification from partition-by aliases, order-by aliases, and the
     * sort order keyed by alias.
     */
    public static ExpectedValueProvider<WindowNode.Specification> specification(
            List<String> partitionBy,
            List<String> orderBy,
            Map<String, SortOrder> orderings)
    {
        return new SpecificationProvider(
                partitionBy
                        .stream()
                        .map(SymbolAlias::new)
                        .collect(toImmutableList()),
                orderBy
                        .stream()
                        .map(SymbolAlias::new)
                        .collect(toImmutableList()),
                orderings
                        .entrySet()
                        .stream()
                        .collect(toImmutableMap(entry -> new SymbolAlias(entry.getKey()), Map.Entry::getValue)));
    }

    /**
     * Parses {@code sql} with a fresh {@link SqlParser} and rewrites its identifiers to symbol references.
     * Shared by {@link #filter(String, PlanMatchPattern)} and the join factory so both treat expected
     * predicates the same way.
     */
    private static Expression parseExpressionWithSymbolReferences(String sql)
    {
        return rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql));
    }

    /**
     * Renders this pattern and its sources as an indented tree, one matcher per line.
     */
    @Override
    public String toString()
    {
        StringBuilder builder = new StringBuilder();
        toString(builder, 0);
        return builder.toString();
    }

    /**
     * Appends this pattern at depth {@code indent}, followed by its non-node matchers and its source
     * patterns at depth {@code indent + 1}.
     */
    private void toString(StringBuilder builder, int indent)
    {
        // Only the first PlanNodeMatcher is printed below, so more than one is treated as an error.
        checkState(matchers.stream().filter(PlanNodeMatcher.class::isInstance).count() <= 1);

        builder.append(indentString(indent)).append("- ");
        if (anyTree) {
            builder.append("anyTree");
        }
        else {
            builder.append("node");
        }

        Optional<PlanNodeMatcher> planNodeMatcher = matchers.stream()
                .filter(PlanNodeMatcher.class::isInstance)
                .map(PlanNodeMatcher.class::cast)
                .findFirst();

        if (planNodeMatcher.isPresent()) {
            builder.append("(").append(planNodeMatcher.get().getNodeClass().getSimpleName()).append(")");
        }

        builder.append("\n");

        List<Matcher> matchersToPrint = matchers.stream()
                .filter(matcher -> !(matcher instanceof PlanNodeMatcher))
                .collect(toImmutableList());

        for (Matcher matcher : matchersToPrint) {
            builder.append(indentString(indent + 1)).append(matcher.toString()).append("\n");
        }

        for (PlanMatchPattern pattern : sourcePatterns) {
            pattern.toString(builder, indent + 1);
        }
    }

    /**
     * Returns {@code indent} spaces.
     */
    private static String indentString(int indent)
    {
        return Strings.repeat(" ", indent);
    }
}