package org.vertexium.cypher.ast;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.antlr.v4.runtime.tree.ErrorNode;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.vertexium.VertexiumException;
import org.vertexium.cypher.CypherBaseVisitor;
import org.vertexium.cypher.CypherParser;
import org.vertexium.cypher.ast.model.*;
import org.vertexium.cypher.exceptions.VertexiumCypherNotImplemented;
import org.vertexium.cypher.exceptions.VertexiumCypherSyntaxErrorException;
import org.vertexium.cypher.functions.CypherFunction;
import org.vertexium.util.StreamUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.vertexium.util.StreamUtils.stream;

/**
 * ANTLR visitor that walks the concrete syntax tree produced by {@link CypherParser} and builds
 * the corresponding AST out of the {@code org.vertexium.cypher.ast.model} classes.
 * <p>
 * Grammar rules that are consumed directly by their parent's visit method (for example
 * {@code patternElementChain} or {@code relationshipDetail}) deliberately throw when visited on
 * their own, so an accidental generic dispatch is noticed immediately.
 */
public class CypherCstToAstVisitor extends CypherBaseVisitor<CypherAstBase> {
    private final CypherCompilerContext compilerContext;

    /**
     * Creates a visitor backed by a freshly constructed {@link CypherCompilerContext}.
     */
    public CypherCstToAstVisitor() {
        this(new CypherCompilerContext());
    }

    /**
     * Creates a visitor using the supplied compiler context.
     *
     * @param compilerContext context used to look up and compile functions in
     *                        {@link #visitFunctionInvocation(CypherParser.FunctionInvocationContext)}
     */
    public CypherCstToAstVisitor(CypherCompilerContext compilerContext) {
        this.compilerContext = compilerContext;
    }

    /**
     * Wraps the visited query of a statement in a {@link CypherStatement}.
     */
    @Override
    public CypherStatement visitStatement(CypherParser.StatementContext ctx) {
        return new CypherStatement(visitQuery(ctx.query()));
    }

    /**
     * Delegates to {@link #visitRegularQuery(CypherParser.RegularQueryContext)}.
     */
    @Override
    public CypherAstBase visitQuery(CypherParser.QueryContext ctx) {
        return visitRegularQuery(ctx.regularQuery());
    }

    /**
     * Visits the leading single query and, when union parts are present, folds them into
     * nested {@link CypherUnion} nodes.
     */
    @Override
    public CypherAstBase visitRegularQuery(CypherParser.RegularQueryContext ctx) {
        CypherQuery left = visitSingleQuery(ctx.singleQuery());
        if (ctx.union().size() > 0) {
            return visitUnions(left, ctx.union());
        }
        return left;
    }

    /**
     * Builds a {@link CypherQuery} from the clauses of a single query, in source order.
     */
    @Override
    public CypherQuery visitSingleQuery(CypherParser.SingleQueryContext ctx) {
        return new CypherQuery(
            ctx.clause().stream()
                .map(this::visitClause)
                .collect(StreamUtils.toImmutableList())
        );
    }

    /**
     * Dispatches to the specific clause visitor and verifies the result is a {@link CypherClause}.
     *
     * @throws VertexiumException if the default dispatch does not yield a {@link CypherClause}
     */
    @Override
    public CypherClause visitClause(CypherParser.ClauseContext ctx) {
        Object o = super.visitClause(ctx);
        if (!(o instanceof CypherClause)) {
            throw new VertexiumException("clause not supported: " + ctx.getText());
        }
        return (CypherClause) o;
    }

    /**
     * Builds a {@link CypherCreateClause} from the pattern parts of a CREATE clause.
     */
    @Override
    public CypherCreateClause visitCreate(CypherParser.CreateContext ctx) {
        ImmutableList<CypherPatternPart> patternParts = ctx.pattern().patternPart().stream()
            .map(this::visitPatternPart)
            .collect(StreamUtils.toImmutableList());
        return new CypherCreateClause(patternParts);
    }

    /**
     * Builds a {@link CypherReturnClause}; the distinct flag is set when the DISTINCT token is present.
     */
    @Override
    public CypherReturnClause visitReturnClause(CypherParser.ReturnClauseContext ctx) {
        boolean distinct = ctx.DISTINCT() != null;
        return new CypherReturnClause(distinct, visitReturnBody(ctx.returnBody()));
    }

    /**
     * Builds a {@link CypherReturnBody}; order, limit and skip are passed as {@code null} when absent.
     */
    @Override
    public CypherReturnBody visitReturnBody(CypherParser.ReturnBodyContext ctx) {
        CypherParser.OrderContext order = ctx.order();
        CypherParser.LimitContext limit = ctx.limit();
        CypherParser.SkipContext skip = ctx.skip();
        return new CypherReturnBody(
            visitReturnItems(ctx.returnItems()),
            (order == null) ? null : visitOrder(order),
            (limit == null) ? null : visitLimit(limit),
            (skip == null) ? null : visitSkip(skip)
        );
    }

    /**
     * Builds a {@link CypherMatchClause}; the where expression is {@code null} when there is no WHERE.
     */
    @Override
    public CypherMatchClause visitMatch(CypherParser.MatchContext ctx) {
        boolean optional = ctx.OPTIONAL() != null;
        CypherListLiteral<CypherPatternPart> patternParts = visitPattern(ctx.pattern());
        CypherAstBase whereExpression = visitWhere(ctx.where());
        return new CypherMatchClause(optional, patternParts, whereExpression);
    }

    /**
     * Visits every pattern part of a pattern and collects them into a list literal.
     */
    @Override
    public CypherListLiteral<CypherPatternPart> visitPattern(CypherParser.PatternContext ctx) {
        return ctx.patternPart().stream()
            .map(this::visitPatternPart)
            .collect(CypherListLiteral.collect());
    }

    /**
     * Builds a {@link CypherPatternPart}; the name is {@code null} when no variable is bound.
     */
    @Override
    public CypherPatternPart visitPatternPart(CypherParser.PatternPartContext ctx) {
        String name = visitVariableString(ctx.variable());
        CypherListLiteral<CypherElementPattern> elementPatterns = visitAnonymousPatternPart(ctx.anonymousPatternPart());
        return new CypherPatternPart(name, elementPatterns);
    }

    /**
     * Delegates to {@link #visitPatternElement(CypherParser.PatternElementContext)}.
     */
    @Override
    public CypherListLiteral<CypherElementPattern> visitAnonymousPatternPart(CypherParser.AnonymousPatternPartContext ctx) {
        return visitPatternElement(ctx.patternElement());
    }

