/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.Test;

import java.util.Map;
import java.util.Set;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.planner.DependencyExtractor.extractUnique;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Tests for {@link ExpressionEquivalence#areExpressionsEquivalent}.
 *
 * <p>Each case is a pair of SQL expression strings. Identifiers in those strings follow a
 * {@code <name>_<type>} convention (for example {@code a_bigint}); the text after the first
 * underscore is used as the type name when building the symbol-to-type map passed to the
 * equivalence check. Every pair is checked in both argument orders.
 */
public class TestExpressionEquivalence
{
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();
    private static final ExpressionEquivalence EQUIVALENCE = new ExpressionEquivalence(METADATA, SQL_PARSER);

    /**
     * Pairs of expressions that the equivalence check is expected to accept.
     */
    @Test
    public void testEquivalent()
            throws Exception
    {
        // literals
        assertEquivalent("true", "true");
        assertEquivalent("4", "4");
        assertEquivalent("4.4", "4.4");
        assertEquivalent("'foo'", "'foo'");

        // comparisons with swapped operands
        assertEquivalent("4 = 5", "5 = 4");
        assertEquivalent("4.4 = 5.5", "5.5 = 4.4");
        assertEquivalent("'foo' = 'bar'", "'bar' = 'foo'");
        assertEquivalent("4 <> 5", "5 <> 4");
        assertEquivalent("4 is distinct from 5", "5 is distinct from 4");
        assertEquivalent("4 < 5", "5 > 4");
        assertEquivalent("4 <= 5", "5 >= 4");

        assertEquivalent("mod(4, 5)", "mod(4, 5)");

        // symbol references
        assertEquivalent("a_bigint", "a_bigint");
        assertEquivalent("a_bigint = b_bigint", "b_bigint = a_bigint");
        assertEquivalent("a_bigint < b_bigint", "b_bigint > a_bigint");

        assertEquivalent("a_bigint < b_double", "b_double > a_bigint");

        // logical operators with reordered terms
        assertEquivalent("true and false", "false and true");
        assertEquivalent("4 <= 5 and 6 < 7", "7 > 6 and 5 >= 4");
        assertEquivalent("4 <= 5 or 6 < 7", "7 > 6 or 5 >= 4");
        assertEquivalent("a_bigint <= b_bigint and c_bigint < d_bigint", "d_bigint > c_bigint and b_bigint >= a_bigint");
        assertEquivalent("a_bigint <= b_bigint or c_bigint < d_bigint", "d_bigint > c_bigint or b_bigint >= a_bigint");

        // repeated terms
        assertEquivalent("4 <= 5 and 4 <= 5", "4 <= 5");
        assertEquivalent("4 <= 5 and 6 < 7", "7 > 6 and 5 >= 4 and 5 >= 4");
        assertEquivalent("2 <= 3 and 4 <= 5 and 6 < 7", "7 > 6 and 5 >= 4 and 3 >= 2");

        assertEquivalent("4 <= 5 or 4 <= 5", "4 <= 5");
        assertEquivalent("4 <= 5 or 6 < 7", "7 > 6 or 5 >= 4 or 5 >= 4");
        assertEquivalent("2 <= 3 or 4 <= 5 or 6 < 7", "7 > 6 or 5 >= 4 or 3 >= 2");

        // differently grouped and nested boolean expressions
        assertEquivalent("a_boolean and b_boolean and c_boolean", "c_boolean and b_boolean and a_boolean");
        assertEquivalent("(a_boolean and b_boolean) and c_boolean", "(c_boolean and b_boolean) and a_boolean");
        assertEquivalent("a_boolean and (b_boolean or c_boolean)", "a_boolean and (c_boolean or b_boolean) and a_boolean");

        assertEquivalent(
                "(a_boolean or b_boolean or c_boolean) and (d_boolean or e_boolean) and (f_boolean or g_boolean or h_boolean)",
                "(h_boolean or g_boolean or f_boolean) and (b_boolean or a_boolean or c_boolean) and (e_boolean or d_boolean)");

        assertEquivalent(
                "(a_boolean and b_boolean and c_boolean) or (d_boolean and e_boolean) or (f_boolean and g_boolean and h_boolean)",
                "(h_boolean and g_boolean and f_boolean) or (b_boolean and a_boolean and c_boolean) or (e_boolean and d_boolean)");
    }

    /**
     * Asserts that the two SQL expressions are reported as equivalent in both argument orders.
     *
     * @param left first SQL expression
     * @param right second SQL expression
     */
    private static void assertEquivalent(@Language("SQL") String left, @Language("SQL") String right)
    {
        Expression leftExpression = toExpression(left);
        Expression rightExpression = toExpression(right);
        Map<Symbol, Type> types = symbolTypes(leftExpression, rightExpression);

        assertTrue(
                EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, leftExpression, rightExpression, types),
                String.format("Expected (%s) and (%s) to be equivalent", left, right));
        assertTrue(
                EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, rightExpression, leftExpression, types),
                String.format("Expected (%s) and (%s) to be equivalent", right, left));
    }

    /**
     * Pairs of expressions that the equivalence check is expected to reject.
     */
    @Test
    public void testNotEquivalent()
            throws Exception
    {
        // literals
        assertNotEquivalent("true", "false");
        assertNotEquivalent("4", "5");
        assertNotEquivalent("4.4", "5.5");
        assertNotEquivalent("'foo'", "'bar'");

        // comparisons over different operands
        assertNotEquivalent("4 = 5", "5 = 6");
        assertNotEquivalent("4 <> 5", "5 <> 6");
        assertNotEquivalent("4 is distinct from 5", "5 is distinct from 6");
        assertNotEquivalent("4 < 5", "5 > 6");
        assertNotEquivalent("4 <= 5", "5 >= 6");

        // argument order matters for function calls
        assertNotEquivalent("mod(4, 5)", "mod(5, 4)");

        // symbol references
        assertNotEquivalent("a_bigint", "b_bigint");
        assertNotEquivalent("a_bigint = b_bigint", "b_bigint = c_bigint");
        assertNotEquivalent("a_bigint < b_bigint", "b_bigint > c_bigint");

        assertNotEquivalent("a_bigint < b_double", "b_double > c_bigint");

        // logical operators where one term differs
        assertNotEquivalent("4 <= 5 and 6 < 7", "7 > 6 and 5 >= 6");
        assertNotEquivalent("4 <= 5 or 6 < 7", "7 > 6 or 5 >= 6");
        assertNotEquivalent("a_bigint <= b_bigint and c_bigint < d_bigint", "d_bigint > c_bigint and b_bigint >= c_bigint");
        assertNotEquivalent("a_bigint <= b_bigint or c_bigint < d_bigint", "d_bigint > c_bigint or b_bigint >= c_bigint");
    }

    /**
     * Asserts that the two SQL expressions are reported as not equivalent in both argument orders.
     *
     * @param left first SQL expression
     * @param right second SQL expression
     */
    private static void assertNotEquivalent(@Language("SQL") String left, @Language("SQL") String right)
    {
        Expression leftExpression = toExpression(left);
        Expression rightExpression = toExpression(right);
        Map<Symbol, Type> types = symbolTypes(leftExpression, rightExpression);

        assertFalse(
                EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, leftExpression, rightExpression, types),
                String.format("Expected (%s) and (%s) to not be equivalent", left, right));
        assertFalse(
                EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, rightExpression, leftExpression, types),
                String.format("Expected (%s) and (%s) to not be equivalent", right, left));
    }

    /**
     * Parses a SQL expression string and rewrites its identifiers to symbol references.
     */
    private static Expression toExpression(@Language("SQL") String sql)
    {
        return rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(sql));
    }

    /**
     * Builds the symbol-to-type map for every unique symbol referenced by either expression.
     */
    private static Map<Symbol, Type> symbolTypes(Expression left, Expression right)
    {
        Set<Symbol> symbols = extractUnique(ImmutableList.of(left, right));
        return symbols.stream()
                .collect(toMap(identity(), TestExpressionEquivalence::generateType));
    }

    /**
     * Resolves the type of a test symbol from its name: the text after the first underscore
     * is looked up as a type name (so {@code a_bigint} resolves the type named {@code bigint}).
     *
     * @param symbol symbol whose name follows the {@code <name>_<type>} convention
     * @return the type registered in the test metadata under that type name
     * @throws IllegalArgumentException if the name contains no underscore, or if the metadata
     *         lookup yields no type for the derived type name
     */
    private static Type generateType(Symbol symbol)
    {
        String symbolName = symbol.getName();

        // limit(2) keeps everything after the first underscore together as the type name
        ImmutableList<String> parts = ImmutableList.copyOf(Splitter.on('_').limit(2).split(symbolName));
        if (parts.size() < 2) {
            throw new IllegalArgumentException(
                    String.format("Symbol '%s' does not follow the <name>_<type> naming convention", symbolName));
        }

        String typeName = parts.get(1);
        Type type = METADATA.getType(new TypeSignature(typeName, ImmutableList.of()));
        if (type == null) {
            // Fail here with context rather than with a bare NullPointerException from Collectors.toMap
            throw new IllegalArgumentException(
                    String.format("Unknown type '%s' derived from symbol '%s'", typeName, symbolName));
        }
        return type;
    }
}