/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.parser; import com.facebook.presto.sql.SqlFormatter; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Statement; import com.google.common.io.Resources; import org.testng.annotations.Test; import java.io.IOException; import java.util.Optional; import static com.facebook.presto.sql.testing.TreeAssertions.assertFormattedSql; import static com.google.common.base.Strings.repeat; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestStatementBuilder { private static final SqlParser SQL_PARSER = new SqlParser(); @Test public void testStatementBuilder() throws Exception { printStatement("select * from foo"); printStatement("explain select * from foo"); printStatement("explain (type distributed, format graphviz) select * from foo"); printStatement("select * from foo /* end */"); printStatement("/* start */ select * from foo"); printStatement("/* start */ select * /* middle */ from foo /* end */"); printStatement("-- start\nselect * -- junk\n-- hi\nfrom foo -- done"); printStatement("select * from foo a (x, y, z)"); printStatement("select *, 123, * from foo"); printStatement("select show from foo"); printStatement("select extract(day from x), extract(dow from x) from y"); printStatement("select 1 + 13 || '15' from foo"); printStatement("select x is distinct from y from foo where a is not distinct from b"); printStatement("select x[1] from my_table"); printStatement("select x[1][2] from my_table"); printStatement("select x[cast(10 * sin(x) as bigint)] from my_table"); printStatement("select * from unnest(t.my_array)"); printStatement("select * from unnest(array[1, 2, 3])"); printStatement("select x from unnest(array[1, 2, 3]) t(x)"); printStatement("select * from users cross join unnest(friends)"); printStatement("select id, friend from users cross join unnest(friends) t(friend)"); printStatement("select * from unnest(t.my_array) with ordinality"); printStatement("select * from unnest(array[1, 2, 3]) with ordinality"); printStatement("select x from unnest(array[1, 2, 3]) with ordinality t(x)"); printStatement("select * from users cross join unnest(friends) with ordinality"); printStatement("select id, friend from users cross join unnest(friends) with ordinality t(friend)"); printStatement("select count(*) x from src group by k, v"); printStatement("select count(*) x from src group by cube (k, v)"); printStatement("select count(*) x from src group by rollup (k, v)"); printStatement("select count(*) x from src group by grouping sets ((k, v))"); printStatement("select count(*) x from src group by grouping sets ((k, v), (v))"); printStatement("select count(*) x from src group by grouping sets (k, v, k)"); printStatement("select count(*) filter (where x > 4) y from t"); printStatement("select sum(x) filter (where x > 4) y from t"); printStatement("select sum(x) filter (where x > 4) y, sum(x) filter (where x < 2) z from t"); printStatement("select sum(distinct x) filter (where x > 4) y, sum(x) filter (where x < 2) z from t"); printStatement("select sum(x) filter (where x > 4) over (partition by y) z from t"); printStatement("" + "select depname, empno, salary\n" + ", count(*) over ()\n" + ", avg(salary) over (partition by depname)\n" + ", rank() over (partition by depname order by salary desc)\n" + ", sum(salary) over (order by salary rows unbounded preceding)\n" + ", sum(salary) over (partition by depname order by salary rows between current row and 3 following)\n" + ", sum(salary) over (partition by depname range unbounded preceding)\n" + ", sum(salary) over (rows between 2 preceding and unbounded following)\n" + "from emp"); printStatement("" + "with a (id) as (with x as (select 123 from z) select * from x) " + " , b (id) as (select 999 from z) " + "select * from a join b using (id)"); printStatement("with recursive t as (select * from x) select * from t"); printStatement("select * from information_schema.tables"); printStatement("show catalogs"); printStatement("show schemas"); printStatement("show schemas from sys"); printStatement("show tables"); printStatement("show tables from information_schema"); printStatement("show tables like '%'"); printStatement("show tables from information_schema like '%'"); printStatement("show partitions from foo"); printStatement("show partitions from foo where name = 'foo'"); printStatement("show partitions from foo order by x"); printStatement("show partitions from foo limit 10"); printStatement("show partitions from foo limit all"); printStatement("show partitions from foo order by x desc limit 10"); printStatement("show partitions from foo order by x desc limit all"); printStatement("show functions"); printStatement("select cast('123' as bigint), try_cast('foo' as bigint)"); printStatement("select * from a.b.c"); printStatement("select * from a.b.c.e.f.g"); printStatement("select \"TOTALPRICE\" \"my price\" from \"$MY\"\"ORDERS\""); printStatement("select * from foo tablesample system (10+1)"); printStatement("select * from foo tablesample system (10) join bar tablesample bernoulli (30) on a.id = b.id"); printStatement("select * from foo tablesample system (10) join bar tablesample bernoulli (30) on not(a.id > b.id)"); printStatement("create table foo as (select * from abc)"); printStatement("create table if not exists foo as (select * from abc)"); printStatement("create table foo with (a = 'apple', b = 'banana') as select * from abc"); printStatement("create table foo comment 'test' with (a = 'apple') as select * from abc"); printStatement("create table foo as select * from abc WITH NO DATA"); printStatement("create table foo as (with t(x) as (values 1) select x from t)"); printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t)"); printStatement("create table foo as (with t(x) as (values 1) select x from t) WITH DATA"); printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t) WITH DATA"); printStatement("create table foo as (with t(x) as (values 1) select x from t) WITH NO DATA"); printStatement("create table if not exists foo as (with t(x) as (values 1) select x from t) WITH NO DATA"); printStatement("drop table foo"); printStatement("insert into foo select * from abc"); printStatement("delete from foo"); printStatement("delete from foo where a = b"); printStatement("values ('a', 1, 2.2), ('b', 2, 3.3)"); printStatement("table foo"); printStatement("table foo order by x limit 10"); printStatement("(table foo)"); printStatement("(table foo) limit 10"); printStatement("(table foo limit 5) limit 10"); printStatement("select * from a limit all"); printStatement("select * from a order by x limit all"); printStatement("select * from a union select * from b"); printStatement("table a union all table b"); printStatement("(table foo) union select * from foo union (table foo order by x)"); printStatement("table a union table b intersect table c"); printStatement("(table a union table b) intersect table c"); printStatement("table a union table b except table c intersect table d"); printStatement("(table a union table b except table c) intersect table d"); printStatement("((table a union table b) except table c) intersect table d"); printStatement("(table a union (table b except table c)) intersect table d"); printStatement("table a intersect table b union table c"); printStatement("table a intersect (table b union table c)"); printStatement("alter table foo rename to bar"); printStatement("alter table a.b.c rename to d.e.f"); printStatement("alter table a.b.c rename column x to y"); printStatement("alter table a.b.c add column x bigint"); printStatement("create schema test"); printStatement("create schema if not exists test"); printStatement("create schema test with (a = 'apple', b = 123)"); printStatement("drop schema test"); printStatement("drop schema test cascade"); printStatement("drop schema if exists test"); printStatement("drop schema if exists test restrict"); printStatement("alter schema foo rename to bar"); printStatement("alter schema foo.bar rename to baz"); printStatement("create table test (a boolean, b bigint, c double, d varchar, e timestamp)"); printStatement("create table test (a boolean, b bigint comment 'test')"); printStatement("create table if not exists baz (a timestamp, b varchar)"); printStatement("create table test (a boolean, b bigint) with (a = 'apple', b = 'banana')"); printStatement("create table test (a boolean, b bigint) comment 'test' with (a = 'apple')"); printStatement("drop table test"); printStatement("create view foo as with a as (select 123) select * from a"); printStatement("create or replace view foo as select 123 from t"); printStatement("drop view foo"); printStatement("insert into t select * from t"); printStatement("insert into t (c1, c2) select * from t"); printStatement("start transaction"); printStatement("start transaction isolation level read uncommitted"); printStatement("start transaction isolation level read committed"); printStatement("start transaction isolation level repeatable read"); printStatement("start transaction isolation level serializable"); printStatement("start transaction read only"); printStatement("start transaction read write"); printStatement("start transaction isolation level read committed, read only"); printStatement("start transaction read only, isolation level read committed"); printStatement("start transaction read write, isolation level serializable"); printStatement("commit"); printStatement("commit work"); printStatement("rollback"); printStatement("rollback work"); printStatement("call foo()"); printStatement("call foo(123, a => 1, b => 'go', 456)"); printStatement("grant select on foo to alice with grant option"); printStatement("grant all privileges on foo to alice"); printStatement("grant delete, select on foo to public"); printStatement("revoke grant option for select on foo from alice"); printStatement("revoke all privileges on foo from alice"); printStatement("revoke insert, delete on foo from public"); //check support for public printStatement("show grants on table t"); printStatement("show grants on t"); printStatement("show grants"); printStatement("prepare p from select * from (select * from T) \"A B\""); printStatement("SELECT * FROM table1 WHERE a >= ALL (VALUES 2, 3, 4)"); printStatement("SELECT * FROM table1 WHERE a <> ANY (SELECT 2, 3, 4)"); printStatement("SELECT * FROM table1 WHERE a = SOME (SELECT id FROM table2)"); } @Test public void testStringFormatter() throws Exception { assertSqlFormatter("U&'hello\\6d4B\\8Bd5\\+10FFFFworld\\7F16\\7801'", "U&'hello\\6D4B\\8BD5\\+10FFFFworld\\7F16\\7801'"); assertSqlFormatter("'hello world'", "'hello world'"); assertSqlFormatter("U&'!+10FFFF!6d4B!8Bd5ABC!6d4B!8Bd5' UESCAPE '!'", "U&'\\+10FFFF\\6D4B\\8BD5ABC\\6D4B\\8BD5'"); assertSqlFormatter("U&'\\+10FFFF\\6D4B\\8BD5\\0041\\0042\\0043\\6D4B\\8BD5'", "U&'\\+10FFFF\\6D4B\\8BD5ABC\\6D4B\\8BD5'"); assertSqlFormatter("U&'\\\\abc\\6D4B'''", "U&'\\\\abc\\6D4B'''"); } @Test public void testStatementBuilderTpch() throws Exception { printTpchQuery(1, 3); printTpchQuery(2, 33, "part type like", "region name"); printTpchQuery(3, "market segment", "2013-03-05"); printTpchQuery(4, "2013-03-05"); printTpchQuery(5, "region name", "2013-03-05"); printTpchQuery(6, "2013-03-05", 33, 44); printTpchQuery(7, "nation name 1", "nation name 2"); printTpchQuery(8, "nation name", "region name", "part type"); printTpchQuery(9, "part name like"); printTpchQuery(10, "2013-03-05"); printTpchQuery(11, "nation name", 33); printTpchQuery(12, "ship mode 1", "ship mode 2", "2013-03-05"); printTpchQuery(13, "comment like 1", "comment like 2"); printTpchQuery(14, "2013-03-05"); // query 15: views not supported printTpchQuery(16, "part brand", "part type like", 3, 4, 5, 6, 7, 8, 9, 10); printTpchQuery(17, "part brand", "part container"); printTpchQuery(18, 33); printTpchQuery(19, "part brand 1", "part brand 2", "part brand 3", 11, 22, 33); printTpchQuery(20, "part name like", "2013-03-05", "nation name"); printTpchQuery(21, "nation name"); printTpchQuery(22, "phone 1", "phone 2", "phone 3", "phone 4", "phone 5", "phone 6", "phone 7"); } private static void printStatement(String sql) { println(sql.trim()); println(""); Statement statement = SQL_PARSER.createStatement(sql); println(statement.toString()); println(""); println(SqlFormatter.formatSql(statement, Optional.empty())); println(""); assertFormattedSql(SQL_PARSER, statement); println(repeat("=", 60)); println(""); } private static void assertSqlFormatter(String expression, String formatted) { Expression originalExpression = SQL_PARSER.createExpression(expression); String real = SqlFormatter.formatSql(originalExpression, Optional.empty()); assertTrue(real.equals(formatted)); } private static void println(String s) { if (Boolean.parseBoolean(System.getProperty("printParse"))) { System.out.println(s); } } private static String getTpchQuery(int q) throws IOException { return readResource("tpch/queries/" + q + ".sql"); } private static void printTpchQuery(int query, Object... values) throws IOException { String sql = getTpchQuery(query); for (int i = values.length - 1; i >= 0; i--) { sql = sql.replaceAll(format(":%s", i + 1), String.valueOf(values[i])); } assertFalse(sql.matches("(?s).*:[0-9].*"), "Not all bind parameters were replaced: " + sql); sql = fixTpchQuery(sql); printStatement(sql); } private static String readResource(String name) throws IOException { return Resources.toString(Resources.getResource(name), UTF_8); } private static String fixTpchQuery(String s) { s = s.replaceFirst("(?m);$", ""); s = s.replaceAll("(?m)^:[xo]$", ""); s = s.replaceAll("(?m)^:n -1$", ""); s = s.replaceAll("(?m)^:n ([0-9]+)$", "LIMIT $1"); s = s.replace("day (3)", "day"); // for query 1 return s; } }