/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.parser; import com.facebook.presto.sql.parser.SqlBaseParser.TablePropertiesContext; import com.facebook.presto.sql.parser.SqlBaseParser.TablePropertyContext; import com.facebook.presto.sql.tree.AddColumn; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression; import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.AtTimeZone; import com.facebook.presto.sql.tree.BetweenPredicate; import com.facebook.presto.sql.tree.BinaryLiteral; import com.facebook.presto.sql.tree.BindExpression; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.CallArgument; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CharLiteral; import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ColumnDefinition; import com.facebook.presto.sql.tree.Commit; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; import com.facebook.presto.sql.tree.CreateSchema; import com.facebook.presto.sql.tree.CreateTable; import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.CreateView; import 
com.facebook.presto.sql.tree.Cube; import com.facebook.presto.sql.tree.CurrentTime; import com.facebook.presto.sql.tree.Deallocate; import com.facebook.presto.sql.tree.DecimalLiteral; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; import com.facebook.presto.sql.tree.DropView; import com.facebook.presto.sql.tree.Except; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.ExplainFormat; import com.facebook.presto.sql.tree.ExplainOption; import com.facebook.presto.sql.tree.ExplainType; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GroupBy; import com.facebook.presto.sql.tree.GroupingElement; import com.facebook.presto.sql.tree.GroupingSets; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.Insert; import com.facebook.presto.sql.tree.Intersect; import com.facebook.presto.sql.tree.IntervalLiteral; import com.facebook.presto.sql.tree.IsNotNullPredicate; import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.Isolation; import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.JoinCriteria; import com.facebook.presto.sql.tree.JoinOn; import 
com.facebook.presto.sql.tree.JoinUsing; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; import com.facebook.presto.sql.tree.LikeClause; import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QueryBody; import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.Relation; import com.facebook.presto.sql.tree.RenameColumn; import com.facebook.presto.sql.tree.RenameSchema; import com.facebook.presto.sql.tree.RenameTable; import com.facebook.presto.sql.tree.ResetSession; import com.facebook.presto.sql.tree.Revoke; import com.facebook.presto.sql.tree.Rollback; import com.facebook.presto.sql.tree.Rollup; import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.Select; import com.facebook.presto.sql.tree.SelectItem; import com.facebook.presto.sql.tree.SetSession; import com.facebook.presto.sql.tree.ShowCatalogs; import com.facebook.presto.sql.tree.ShowColumns; import com.facebook.presto.sql.tree.ShowCreate; import com.facebook.presto.sql.tree.ShowFunctions; import com.facebook.presto.sql.tree.ShowGrants; import 
com.facebook.presto.sql.tree.ShowPartitions; import com.facebook.presto.sql.tree.ShowSchemas; import com.facebook.presto.sql.tree.ShowSession; import com.facebook.presto.sql.tree.ShowTables; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.SimpleGroupBy; import com.facebook.presto.sql.tree.SingleColumn; import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.StartTransaction; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.StringLiteral; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.Table; import com.facebook.presto.sql.tree.TableElement; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TimeLiteral; import com.facebook.presto.sql.tree.TimestampLiteral; import com.facebook.presto.sql.tree.TransactionAccessMode; import com.facebook.presto.sql.tree.TransactionMode; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; import com.facebook.presto.sql.tree.Use; import com.facebook.presto.sql.tree.Values; import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; import com.facebook.presto.sql.tree.With; import com.facebook.presto.sql.tree.WithQuery; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.TerminalNode; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import static com.google.common.collect.Iterables.getOnlyElement; import static 
java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

/**
 * Parse-tree visitor that turns ANTLR {@code SqlBaseParser} contexts into SQL AST {@link Node}s.
 */
