/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; import com.google.common.collect.ImmutableSet; import java.util.Map; import java.util.Optional; import java.util.function.Function; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.google.common.base.Preconditions.checkArgument; import static org.testng.Assert.fail; public class RuleAssert { private final Metadata metadata; private Session session; private final Rule rule; private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private Map<Symbol, Type> symbols; private PlanNode plan; public RuleAssert(Metadata metadata, Session session, Rule rule) { this.metadata = metadata; this.session = session; this.rule = rule; } public RuleAssert setSystemProperty(String key, String value) { return withSession(Session.builder(session) .setSystemProperty(key, value) .build()); } public RuleAssert withSession(Session session) { this.session = session; return this; } public RuleAssert on(Function<PlanBuilder, PlanNode> planProvider) { checkArgument(plan == null, "plan has already been set"); PlanBuilder builder = new PlanBuilder(idAllocator, metadata); plan = planProvider.apply(builder); symbols = builder.getSymbols(); return this; } public void doesNotFire() { SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); Optional<PlanNode> result = rule.apply(plan, x -> x, idAllocator, symbolAllocator, session); if (result.isPresent()) { fail(String.format( "Expected %s to not fire for:\n%s", rule.getClass().getName(), PlanPrinter.textLogicalPlan(plan, symbolAllocator.getTypes(), metadata, session, 2))); } } public void matches(PlanMatchPattern pattern) { SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); Optional<PlanNode> result = rule.apply(plan, x -> x, idAllocator, symbolAllocator, session); Map<Symbol, Type> types = symbolAllocator.getTypes(); if (!result.isPresent()) { fail(String.format( "%s did not fire for:\n%s", rule.getClass().getName(), PlanPrinter.textLogicalPlan(plan, types, metadata, session, 2))); } PlanNode actual = result.get(); if (actual == plan) { // plans are not comparable, so we can only ensure they are not the same instance fail(String.format( "%s: rule fired but return the original plan:\n%s", rule.getClass().getName(), PlanPrinter.textLogicalPlan(plan, types, metadata, session, 2))); } if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) { fail(String.format( "%s: output schema of transformed and original plans are not equivalent\n" + "\texpected: %s\n" + "\tactual: %s", rule.getClass().getName(), plan.getOutputSymbols(), actual.getOutputSymbols())); } assertPlan(session, metadata, new Plan(actual, types), pattern); } }