/* * Licensed to Crate.io Inc. or its affiliates ("Crate.io") under one or * more contributor license agreements. See the NOTICE file distributed * with this work for additional information regarding copyright ownership. * Crate.io licenses this file to you under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with the * License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * However, if you have executed another commercial license agreement with * Crate.io these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial * agreement. */ package io.crate.sql.parser; import com.google.common.collect.ImmutableList; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.Multimap; import io.crate.sql.parser.antlr.v4.SqlBaseBaseVisitor; import io.crate.sql.parser.antlr.v4.SqlBaseLexer; import io.crate.sql.parser.antlr.v4.SqlBaseParser; import io.crate.sql.tree.*; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.TerminalNode; import java.util.List; import java.util.Locale; import java.util.Optional; import static java.util.stream.Collectors.toList; class AstBuilder extends SqlBaseBaseVisitor<Node> { private int parameterPosition = 1; @Override public Node visitSingleStatement(SqlBaseParser.SingleStatementContext context) { return visit(context.statement()); } @Override public Node visitSingleExpression(SqlBaseParser.SingleExpressionContext context) { return visit(context.expr()); } // Statements @Override public Node visitBegin(SqlBaseParser.BeginContext context) { return new BeginStatement(); } @Override public Node visitOptimize(SqlBaseParser.OptimizeContext context) { return new OptimizeStatement( visit(context.tableWithPartitions().tableWithPartition(), Table.class), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitCreateTable(SqlBaseParser.CreateTableContext context) { boolean notExists = context.EXISTS() != null; return new CreateTable( (Table) visit(context.table()), visit(context.tableElement(), TableElement.class), visit(context.crateTableOption(), CrateTableOption.class), visitIfPresent(context.withProperties(), GenericProperties.class), notExists); } @Override public Node visitCreateBlobTable(SqlBaseParser.CreateBlobTableContext context) { return new CreateBlobTable( (Table) visit(context.table()), visitIfPresent(context.numShards, ClusteredBy.class), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitCreateRepository(SqlBaseParser.CreateRepositoryContext context) { return new CreateRepository( getIdentText(context.name), getIdentText(context.type), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitCreateSnapshot(SqlBaseParser.CreateSnapshotContext context) { if (context.ALL() != null) { return new CreateSnapshot( getQualifiedName(context.qname()), visitIfPresent(context.withProperties(), GenericProperties.class)); } return new CreateSnapshot( getQualifiedName(context.qname()), visit(context.tableWithPartitions().tableWithPartition(), Table.class), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitCreateAnalyzer(SqlBaseParser.CreateAnalyzerContext context) { return new CreateAnalyzer( getIdentText(context.name), getIdentTextIfPresent(context.extendedName), visit(context.analyzerElement(), AnalyzerElement.class) ); } @Override public Node visitCharFilters(SqlBaseParser.CharFiltersContext context) { return new CharFilters(visit(context.namedProperties(), NamedProperties.class)); } @Override public Node visitTokenFilters(SqlBaseParser.TokenFiltersContext context) { return new TokenFilters(visit(context.namedProperties(), NamedProperties.class)); } @Override public Node visitTokenizer(SqlBaseParser.TokenizerContext context) { return new Tokenizer((NamedProperties) visit(context.namedProperties())); } @Override public Node visitNamedProperties(SqlBaseParser.NamedPropertiesContext context) { return new NamedProperties( getIdentText(context.ident()), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitRestore(SqlBaseParser.RestoreContext context) { if (context.ALL() != null) { return new RestoreSnapshot( getQualifiedName(context.qname()), visitIfPresent(context.withProperties(), GenericProperties.class)); } return new RestoreSnapshot(getQualifiedName(context.qname()), visit(context.tableWithPartitions().tableWithPartition(), Table.class), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitShowCreateTable(SqlBaseParser.ShowCreateTableContext context) { return new ShowCreateTable((Table) visit(context.table())); } @Override public Node visitShowTransaction(SqlBaseParser.ShowTransactionContext context) { return new ShowTransaction(); } @Override public Node visitDropTable(SqlBaseParser.DropTableContext context) { return new DropTable((Table) visit(context.table()), context.EXISTS() != null); } @Override public Node visitDropRepository(SqlBaseParser.DropRepositoryContext context) { return new DropRepository(getIdentText(context.ident())); } @Override public Node visitDropBlobTable(SqlBaseParser.DropBlobTableContext context) { return new DropBlobTable((Table) visit(context.table()), context.EXISTS() != null); } @Override public Node visitDropSnapshot(SqlBaseParser.DropSnapshotContext context) { return new DropSnapshot(getQualifiedName(context.qname())); } @Override public Node visitCopyFrom(SqlBaseParser.CopyFromContext context) { return new CopyFrom( (Table) visit(context.tableWithPartition()), (Expression) visit(context.path), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitCopyTo(SqlBaseParser.CopyToContext context) { List<Expression> columns = Optional.ofNullable(context.columns()) .map(list -> visit(list.primaryExpression(), Expression.class)) .orElse(null); return new CopyTo( (Table) visit(context.tableWithPartition()), columns, visitIfPresent(context.where(), Expression.class), context.DIRECTORY() != null, (Expression) visit(context.path), visitIfPresent(context.withProperties(), GenericProperties.class)); } @Override public Node visitInsert(SqlBaseParser.InsertContext context) { List<String> columns = identsToStrings(context.ident()); if (context.insertSource().VALUES() != null) { return new InsertFromValues( (Table) visit(context.table()), visit(context.insertSource().values(), ValuesList.class), columns, visit(context.assignment(), Assignment.class)); } return new InsertFromSubquery( (Table) visit(context.table()), (Query) visit(context.insertSource().query()), columns, visit(context.assignment(), Assignment.class)); } @Override public Node visitValues(SqlBaseParser.ValuesContext context) { return new ValuesList(visit(context.expr(), Expression.class)); } @Override public Node visitDelete(SqlBaseParser.DeleteContext context) { return new Delete( (Relation) visit(context.aliasedRelation()), visitIfPresent(context.where(), Expression.class)); } @Override public Node visitUpdate(SqlBaseParser.UpdateContext context) { return new Update( (Relation) visit(context.aliasedRelation()), visit(context.assignment(), Assignment.class), visitIfPresent(context.where(), Expression.class)); } @Override public Node visitSet(SqlBaseParser.SetContext context) { Assignment setAssignment = prepareSetAssignment(context); if (context.LOCAL() != null) { return new SetStatement(SetStatement.Scope.LOCAL, setAssignment); } return new SetStatement(SetStatement.Scope.SESSION, setAssignment); } private Assignment prepareSetAssignment(SqlBaseParser.SetContext context) { Expression settingName = new QualifiedNameReference(getQualifiedName(context.qname())); if (context.DEFAULT() != null) { return new Assignment(settingName, ImmutableList.of()); } return new Assignment(settingName, visit(context.setExpr(), Expression.class)); } @Override public Node visitSetGlobal(SqlBaseParser.SetGlobalContext context) { if (context.PERSISTENT() != null) { return new SetStatement(SetStatement.Scope.GLOBAL, SetStatement.SettingType.PERSISTENT, visit(context.setGlobalAssignment(), Assignment.class)); } return new SetStatement(SetStatement.Scope.GLOBAL, visit(context.setGlobalAssignment(), Assignment.class)); } @Override public Node visitResetGlobal(SqlBaseParser.ResetGlobalContext context) { return new ResetStatement(visit(context.primaryExpression(), Expression.class)); } @Override public Node visitKill(SqlBaseParser.KillContext context) { if (context.ALL() != null) { return new KillStatement(); } return new KillStatement((Expression) visit(context.jobId)); } @Override public Node visitExplain(SqlBaseParser.ExplainContext context) { return new Explain((Statement) visit(context.statement())); } @Override public Node visitShowTables(SqlBaseParser.ShowTablesContext context) { return new ShowTables( Optional.ofNullable(context.qname()).map(this::getQualifiedName), getTextIfPresent(context.pattern).map(AstBuilder::unquote), visitIfPresent(context.where(), Expression.class)); } @Override public Node visitShowSchemas(SqlBaseParser.ShowSchemasContext context) { return new ShowSchemas( getTextIfPresent(context.pattern).map(AstBuilder::unquote), visitIfPresent(context.where(), Expression.class)); } @Override public Node visitShowColumns(SqlBaseParser.ShowColumnsContext context) { return new ShowColumns( getQualifiedName(context.tableName), Optional.ofNullable(context.schema).map(this::getQualifiedName), visitIfPresent(context.where(), Expression.class), getTextIfPresent(context.pattern).map(AstBuilder::unquote)); } @Override public Node visitRefreshTable(SqlBaseParser.RefreshTableContext context) { return new RefreshStatement(visit(context.tableWithPartitions().tableWithPartition(), Table.class)); } @Override public Node visitTableOnly(SqlBaseParser.TableOnlyContext context) { return new Table(getQualifiedName(context.qname())); } @Override public Node visitTableWithPartition(SqlBaseParser.TableWithPartitionContext context) { return new Table(getQualifiedName(context.qname()), visit(context.assignment(), Assignment.class)); } @Override public Node visitCreateFunction(SqlBaseParser.CreateFunctionContext context) { QualifiedName functionName = getQualifiedName(context.name); validateFunctionName(functionName); return new CreateFunction( functionName, context.REPLACE() != null, visit(context.functionArgument(), FunctionArgument.class), (ColumnType) visit(context.returnType), (Expression) visit(context.language), (Expression) visit(context.body)); } @Override public Node visitDropFunction(SqlBaseParser.DropFunctionContext context) { QualifiedName functionName = getQualifiedName(context.name); validateFunctionName(functionName); return new DropFunction( functionName, context.EXISTS() != null, visit(context.functionArgument(), FunctionArgument.class)); } // Column / Table definition @Override public Node visitColumnDefinition(SqlBaseParser.ColumnDefinitionContext context) { if (context.generatedColumnDefinition() != null) { return visit(context.generatedColumnDefinition()); } return new ColumnDefinition( getIdentText(context.ident()), null, visitIfPresent(context.dataType(), ColumnType.class).orElse(null), visit(context.columnConstraint(), ColumnConstraint.class)); } @Override public Node visitGeneratedColumnDefinition(SqlBaseParser.GeneratedColumnDefinitionContext context) { return new ColumnDefinition( getIdentText(context.ident()), visitIfPresent(context.generatedExpr, Expression.class).orElse(null), visitIfPresent(context.dataType(), ColumnType.class).orElse(null), visit(context.columnConstraint(), ColumnConstraint.class)); } @Override public Node visitColumnConstraintPrimaryKey(SqlBaseParser.ColumnConstraintPrimaryKeyContext context) { return new PrimaryKeyColumnConstraint(); } @Override public Node visitColumnConstraintNotNull(SqlBaseParser.ColumnConstraintNotNullContext context) { return new NotNullColumnConstraint(); } @Override public Node visitPrimaryKeyConstraint(SqlBaseParser.PrimaryKeyConstraintContext context) { return new PrimaryKeyConstraint(visit(context.columns().primaryExpression(), Expression.class)); } @Override public Node visitColumnIndexOff(SqlBaseParser.ColumnIndexOffContext context) { return IndexColumnConstraint.OFF; } @Override public Node visitColumnIndexConstraint(SqlBaseParser.ColumnIndexConstraintContext context) { return new IndexColumnConstraint( getIdentText(context.method), visitIfPresent(context.withProperties(), GenericProperties.class) .orElse(GenericProperties.EMPTY)); } @Override public Node visitIndexDefinition(SqlBaseParser.IndexDefinitionContext context) { return new IndexDefinition( getIdentText(context.name), getIdentText(context.method), visit(context.columns().primaryExpression(), Expression.class), visitIfPresent(context.withProperties(), GenericProperties.class) .orElse(GenericProperties.EMPTY)); } @Override public Node visitPartitionedBy(SqlBaseParser.PartitionedByContext context) { return new PartitionedBy(visit(context.columns().primaryExpression(), Expression.class)); } @Override public Node visitClusteredBy(SqlBaseParser.ClusteredByContext context) { return new ClusteredBy( visitIfPresent(context.routing, Expression.class), visitIfPresent(context.numShards, Expression.class)); } @Override public Node visitClusteredInto(SqlBaseParser.ClusteredIntoContext context) { return new ClusteredBy(null, visitIfPresent(context.numShards, Expression.class)); } @Override public Node visitFunctionArgument(SqlBaseParser.FunctionArgumentContext context) { return new FunctionArgument(getIdentTextIfPresent(context.ident()), (ColumnType) visit(context.dataType())); } // Properties @Override public Node visitWithGenericProperties(SqlBaseParser.WithGenericPropertiesContext context) { return visitGenericProperties(context.genericProperties()); } @Override public Node visitGenericProperties(SqlBaseParser.GenericPropertiesContext context) { GenericProperties properties = new GenericProperties(); context.genericProperty().forEach(p -> properties.add((GenericProperty) visit(p))); return properties; } @Override public Node visitGenericProperty(SqlBaseParser.GenericPropertyContext context) { return new GenericProperty(getIdentText(context.ident()), (Expression) visit(context.expr())); } // Amending tables @Override public Node visitAlterTableProperties(SqlBaseParser.AlterTablePropertiesContext context) { Table name = (Table) visit(context.alterTableDefinition()); if (context.SET() != null) { return new AlterTable(name, (GenericProperties) visit(context.genericProperties())); } return new AlterTable(name, identsToStrings(context.ident())); } @Override public Node visitAlterBlobTableProperties(SqlBaseParser.AlterBlobTablePropertiesContext context) { Table name = (Table) visit(context.alterTableDefinition()); if (context.SET() != null) { return new AlterBlobTable(name, (GenericProperties) visit(context.genericProperties())); } return new AlterBlobTable(name, identsToStrings(context.ident())); } @Override public Node visitAddColumn(SqlBaseParser.AddColumnContext context) { return new AlterTableAddColumn( (Table) visit(context.alterTableDefinition()), (AddColumnDefinition) visit(context.addColumnDefinition())); } @Override public Node visitAddColumnDefinition(SqlBaseParser.AddColumnDefinitionContext context) { if (context.addGeneratedColumnDefinition() != null) { return visit(context.addGeneratedColumnDefinition()); } return new AddColumnDefinition( (Expression) visit(context.subscriptSafe()), null, visitIfPresent(context.dataType(), ColumnType.class).orElse(null), visit(context.columnConstraint(), ColumnConstraint.class)); } @Override public Node visitAddGeneratedColumnDefinition(SqlBaseParser.AddGeneratedColumnDefinitionContext context) { return new AddColumnDefinition( (Expression) visit(context.subscriptSafe()), visitIfPresent(context.generatedExpr, Expression.class).orElse(null), visitIfPresent(context.dataType(), ColumnType.class).orElse(null), visit(context.columnConstraint(), ColumnConstraint.class)); } // Assignments @Override public Node visitSetGlobalAssignment(SqlBaseParser.SetGlobalAssignmentContext context) { return new Assignment((Expression) visit(context.primaryExpression()), (Expression) visit(context.expr())); } @Override public Node visitAssignment(SqlBaseParser.AssignmentContext context) { Expression column = (Expression) visit(context.primaryExpression()); // such as it is currently hard to restrict a left side of an assignment to subscript and // qname in the grammar, because of our current grammar structure which causes the // indirect left-side recursion when attempting to do so. We restrict it before initializing // an Assignment. if (column instanceof SubscriptExpression || column instanceof QualifiedNameReference) { return new Assignment(column, (Expression) visit(context.expr())); } throw new IllegalArgumentException( String.format(Locale.ENGLISH, "cannot use expression %s as a left side of an assignment", column)); } // Query specification @Override public Node visitQuery(SqlBaseParser.QueryContext context) { Query body = (Query) visit(context.queryNoWith()); return new Query( visitIfPresent(context.with(), With.class), body.getQueryBody(), body.getOrderBy(), body.getLimit(), body.getOffset()); } @Override public Node visitWith(SqlBaseParser.WithContext context) { return new With(context.RECURSIVE() != null, visit(context.namedQuery(), WithQuery.class)); } @Override public Node visitNamedQuery(SqlBaseParser.NamedQueryContext context) { return new WithQuery( getIdentText(context.name), (Query) visit(context.query()), Optional.ofNullable(getColumnAliases(context.aliasedColumns())).orElse(ImmutableList.of())); } @Override public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) { QueryBody term = (QueryBody) visit(context.queryTerm()); if (term instanceof QuerySpecification) { // When we have a simple query specification // followed by order by limit, fold the order by and limit // clauses into the query specification (analyzer/planner // expects this structure to resolve references with respect // to columns defined in the query specification) QuerySpecification query = (QuerySpecification) term; return new Query( Optional.empty(), new QuerySpecification( query.getSelect(), query.getFrom(), query.getWhere(), query.getGroupBy(), query.getHaving(), visit(context.sortItem(), SortItem.class), visitIfPresent(context.limit, Expression.class), visitIfPresent(context.offset, Expression.class)), ImmutableList.of(), Optional.empty(), Optional.empty()); } return new Query( Optional.empty(), term, visit(context.sortItem(), SortItem.class), visitIfPresent(context.limit, Expression.class), visitIfPresent(context.offset, Expression.class)); } @Override public Node visitQuerySpecification(SqlBaseParser.QuerySpecificationContext context) { List<SelectItem> selectItems = visit(context.selectItem(), SelectItem.class); List<Relation> relations = null; if (context.FROM() != null) { relations = visit(context.relation(), Relation.class); } return new QuerySpecification( new Select(isDistinct(context.setQuant()), selectItems), relations, visitIfPresent(context.where(), Expression.class), visit(context.expr(), Expression.class), visitIfPresent(context.having, Expression.class), ImmutableList.of(), Optional.empty(), Optional.empty()); } @Override public Node visitWhere(SqlBaseParser.WhereContext context) { return visit(context.condition); } @Override public Node visitSortItem(SqlBaseParser.SortItemContext context) { return new SortItem( (Expression) visit(context.expr()), Optional.ofNullable(context.ordering) .map(AstBuilder::getOrderingType) .orElse(SortItem.Ordering.ASCENDING), Optional.ofNullable(context.nullOrdering) .map(AstBuilder::getNullOrderingType) .orElse(SortItem.NullOrdering.UNDEFINED)); } @Override public Node visitSetOperation(SqlBaseParser.SetOperationContext context) { QueryBody left = (QueryBody) visit(context.left); QueryBody right = (QueryBody) visit(context.right); boolean distinct = context.setQuant() == null || context.setQuant().DISTINCT() != null; switch (context.operator.getType()) { case SqlBaseLexer.UNION: return new Union(ImmutableList.of(left, right), distinct); case SqlBaseLexer.INTERSECT: return new Intersect(ImmutableList.of(left, right), distinct); case SqlBaseLexer.EXCEPT: return new Except(left, right, distinct); } throw new IllegalArgumentException("Unsupported set operation: " + context.operator.getText()); } @Override public Node visitSelectAll(SqlBaseParser.SelectAllContext context) { if (context.qname() != null) { return new AllColumns(getQualifiedName(context.qname())); } return new AllColumns(); } @Override public Node visitSelectSingle(SqlBaseParser.SelectSingleContext context) { return new SingleColumn((Expression) visit(context.expr()), getIdentTextIfPresent(context.ident())); } /* * case sensitivity like it is in postgres * see also http://www.thenextage.com/wordpress/postgresql-case-sensitivity-part-1-the-ddl/ * * unfortunately this has to be done in the parser because afterwards the * knowledge of the IDENT / QUOTED_IDENT difference is lost */ @Override public Node visitUnquotedIdentifier(SqlBaseParser.UnquotedIdentifierContext context) { return new StringLiteral(context.IDENTIFIER().getText().replace("``", "`").toLowerCase(Locale.ENGLISH)); } @Override public Node visitQuotedIdentifierAlternative(SqlBaseParser.QuotedIdentifierAlternativeContext context) { return new StringLiteral(context.getText().replace("\"\"", "\"")); } private String getIdentText(SqlBaseParser.IdentContext ident) { StringLiteral literal = (StringLiteral) visit(ident); return literal.getValue(); } private Optional<String> getIdentTextIfPresent(SqlBaseParser.IdentContext ident) { return Optional.ofNullable(ident).map(this::getIdentText); } @Override public Node visitTable(SqlBaseParser.TableContext context) { if (context.qname() != null) { return new Table(getQualifiedName(context.qname()), visit(context.parameterOrLiteral(), Assignment.class)); } FunctionCall fc = new FunctionCall( getQualifiedName(context.ident()), visit(context.parameterOrLiteral(), Expression.class)); return new TableFunction(fc); } // Boolean expressions @Override public Node visitLogicalNot(SqlBaseParser.LogicalNotContext context) { return new NotExpression((Expression) visit(context.booleanExpression())); } @Override public Node visitLogicalBinary(SqlBaseParser.LogicalBinaryContext context) { return new LogicalBinaryExpression( getLogicalBinaryOperator(context.operator), (Expression) visit(context.left), (Expression) visit(context.right)); } // From clause @Override public Node visitJoinRelation(SqlBaseParser.JoinRelationContext context) { Relation left = (Relation) visit(context.left); Relation right; if (context.CROSS() != null) { right = (Relation) visit(context.right); return new Join(Join.Type.CROSS, left, right, Optional.empty()); } JoinCriteria criteria; if (context.NATURAL() != null) { right = (Relation) visit(context.right); criteria = new NaturalJoin(); } else { right = (Relation) visit(context.rightRelation); if (context.joinCriteria().ON() != null) { criteria = new JoinOn((Expression) visit(context.joinCriteria().booleanExpression())); } else if (context.joinCriteria().USING() != null) { List<String> columns = identsToStrings(context.joinCriteria().ident()); criteria = new JoinUsing(columns); } else { throw new IllegalArgumentException("Unsupported join criteria"); } } return new Join(getJoinType(context.joinType()), left, right, Optional.of(criteria)); } private static Join.Type getJoinType(SqlBaseParser.JoinTypeContext joinTypeContext) { Join.Type joinType; if (joinTypeContext.LEFT() != null) { joinType = Join.Type.LEFT; } else if (joinTypeContext.RIGHT() != null) { joinType = Join.Type.RIGHT; } else if (joinTypeContext.FULL() != null) { joinType = Join.Type.FULL; } else { joinType = Join.Type.INNER; } return joinType; } @Override public Node visitAliasedRelation(SqlBaseParser.AliasedRelationContext context) { Relation child = (Relation) visit(context.relationPrimary()); if (context.ident() == null) { return child; } return new AliasedRelation(child, getIdentText(context.ident()), getColumnAliases(context.aliasedColumns())); } @Override public Node visitSubqueryRelation(SqlBaseParser.SubqueryRelationContext context) { return new TableSubquery((Query) visit(context.query())); } @Override public Node visitParenthesizedRelation(SqlBaseParser.ParenthesizedRelationContext context) { return visit(context.relation()); } // Predicates @Override public Node visitPredicated(SqlBaseParser.PredicatedContext context) { if (context.predicate() != null) { return visit(context.predicate()); } return visit(context.valueExpression); } @Override public Node visitComparison(SqlBaseParser.ComparisonContext context) { return new ComparisonExpression( getComparisonOperator(((TerminalNode) context.cmpOp().getChild(0)).getSymbol()), (Expression) visit(context.value), (Expression) visit(context.right)); } @Override public Node visitDistinctFrom(SqlBaseParser.DistinctFromContext context) { Expression expression = new ComparisonExpression( ComparisonExpression.Type.IS_DISTINCT_FROM, (Expression) visit(context.value), (Expression) visit(context.right)); if (context.NOT() != null) { expression = new NotExpression(expression); } return expression; } @Override public Node visitBetween(SqlBaseParser.BetweenContext context) { Expression expression = new BetweenPredicate( (Expression) visit(context.value), (Expression) visit(context.lower), (Expression) visit(context.upper)); if (context.NOT() != null) { expression = new NotExpression(expression); } return expression; } @Override public Node visitNullPredicate(SqlBaseParser.NullPredicateContext context) { Expression child = (Expression) visit(context.value); if (context.NOT() == null) { return new IsNullPredicate(child); } return new IsNotNullPredicate(child); } @Override public Node visitLike(SqlBaseParser.LikeContext context) { Expression escape = null; if (context.escape != null) { escape = (Expression) visit(context.escape); } Expression result = new LikePredicate( (Expression) visit(context.value), (Expression) visit(context.pattern), escape); if (context.NOT() != null) { result = new NotExpression(result); } return result; } @Override public Node visitArrayLike(SqlBaseParser.ArrayLikeContext context) { boolean inverse = context.NOT() != null; return new ArrayLikePredicate( getComparisonQuantifier(((TerminalNode) context.setCmpQuantifier().getChild(0)).getSymbol()), (Expression) visit(context.value), (Expression) visit(context.v), visitIfPresent(context.escape, Expression.class).orElse(null), inverse); } @Override public Node visitInList(SqlBaseParser.InListContext context) { Expression result = new InPredicate( (Expression) visit(context.value), new InListExpression(visit(context.expr(), Expression.class))); if (context.NOT() != null) { result = new NotExpression(result); } return result; } @Override public Node visitInSubquery(SqlBaseParser.InSubqueryContext context) { Expression result = new InPredicate( (Expression) visit(context.value), (Expression) visit(context.subqueryExpression())); if (context.NOT() != null) { result = new NotExpression(result); } return result; } @Override public Node visitExists(SqlBaseParser.ExistsContext context) { return new ExistsPredicate((Query) visit(context.query())); } @Override public Node visitQuantifiedComparison(SqlBaseParser.QuantifiedComparisonContext context) { return new ArrayComparisonExpression( getComparisonOperator(((TerminalNode) context.cmpOp().getChild(0)).getSymbol()), getComparisonQuantifier(((TerminalNode) context.setCmpQuantifier().getChild(0)).getSymbol()), (Expression) visit(context.value), (Expression) visit(context.parenthesizedPrimaryExpressionOrSubquery())); } @Override public Node visitMatch(SqlBaseParser.MatchContext context) { SqlBaseParser.MatchPredicateIdentsContext predicateIdents = context.matchPredicateIdents(); List<MatchPredicateColumnIdent> idents; if (predicateIdents.matchPred != null) { idents = ImmutableList.of((MatchPredicateColumnIdent) visit(predicateIdents.matchPred)); } else { idents = visit(predicateIdents.matchPredicateIdent(), MatchPredicateColumnIdent.class); } return new MatchPredicate( idents, (Expression) visit(context.term), getIdentTextIfPresent(context.matchType).orElse(null), visitIfPresent(context.withProperties(), GenericProperties.class).orElse(null)); } @Override public Node visitMatchPredicateIdent(SqlBaseParser.MatchPredicateIdentContext context) { return new MatchPredicateColumnIdent( (Expression) visit(context.subscriptSafe()), visitIfPresent(context.boost, Expression.class).orElse(null)); } // Value expressions @Override public Node visitArithmeticUnary(SqlBaseParser.ArithmeticUnaryContext context) { switch (context.operator.getType()) { case SqlBaseLexer.MINUS: return new NegativeExpression((Expression) visit(context.valueExpression())); case SqlBaseLexer.PLUS: return visit(context.valueExpression()); default: throw new UnsupportedOperationException("Unsupported sign: " + context.operator.getText()); } } @Override public Node visitArithmeticBinary(SqlBaseParser.ArithmeticBinaryContext context) { return new ArithmeticExpression( getArithmeticBinaryOperator(context.operator), (Expression) visit(context.left), (Expression) visit(context.right)); } @Override public Node visitConcatenation(SqlBaseParser.ConcatenationContext context) { return new FunctionCall( QualifiedName.of("concat"), ImmutableList.of( (Expression) visit(context.left), (Expression) visit(context.right))); } @Override public Node visitDoubleColonCast(SqlBaseParser.DoubleColonCastContext context) { return new Cast((Expression) visit(context.valueExpression()), (ColumnType) visit(context.dataType())); } // Primary expressions @Override public Node visitCast(SqlBaseParser.CastContext context) { if (context.TRY_CAST() != null) { return new TryCast((Expression) visit(context.expr()), (ColumnType) visit(context.dataType())); } else { return new Cast((Expression) visit(context.expr()), (ColumnType) visit(context.dataType())); } } @Override public Node visitSpecialDateTimeFunction(SqlBaseParser.SpecialDateTimeFunctionContext context) { CurrentTime.Type type = getDateTimeFunctionType(context.name); if (context.precision != null) { return new CurrentTime(type, Integer.parseInt(context.precision.getText())); } return new CurrentTime(type); } @Override public Node visitExtract(SqlBaseParser.ExtractContext context) { return new Extract((Expression) visit(context.expr()), (Expression) visit(context.identExpr())); } @Override public Node visitSubstring(SqlBaseParser.SubstringContext context) { return new FunctionCall(QualifiedName.of("substr"), visit(context.expr(), Expression.class)); } @Override public Node visitCurrentSchema(SqlBaseParser.CurrentSchemaContext context) { return new FunctionCall(QualifiedName.of("current_schema"), ImmutableList.of()); } @Override public Node visitNestedExpression(SqlBaseParser.NestedExpressionContext context) { return visit(context.expr()); } @Override public Node visitSubqueryExpression(SqlBaseParser.SubqueryExpressionContext context) { return new SubqueryExpression((Query) visit(context.query())); } @Override public Node visitParenthesizedPrimaryExpression(SqlBaseParser.ParenthesizedPrimaryExpressionContext context) { return visit(context.primaryExpression()); } @Override public Node visitDereference(SqlBaseParser.DereferenceContext context) { return new QualifiedNameReference( QualifiedName.of(identsToStrings(context.ident())) ); } @Override public Node visitColumnReference(SqlBaseParser.ColumnReferenceContext context) { return new QualifiedNameReference(QualifiedName.of(getIdentText(context.ident()))); } @Override public Node visitSubscript(SqlBaseParser.SubscriptContext context) { return new SubscriptExpression((Expression) visit(context.value), (Expression) visit(context.index)); } @Override public Node visitSubscriptSafe(SqlBaseParser.SubscriptSafeContext context) { if (context.qname() != null) { return new QualifiedNameReference(getQualifiedName(context.qname())); } return new SubscriptExpression((Expression) visit(context.value), (Expression) visit(context.index)); } @Override public Node visitQname(SqlBaseParser.QnameContext context) { return new QualifiedNameReference(getQualifiedName(context)); } @Override public Node visitSimpleCase(SqlBaseParser.SimpleCaseContext context) { return new SimpleCaseExpression( (Expression) visit(context.valueExpression()), visit(context.whenClause(), WhenClause.class), visitIfPresent(context.elseExpr, Expression.class).orElse(null)); } @Override public Node visitSearchedCase(SqlBaseParser.SearchedCaseContext context) { return new SearchedCaseExpression( visit(context.whenClause(), WhenClause.class), visitIfPresent(context.elseExpr, Expression.class).orElse(null)); } @Override public Node visitIfCase(SqlBaseParser.IfCaseContext context) { return new IfExpression( (Expression) visit(context.condition), (Expression) visit(context.trueValue), visitIfPresent(context.falseValue, Expression.class)); } @Override public Node visitWhenClause(SqlBaseParser.WhenClauseContext context) { return new WhenClause((Expression) visit(context.condition), (Expression) visit(context.result)); } @Override public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context) { return new FunctionCall( getQualifiedName(context.qname()), isDistinct(context.setQuant()), visit(context.expr(), Expression.class)); } // Literals @Override public Node visitNullLiteral(SqlBaseParser.NullLiteralContext context) { return NullLiteral.INSTANCE; } @Override public Node visitStringLiteral(SqlBaseParser.StringLiteralContext context) { return new StringLiteral(unquote(context.STRING().getText())); } @Override public Node visitIntegerLiteral(SqlBaseParser.IntegerLiteralContext context) { return new LongLiteral(context.getText()); } @Override public Node visitDecimalLiteral(SqlBaseParser.DecimalLiteralContext context) { return new DoubleLiteral(context.getText()); } @Override public Node visitBooleanLiteral(SqlBaseParser.BooleanLiteralContext context) { return context.TRUE() != null ? BooleanLiteral.TRUE_LITERAL : BooleanLiteral.FALSE_LITERAL; } @Override public Node visitArrayLiteral(SqlBaseParser.ArrayLiteralContext context) { return new ArrayLiteral(visit(context.expr(), Expression.class)); } @Override public Node visitDateLiteral(SqlBaseParser.DateLiteralContext context) { return new DateLiteral(unquote(context.STRING().getText())); } @Override public Node visitTimeLiteral(SqlBaseParser.TimeLiteralContext context) { return new TimeLiteral(unquote(context.STRING().getText())); } @Override public Node visitTimestampLiteral(SqlBaseParser.TimestampLiteralContext context) { return new TimestampLiteral(unquote(context.STRING().getText())); } @Override public Node visitObjectLiteral(SqlBaseParser.ObjectLiteralContext context) { Multimap<String, Expression> objAttributes = LinkedListMultimap.create(); context.objectKeyValue().forEach(attr -> objAttributes.put(getIdentText(attr.key), (Expression) visit(attr.value)) ); return new ObjectLiteral(objAttributes); } @Override public Node visitParameterPlaceholder(SqlBaseParser.ParameterPlaceholderContext context) { return new ParameterExpression(parameterPosition++); } @Override public Node visitPositionalParameter(SqlBaseParser.PositionalParameterContext context) { return new ParameterExpression(Integer.valueOf(context.integerLiteral().getText())); } @Override public Node visitOn(SqlBaseParser.OnContext context) { return BooleanLiteral.TRUE_LITERAL; } // Data types @Override public Node visitDataType(SqlBaseParser.DataTypeContext context) { if (context.objectTypeDefinition() != null) { return new ObjectColumnType( getObjectType(context.objectTypeDefinition().type), visit(context.objectTypeDefinition().columnDefinition(), ColumnDefinition.class)); } else if (context.arrayTypeDefinition() != null) { return CollectionColumnType.array((ColumnType) visit(context.arrayTypeDefinition().dataType())); } else if (context.setTypeDefinition() != null) { return CollectionColumnType.set((ColumnType) visit(context.setTypeDefinition().dataType())); } return new ColumnType(context.getText().toLowerCase(Locale.ENGLISH)); } private String getObjectType(Token type) { if (type == null) return null; switch (type.getType()) { case SqlBaseLexer.DYNAMIC: return type.getText().toLowerCase(Locale.ENGLISH); case SqlBaseLexer.STRICT: return type.getText().toLowerCase(Locale.ENGLISH); case SqlBaseLexer.IGNORED: return type.getText().toLowerCase(Locale.ENGLISH); } throw new UnsupportedOperationException("Unsupported object type: " + type.getText()); } // Helpers @Override protected Node defaultResult() { return null; } @Override protected Node aggregateResult(Node aggregate, Node nextResult) { if (nextResult == null) { throw new UnsupportedOperationException("not yet implemented"); } if (aggregate == null) { return nextResult; } throw new UnsupportedOperationException("not yet implemented"); } private <T> Optional<T> visitIfPresent(ParserRuleContext context, Class<T> clazz) { return Optional.ofNullable(context) .map(this::visit) .map(clazz::cast); } private <T> List<T> visit(List<? extends ParserRuleContext> contexts, Class<T> clazz) { return contexts.stream() .map(this::visit) .map(clazz::cast) .collect(toList()); } private static String unquote(String value) { return value.substring(1, value.length() - 1) .replace("''", "'"); } private QualifiedName getQualifiedName(SqlBaseParser.QnameContext context) { return QualifiedName.of(identsToStrings(context.ident())); } private QualifiedName getQualifiedName(SqlBaseParser.IdentContext context) { return QualifiedName.of(getIdentText(context)); } private List<String> identsToStrings(List<SqlBaseParser.IdentContext> idents) { return idents.stream() .map(this::getIdentText) .collect(toList()); } private static boolean isDistinct(SqlBaseParser.SetQuantContext setQuantifier) { return setQuantifier != null && setQuantifier.DISTINCT() != null; } private static Optional<String> getTextIfPresent(ParserRuleContext context) { return Optional.ofNullable(context).map(ParseTree::getText); } private List<String> getColumnAliases(SqlBaseParser.AliasedColumnsContext columnAliasesContext) { if (columnAliasesContext == null) { return null; } return identsToStrings(columnAliasesContext.ident()); } private static ArithmeticExpression.Type getArithmeticBinaryOperator(Token operator) { switch (operator.getType()) { case SqlBaseLexer.PLUS: return ArithmeticExpression.Type.ADD; case SqlBaseLexer.MINUS: return ArithmeticExpression.Type.SUBTRACT; case SqlBaseLexer.ASTERISK: return ArithmeticExpression.Type.MULTIPLY; case SqlBaseLexer.SLASH: return ArithmeticExpression.Type.DIVIDE; case SqlBaseLexer.PERCENT: return ArithmeticExpression.Type.MODULUS; } throw new UnsupportedOperationException("Unsupported operator: " + operator.getText()); } private static ComparisonExpression.Type getComparisonOperator(Token symbol) { switch (symbol.getType()) { case SqlBaseLexer.EQ: return ComparisonExpression.Type.EQUAL; case SqlBaseLexer.NEQ: return ComparisonExpression.Type.NOT_EQUAL; case SqlBaseLexer.LT: return ComparisonExpression.Type.LESS_THAN; case SqlBaseLexer.LTE: return ComparisonExpression.Type.LESS_THAN_OR_EQUAL; case SqlBaseLexer.GT: return ComparisonExpression.Type.GREATER_THAN; case SqlBaseLexer.GTE: return ComparisonExpression.Type.GREATER_THAN_OR_EQUAL; case SqlBaseLexer.REGEX_MATCH: return ComparisonExpression.Type.REGEX_MATCH; case SqlBaseLexer.REGEX_NO_MATCH: return ComparisonExpression.Type.REGEX_NO_MATCH; case SqlBaseLexer.REGEX_MATCH_CI: return ComparisonExpression.Type.REGEX_MATCH_CI; case SqlBaseLexer.REGEX_NO_MATCH_CI: return ComparisonExpression.Type.REGEX_NO_MATCH_CI; } throw new IllegalArgumentException("Unsupported operator: " + symbol.getText()); } private static CurrentTime.Type getDateTimeFunctionType(Token token) { switch (token.getType()) { case SqlBaseLexer.CURRENT_DATE: return CurrentTime.Type.DATE; case SqlBaseLexer.CURRENT_TIME: return CurrentTime.Type.TIME; case SqlBaseLexer.CURRENT_TIMESTAMP: return CurrentTime.Type.TIMESTAMP; } throw new IllegalArgumentException("Unsupported special function: " + token.getText()); } private static LogicalBinaryExpression.Type getLogicalBinaryOperator(Token token) { switch (token.getType()) { case SqlBaseLexer.AND: return LogicalBinaryExpression.Type.AND; case SqlBaseLexer.OR: return LogicalBinaryExpression.Type.OR; } throw new IllegalArgumentException("Unsupported operator: " + token.getText()); } private static SortItem.NullOrdering getNullOrderingType(Token token) { switch (token.getType()) { case SqlBaseLexer.FIRST: return SortItem.NullOrdering.FIRST; case SqlBaseLexer.LAST: return SortItem.NullOrdering.LAST; } throw new IllegalArgumentException("Unsupported ordering: " + token.getText()); } private static SortItem.Ordering getOrderingType(Token token) { switch (token.getType()) { case SqlBaseLexer.ASC: return SortItem.Ordering.ASCENDING; case SqlBaseLexer.DESC: return SortItem.Ordering.DESCENDING; } throw new IllegalArgumentException("Unsupported ordering: " + token.getText()); } private static ArrayComparisonExpression.Quantifier getComparisonQuantifier(Token symbol) { switch (symbol.getType()) { case SqlBaseLexer.ALL: return ArrayComparisonExpression.Quantifier.ALL; case SqlBaseLexer.ANY: return ArrayComparisonExpression.Quantifier.ANY; case SqlBaseLexer.SOME: return ArrayComparisonExpression.Quantifier.ANY; } throw new IllegalArgumentException("Unsupported quantifier: " + symbol.getText()); } private static void validateFunctionName(QualifiedName functionName) { if (functionName.getParts().size() > 2) { throw new IllegalArgumentException(String.format(Locale.ENGLISH, "The function name is not correct! " + "name [%s] does not conform the [[schema_name .] function_name] format.", functionName)); } } }