class AstBuilder
        extends SqlBaseBaseVisitor<Node>
{
    // Mutable counter whose use lies outside this section (presumably numbering of positional
    // parameters in visit order) — TODO confirm.
    private int parameterPosition = 0;

    /**
     * Entry rule for a whole statement: visits and returns its only child statement.
     */
    @Override
    public Node visitSingleStatement(SqlBaseParser.SingleStatementContext context)
    {
        Node statement = visit(context.statement());
        return statement;
    }

    /**
     * Entry rule for a standalone expression: visits and returns its only child expression.
     */
    @Override
    public Node visitSingleExpression(SqlBaseParser.SingleExpressionContext context)
    {
        Node expression = visit(context.expression());
        return expression;
    }

    // ******************* statements **********************

    /**
     * Builds a {@link Use} node from the optional catalog and the schema name text.
     */
    @Override
    public Node visitUse(SqlBaseParser.UseContext context)
    {
        Optional<String> catalog = getTextIfPresent(context.catalog);
        String schema = context.schema.getText();
        return new Use(getLocation(context), catalog, schema);
    }

    /**
     * Builds a {@link CreateSchema}; the boolean flag is set when the {@code EXISTS} token is present.
     */
    @Override
    public Node visitCreateSchema(SqlBaseParser.CreateSchemaContext context)
    {
        QualifiedName schemaName = getQualifiedName(context.qualifiedName());
        boolean existsClause = context.EXISTS() != null;
        Map<String, Expression> properties = processTableProperties(context.tableProperties());
        return new CreateSchema(getLocation(context), schemaName, existsClause, properties);
    }

    /**
     * Builds a {@link DropSchema}; flags reflect presence of the {@code EXISTS} and {@code CASCADE} tokens.
     */
    @Override
    public Node visitDropSchema(SqlBaseParser.DropSchemaContext context)
    {
        QualifiedName schemaName = getQualifiedName(context.qualifiedName());
        boolean existsClause = context.EXISTS() != null;
        boolean cascade = context.CASCADE() != null;
        return new DropSchema(getLocation(context), schemaName, existsClause, cascade);
    }

    /**
     * Builds a {@link RenameSchema} from the source schema name and the new identifier text.
     */
    @Override
    public Node visitRenameSchema(SqlBaseParser.RenameSchemaContext context)
    {
        QualifiedName source = getQualifiedName(context.qualifiedName());
        String target = context.identifier().getText();
        return new RenameSchema(getLocation(context), source, target);
    }

    /**
     * Builds a {@link CreateTableAsSelect}. Children are visited in this order: the optional
     * comment string, the query, then the table properties.
     */
    @Override
    public Node visitCreateTableAsSelect(SqlBaseParser.CreateTableAsSelectContext context)
    {
        Optional<String> comment = (context.COMMENT() == null)
                ? Optional.empty()
                : Optional.of(((StringLiteral) visit(context.string())).getValue());
        Query query = (Query) visit(context.query());
        Map<String, Expression> properties = processTableProperties(context.tableProperties());
        // true unless the NO token was parsed (presumably "WITH NO DATA") — confirm against the grammar
        boolean withData = context.NO() == null;
        return new CreateTableAsSelect(
                getLocation(context),
                getQualifiedName(context.qualifiedName()),
                query,
                context.EXISTS() != null,
                properties,
                withData,
                comment);
    }

    @Override
    public Node visitCreateTable(SqlBaseParser.CreateTableContext context)
    {
        Optional<String> comment =
Optional.empty(); if (context.COMMENT() != null) { comment = Optional.of(((StringLiteral) visit(context.string())).getValue()); } return new CreateTable(getLocation(context), getQualifiedName(context.qualifiedName()), visit(context.tableElement(), TableElement.class), context.EXISTS() != null, processTableProperties(context.tableProperties()), comment); } private Map<String, Expression> processTableProperties(TablePropertiesContext tablePropertiesContext) { ImmutableMap.Builder<String, Expression> properties = ImmutableMap.builder(); if (tablePropertiesContext != null) { for (TablePropertyContext tablePropertyContext : tablePropertiesContext.tableProperty()) { properties.put(tablePropertyContext.identifier().getText(), (Expression) visit(tablePropertyContext.expression())); } } return properties.build(); } @Override public Node visitShowCreateTable(SqlBaseParser.ShowCreateTableContext context) { return new ShowCreate(getLocation(context), ShowCreate.Type.TABLE, getQualifiedName(context.qualifiedName())); } @Override public Node visitDropTable(SqlBaseParser.DropTableContext context) { return new DropTable(getLocation(context), getQualifiedName(context.qualifiedName()), context.EXISTS() != null); } @Override public Node visitDropView(SqlBaseParser.DropViewContext context) { return new DropView(getLocation(context), getQualifiedName(context.qualifiedName()), context.EXISTS() != null); } @Override public Node visitInsertInto(SqlBaseParser.InsertIntoContext context) { return new Insert( getQualifiedName(context.qualifiedName()), Optional.ofNullable(getColumnAliases(context.columnAliases())), (Query) visit(context.query())); } @Override public Node visitDelete(SqlBaseParser.DeleteContext context) { return new Delete( getLocation(context), new Table(getLocation(context), getQualifiedName(context.qualifiedName())), visitIfPresent(context.booleanExpression(), Expression.class)); } @Override public Node visitRenameTable(SqlBaseParser.RenameTableContext context) { return new 
RenameTable(getLocation(context), getQualifiedName(context.from), getQualifiedName(context.to));
    }

    /**
     * Builds a {@link RenameColumn} from the table name and the old/new column identifier text.
     */
    @Override
    public Node visitRenameColumn(SqlBaseParser.RenameColumnContext context)
    {
        QualifiedName tableName = getQualifiedName(context.tableName);
        String source = context.from.getText();
        String target = context.to.getText();
        return new RenameColumn(getLocation(context), tableName, source, target);
    }

    /**
     * Builds an {@link AddColumn} for the named table from the visited column definition.
     */
    @Override
    public Node visitAddColumn(SqlBaseParser.AddColumnContext context)
    {
        QualifiedName tableName = getQualifiedName(context.qualifiedName());
        ColumnDefinition column = (ColumnDefinition) visit(context.columnDefinition());
        return new AddColumn(getLocation(context), tableName, column);
    }

    /**
     * Builds a {@link CreateView}; the flag is set when the {@code REPLACE} token is present.
     */
    @Override
    public Node visitCreateView(SqlBaseParser.CreateViewContext context)
    {
        QualifiedName viewName = getQualifiedName(context.qualifiedName());
        Query query = (Query) visit(context.query());
        boolean replace = context.REPLACE() != null;
        return new CreateView(getLocation(context), viewName, query, replace);
    }

    /**
     * Builds a {@link StartTransaction} from the listed transaction modes (no source location is passed).
     */
    @Override
    public Node visitStartTransaction(SqlBaseParser.StartTransactionContext context)
    {
        List<TransactionMode> modes = visit(context.transactionMode(), TransactionMode.class);
        return new StartTransaction(modes);
    }

    @Override
    public Node visitCommit(SqlBaseParser.CommitContext context)
    {
        return new Commit(getLocation(context));
    }

    @Override
    public Node visitRollback(SqlBaseParser.RollbackContext context)
    {
        return new Rollback(getLocation(context));
    }

    /**
     * Builds a {@link TransactionAccessMode}; the flag is true when the access-mode token is {@code ONLY}.
     */
    @Override
    public Node visitTransactionAccessMode(SqlBaseParser.TransactionAccessModeContext context)
    {
        boolean readOnly = context.accessMode.getType() == SqlBaseLexer.ONLY;
        return new TransactionAccessMode(getLocation(context), readOnly);
    }

    /**
     * Delegates to the concrete isolation-level alternative.
     */
    @Override
    public Node visitIsolationLevel(SqlBaseParser.IsolationLevelContext context)
    {
        Node level = visit(context.levelOfIsolation());
        return level;
    }

    @Override
    public Node visitReadUncommitted(SqlBaseParser.ReadUncommittedContext context)
    {
        Isolation.Level level = Isolation.Level.READ_UNCOMMITTED;
        return new Isolation(getLocation(context), level);
    }

    @Override
    public Node visitReadCommitted(SqlBaseParser.ReadCommittedContext context)
    {
        Isolation.Level level = Isolation.Level.READ_COMMITTED;
        return new Isolation(getLocation(context), level);
    }

    @Override
    public Node visitRepeatableRead(SqlBaseParser.RepeatableReadContext context)
    {
        return new
Isolation(getLocation(context), Isolation.Level.REPEATABLE_READ); } @Override public Node visitSerializable(SqlBaseParser.SerializableContext context) { return new Isolation(getLocation(context), Isolation.Level.SERIALIZABLE); } @Override public Node visitCall(SqlBaseParser.CallContext context) { return new Call( getLocation(context), getQualifiedName(context.qualifiedName()), visit(context.callArgument(), CallArgument.class)); } @Override public Node visitPrepare(SqlBaseParser.PrepareContext context) { String name = context.identifier().getText(); return new Prepare(getLocation(context), name, (Statement) visit(context.statement())); } @Override public Node visitDeallocate(SqlBaseParser.DeallocateContext context) { String name = context.identifier().getText(); return new Deallocate(getLocation(context), name); } @Override public Node visitExecute(SqlBaseParser.ExecuteContext context) { String name = context.identifier().getText(); return new Execute(getLocation(context), name, visit(context.expression(), Expression.class)); } @Override public Node visitDescribeOutput(SqlBaseParser.DescribeOutputContext context) { String name = context.identifier().getText(); return new DescribeOutput(getLocation(context), name); } @Override public Node visitDescribeInput(SqlBaseParser.DescribeInputContext context) { String name = context.identifier().getText(); return new DescribeInput(getLocation(context), name); } // ********************** query expressions ******************** @Override public Node visitQuery(SqlBaseParser.QueryContext context) { Query body = (Query) visit(context.queryNoWith()); return new Query( getLocation(context), visitIfPresent(context.with(), With.class), body.getQueryBody(), body.getOrderBy(), body.getLimit()); } @Override public Node visitWith(SqlBaseParser.WithContext context) { return new With(getLocation(context), context.RECURSIVE() != null, visit(context.namedQuery(), WithQuery.class)); } @Override public Node 
visitNamedQuery(SqlBaseParser.NamedQueryContext context)
    {
        return new WithQuery(getLocation(context), context.name.getText(), (Query) visit(context.query()), Optional.ofNullable(getColumnAliases(context.columnAliases())));
    }

    /**
     * Builds a {@link Query} from a query term plus optional ORDER BY and LIMIT.
     * <p>
     * When the term is a plain {@link QuerySpecification}, ORDER BY and LIMIT are folded into that
     * specification, because the analyzer/planner expects this structure to resolve references with
     * respect to the columns defined in the query specification. Any other term keeps ORDER BY and
     * LIMIT on the enclosing {@link Query}.
     */
    @Override
    public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context)
    {
        QueryBody term = (QueryBody) visit(context.queryTerm());
        Optional<OrderBy> orderBy = (context.ORDER() == null)
                ? Optional.empty()
                : Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class)));
        Optional<String> limit = getTextIfPresent(context.limit);

        if (!(term instanceof QuerySpecification)) {
            return new Query(getLocation(context), Optional.empty(), term, orderBy, limit);
        }

        QuerySpecification specification = (QuerySpecification) term;
        QuerySpecification folded = new QuerySpecification(
                getLocation(context),
                specification.getSelect(),
                specification.getFrom(),
                specification.getWhere(),
                specification.getGroupBy(),
                specification.getHaving(),
                orderBy,
                limit);
        return new Query(getLocation(context), Optional.empty(), folded, Optional.empty(), Optional.empty());
    }

    /**
     * Builds a {@link QuerySpecification}. Multiple FROM relations are chained left-to-right into
     * {@code IMPLICIT} joins.
     */
    @Override
    public Node visitQuerySpecification(SqlBaseParser.QuerySpecificationContext context)
    {
        Optional<Relation> from = Optional.empty();
        List<SelectItem> selectItems = visit(context.selectItem(), SelectItem.class);

        List<Relation> relations = visit(context.relation(), Relation.class);
        if (!relations.isEmpty()) {
            // synthesize implicit join nodes
            Iterator<Relation> iterator = relations.iterator();
            Relation relation = iterator.next();

            while (iterator.hasNext()) {
                relation = new Join(getLocation(context), Join.Type.IMPLICIT, relation, iterator.next(), Optional.empty());
            }

            from = Optional.of(relation);
        }

        return new QuerySpecification(
                getLocation(context),
                new
Select(getLocation(context.SELECT()), isDistinct(context.setQuantifier()), selectItems), from, visitIfPresent(context.where, Expression.class), visitIfPresent(context.groupBy(), GroupBy.class), visitIfPresent(context.having, Expression.class), Optional.empty(), Optional.empty()); } @Override public Node visitGroupBy(SqlBaseParser.GroupByContext context) { return new GroupBy(getLocation(context), isDistinct(context.setQuantifier()), visit(context.groupingElement(), GroupingElement.class)); } @Override public Node visitSingleGroupingSet(SqlBaseParser.SingleGroupingSetContext context) { return new SimpleGroupBy(getLocation(context), visit(context.groupingExpressions().expression(), Expression.class)); } @Override public Node visitRollup(SqlBaseParser.RollupContext context) { return new Rollup(getLocation(context), context.qualifiedName().stream() .map(AstBuilder::getQualifiedName) .collect(toList())); } @Override public Node visitCube(SqlBaseParser.CubeContext context) { return new Cube(getLocation(context), context.qualifiedName().stream() .map(AstBuilder::getQualifiedName) .collect(toList())); } @Override public Node visitMultipleGroupingSets(SqlBaseParser.MultipleGroupingSetsContext context) { return new GroupingSets(getLocation(context), context.groupingSet().stream() .map(groupingSet -> groupingSet.qualifiedName().stream() .map(AstBuilder::getQualifiedName) .collect(toList())) .collect(toList())); } @Override public Node visitSetOperation(SqlBaseParser.SetOperationContext context) { QueryBody left = (QueryBody) visit(context.left); QueryBody right = (QueryBody) visit(context.right); boolean distinct = context.setQuantifier() == null || context.setQuantifier().DISTINCT() != null; switch (context.operator.getType()) { case SqlBaseLexer.UNION: return new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct); case SqlBaseLexer.INTERSECT: return new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct); case 
SqlBaseLexer.EXCEPT: return new Except(getLocation(context.EXCEPT()), left, right, distinct); } throw new IllegalArgumentException("Unsupported set operation: " + context.operator.getText()); } @Override public Node visitSelectAll(SqlBaseParser.SelectAllContext context) { if (context.qualifiedName() != null) { return new AllColumns(getLocation(context), getQualifiedName(context.qualifiedName())); } return new AllColumns(getLocation(context)); } @Override public Node visitSelectSingle(SqlBaseParser.SelectSingleContext context) { Optional<String> alias = getTextIfPresent(context.identifier()); return new SingleColumn(getLocation(context), (Expression) visit(context.expression()), alias); } @Override public Node visitTable(SqlBaseParser.TableContext context) { return new Table(getLocation(context), getQualifiedName(context.qualifiedName())); } @Override public Node visitSubquery(SqlBaseParser.SubqueryContext context) { return new TableSubquery(getLocation(context), (Query) visit(context.queryNoWith())); } @Override public Node visitInlineTable(SqlBaseParser.InlineTableContext context) { return new Values(getLocation(context), visit(context.expression(), Expression.class)); } @Override public Node visitExplain(SqlBaseParser.ExplainContext context) { return new Explain(getLocation(context), context.ANALYZE() != null, (Statement) visit(context.statement()), visit(context.explainOption(), ExplainOption.class)); } @Override public Node visitExplainFormat(SqlBaseParser.ExplainFormatContext context) { switch (context.value.getType()) { case SqlBaseLexer.GRAPHVIZ: return new ExplainFormat(getLocation(context), ExplainFormat.Type.GRAPHVIZ); case SqlBaseLexer.TEXT: return new ExplainFormat(getLocation(context), ExplainFormat.Type.TEXT); } throw new IllegalArgumentException("Unsupported EXPLAIN format: " + context.value.getText()); } @Override public Node visitExplainType(SqlBaseParser.ExplainTypeContext context) { switch (context.value.getType()) { case SqlBaseLexer.LOGICAL: 
return new ExplainType(getLocation(context), ExplainType.Type.LOGICAL); case SqlBaseLexer.DISTRIBUTED: return new ExplainType(getLocation(context), ExplainType.Type.DISTRIBUTED); case SqlBaseLexer.VALIDATE: return new ExplainType(getLocation(context), ExplainType.Type.VALIDATE); } throw new IllegalArgumentException("Unsupported EXPLAIN type: " + context.value.getText()); } @Override public Node visitShowTables(SqlBaseParser.ShowTablesContext context) { return new ShowTables( getLocation(context), Optional.ofNullable(context.qualifiedName()) .map(AstBuilder::getQualifiedName), getTextIfPresent(context.pattern) .map(AstBuilder::unquote)); } @Override public Node visitShowSchemas(SqlBaseParser.ShowSchemasContext context) { return new ShowSchemas( getLocation(context), getTextIfPresent(context.identifier()), getTextIfPresent(context.pattern) .map(AstBuilder::unquote)); } @Override public Node visitShowCatalogs(SqlBaseParser.ShowCatalogsContext context) { return new ShowCatalogs(getLocation(context), getTextIfPresent(context.pattern) .map(AstBuilder::unquote)); } @Override public Node visitShowColumns(SqlBaseParser.ShowColumnsContext context) { return new ShowColumns(getLocation(context), getQualifiedName(context.qualifiedName())); } @Override public Node visitShowPartitions(SqlBaseParser.ShowPartitionsContext context) { return new ShowPartitions( getLocation(context), getQualifiedName(context.qualifiedName()), visitIfPresent(context.booleanExpression(), Expression.class), visit(context.sortItem(), SortItem.class), getTextIfPresent(context.limit)); } @Override public Node visitShowCreateView(SqlBaseParser.ShowCreateViewContext context) { return new ShowCreate(getLocation(context), ShowCreate.Type.VIEW, getQualifiedName(context.qualifiedName())); } @Override public Node visitShowFunctions(SqlBaseParser.ShowFunctionsContext context) { return new ShowFunctions(getLocation(context)); } @Override public Node visitShowSession(SqlBaseParser.ShowSessionContext context) { return 
new ShowSession(getLocation(context)); } @Override public Node visitSetSession(SqlBaseParser.SetSessionContext context) { return new SetSession(getLocation(context), getQualifiedName(context.qualifiedName()), (Expression) visit(context.expression())); } @Override public Node visitResetSession(SqlBaseParser.ResetSessionContext context) { return new ResetSession(getLocation(context), getQualifiedName(context.qualifiedName())); } @Override public Node visitGrant(SqlBaseParser.GrantContext context) { String grantee = context.grantee.getText(); Optional<List<String>> privileges; if (context.ALL() != null) { privileges = Optional.empty(); } else { privileges = Optional.of(context.privilege().stream() .map(SqlBaseParser.PrivilegeContext::getText) .collect(toList())); } return new Grant( getLocation(context), privileges, context.TABLE() != null, getQualifiedName(context.qualifiedName()), grantee, context.OPTION() != null); } @Override public Node visitRevoke(SqlBaseParser.RevokeContext context) { Optional<List<String>> privileges; if (context.ALL() != null) { privileges = Optional.empty(); } else { privileges = Optional.of(context.privilege().stream() .map(SqlBaseParser.PrivilegeContext::getText) .collect(toList())); } return new Revoke( getLocation(context), context.OPTION() != null, privileges, context.TABLE() != null, getQualifiedName(context.qualifiedName()), context.grantee.getText()); } @Override public Node visitShowGrants(SqlBaseParser.ShowGrantsContext context) { Optional<QualifiedName> tableName = Optional.empty(); if (context.qualifiedName() != null) { tableName = Optional.of(getQualifiedName(context.qualifiedName())); } return new ShowGrants( getLocation(context), context.TABLE() != null, tableName); } // ***************** boolean expressions ****************** @Override public Node visitLogicalNot(SqlBaseParser.LogicalNotContext context) { return new NotExpression(getLocation(context), (Expression) visit(context.booleanExpression())); } @Override public Node 
visitLogicalBinary(SqlBaseParser.LogicalBinaryContext context)
    {
        return new LogicalBinaryExpression(
                getLocation(context.operator),
                getLogicalBinaryOperator(context.operator),
                (Expression) visit(context.left),
                (Expression) visit(context.right));
    }

    // *************** from clause *****************

    /**
     * Builds a {@link Join} from a join relation.
     * <p>
     * CROSS joins carry no criteria. NATURAL joins use {@link NaturalJoin}. Other joins take either an
     * ON expression ({@link JoinOn}) or a USING identifier list ({@link JoinUsing}). For non-CROSS joins
     * the type is LEFT, RIGHT or FULL when that token is present, otherwise INNER. The left side is
     * visited first, then the right side, then any ON expression.
     *
     * @throws IllegalArgumentException if the join criteria contain neither ON nor USING
     */
    @Override
    public Node visitJoinRelation(SqlBaseParser.JoinRelationContext context)
    {
        Relation left = (Relation) visit(context.left);
        Relation right;

        if (context.CROSS() != null) {
            right = (Relation) visit(context.right);
            return new Join(getLocation(context), Join.Type.CROSS, left, right, Optional.empty());
        }

        JoinCriteria criteria;
        if (context.NATURAL() != null) {
            // NATURAL joins read the right side from the `right` label (same as CROSS)
            right = (Relation) visit(context.right);
            criteria = new NaturalJoin();
        }
        else {
            // criteria-based joins read the right side from the `rightRelation` label
            right = (Relation) visit(context.rightRelation);
            if (context.joinCriteria().ON() != null) {
                criteria = new JoinOn((Expression) visit(context.joinCriteria().booleanExpression()));
            }
            else if (context.joinCriteria().USING() != null) {
                List<String> columns = context.joinCriteria()
                        .identifier().stream()
                        .map(ParseTree::getText)
                        .collect(toList());

                criteria = new JoinUsing(columns);
            }
            else {
                throw new IllegalArgumentException("Unsupported join criteria");
            }
        }

        // Shared by NATURAL and criteria-based joins; falls back to INNER when no LEFT/RIGHT/FULL token
        Join.Type joinType;
        if (context.joinType().LEFT() != null) {
            joinType = Join.Type.LEFT;
        }
        else if (context.joinType().RIGHT() != null) {
            joinType = Join.Type.RIGHT;
        }
        else if (context.joinType().FULL() != null) {
            joinType = Join.Type.FULL;
        }
        else {
            joinType = Join.Type.INNER;
        }

        return new Join(getLocation(context), joinType, left, right, Optional.of(criteria));
    }

    /**
     * Wraps the aliased relation in a {@link SampledRelation} when the TABLESAMPLE token is present;
     * otherwise returns the aliased relation unchanged. The sampling method is derived from the
     * first token of the sample type, and the percentage expression is visited last.
     */
    @Override
    public Node visitSampledRelation(SqlBaseParser.SampledRelationContext context)
    {
        Relation child = (Relation) visit(context.aliasedRelation());

        if (context.TABLESAMPLE() == null) {
            return child;
        }

        return new SampledRelation(
                getLocation(context),
                child,
                getSamplingMethod((Token) context.sampleType().getChild(0).getPayload()),
                (Expression) visit(context.percentage));
    }

    @Override
    public Node