    /**
     * Flattens a pattern element into an alternating list of node and relationship patterns,
     * starting with the leading node pattern.
     */
    @Override
    public CypherListLiteral<CypherElementPattern> visitPatternElement(CypherParser.PatternElementContext ctx) {
        // unwind parenthesis
        if (ctx.patternElement() != null) {
            return visitPatternElement(ctx.patternElement());
        }

        List<CypherElementPattern> list = new ArrayList<>();
        CypherNodePattern nodePattern = visitNodePattern(ctx.nodePattern());
        list.add(nodePattern);
        list.addAll(visitPatternElementChainList(nodePattern, ctx.patternElementChain()));
        return new CypherListLiteral<>(list);
    }

    /**
     * Converts a chain of (relationship, node) pairs into a flat list and links every
     * relationship pattern to the node patterns on either side of it.
     *
     * @param previousNodePattern     the node pattern that precedes the first relationship
     * @param patternElementChainList the chain contexts, in source order
     * @return relationship and node patterns, alternating, excluding {@code previousNodePattern}
     */
    private List<CypherElementPattern> visitPatternElementChainList(
        CypherNodePattern previousNodePattern,
        List<CypherParser.PatternElementChainContext> patternElementChainList
    ) {
        List<CypherElementPattern> list = new ArrayList<>();
        for (CypherParser.PatternElementChainContext chainContext : patternElementChainList) {
            CypherRelationshipPattern relationshipPattern = visitRelationshipPattern(chainContext.relationshipPattern());
            relationshipPattern.setPreviousNodePattern(previousNodePattern);
            list.add(relationshipPattern);

            CypherNodePattern nodePattern = visitNodePattern(chainContext.nodePattern());
            relationshipPattern.setNextNodePattern(nodePattern);
            list.add(nodePattern);

            // the node just visited becomes the left-hand side of the next relationship
            previousNodePattern = nodePattern;
        }
        return list;
    }

    /**
     * Builds a {@link CypherNodePattern} from the optional variable, properties and labels.
     */
    @Override
    public CypherNodePattern visitNodePattern(CypherParser.NodePatternContext ctx) {
        return new CypherNodePattern(
            visitVariableString(ctx.variable()),
            visitProperties(ctx.properties()),
            visitNodeLabels(ctx.nodeLabels())
        );
    }

    /**
     * Builds a {@link CypherRelationshipPattern}. When the relationship has no detail section,
     * name, types, properties and range are all {@code null}; the direction is always derived
     * from the arrow heads.
     */
    @Override
    public CypherRelationshipPattern visitRelationshipPattern(CypherParser.RelationshipPatternContext ctx) {
        CypherParser.RelationshipDetailContext relationshipDetail = ctx.relationshipDetail();

        String name;
        CypherListLiteral<CypherRelTypeName> relTypeNames;
        CypherMapLiteral<String, CypherAstBase> properties;
        CypherRangeLiteral range;
        if (relationshipDetail == null) {
            name = null;
            relTypeNames = null;
            properties = null;
            range = null;
        } else {
            if (relationshipDetail.rangeLiteral() != null) {
                range = visitRangeLiteral(relationshipDetail.rangeLiteral());
            } else {
                range = null;
            }
            name = visitVariableString(relationshipDetail.variable());
            if (relationshipDetail.relationshipTypes() == null) {
                relTypeNames = null;
            } else {
                relTypeNames = visitRelationshipTypes(relationshipDetail.relationshipTypes());
            }
            properties = visitProperties(relationshipDetail.properties());
        }

        CypherDirection direction = getDirectionFromRelationshipPattern(ctx);
        return new CypherRelationshipPattern(name, relTypeNames, properties, range, direction);
    }

    /**
     * Maps the presence of left/right arrow heads onto a {@link CypherDirection}:
     * both heads give BOTH, only left gives IN, only right gives OUT, none gives UNSPECIFIED.
     */
    private static CypherDirection getDirectionFromRelationshipPattern(CypherParser.RelationshipPatternContext relationshipPatternContext) {
        if (relationshipPatternContext.leftArrowHead() != null && relationshipPatternContext.rightArrowHead() != null) {
            return CypherDirection.BOTH;
        }
        if (relationshipPatternContext.leftArrowHead() != null) {
            return CypherDirection.IN;
        }
        if (relationshipPatternContext.rightArrowHead() != null) {
            return CypherDirection.OUT;
        }
        return CypherDirection.UNSPECIFIED;
    }

    /**
     * Visits a properties context via the default dispatch.
     *
     * @return the resulting map literal, or {@code null} when {@code ctx} is {@code null}
     */
    @Override
    public CypherMapLiteral<String, CypherAstBase> visitProperties(CypherParser.PropertiesContext ctx) {
        if (ctx == null) {
            return null;
        }
        //noinspection unchecked
        return (CypherMapLiteral<String, CypherAstBase>) super.visitProperties(ctx);
    }

    /**
     * Builds a {@link CypherMapLiteral} by pairing the i-th property key with the i-th expression.
     */
    @Override
    public CypherMapLiteral<String, CypherAstBase> visitMapLiteral(CypherParser.MapLiteralContext ctx) {
        List<CypherParser.PropertyKeyNameContext> keys = ctx.propertyKeyName();
        List<CypherParser.ExpressionContext> values = ctx.expression();
        Map<String, CypherAstBase> result = new HashMap<>();
        for (int i = 0, keysSize = keys.size(); i < keysSize; i++) {
            String key = visitPropertyKeyName(keys.get(i)).getValue();
            CypherAstBase value = visitExpression(values.get(i));
            result.put(key, value);
        }
        return new CypherMapLiteral<>(result);
    }

    /**
     * Returns the property key as a {@link CypherString} by visiting its symbolic name.
     */
    @Override
    public CypherString visitPropertyKeyName(CypherParser.PropertyKeyNameContext ctx) {
        return visitSymbolicName(ctx.symbolicName());
    }

    /**
     * Collects the label names of a node-labels context.
     *
     * @return the labels, or an empty list literal when {@code ctx} is {@code null}
     */
    @Override
    public CypherListLiteral<CypherLabelName> visitNodeLabels(CypherParser.NodeLabelsContext ctx) {
        if (ctx == null) {
            return new CypherListLiteral<>();
        }
        return ctx.nodeLabel().stream()
            .map(nl -> visitLabelName(nl.labelName()))
            .collect(CypherListLiteral.collect());
    }

    /**
     * Builds a {@link CypherLabelName} from the label's symbolic name.
     */
    @Override
    public CypherLabelName visitLabelName(CypherParser.LabelNameContext ctx) {
        return new CypherLabelName(visitSymbolicName(ctx.symbolicName()).getValue());
    }

    /**
     * Not supported directly; chains are handled by {@code visitPatternElementChainList}.
     *
     * @throws VertexiumException always
     */
    @Override
    public CypherAstBase visitPatternElementChain(CypherParser.PatternElementChainContext ctx) {
        throw new VertexiumException("should not be called, see visitPatternElementChainList");
    }