visitAliasedRelation(SqlBaseParser.AliasedRelationContext context) { Relation child = (Relation) visit(context.relationPrimary()); if (context.identifier() == null) { return child; } return new AliasedRelation(getLocation(context), child, context.identifier().getText(), getColumnAliases(context.columnAliases())); } @Override public Node visitTableName(SqlBaseParser.TableNameContext context) { return new Table(getLocation(context), getQualifiedName(context.qualifiedName())); } @Override public Node visitSubqueryRelation(SqlBaseParser.SubqueryRelationContext context) { return new TableSubquery(getLocation(context), (Query) visit(context.query())); } @Override public Node visitUnnest(SqlBaseParser.UnnestContext context) { return new Unnest(getLocation(context), visit(context.expression(), Expression.class), context.ORDINALITY() != null); } @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { return visit(context.relation()); } // ********************* predicates ******************* @Override public Node visitPredicated(SqlBaseParser.PredicatedContext context) { if (context.predicate() != null) { return visit(context.predicate()); } return visit(context.valueExpression); } @Override public Node visitComparison(SqlBaseParser.ComparisonContext context) { return new ComparisonExpression( getLocation(context.comparisonOperator()), getComparisonOperator(((TerminalNode) context.comparisonOperator().getChild(0)).getSymbol()), (Expression) visit(context.value), (Expression) visit(context.right)); } @Override public Node visitDistinctFrom(SqlBaseParser.DistinctFromContext context) { Expression expression = new ComparisonExpression( getLocation(context), ComparisonExpressionType.IS_DISTINCT_FROM, (Expression) visit(context.value), (Expression) visit(context.right)); if (context.NOT() != null) { expression = new NotExpression(getLocation(context), expression); } return expression; } @Override public Node 
visitBetween(SqlBaseParser.BetweenContext context) { Expression expression = new BetweenPredicate( getLocation(context), (Expression) visit(context.value), (Expression) visit(context.lower), (Expression) visit(context.upper)); if (context.NOT() != null) { expression = new NotExpression(getLocation(context), expression); } return expression; } @Override public Node visitNullPredicate(SqlBaseParser.NullPredicateContext context) { Expression child = (Expression) visit(context.value); if (context.NOT() == null) { return new IsNullPredicate(getLocation(context), child); } return new IsNotNullPredicate(getLocation(context), child); } @Override public Node visitLike(SqlBaseParser.LikeContext context) { Expression escape = null; if (context.escape != null) { escape = (Expression) visit(context.escape); } Expression result = new LikePredicate(getLocation(context), (Expression) visit(context.value), (Expression) visit(context.pattern), escape); if (context.NOT() != null) { result = new NotExpression(getLocation(context), result); } return result; } @Override public Node visitInList(SqlBaseParser.InListContext context) { Expression result = new InPredicate( getLocation(context), (Expression) visit(context.value), new InListExpression(getLocation(context), visit(context.expression(), Expression.class))); if (context.NOT() != null) { result = new NotExpression(getLocation(context), result); } return result; } @Override public Node visitInSubquery(SqlBaseParser.InSubqueryContext context) { Expression result = new InPredicate( getLocation(context), (Expression) visit(context.value), new SubqueryExpression(getLocation(context), (Query) visit(context.query()))); if (context.NOT() != null) { result = new NotExpression(getLocation(context), result); } return result; } @Override public Node visitExists(SqlBaseParser.ExistsContext context) { return new ExistsPredicate(getLocation(context), new SubqueryExpression(getLocation(context), (Query) visit(context.query()))); } @Override public 
Node visitQuantifiedComparison(SqlBaseParser.QuantifiedComparisonContext context) { return new QuantifiedComparisonExpression( getLocation(context.comparisonOperator()), getComparisonOperator(((TerminalNode) context.comparisonOperator().getChild(0)).getSymbol()), getComparisonQuantifier(((TerminalNode) context.comparisonQuantifier().getChild(0)).getSymbol()), (Expression) visit(context.value), new SubqueryExpression(getLocation(context.query()), (Query) visit(context.query()))); } // ************** value expressions ************** @Override public Node visitArithmeticUnary(SqlBaseParser.ArithmeticUnaryContext context) { Expression child = (Expression) visit(context.valueExpression()); switch (context.operator.getType()) { case SqlBaseLexer.MINUS: return ArithmeticUnaryExpression.negative(getLocation(context), child); case SqlBaseLexer.PLUS: return ArithmeticUnaryExpression.positive(getLocation(context), child); default: throw new UnsupportedOperationException("Unsupported sign: " + context.operator.getText()); } } @Override public Node visitArithmeticBinary(SqlBaseParser.ArithmeticBinaryContext context) { return new ArithmeticBinaryExpression( getLocation(context.operator), getArithmeticBinaryOperator(context.operator), (Expression) visit(context.left), (Expression) visit(context.right)); } @Override public Node visitConcatenation(SqlBaseParser.ConcatenationContext context) { return new FunctionCall( getLocation(context.CONCAT()), QualifiedName.of("concat"), ImmutableList.of( (Expression) visit(context.left), (Expression) visit(context.right))); } @Override public Node visitAtTimeZone(SqlBaseParser.AtTimeZoneContext context) { return new AtTimeZone( getLocation(context.AT()), (Expression) visit(context.valueExpression()), (Expression) visit(context.timeZoneSpecifier())); } @Override public Node visitTimeZoneInterval(SqlBaseParser.TimeZoneIntervalContext context) { return visit(context.interval()); } @Override public Node 
visitTimeZoneString(SqlBaseParser.TimeZoneStringContext context) { return visit(context.string()); } // ********************* primary expressions ********************** @Override public Node visitParenthesizedExpression(SqlBaseParser.ParenthesizedExpressionContext context) { return visit(context.expression()); } @Override public Node visitRowConstructor(SqlBaseParser.RowConstructorContext context) { return new Row(getLocation(context), visit(context.expression(), Expression.class)); } @Override public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context) { return new ArrayConstructor(getLocation(context), visit(context.expression(), Expression.class)); } @Override public Node visitCast(SqlBaseParser.CastContext context) { boolean isTryCast = context.TRY_CAST() != null; return new Cast(getLocation(context), (Expression) visit(context.expression()), getType(context.type()), isTryCast); } @Override public Node visitSpecialDateTimeFunction(SqlBaseParser.SpecialDateTimeFunctionContext context) { CurrentTime.Type type = getDateTimeFunctionType(context.name); if (context.precision != null) { return new CurrentTime(getLocation(context), type, Integer.parseInt(context.precision.getText())); } return new CurrentTime(getLocation(context), type); } @Override public Node visitExtract(SqlBaseParser.ExtractContext context) { String fieldString = context.identifier().getText(); Extract.Field field; try { field = Extract.Field.valueOf(fieldString.toUpperCase()); } catch (IllegalArgumentException e) { throw parseError("Invalid EXTRACT field: " + fieldString, context); } return new Extract(getLocation(context), (Expression) visit(context.valueExpression()), field); } @Override public Node visitSubstring(SqlBaseParser.SubstringContext context) { return new FunctionCall(getLocation(context), QualifiedName.of("substr"), visit(context.valueExpression(), Expression.class)); } @Override public Node visitPosition(SqlBaseParser.PositionContext context) { List<Expression> 
arguments = Lists.reverse(visit(context.valueExpression(), Expression.class)); return new FunctionCall(getLocation(context), QualifiedName.of("strpos"), arguments); } @Override public Node visitNormalize(SqlBaseParser.NormalizeContext context) { Expression str = (Expression) visit(context.valueExpression()); String normalForm = Optional.ofNullable(context.normalForm()).map(ParserRuleContext::getText).orElse("NFC"); return new FunctionCall(getLocation(context), QualifiedName.of("normalize"), ImmutableList.of(str, new StringLiteral(getLocation(context), normalForm))); } @Override public Node visitSubscript(SqlBaseParser.SubscriptContext context) { return new SubscriptExpression(getLocation(context), (Expression) visit(context.value), (Expression) visit(context.index)); } @Override public Node visitSubqueryExpression(SqlBaseParser.SubqueryExpressionContext context) { return new SubqueryExpression(getLocation(context), (Query) visit(context.query())); } @Override public Node visitDereference(SqlBaseParser.DereferenceContext context) { return new DereferenceExpression(getLocation(context), (Expression) visit(context.base), context.fieldName.getText()); } @Override public Node visitColumnReference(SqlBaseParser.ColumnReferenceContext context) { return new Identifier(getLocation(context), context.getText()); } @Override public Node visitSimpleCase(SqlBaseParser.SimpleCaseContext context) { return new SimpleCaseExpression( getLocation(context), (Expression) visit(context.valueExpression()), visit(context.whenClause(), WhenClause.class), visitIfPresent(context.elseExpression, Expression.class)); } @Override public Node visitSearchedCase(SqlBaseParser.SearchedCaseContext context) { return new SearchedCaseExpression( getLocation(context), visit(context.whenClause(), WhenClause.class), visitIfPresent(context.elseExpression, Expression.class)); } @Override public Node visitWhenClause(SqlBaseParser.WhenClauseContext context) { return new WhenClause(getLocation(context), 
(Expression) visit(context.condition), (Expression) visit(context.result)); } @Override public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context) { Optional<Expression> filter = visitIfPresent(context.filter(), Expression.class); Optional<Window> window = visitIfPresent(context.over(), Window.class); QualifiedName name = getQualifiedName(context.qualifiedName()); boolean distinct = isDistinct(context.setQuantifier()); if (name.toString().equalsIgnoreCase("if")) { check(context.expression().size() == 2 || context.expression().size() == 3, "Invalid number of arguments for 'if' function", context); check(!window.isPresent(), "OVER clause not valid for 'if' function", context); check(!distinct, "DISTINCT not valid for 'if' function", context); Expression elseExpression = null; if (context.expression().size() == 3) { elseExpression = (Expression) visit(context.expression(2)); } return new IfExpression( getLocation(context), (Expression) visit(context.expression(0)), (Expression) visit(context.expression(1)), elseExpression); } if (name.toString().equalsIgnoreCase("nullif")) { check(context.expression().size() == 2, "Invalid number of arguments for 'nullif' function", context); check(!window.isPresent(), "OVER clause not valid for 'nullif' function", context); check(!distinct, "DISTINCT not valid for 'nullif' function", context); return new NullIfExpression( getLocation(context), (Expression) visit(context.expression(0)), (Expression) visit(context.expression(1))); } if (name.toString().equalsIgnoreCase("coalesce")) { check(!window.isPresent(), "OVER clause not valid for 'coalesce' function", context); check(!distinct, "DISTINCT not valid for 'coalesce' function", context); return new CoalesceExpression(getLocation(context), visit(context.expression(), Expression.class)); } if (name.toString().equalsIgnoreCase("try")) { check(context.expression().size() == 1, "The 'try' function must have exactly one argument", context); check(!window.isPresent(), "OVER 
clause not valid for 'try' function", context); check(!distinct, "DISTINCT not valid for 'try' function", context); return new TryExpression(getLocation(context), (Expression) visit(getOnlyElement(context.expression()))); } if (name.toString().equalsIgnoreCase("$internal$bind")) { check(context.expression().size() == 2, "The '$internal$bind' function must have exactly two arguments", context); check(!window.isPresent(), "OVER clause not valid for '$internal$bind' function", context); check(!distinct, "DISTINCT not valid for '$internal$bind' function", context); return new BindExpression( getLocation(context), (Expression) visit(context.expression(0)), (Expression) visit(context.expression(1))); } return new FunctionCall( getLocation(context), getQualifiedName(context.qualifiedName()), window, filter, distinct, visit(context.expression(), Expression.class)); } @Override public Node visitLambda(SqlBaseParser.LambdaContext context) { List<LambdaArgumentDeclaration> arguments = context.identifier().stream() .map(SqlBaseParser.IdentifierContext::getText) .map(LambdaArgumentDeclaration::new) .collect(toList()); Expression body = (Expression) visit(context.expression()); return new LambdaExpression(arguments, body); } @Override public Node visitFilter(SqlBaseParser.FilterContext context) { return visit(context.booleanExpression()); } @Override public Node visitOver(SqlBaseParser.OverContext context) { Optional<OrderBy> orderBy = Optional.empty(); if (context.ORDER() != null) { orderBy = Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class))); } return new Window( getLocation(context), visit(context.partition, Expression.class), orderBy, visitIfPresent(context.windowFrame(), WindowFrame.class)); } @Override public Node visitColumnDefinition(SqlBaseParser.ColumnDefinitionContext context) { Optional<String> comment = Optional.empty(); if (context.COMMENT() != null) { comment = Optional.of(((StringLiteral) 
visit(context.string())).getValue()); } return new ColumnDefinition(getLocation(context), context.identifier().getText(), getType(context.type()), comment); } @Override public Node visitLikeClause(SqlBaseParser.LikeClauseContext context) { return new LikeClause( getLocation(context), getQualifiedName(context.qualifiedName()), Optional.ofNullable(context.optionType) .map(AstBuilder::getPropertiesOption)); } @Override public Node visitSortItem(SqlBaseParser.SortItemContext context) { return new SortItem( getLocation(context), (Expression) visit(context.expression()), Optional.ofNullable(context.ordering) .map(AstBuilder::getOrderingType) .orElse(SortItem.Ordering.ASCENDING), Optional.ofNullable(context.nullOrdering) .map(AstBuilder::getNullOrderingType) .orElse(SortItem.NullOrdering.UNDEFINED)); } @Override public Node visitWindowFrame(SqlBaseParser.WindowFrameContext context) { return new WindowFrame( getLocation(context), getFrameType(context.frameType), (FrameBound) visit(context.start), visitIfPresent(context.end, FrameBound.class)); } @Override public Node visitUnboundedFrame(SqlBaseParser.UnboundedFrameContext context) { return new FrameBound(getLocation(context), getUnboundedFrameBoundType(context.boundType)); } @Override public Node visitBoundedFrame(SqlBaseParser.BoundedFrameContext context) { return new FrameBound(getLocation(context), getBoundedFrameBoundType(context.boundType), (Expression) visit(context.expression())); } @Override public Node visitCurrentRowBound(SqlBaseParser.CurrentRowBoundContext context) { return new FrameBound(getLocation(context), FrameBound.Type.CURRENT_ROW); } // ************** literals ************** @Override public Node visitNullLiteral(SqlBaseParser.NullLiteralContext context) { return new NullLiteral(getLocation(context)); } @Override public Node visitBasicStringLiteral(SqlBaseParser.BasicStringLiteralContext context) { return new StringLiteral(getLocation(context), unquote(context.STRING().getText())); } @Override public 
Node visitUnicodeStringLiteral(SqlBaseParser.UnicodeStringLiteralContext context)
{
    // Escape sequences inside the literal are decoded by decodeUnicodeLiteral below.
    return new StringLiteral(getLocation(context), decodeUnicodeLiteral(context));
}

/**
 * Builds a {@link BinaryLiteral} from the BINARY_LITERAL token text. The first character is
 * dropped (presumably the X prefix; confirm against the lexer grammar), then the surrounding
 * quotes are removed by {@link #unquote}.
 */
@Override
public Node visitBinaryLiteral(SqlBaseParser.BinaryLiteralContext context)
{
    String raw = context.BINARY_LITERAL().getText();
    return new BinaryLiteral(getLocation(context), unquote(raw.substring(1)));
}

/**
 * Builds a typed literal (type name followed by a string).
 * <ul>
 * <li>DOUBLE PRECISION becomes a {@link GenericLiteral} of type "DOUBLE".</li>
 * <li>time, timestamp, decimal and char (matched case-insensitively) get dedicated literal
 * node types.</li>
 * <li>Any other type name becomes a {@link GenericLiteral} carrying that name.</li>
 * </ul>
 */
@Override
public Node visitTypeConstructor(SqlBaseParser.TypeConstructorContext context)
{
    String value = ((StringLiteral) visit(context.string())).getValue();

    if (context.DOUBLE_PRECISION() != null) {
        // TODO: Temporary hack that should be removed with new planner.
        return new GenericLiteral(getLocation(context), "DOUBLE", value);
    }

    String type = context.identifier().getText();
    if (type.equalsIgnoreCase("time")) {
        return new TimeLiteral(getLocation(context), value);
    }
    if (type.equalsIgnoreCase("timestamp")) {
        return new TimestampLiteral(getLocation(context), value);
    }
    if (type.equalsIgnoreCase("decimal")) {
        return new DecimalLiteral(getLocation(context), value);
    }
    if (type.equalsIgnoreCase("char")) {
        return new CharLiteral(getLocation(context), value);
    }
    return new GenericLiteral(getLocation(context), type, value);
}