    /**
     * Builds a {@link CypherUnwindClause} from the target variable name and the source expression.
     */
    @Override
    public CypherUnwindClause visitUnwind(CypherParser.UnwindContext ctx) {
        String name = visitVariableString(ctx.variable());
        CypherAstBase expression = visitExpression(ctx.expression());
        return new CypherUnwindClause(name, expression);
    }

    /**
     * Builds a {@link CypherWithClause}; the where expression is {@code null} when there is no WHERE.
     */
    @Override
    public CypherWithClause visitWith(CypherParser.WithContext ctx) {
        boolean distinct = ctx.DISTINCT() != null;
        CypherReturnBody returnBody = visitReturnBody(ctx.returnBody());
        CypherAstBase where = visitWhere(ctx.where());
        return new CypherWithClause(distinct, returnBody, where);
    }

    /**
     * Builds a {@link CypherMergeClause} from the pattern part and its merge actions.
     */
    @Override
    public CypherMergeClause visitMerge(CypherParser.MergeContext ctx) {
        CypherPatternPart patternPart = visitPatternPart(ctx.patternPart());
        List<CypherMergeAction> mergeActions = ctx.mergeAction().stream()
            .map(this::visitMergeAction)
            .collect(Collectors.toList());
        return new CypherMergeClause(
            patternPart,
            mergeActions
        );
    }

    /**
     * Visits the expression of a WHERE.
     *
     * @return the expression, or {@code null} when {@code ctx} is {@code null}
     */
    @Override
    public CypherAstBase visitWhere(CypherParser.WhereContext ctx) {
        if (ctx == null) {
            return null;
        }
        return visitExpression(ctx.expression());
    }

    /**
     * Visits each expression context and collects the results into a list literal.
     */
    public CypherListLiteral<CypherAstBase> visitExpressions(Iterable<CypherParser.ExpressionContext> expressionContexts) {
        return stream(expressionContexts)
            .map(this::visitExpression)
            .collect(CypherListLiteral.collect());
    }

    /**
     * Entry point for expressions; delegates to the lowest-precedence level (expression12).
     */
    @Override
    public CypherAstBase visitExpression(CypherParser.ExpressionContext ctx) {
        return visitExpression12(ctx.expression12());
    }