@Override
public Node visitIntegerLiteral(SqlBaseParser.IntegerLiteralContext context)
{
    return new LongLiteral(getLocation(context), context.getText());
}

// Note: this visitor produces a DoubleLiteral (not a DecimalLiteral) from the token text.
@Override
public Node visitDecimalLiteral(SqlBaseParser.DecimalLiteralContext context)
{
    return new DoubleLiteral(getLocation(context), context.getText());
}

@Override
public Node visitBooleanValue(SqlBaseParser.BooleanValueContext context)
{
    return new BooleanLiteral(getLocation(context), context.getText());
}

/**
 * Builds an {@link IntervalLiteral}. The sign defaults to POSITIVE. The "from" field is
 * mandatory and the "to" field optional; each is read from its context's first child token.
 */
@Override
public Node visitInterval(SqlBaseParser.IntervalContext context)
{
    return new IntervalLiteral(
            getLocation(context),
            ((StringLiteral) visit(context.string())).getValue(),
            Optional.ofNullable(context.sign)
                    .map(AstBuilder::getIntervalSign)
                    .orElse(IntervalLiteral.Sign.POSITIVE),
            getIntervalFieldType((Token) context.from.getChild(0).getPayload()),
            Optional.ofNullable(context.to)
                    .map((x) -> x.getChild(0).getPayload())
                    .map(Token.class::cast)
                    .map(AstBuilder::getIntervalFieldType));
}

/**
 * Builds a {@link Parameter} numbered by the order in which this visitor reaches parameters.
 * Each call takes the current parameterPosition (a field declared outside this chunk) and
 * then increments it, so the other visit methods must visit children in query-text order
 * for positions to run left to right.
 */
@Override
public Node visitParameter(SqlBaseParser.ParameterContext context)
{
    Parameter parameter = new Parameter(getLocation(context), parameterPosition);
    parameterPosition++;
    return parameter;
}

// ***************** arguments *****************

@Override
public Node visitPositionalArgument(SqlBaseParser.PositionalArgumentContext context)
{
    return new CallArgument(getLocation(context), (Expression) visit(context.expression()));
}

@Override
public Node visitNamedArgument(SqlBaseParser.NamedArgumentContext context)
{
    return new CallArgument(getLocation(context), context.identifier().getText(), (Expression) visit(context.expression()));
}

// ***************** helpers *****************

// Starting value for the default child-visiting fold: no node yet.
@Override
protected Node defaultResult()
{
    return null;
}

/**
 * Combines child results for the default child-visiting implementation. Only one non-null
 * child result is supported: a null child result, or a second non-null one, throws.
 */
@Override
protected Node aggregateResult(Node aggregate, Node nextResult)
{
    if (nextResult == null) {
        throw new UnsupportedOperationException("not yet implemented");
    }

    if (aggregate == null) {
        return nextResult;
    }

    throw new UnsupportedOperationException("not yet implemented");
}

// States of the decodeUnicodeLiteral scanner: plain text, just after an escape character,
// and inside a hex code-point sequence.
private enum UnicodeDecodeState
{
    EMPTY,
    ESCAPED,
    UNICODE_SEQUENCE
}

/**
 * Decodes the body of a Unicode string literal.
 * <p>
 * The escape character is backslash unless a UESCAPE clause supplies one. A supplied escape
 * must be exactly one character accepted by {@link #isValidUnicodeEscape}. Within the content:
 * <ul>
 * <li>{@code <esc><esc>} yields the escape character itself;</li>
 * <li>{@code <esc>XXXX} (4 hex digits) yields that code point;</li>
 * <li>{@code <esc>+XXXXXX} (6 hex digits) yields that code point.</li>
 * </ul>
 * Code points must be valid, and BMP code points must not be surrogates.
 *
 * @throws ParsingException on an invalid escape character, digit, code point or an incomplete sequence
 */