    // OR
    @Override
    public CypherAstBase visitExpression12(CypherParser.Expression12Context ctx) {
        List<CypherParser.Expression11Context> children = ctx.expression11();
        if (children.size() == 1) {
            return visitExpression11(children.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression11);
    }

    // XOR
    @Override
    public CypherAstBase visitExpression11(CypherParser.Expression11Context ctx) {
        List<CypherParser.Expression10Context> children = ctx.expression10();
        if (children.size() == 1) {
            return visitExpression10(children.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression10);
    }

    // AND
    @Override
    public CypherAstBase visitExpression10(CypherParser.Expression10Context ctx) {
        List<CypherParser.Expression9Context> children = ctx.expression9();
        if (children.size() == 1) {
            return visitExpression9(children.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression9);
    }

    /**
     * Folds a flat child list of the form {@code operand (op operand)*} into a left-associative
     * tree of {@link CypherBinaryExpression}s.
     * <p>
     * Terminal nodes whose text is not recognised by {@code CypherBinaryExpression.Op.parseOrNull}
     * are skipped. Callers in this class only invoke this method when more than one operand is
     * present; the final cast relies on that.
     *
     * @param children      the raw children of the expression context
     * @param itemTransform visitor applied to every non-terminal child
     * @throws VertexiumException if two operators appear without an operand between them, or two
     *                            operands appear without an operator between them
     */
    private <T extends ParseTree> CypherBinaryExpression toBinaryExpressions(List<ParseTree> children, Function<T, CypherAstBase> itemTransform) {
        CypherAstBase left = null;
        CypherBinaryExpression.Op op = null;
        for (int i = 0; i < children.size(); i++) {
            ParseTree child = children.get(i);
            if (child instanceof TerminalNode) {
                CypherBinaryExpression.Op newOp = CypherBinaryExpression.Op.parseOrNull(child.getText());
                if (newOp != null) {
                    if (op == null) {
                        op = newOp;
                    } else {
                        throw new VertexiumException("unexpected op, found too many ops in a row");
                    }
                }
            } else {
                //noinspection unchecked
                CypherAstBase childObj = itemTransform.apply((T) child);
                if (left == null) {
                    left = childObj;
                } else {
                    if (op == null) {
                        throw new VertexiumException("unexpected binary expression. expected an op between expressions");
                    }
                    left = new CypherBinaryExpression(left, op, childObj);
                }
                // the pending operator (if any) has been consumed by this operand
                op = null;
            }
        }
        return (CypherBinaryExpression) left;
    }

    // NOT
    @Override
    public CypherAstBase visitExpression9(CypherParser.Expression9Context ctx) {
        // an even number of NOTs cancels out
        if (ctx.NOT().size() % 2 == 0) {
            return visitExpression8(ctx.expression8());
        } else {
            return new CypherUnaryExpression(CypherUnaryExpression.Op.NOT, visitExpression8(ctx.expression8()));
        }
    }

    // comparison
    @Override
    public CypherAstBase visitExpression8(CypherParser.Expression8Context ctx) {
        List<CypherParser.PartialComparisonExpressionContext> partialComparisonExpressions = ctx.partialComparisonExpression();
        if (partialComparisonExpressions.size() == 0) {
            return visitExpression7(ctx.expression7());
        }
        CypherAstBase left = visitExpression7(ctx.expression7());
        String op = partialComparisonExpressions.get(0).children.get(0).getText();
        CypherAstBase right = visitExpression7(partialComparisonExpressions.get(0).expression7());
        // chained comparisons (a < b < c) become (a < b) AND (b < c) AND ...
        return new CypherBinaryExpression(
            new CypherComparisonExpression(left, op, right),
            CypherBinaryExpression.Op.AND,
            visitPartialComparisonExpression(right, 1, partialComparisonExpressions)
        );
    }

    /**
     * Recursively converts the remaining partial comparisons of a chained comparison, using the
     * previous right-hand operand as the new left-hand operand.
     *
     * @param left                             left operand for the comparison at the given index
     * @param partialComparisonExpressionIndex index of the partial comparison to convert
     * @param partialComparisonExpressions     all partial comparisons of the chain
     * @return a {@link CypherTrueExpression} when the index is past the end of the list; otherwise
     * the comparison at the index, AND-ed with the rest of the chain when there is one
     */
    private CypherExpression visitPartialComparisonExpression(
        CypherAstBase left,
        int partialComparisonExpressionIndex,
        List<CypherParser.PartialComparisonExpressionContext> partialComparisonExpressions
    ) {
        if (partialComparisonExpressionIndex >= partialComparisonExpressions.size()) {
            return new CypherTrueExpression();
        }
        String op = partialComparisonExpressions.get(partialComparisonExpressionIndex).children.get(0).getText();
        CypherAstBase right = visitExpression7(partialComparisonExpressions.get(partialComparisonExpressionIndex).expression7());
        CypherComparisonExpression binLeft = new CypherComparisonExpression(left, op, right);
        CypherExpression binRight = visitPartialComparisonExpression(right, partialComparisonExpressionIndex + 1, partialComparisonExpressions);
        if (binRight instanceof CypherTrueExpression) {
            return binLeft;
        }
        return new CypherBinaryExpression(binLeft, CypherBinaryExpression.Op.AND, binRight);
    }

    // + -
    @Override
    public CypherAstBase visitExpression7(CypherParser.Expression7Context ctx) {
        List<CypherParser.Expression6Context> expression6s = ctx.expression6();
        if (expression6s.size() == 1) {
            return visitExpression6(expression6s.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression6);
    }

    // * / %
    @Override
    public CypherAstBase visitExpression6(CypherParser.Expression6Context ctx) {
        List<CypherParser.Expression5Context> expression5s = ctx.expression5();
        if (expression5s.size() == 1) {
            return visitExpression5(expression5s.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression5);
    }

    // ^
    @Override
    public CypherAstBase visitExpression5(CypherParser.Expression5Context ctx) {
        List<CypherParser.Expression4Context> expression4s = ctx.expression4();
        if (expression4s.size() == 1) {
            return visitExpression4(expression4s.get(0));
        }
        return toBinaryExpressions(ctx.children, this::visitExpression4);
    }

    // + - prefix
    @Override
    public CypherAstBase visitExpression4(CypherParser.Expression4Context ctx) {
        int neg = 0;
        for (ParseTree child : ctx.children) {
            if (child instanceof TerminalNode && child.getText().equals("-")) {
                neg++;
            }
        }
        CypherAstBase expr = visitExpression3(ctx.expression3());
        // an odd number of unary minus signs negates the operand; an even number cancels out
        if (neg % 2 == 1) {
            return new CypherNegateExpression(expr);
        } else {
            return expr;
        }
    }

    /**
     * Handles postfix and infix forms at the expression3 level (indexing, slicing, IN, IS NULL,
     * string matching). A context with a single child is passed straight to
     * {@link #visitExpression2(CypherParser.Expression2Context)}.
     */
    @Override
    public CypherAstBase visitExpression3(CypherParser.Expression3Context ctx) {
        if (ctx.children.size() == 1) {
            return visitExpression2(ctx.expression2(0));
        }
        return visitExpression3(filterSpaces(ctx.children).collect(Collectors.toList()));
    }

    /**
     * Streams the given parse tree items, dropping those whose text is empty after trimming.
     */
    private Stream<ParseTree> filterSpaces(List<ParseTree> items) {
        return items.stream()
            .filter(item -> item.getText().trim().length() > 0);
    }

    /**
     * Recognises the supported expression3 shapes from a whitespace-free child list.
     *
     * @param children the children of an expression3 context with blank items removed
     * @throws VertexiumCypherNotImplemented if the children match none of the supported shapes
     */
    private CypherAstBase visitExpression3(List<ParseTree> children) {
        // array slice - v[1..3]
        if (children.size() == 6
            && children.get(1).getText().equals("[")
            && children.get(3).getText().equals("..")
            && children.get(5).getText().equals("]")) {
            CypherAstBase arrayExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            CypherAstBase sliceFrom = visitExpression((CypherParser.ExpressionContext) children.get(2));
            CypherAstBase sliceTo = visitExpression((CypherParser.ExpressionContext) children.get(4));
            return new CypherArraySlice(arrayExpression, sliceFrom, sliceTo);
        }

        // item in list - 'a' IN [ 1, 2, 3 ]
        else if (children.size() > 2 && children.get(1).getText().equalsIgnoreCase("IN")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            List<ParseTree> remainingChildren = children.stream().skip(2).collect(Collectors.toList());
            CypherAstBase arrExpression;
            if (remainingChildren.size() == 1) {
                arrExpression = visitExpression2((CypherParser.Expression2Context) remainingChildren.get(0));
            } else {
                arrExpression = visitExpression3(remainingChildren);
            }
            return new CypherIn(valueExpression, arrExpression);
        }

        // is null - a IS NULL
        else if (children.size() == 3
            && children.get(1).getText().equalsIgnoreCase("IS")
            && children.get(2).getText().equalsIgnoreCase("NULL")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            return new CypherIsNull(valueExpression);
        }

        // is not null - a IS NOT NULL
        else if (children.size() == 4
            && children.get(1).getText().equalsIgnoreCase("IS")
            && children.get(2).getText().equalsIgnoreCase("NOT")
            && children.get(3).getText().equalsIgnoreCase("NULL")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            return new CypherIsNotNull(valueExpression);
        }

        // starts with - 'abc' STARTS WITH 'a'
        else if (children.size() == 4
            && children.get(1).getText().equalsIgnoreCase("STARTS")
            && children.get(2).getText().equalsIgnoreCase("WITH")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            CypherAstBase stringExpression = visitExpression2((CypherParser.Expression2Context) children.get(3));
            return new CypherStringMatch(valueExpression, stringExpression, CypherStringMatch.Op.STARTS_WITH);
        }

        // ends with - 'abc' ENDS WITH 'a'
        else if (children.size() == 4
            && children.get(1).getText().equalsIgnoreCase("ENDS")
            && children.get(2).getText().equalsIgnoreCase("WITH")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            CypherAstBase stringExpression = visitExpression2((CypherParser.Expression2Context) children.get(3));
            return new CypherStringMatch(valueExpression, stringExpression, CypherStringMatch.Op.ENDS_WITH);
        }

        // contains - 'abc' CONTAINS 'a'
        else if (children.size() == 3 && children.get(1).getText().equalsIgnoreCase("CONTAINS")) {
            CypherAstBase valueExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            CypherAstBase stringExpression = visitExpression2((CypherParser.Expression2Context) children.get(2));
            return new CypherStringMatch(valueExpression, stringExpression, CypherStringMatch.Op.CONTAINS);
        }

        // array index - a[0] or a[0][1]
        else if (children.size() >= 4
            && children.get(1).getText().equals("[")
            && children.get(3).getText().equals("]")) {
            CypherAstBase arrayExpression = visitExpression2((CypherParser.Expression2Context) children.get(0));
            CypherAstBase indexExpression = visitExpression((CypherParser.ExpressionContext) children.get(2));
            CypherArrayAccess arrayAccess = new CypherArrayAccess(arrayExpression, indexExpression);
            List<ParseTree> remaining = children.subList(4, children.size());
            while (remaining.size() > 0) {
                // every further access must be exactly "[" expression "]"; anything else
                // (e.g. a trailing slice or keyword) is reported as unsupported instead of
                // failing with a ClassCastException/IndexOutOfBoundsException
                if (!isIndexTriple(remaining)) {
                    throw notImplemented(children);
                }
                indexExpression = visitExpression((CypherParser.ExpressionContext) remaining.get(1));
                arrayAccess = new CypherArrayAccess(arrayAccess, indexExpression);
                remaining = remaining.subList(3, remaining.size());
            }
            return arrayAccess;
        }

        throw notImplemented(children);
    }

    /**
     * Tests whether the list starts with the three items {@code "["}, an expression context and {@code "]"}.
     */
    private static boolean isIndexTriple(List<ParseTree> items) {
        return items.size() >= 3
            && items.get(0).getText().equals("[")
            && items.get(1) instanceof CypherParser.ExpressionContext
            && items.get(2).getText().equals("]");
    }

    /**
     * Creates the exception used for unsupported expression3 shapes; the message is the
     * comma separated text of the children.
     */
    private static VertexiumCypherNotImplemented notImplemented(List<ParseTree> children) {
        return new VertexiumCypherNotImplemented("" + children.stream().map(ParseTree::getText).collect(Collectors.joining(", ")));
    }

    /**
     * Visits an atom optionally followed by property lookups and/or node labels. Without either,
     * the atom itself is returned; otherwise a {@link CypherLookup} is created.
     *
     * @throws VertexiumCypherSyntaxErrorException if there are no lookups or labels but the
     *                                             context has more than one child
     */
    @Override
    public CypherAstBase visitExpression2(CypherParser.Expression2Context ctx) {
        CypherParser.AtomContext atom = ctx.atom();
        List<CypherParser.PropertyLookupContext> propertyLookups = ctx.propertyLookup();
        List<CypherParser.NodeLabelsContext> nodeLabels = ctx.nodeLabels();
        if ((propertyLookups == null || propertyLookups.size() == 0)
            && (nodeLabels == null || nodeLabels.size() == 0)) {
            if (ctx.children.size() != 1) {
                throw new VertexiumCypherSyntaxErrorException("invalid expression \"" + ctx.getText() + "\"");
            }
            return visitAtom(atom);
        }
        return createLookup(atom, propertyLookups, nodeLabels);
    }

    /**
     * Creates a {@link CypherLookup} for an atom with optional property lookups and node labels.
     * <p>
     * Property names are joined with {@code "."}; an empty result is passed as {@code null}.
     * Labels are passed as {@code null} when {@code nodeLabels} is {@code null}.
     *
     * @param atomCtx         the atom being looked up on
     * @param propertyLookups property lookups, may be {@code null} or empty
     * @param nodeLabels      node label groups, may be {@code null} or empty
     */
    private CypherLookup createLookup(
        CypherParser.AtomContext atomCtx,
        List<CypherParser.PropertyLookupContext> propertyLookups,
        List<CypherParser.NodeLabelsContext> nodeLabels
    ) {
        CypherAstBase atom = visitAtom(atomCtx);
        // visitPropertyExpression passes null for nodeLabels, so both lists are checked null-safely
        boolean hasPropertyLookups = propertyLookups != null && propertyLookups.size() > 0;
        boolean hasNodeLabels = nodeLabels != null && nodeLabels.size() > 0;
        if (!hasPropertyLookups && !hasNodeLabels) {
            return new CypherLookup(atom, null, null);
        }

        String property = null;
        if (hasPropertyLookups) {
            property = propertyLookups.stream()
                .map(pl -> visitPropertyLookup(pl).getValue())
                .collect(Collectors.joining("."));
            if (property.length() == 0) {
                property = null;
            }
        }

        List<CypherLabelName> labels;
        if (nodeLabels == null) {
            labels = null;
        } else {
            labels = nodeLabels.stream()
                .flatMap(l -> visitNodeLabels(l).getValue().stream())
                .collect(Collectors.toList());
        }
        return new CypherLookup(atom, property, labels);
    }

    /**
     * Returns the looked-up property key as a {@link CypherString}.
     */
    @Override
    public CypherString visitPropertyLookup(CypherParser.PropertyLookupContext ctx) {
        return visitPropertyKeyName(ctx.propertyKeyName());
    }

    /**
     * Visits an atom. An atom carrying the COUNT token becomes a non-distinct {@code count}
     * invocation over {@link CypherMatchAll}; everything else uses the default dispatch.
     */
    @Override
    public CypherAstBase visitAtom(CypherParser.AtomContext ctx) {
        if (ctx.COUNT() != null) {
            return new CypherFunctionInvocation("count", false, new CypherMatchAll());
        }
        return super.visitAtom(ctx);
    }

    /**
     * Visits a literal. String literals have their first and last character (the quotes)
     * removed; other literals use the default dispatch.
     */
    @Override
    public CypherLiteral visitLiteral(CypherParser.LiteralContext ctx) {
        if (ctx.StringLiteral() != null) {
            String text = ctx.StringLiteral().getText();
            return new CypherString(text.substring(1, text.length() - 1));
        }
        return (CypherLiteral) super.visitLiteral(ctx);
    }

    /**
     * Builds a {@link CypherVariable} from a variable context.
     *
     * @return the variable, or {@code null} when {@code ctx} is {@code null} or the name is {@code null}
     */
    @Override
    public CypherVariable visitVariable(CypherParser.VariableContext ctx) {
        if (ctx == null) {
            return null;
        }
        String name = visitSymbolicName(ctx.symbolicName()).getValue();
        if (name == null) {
            return null;
        }
        return new CypherVariable(name);
    }

    /**
     * Convenience wrapper around {@link #visitVariable(CypherParser.VariableContext)} returning
     * only the variable name, or {@code null} when there is no variable.
     */
    public String visitVariableString(CypherParser.VariableContext ctx) {
        CypherVariable v = visitVariable(ctx);
        if (v == null) {
            return null;
        }
        return v.getName();
    }

    /**
     * Returns the symbolic name as a {@link CypherString}; escaped names have their first and
     * last character stripped.
     */
    @Override
    public CypherString visitSymbolicName(CypherParser.SymbolicNameContext ctx) {
        if (ctx.EscapedSymbolicName() != null) {
            return visitEscapedSymbolicName(ctx.EscapedSymbolicName());
        }
        return new CypherString(ctx.getText());
    }

    /**
     * Visits the return items. A leading {@code *} yields a single item wrapping
     * {@link CypherAllLiteral}; otherwise every return item is visited.
     */
    @Override
    public CypherListLiteral<CypherReturnItem> visitReturnItems(CypherParser.ReturnItemsContext ctx) {
        if (ctx.children.get(0).getText().equals("*")) {
            return CypherListLiteral.of(new CypherReturnItem("*", new CypherAllLiteral(), null));
        }
        return ctx.returnItem().stream()
            .map(this::visitReturnItem)
            .collect(CypherListLiteral.collect());
    }

    /**
     * Builds a {@link CypherReturnItem} from the item's original text, its expression and its
     * optional alias variable.
     */
    @Override
    public CypherReturnItem visitReturnItem(CypherParser.ReturnItemContext ctx) {
        return new CypherReturnItem(
            ctx.getText(),
            visitExpression(ctx.expression()),
            visitVariableString(ctx.variable())
        );
    }

    /**
     * Not supported directly; partial comparisons are handled by
     * {@link #visitExpression8(CypherParser.Expression8Context)}.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitPartialComparisonExpression(CypherParser.PartialComparisonExpressionContext ctx) {
        throw new VertexiumCypherNotImplemented("PartialComparisonExpression");
    }

    /**
     * Returns the AST of the expression inside the parentheses.
     */
    @Override
    public CypherAstBase visitParenthesizedExpression(CypherParser.ParenthesizedExpressionContext ctx) {
        return visitExpression(ctx.expression());
    }

    /**
     * Builds a {@link CypherPatternComprehension} by turning the relationships pattern into a
     * non-optional match clause. When more than one expression is present the first one is the
     * where expression; the last expression is always the projection.
     */
    @Override
    public CypherPatternComprehension visitPatternComprehension(CypherParser.PatternComprehensionContext ctx) {
        CypherVariable variable = ctx.variable() == null ? null : visitVariable(ctx.variable());
        CypherRelationshipsPattern relationshipsPattern = visitRelationshipsPattern(ctx.relationshipsPattern());
        List<CypherParser.ExpressionContext> expressions = ctx.expression();
        CypherAstBase whereExpression = expressions.size() > 1 ? visitExpression(expressions.get(0)) : null;
        CypherAstBase expression = visitExpression(expressions.get(expressions.size() - 1));

        ArrayList<CypherElementPattern> patternPartPatterns = Lists.newArrayList(relationshipsPattern.getNodePattern());
        for (CypherElementPattern elementPattern : relationshipsPattern.getPatternElementChains()) {
            patternPartPatterns.add(elementPattern);
        }
        CypherPatternPart patternPart = new CypherPatternPart(
            variable == null ? null : variable.getName(),
            new CypherListLiteral<>(patternPartPatterns)
        );
        CypherMatchClause matchClause = new CypherMatchClause(false, CypherListLiteral.of(patternPart), whereExpression);
        return new CypherPatternComprehension(matchClause, expression);
    }

    /**
     * Builds a {@link CypherLimit}.
     *
     * @throws VertexiumCypherSyntaxErrorException if the expression text is a negative integer
     */
    @Override
    public CypherLimit visitLimit(CypherParser.LimitContext ctx) {
        String expressionText = ctx.expression().getText();
        Integer i = tryParseInteger(expressionText);
        if (i != null && i < 0) {
            throw new VertexiumCypherSyntaxErrorException("NegativeIntegerArgument: limit should only have positive arguments: " + expressionText);
        }

        CypherAstBase expression = visitExpression(ctx.expression());
        return new CypherLimit(expression);
    }

    /**
     * Parses the text as an {@code int}.
     *
     * @return the parsed value, or {@code null} when the text is not a valid integer
     */
    private Integer tryParseInteger(String expressionText) {
        try {
            return Integer.parseInt(expressionText);
        } catch (NumberFormatException ex) {
            // not a plain integer (e.g. a parameter or arithmetic expression); caller treats as unknown
            return null;
        }
    }

    /**
     * Builds a {@link CypherBoolean} from the TRUE or FALSE token.
     *
     * @throws VertexiumException if neither token is present
     */
    @Override
    public CypherBoolean visitBooleanLiteral(CypherParser.BooleanLiteralContext ctx) {
        if (ctx.TRUE() != null) {
            return new CypherBoolean(true);
        }
        if (ctx.FALSE() != null) {
            return new CypherBoolean(false);
        }
        throw new VertexiumException("unexpected boolean: " + ctx.getText());
    }

    /**
     * Builds a {@link CypherOrderBy} from the sort items, in source order.
     */
    @Override
    public CypherOrderBy visitOrder(CypherParser.OrderContext ctx) {
        List<CypherSortItem> sortItems = ctx.sortItem().stream()
            .map(this::visitSortItem)
            .collect(Collectors.toList());
        return new CypherOrderBy(sortItems);
    }

    /**
     * Builds a {@link CypherIdInColl} from the iteration variable and the collection expression.
     */
    @Override
    public CypherIdInColl visitIdInColl(CypherParser.IdInCollContext ctx) {
        CypherVariable variable = visitVariable(ctx.variable());
        CypherAstBase expression = visitExpression(ctx.expression());
        return new CypherIdInColl(variable, expression);
    }

    /**
     * Builds a {@link CypherRelTypeName} from the relationship type's symbolic name.
     */
    @Override
    public CypherRelTypeName visitRelTypeName(CypherParser.RelTypeNameContext ctx) {
        return new CypherRelTypeName(visitSymbolicName(ctx.symbolicName()).getValue());
    }

    /**
     * Builds a {@link CypherDouble} by parsing the literal text with {@link Double#parseDouble(String)}.
     */
    @Override
    public CypherDouble visitDoubleLiteral(CypherParser.DoubleLiteralContext ctx) {
        return new CypherDouble(Double.parseDouble(ctx.getText()));
    }

    /**
     * Not supported as a standalone visit.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitDash(CypherParser.DashContext ctx) {
        throw new VertexiumCypherNotImplemented("Dash");
    }

    /**
     * Not supported as a standalone visit; labels are read in
     * {@link #visitNodeLabels(CypherParser.NodeLabelsContext)}.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitNodeLabel(CypherParser.NodeLabelContext ctx) {
        throw new VertexiumCypherNotImplemented("NodeLabel");
    }

    /**
     * Not supported as a standalone visit; arrow heads are only inspected for presence when
     * determining a relationship's direction.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitRightArrowHead(CypherParser.RightArrowHeadContext ctx) {
        throw new VertexiumCypherNotImplemented("RightArrowHead");
    }

    /**
     * Visits a property expression. When the property lookup list is non-null a lookup without
     * labels is created; otherwise the bare atom is visited.
     */
    @Override
    public CypherAstBase visitPropertyExpression(CypherParser.PropertyExpressionContext ctx) {
        if (ctx.propertyLookup() != null) {
            return createLookup(ctx.atom(), ctx.propertyLookup(), null);
        }
        return visitAtom(ctx.atom());
    }

    /**
     * Builds a remove item: a property removal when a property expression is present, otherwise
     * a label removal for the given variable.
     */
    @Override
    public CypherRemoveItem visitRemoveItem(CypherParser.RemoveItemContext ctx) {
        if (ctx.propertyExpression() != null) {
            return new CypherRemovePropertyExpressionItem(visitPropertyExpression(ctx.propertyExpression()));
        } else {
            return new CypherRemoveLabelItem(
                visitVariable(ctx.variable()),
                visitNodeLabels(ctx.nodeLabels())
            );
        }
    }

    /**
     * Builds a list literal from the element expressions.
     */
    @Override
    public CypherListLiteral<CypherAstBase> visitListLiteral(CypherParser.ListLiteralContext ctx) {
        return visitExpressions(ctx.expression());
    }

    /**
     * Builds a {@link CypherSkip} from the skip expression.
     */
    @Override
    public CypherSkip visitSkip(CypherParser.SkipContext ctx) {
        CypherAstBase expression = visitExpression(ctx.expression());
        return new CypherSkip(expression);
    }

    /**
     * Not supported as a standalone visit; arrow heads are only inspected for presence when
     * determining a relationship's direction.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitLeftArrowHead(CypherParser.LeftArrowHeadContext ctx) {
        throw new VertexiumCypherNotImplemented("LeftArrowHead");
    }

    /**
     * Builds a {@link CypherDeleteClause}; the detach flag is set when the DETACH token is present.
     */
    @Override
    public CypherAstBase visitDelete(CypherParser.DeleteContext ctx) {
        boolean detach = ctx.DETACH() != null;
        CypherListLiteral<CypherAstBase> expressions = visitExpressions(ctx.expression());
        return new CypherDeleteClause(expressions, detach);
    }

    /**
     * Builds a {@link CypherRemoveClause} from the remove items, in source order.
     */
    @Override
    public CypherAstBase visitRemove(CypherParser.RemoveContext ctx) {
        List<CypherRemoveItem> removeItems = ctx.removeItem().stream()
            .map(this::visitRemoveItem)
            .collect(Collectors.toList());
        return new CypherRemoveClause(removeItems);
    }

    /**
     * Builds a {@link CypherFunctionInvocation}. The function is looked up in the compiler
     * context and given a chance to compile against its arguments before the node is created.
     *
     * @throws VertexiumCypherSyntaxErrorException if the compiler context has no function with that name
     */
    @Override
    public CypherAstBase visitFunctionInvocation(CypherParser.FunctionInvocationContext ctx) {
        String functionName = visitFunctionName(ctx.functionName()).getValue();
        CypherFunction fn = compilerContext.getFunction(functionName);
        if (fn == null) {
            throw new VertexiumCypherSyntaxErrorException("UnknownFunction: Could not find function with name \"" + functionName + "\"");
        }
        boolean distinct = ctx.DISTINCT() != null;
        CypherListLiteral<CypherAstBase> argumentsList = visitExpressions(ctx.expression());
        CypherAstBase[] arguments = argumentsList.toArray(new CypherAstBase[argumentsList.size()]);
        fn.compile(compilerContext, arguments);
        return new CypherFunctionInvocation(functionName, distinct, arguments);
    }

    /**
     * Builds a {@link CypherListComprehension}; the projection expression is {@code null} when absent.
     */
    @Override
    public CypherAstBase visitListComprehension(CypherParser.ListComprehensionContext ctx) {
        CypherFilterExpression filterExpression = visitFilterExpression(ctx.filterExpression());
        CypherAstBase expression = ctx.expression() == null ? null : visitExpression(ctx.expression());
        return new CypherListComprehension(filterExpression, expression);
    }

    /**
     * Root entry point; visits the contained statement.
     */
    @Override
    public CypherStatement visitCypher(CypherParser.CypherContext ctx) {
        return visitStatement(ctx.statement());
    }

    /**
     * Builds a named parameter when a symbolic name is present, or an indexed parameter when a
     * decimal integer is present.
     *
     * @throws VertexiumCypherNotImplemented if the parameter has neither form
     */
    @Override
    public CypherAstBase visitParameter(CypherParser.ParameterContext ctx) {
        if (ctx.symbolicName() != null) {
            String parameterName = visitSymbolicName(ctx.symbolicName()).getValue();
            return new CypherNameParameter(parameterName);
        }
        if (ctx.DecimalInteger() != null) {
            return new CypherIndexedParameter(Integer.parseInt(ctx.DecimalInteger().getText()));
        }
        throw new VertexiumCypherNotImplemented("Parameter");
    }

    /**
     * Builds the merge action for a CREATE or MATCH token wrapping the visited SET clause.
     *
     * @throws VertexiumCypherSyntaxErrorException if neither the CREATE nor the MATCH token is present
     */
    @Override
    public CypherMergeAction visitMergeAction(CypherParser.MergeActionContext ctx) {
        CypherSetClause set = visitSet(ctx.set());
        if (ctx.CREATE() != null) {
            return new CypherMergeActionCreate(set);
        } else if (ctx.MATCH() != null) {
            return new CypherMergeActionMatch(set);
        } else {
            throw new VertexiumCypherSyntaxErrorException("Expected ON CREATE or ON MATCH");
        }
    }

    /**
     * Builds a {@link CypherSortItem}; the direction is DESCENDING when a DESC or DESCENDING
     * token is present and ASCENDING otherwise.
     */
    @Override
    public CypherSortItem visitSortItem(CypherParser.SortItemContext ctx) {
        CypherAstBase expr = visitExpression(ctx.expression());
        CypherSortItem.Direction direction;
        if (ctx.DESC() != null || ctx.DESCENDING() != null) {
            direction = CypherSortItem.Direction.DESCENDING;
        } else {
            direction = CypherSortItem.Direction.ASCENDING;
        }
        return new CypherSortItem(expr, direction);
    }

    /**
     * Builds a set item: a property assignment when a property expression is present, a label
     * assignment when node labels are present, otherwise a variable assignment using the
     * operator found by {@code getSetItemOp}.
     *
     * @throws VertexiumException if the property expression does not resolve to a {@link CypherLookup}
     */
    @Override
    public CypherSetItem visitSetItem(CypherParser.SetItemContext ctx) {
        if (ctx.propertyExpression() != null) {
            CypherAstBase lookup = visitPropertyExpression(ctx.propertyExpression());
            if (!(lookup instanceof CypherLookup)) {
                throw new VertexiumException("expected " + CypherLookup.class.getName() + " found " + lookup.getClass().getName());
            }
            return new CypherSetProperty(
                (CypherLookup) lookup,
                visitExpression(ctx.expression())
            );
        }

        if (ctx.nodeLabels() != null) {
            return new CypherSetNodeLabels(
                visitVariable(ctx.variable()),
                visitNodeLabels(ctx.nodeLabels())
            );
        }

        CypherSetItem.Op op = getSetItemOp(ctx);
        return new CypherSetVariable(
            visitVariable(ctx.variable()),
            op,
            visitExpression(ctx.expression())
        );
    }

    /**
     * Finds the first terminal child that is {@code +=} or {@code =} and maps it to an operator.
     *
     * @throws VertexiumException if no such terminal is found
     */
    private CypherSetItem.Op getSetItemOp(CypherParser.SetItemContext ctx) {
        for (ParseTree child : ctx.children) {
            if (child instanceof TerminalNode) {
                String text = child.getText();
                if (text.equals("+=")) {
                    return CypherSetItem.Op.PLUS_EQUAL;
                } else if (text.equals("=")) {
                    return CypherSetItem.Op.EQUAL;
                }
            }
        }
        throw new VertexiumException("Could not find set item op: " + ctx.getText());
    }

    /**
     * Builds a {@link CypherSetClause} from the set items, in source order.
     */
    @Override
    public CypherSetClause visitSet(CypherParser.SetContext ctx) {
        return new CypherSetClause(ctx.setItem().stream().map(this::visitSetItem).collect(Collectors.toList()));
    }

    /**
     * Returns the function name as a {@link CypherString}; the COUNT token maps to {@code "count"}.
     *
     * @throws VertexiumException if the context carries none of the recognised tokens
     */
    @Override
    public CypherString visitFunctionName(CypherParser.FunctionNameContext ctx) {
        if (ctx.UnescapedSymbolicName() != null) {
            return visitUnescapedSymbolicName(ctx.UnescapedSymbolicName());
        } else if (ctx.EscapedSymbolicName() != null) {
            return visitEscapedSymbolicName(ctx.EscapedSymbolicName());
        } else if (ctx.COUNT() != null) {
            return new CypherString("count");
        } else {
            throw new VertexiumException("unexpected function name: " + ctx.getText());
        }
    }

    /**
     * Returns the text of an escaped symbolic name with its first and last character removed.
     */
    private CypherString visitEscapedSymbolicName(TerminalNode escapedSymbolicName) {
        String text = escapedSymbolicName.getText();
        text = text.substring(1, text.length() - 1);
        return new CypherString(text);
    }

    /**
     * Returns the text of an unescaped symbolic name unchanged.
     */
    private CypherString visitUnescapedSymbolicName(TerminalNode unescapedSymbolicName) {
        return new CypherString(unescapedSymbolicName.getText());
    }

    /**
     * Builds a {@link CypherRelationshipsPattern} from the leading node pattern and the
     * flattened element chain.
     */
    @Override
    public CypherRelationshipsPattern visitRelationshipsPattern(CypherParser.RelationshipsPatternContext ctx) {
        CypherNodePattern nodePattern = visitNodePattern(ctx.nodePattern());
        List<CypherElementPattern> patternElementChains = visitPatternElementChainList(nodePattern, ctx.patternElementChain());
        return new CypherRelationshipsPattern(nodePattern, patternElementChains);
    }

    /**
     * Recursively builds right-nested {@link CypherUnion} nodes: the given left query is united
     * with the result of processing the remaining unions.
     *
     * @param left   the query on the left-hand side of the first union
     * @param unions the remaining union contexts
     * @return {@code left} itself when {@code unions} is empty
     */
    private CypherAstBase visitUnions(CypherQuery left, List<CypherParser.UnionContext> unions) {
        if (unions.size() == 0) {
            return left;
        }

        CypherParser.UnionContext firstUnion = unions.get(0);
        boolean all = firstUnion.ALL() != null;
        CypherQuery right = visitSingleQuery(firstUnion.singleQuery());
        return new CypherUnion(left, visitUnions(right, unions.subList(1, unions.size())), all);
    }

    /**
     * Not supported directly; unions are handled by
     * {@link #visitRegularQuery(CypherParser.RegularQueryContext)}.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherUnion visitUnion(CypherParser.UnionContext ctx) {
        throw new VertexiumCypherNotImplemented("Union");
    }

    /**
     * Not supported directly; relationship details are handled by
     * {@link #visitRelationshipPattern(CypherParser.RelationshipPatternContext)}.
     *
     * @throws VertexiumCypherNotImplemented always
     */
    @Override
    public CypherAstBase visitRelationshipDetail(CypherParser.RelationshipDetailContext ctx) {
        throw new VertexiumCypherNotImplemented("RelationshipDetail");
    }

    /**
     * Builds a {@link CypherRangeLiteral}. An integer before {@code ..} is the lower bound and an
     * integer after it the upper bound; when there is no {@code ..} the upper bound equals the
     * lower bound. Missing bounds are {@code null}.
     */
    @Override
    public CypherRangeLiteral visitRangeLiteral(CypherParser.RangeLiteralContext ctx) {
        Integer from = null;
        Integer to = null;
        boolean seenDotDot = false;
        for (ParseTree child : ctx.children) {
            if (child instanceof CypherParser.IntegerLiteralContext) {
                int i = visitIntegerLiteral((CypherParser.IntegerLiteralContext) child).getIntValue();
                if (seenDotDot) {
                    to = i;
                } else {
                    from = i;
                }
                continue;
            }
            String text = child.getText();
            if (text.equals("*")) {
                continue;
            }
            if (text.equals("..")) {
                seenDotDot = true;
                continue;
            }
        }
        if (!seenDotDot) {
            to = from;
        }
        return new CypherRangeLiteral(from, to);
    }

    /**
     * Builds a {@link CypherFilterExpression}; the where expression is {@code null} when absent.
     */
    @Override
    public CypherFilterExpression visitFilterExpression(CypherParser.FilterExpressionContext ctx) {
        CypherIdInColl idInCol = visitIdInColl(ctx.idInColl());
        CypherAstBase where = ctx.where() == null ? null : visitWhere(ctx.where());
        return new CypherFilterExpression(idInCol, where);
    }

    /**
     * Builds a {@link CypherInteger} by decoding the literal text with {@link Long#decode(String)}.
     *
     * @throws VertexiumException if the text cannot be decoded
     */
    @Override
    public CypherInteger visitIntegerLiteral(CypherParser.IntegerLiteralContext ctx) {
        try {
            return new CypherInteger(Long.decode(ctx.getText()));
        } catch (NumberFormatException ex) {
            throw new VertexiumException("could not parse \"" + ctx.getText() + "\" into integer");
        }
    }

    /**
     * Collects the relationship type names into a list literal.
     */
    @Override
    public CypherListLiteral<CypherRelTypeName> visitRelationshipTypes(CypherParser.RelationshipTypesContext ctx) {
        return ctx.relTypeName().stream()
            .map(this::visitRelTypeName)
            .collect(CypherListLiteral.collect());
    }

    /**
     * Visits a number literal via the default dispatch and narrows the result to {@link CypherLiteral}.
     */
    @Override
    public CypherLiteral visitNumberLiteral(CypherParser.NumberLiteralContext ctx) {
        return (CypherLiteral) super.visitNumberLiteral(ctx);
    }

    /**
     * Fails with a message containing the offending text and its line and column.
     *
     * @throws VertexiumException always
     */
    @Override
    public CypherAstBase visitErrorNode(ErrorNode node) {
        throw new VertexiumException(String.format(
            "Could not parse: invalid value \"%s\" (line: %d, pos: %d)",
            node.getText(),
            node.getSymbol().getLine(),
            node.getSymbol().getCharPositionInLine()
        ));
    }
}