private static String decodeUnicodeLiteral(SqlBaseParser.UnicodeStringLiteralContext context)
{
    char escape;
    if (context.UESCAPE() != null) {
        String escapeString = unquote(context.STRING().getText());
        check(!escapeString.isEmpty(), "Empty Unicode escape character", context);
        check(escapeString.length() == 1, "Invalid Unicode escape character: " + escapeString, context);
        escape = escapeString.charAt(0);
        check(isValidUnicodeEscape(escape), "Invalid Unicode escape character: " + escapeString, context);
    }
    else {
        escape = '\\';
    }

    // Drop the two-character token prefix (presumably "U&"; confirm against the lexer), then the quotes.
    String rawContent = unquote(context.UNICODE_STRING().getText().substring(2));
    StringBuilder unicodeStringBuilder = new StringBuilder();
    StringBuilder escapedCharacterBuilder = new StringBuilder();
    // Number of hex digits the current escape sequence requires (4 or 6).
    int charactersNeeded = 0;
    UnicodeDecodeState state = UnicodeDecodeState.EMPTY;
    for (int i = 0; i < rawContent.length(); i++) {
        char ch = rawContent.charAt(i);
        switch (state) {
            case EMPTY:
                if (ch == escape) {
                    state = UnicodeDecodeState.ESCAPED;
                }
                else {
                    unicodeStringBuilder.append(ch);
                }
                break;
            case ESCAPED:
                if (ch == escape) {
                    // Doubled escape character: emit it literally.
                    unicodeStringBuilder.append(escape);
                    state = UnicodeDecodeState.EMPTY;
                }
                else if (ch == '+') {
                    // "+" form: six hex digits follow; the '+' itself is not collected.
                    state = UnicodeDecodeState.UNICODE_SEQUENCE;
                    charactersNeeded = 6;
                }
                else if (isHexDigit(ch)) {
                    // Short form: this digit is the first of four.
                    state = UnicodeDecodeState.UNICODE_SEQUENCE;
                    charactersNeeded = 4;
                    escapedCharacterBuilder.append(ch);
                }
                else {
                    throw parseError("Invalid hexadecimal digit: " + ch, context);
                }
                break;
            case UNICODE_SEQUENCE:
                check(isHexDigit(ch), "Incomplete escape sequence: " + escapedCharacterBuilder.toString(), context);
                escapedCharacterBuilder.append(ch);
                if (charactersNeeded == escapedCharacterBuilder.length()) {
                    // Sequence complete: convert the hex digits to a code point and emit it.
                    String currentEscapedCode = escapedCharacterBuilder.toString();
                    escapedCharacterBuilder.setLength(0);
                    int codePoint = Integer.parseInt(currentEscapedCode, 16);
                    check(Character.isValidCodePoint(codePoint), "Invalid escaped character: " + currentEscapedCode, context);
                    if (Character.isSupplementaryCodePoint(codePoint)) {
                        unicodeStringBuilder.appendCodePoint(codePoint);
                    }
                    else {
                        // A lone surrogate cannot be emitted; the error message points to the 6-digit form instead.
                        char currentCodePoint = (char) codePoint;
                        check(!Character.isSurrogate(currentCodePoint), format("Invalid escaped character: %s. Escaped character is a surrogate. Use '\\+123456' instead.", currentEscapedCode), context);
                        unicodeStringBuilder.append(currentCodePoint);
                    }
                    state = UnicodeDecodeState.EMPTY;
                    charactersNeeded = -1;
                }
                else {
                    check(charactersNeeded > escapedCharacterBuilder.length(), "Unexpected escape sequence length: " + escapedCharacterBuilder.length(), context);
                }
                break;
            default:
                throw new UnsupportedOperationException();
        }
    }

    // Content must not end in the middle of an escape or hex sequence.
    check(state == UnicodeDecodeState.EMPTY, "Incomplete escape sequence: " + escapedCharacterBuilder.toString(), context);
    return unicodeStringBuilder.toString();
}

// Visits an optional child: empty when the context is null, otherwise the result cast to clazz.
private <T> Optional<T> visitIfPresent(ParserRuleContext context, Class<T> clazz)
{
    return Optional.ofNullable(context)
            .map(this::visit)
            .map(clazz::cast);
}

// Visits each context in list order and casts every result to clazz.
private <T> List<T> visit(List<? extends ParserRuleContext> contexts, Class<T> clazz)
{
    return contexts.stream()
            .map(this::visit)
            .map(clazz::cast)
            .collect(toList());
}

// Strips the first and last character (the quotes) and collapses doubled single quotes.
private static String unquote(String value)
{
    return value.substring(1, value.length() - 1)
            .replace("''", "'");
}

private static LikeClause.PropertiesOption getPropertiesOption(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.INCLUDING:
            return LikeClause.PropertiesOption.INCLUDING;
        case SqlBaseLexer.EXCLUDING:
            return LikeClause.PropertiesOption.EXCLUDING;
    }
    throw new IllegalArgumentException("Unsupported LIKE option type: " + token.getText());
}

// Builds a QualifiedName from the text of each identifier part, in order.
private static QualifiedName getQualifiedName(SqlBaseParser.QualifiedNameContext context)
{
    List<String> parts = context
            .identifier().stream()
            .map(ParseTree::getText)
            .collect(toList());

    return QualifiedName.of(parts);
}

// True only when a set quantifier is present and it is DISTINCT.
private static boolean isDistinct(SqlBaseParser.SetQuantifierContext setQuantifier)
{
    return setQuantifier != null && setQuantifier.DISTINCT() != null;
}

// ASCII hex digit check (0-9, A-F, a-f).
private static boolean isHexDigit(char c)
{
    return ((c >= '0') && (c <= '9')) ||
            ((c >= 'A') && (c <= 'F')) ||
            ((c >= 'a') && (c <= 'f'));
}

// A UESCAPE character must be printable ASCII (0x21-0x7E) and not a hex digit, '"', '+' or '\''.
private static boolean isValidUnicodeEscape(char c)
{
    return c < 0x7F && c > 0x20 && !isHexDigit(c) && c != '"' && c != '+' && c != '\'';
}
private static Optional<String> getTextIfPresent(ParserRuleContext context)
{
    // Text of an optional rule context; empty when the context is absent.
    if (context == null) {
        return Optional.empty();
    }
    return Optional.ofNullable(context.getText());
}

/**
 * Text of an optional token; empty when the token (or its text) is absent.
 */
private static Optional<String> getTextIfPresent(Token token)
{
    if (token == null) {
        return Optional.empty();
    }
    return Optional.ofNullable(token.getText());
}

/**
 * Returns the column alias names in order.
 * NOTE(review): returns null (not an empty list) when there is no alias list; presumably
 * callers rely on that distinction.
 */
private static List<String> getColumnAliases(SqlBaseParser.ColumnAliasesContext columnAliasesContext)
{
    if (columnAliasesContext != null) {
        return columnAliasesContext.identifier().stream()
                .map(ParseTree::getText)
                .collect(toList());
    }
    return null;
}

// Maps an arithmetic operator token to its binary expression type.
private static ArithmeticBinaryExpression.Type getArithmeticBinaryOperator(Token operator)
{
    switch (operator.getType()) {
        case SqlBaseLexer.PLUS:
            return ArithmeticBinaryExpression.Type.ADD;
        case SqlBaseLexer.MINUS:
            return ArithmeticBinaryExpression.Type.SUBTRACT;
        case SqlBaseLexer.ASTERISK:
            return ArithmeticBinaryExpression.Type.MULTIPLY;
        case SqlBaseLexer.SLASH:
            return ArithmeticBinaryExpression.Type.DIVIDE;
        case SqlBaseLexer.PERCENT:
            return ArithmeticBinaryExpression.Type.MODULUS;
        default:
            throw new UnsupportedOperationException("Unsupported operator: " + operator.getText());
    }
}

// Maps a comparison operator token to its comparison type.
private static ComparisonExpressionType getComparisonOperator(Token symbol)
{
    switch (symbol.getType()) {
        case SqlBaseLexer.EQ:
            return ComparisonExpressionType.EQUAL;
        case SqlBaseLexer.NEQ:
            return ComparisonExpressionType.NOT_EQUAL;
        case SqlBaseLexer.LT:
            return ComparisonExpressionType.LESS_THAN;
        case SqlBaseLexer.LTE:
            return ComparisonExpressionType.LESS_THAN_OR_EQUAL;
        case SqlBaseLexer.GT:
            return ComparisonExpressionType.GREATER_THAN;
        case SqlBaseLexer.GTE:
            return ComparisonExpressionType.GREATER_THAN_OR_EQUAL;
        default:
            throw new IllegalArgumentException("Unsupported operator: " + symbol.getText());
    }
}

// Maps a special date/time function keyword to its CurrentTime type.
private static CurrentTime.Type getDateTimeFunctionType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.CURRENT_DATE:
            return CurrentTime.Type.DATE;
        case SqlBaseLexer.CURRENT_TIME:
            return CurrentTime.Type.TIME;
        case SqlBaseLexer.CURRENT_TIMESTAMP:
            return CurrentTime.Type.TIMESTAMP;
        case SqlBaseLexer.LOCALTIME:
            return CurrentTime.Type.LOCALTIME;
        case SqlBaseLexer.LOCALTIMESTAMP:
            return CurrentTime.Type.LOCALTIMESTAMP;
        default:
            throw new IllegalArgumentException("Unsupported special function: " + token.getText());
    }
}

// Maps an interval unit keyword to its interval field.
private static IntervalLiteral.IntervalField getIntervalFieldType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.YEAR:
            return IntervalLiteral.IntervalField.YEAR;
        case SqlBaseLexer.MONTH:
            return IntervalLiteral.IntervalField.MONTH;
        case SqlBaseLexer.DAY:
            return IntervalLiteral.IntervalField.DAY;
        case SqlBaseLexer.HOUR:
            return IntervalLiteral.IntervalField.HOUR;
        case SqlBaseLexer.MINUTE:
            return IntervalLiteral.IntervalField.MINUTE;
        case SqlBaseLexer.SECOND:
            return IntervalLiteral.IntervalField.SECOND;
        default:
            throw new IllegalArgumentException("Unsupported interval field: " + token.getText());
    }
}

// Maps a sign token to the interval sign.
private static IntervalLiteral.Sign getIntervalSign(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.MINUS:
            return IntervalLiteral.Sign.NEGATIVE;
        case SqlBaseLexer.PLUS:
            return IntervalLiteral.Sign.POSITIVE;
        default:
            throw new IllegalArgumentException("Unsupported sign: " + token.getText());
    }
}

// Maps RANGE/ROWS to the window frame type.
private static WindowFrame.Type getFrameType(Token type)
{
    switch (type.getType()) {
        case SqlBaseLexer.RANGE:
            return WindowFrame.Type.RANGE;
        case SqlBaseLexer.ROWS:
            return WindowFrame.Type.ROWS;
        default:
            throw new IllegalArgumentException("Unsupported frame type: " + type.getText());
    }
}

// Maps PRECEDING/FOLLOWING of a bounded frame edge to its bound type.
private static FrameBound.Type getBoundedFrameBoundType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.PRECEDING:
            return FrameBound.Type.PRECEDING;
        case SqlBaseLexer.FOLLOWING:
            return FrameBound.Type.FOLLOWING;
        default:
            throw new IllegalArgumentException("Unsupported bound type: " + token.getText());
    }
}

// Maps PRECEDING/FOLLOWING of an UNBOUNDED frame edge to its bound type.
private static FrameBound.Type getUnboundedFrameBoundType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.PRECEDING:
            return FrameBound.Type.UNBOUNDED_PRECEDING;
        case SqlBaseLexer.FOLLOWING:
            return FrameBound.Type.UNBOUNDED_FOLLOWING;
        default:
            throw new IllegalArgumentException("Unsupported bound type: " + token.getText());
    }
}

// Maps a sampling method keyword to the sampled relation type.
private static SampledRelation.Type getSamplingMethod(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.BERNOULLI:
            return SampledRelation.Type.BERNOULLI;
        case SqlBaseLexer.SYSTEM:
            return SampledRelation.Type.SYSTEM;
        default:
            throw new IllegalArgumentException("Unsupported sampling method: " + token.getText());
    }
}

// Maps AND/OR to the logical binary expression type.
private static LogicalBinaryExpression.Type getLogicalBinaryOperator(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.AND:
            return LogicalBinaryExpression.Type.AND;
        case SqlBaseLexer.OR:
            return LogicalBinaryExpression.Type.OR;
        default:
            throw new IllegalArgumentException("Unsupported operator: " + token.getText());
    }
}

// Maps FIRST/LAST (from NULLS FIRST/LAST) to the null ordering.
private static SortItem.NullOrdering getNullOrderingType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.FIRST:
            return SortItem.NullOrdering.FIRST;
        case SqlBaseLexer.LAST:
            return SortItem.NullOrdering.LAST;
        default:
            throw new IllegalArgumentException("Unsupported ordering: " + token.getText());
    }
}

// Maps ASC/DESC to the sort ordering.
private static SortItem.Ordering getOrderingType(Token token)
{
    switch (token.getType()) {
        case SqlBaseLexer.ASC:
            return SortItem.Ordering.ASCENDING;
        case SqlBaseLexer.DESC:
            return SortItem.Ordering.DESCENDING;
        default:
            throw new IllegalArgumentException("Unsupported ordering: " + token.getText());
    }
}

// Maps ALL/ANY/SOME to the quantified comparison quantifier.
private static QuantifiedComparisonExpression.Quantifier getComparisonQuantifier(Token symbol)
{
    switch (symbol.getType()) {
        case SqlBaseLexer.ALL:
            return QuantifiedComparisonExpression.Quantifier.ALL;
        case SqlBaseLexer.ANY:
            return QuantifiedComparisonExpression.Quantifier.ANY;
        case SqlBaseLexer.SOME:
            return QuantifiedComparisonExpression.Quantifier.SOME;
        default:
            throw new IllegalArgumentException("Unsupported quantifier: " + symbol.getText());
    }
}

/**
 * Renders a parsed type as a type signature string.
 * <ul>
 * <li>Base types render with an optional parenthesized, comma-separated parameter list;
 * DOUBLE PRECISION is rendered as DOUBLE.</li>
 * <li>Arrays render as {@code ARRAY(t)} and maps as {@code MAP(k,v)}.</li>
 * <li>Rows render as {@code ROW(name type,...)}.</li>
 * </ul>
 *
 * @throws IllegalArgumentException for any other type shape
 */
private static String getType(SqlBaseParser.TypeContext type)
{
    if (type.baseType() != null) {
        // TODO: Temporary hack that should be removed with new planner.
        String signature = type.baseType().DOUBLE_PRECISION() != null ? "DOUBLE" : type.baseType().getText();

        if (type.typeParameter().isEmpty()) {
            return signature;
        }

        String parameters = type.typeParameter().stream()
                .map(AstBuilder::typeParameterToString)
                .collect(Collectors.joining(","));
        return signature + "(" + parameters + ")";
    }

    if (type.ARRAY() != null) {
        return "ARRAY(" + getType(type.type(0)) + ")";
    }

    if (type.MAP() != null) {
        return "MAP(" + getType(type.type(0)) + "," + getType(type.type(1)) + ")";
    }

    if (type.ROW() != null) {
        StringBuilder row = new StringBuilder("ROW(");
        String separator = "";
        for (int field = 0; field < type.identifier().size(); field++) {
            row.append(separator)
                    .append(type.identifier(field).getText())
                    .append(" ")
                    .append(getType(type.type(field)));
            separator = ",";
        }
        return row.append(")").toString();
    }

    throw new IllegalArgumentException("Unsupported type specification: " + type.getText());
}

// Renders one type parameter: an integer literal's text, or a nested type's signature.
private static String typeParameterToString(SqlBaseParser.TypeParameterContext typeParameter)
{
    TerminalNode integerValue = typeParameter.INTEGER_VALUE();
    if (integerValue != null) {
        return integerValue.toString();
    }

    SqlBaseParser.TypeContext nestedType = typeParameter.type();
    if (nestedType != null) {
        return getType(nestedType);
    }

    throw new IllegalArgumentException("Unsupported typeParameter: " + typeParameter.getText());
}

// Throws a ParsingException located at the context's start token unless the condition holds.
private static void check(boolean condition, String message, ParserRuleContext context)
{
    if (condition) {
        return;
    }
    throw parseError(message, context);
}

/**
 * Location (line, column) of a terminal node's token.
 *
 * @throws NullPointerException if terminalNode is null
 */
public static NodeLocation getLocation(TerminalNode terminalNode)
{
    return getLocation(requireNonNull(terminalNode, "terminalNode is null").getSymbol());
}

/**
 * Location (line, column) of a rule context's start token.
 *
 * @throws NullPointerException if parserRuleContext is null
 */
public static NodeLocation getLocation(ParserRuleContext parserRuleContext)
{
    return getLocation(requireNonNull(parserRuleContext, "parserRuleContext is null").getStart());
}

/**
 * Location built from the token's line and its character position within that line.
 *
 * @throws NullPointerException if token is null
 */
public static NodeLocation getLocation(Token token)
{
    requireNonNull(token, "token is null");
    int line = token.getLine();
    int column = token.getCharPositionInLine();
    return new NodeLocation(line, column);
}

// Creates (does not throw) a ParsingException positioned at the context's start token.
private static ParsingException parseError(String message, ParserRuleContext context)
{
    Token start = context.getStart();
    return new ParsingException(message, null, start.getLine(), start.getCharPositionInLine());
}
}