package org.vertexium.cypher.glue; import cucumber.api.DataTable; import cucumber.api.java.en.Given; import cucumber.api.java.en.Then; import cucumber.api.java.en.When; import org.junit.Assume; import org.vertexium.Authorizations; import org.vertexium.VertexiumException; import org.vertexium.cypher.TestVertexiumCypherQueryContext; import org.vertexium.cypher.VertexiumCypherQuery; import org.vertexium.cypher.VertexiumCypherResult; import org.vertexium.cypher.ast.CypherAstParser; import org.vertexium.cypher.ast.CypherCompilerContext; import org.vertexium.cypher.ast.model.CypherAstBase; import org.vertexium.inmemory.InMemoryGraph; import org.vertexium.util.IOUtils; import org.vertexium.util.VertexiumLogger; import org.vertexium.util.VertexiumLoggerFactory; import java.io.InputStream; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; import static junit.framework.TestCase.fail; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; public class GraphGlue { private static final VertexiumLogger LOGGER = VertexiumLoggerFactory.getLogger(GraphGlue.class); public static final Pattern RELATIONSHIP_REGEX = Pattern.compile("^\\[(.*?)(\\{.*\\})\\]$"); public static final Pattern NODE_REGEX = Pattern.compile("^\\((.*?)(\\{.*\\})?\\)$"); private VertexiumCypherQuery query; private VertexiumCypherResult lastResults; private TestVertexiumCypherQueryContext ctx; private Exception lastCompileTimeException; private Exception lastRuntimeException; @Given("^any graph$") public void givenAnyGraph() throws Throwable { createGraph(); } @Given("^an empty graph$") public void givenEmptyGraph() throws Throwable { createGraph(); } @Given("^the binary-tree-(\\d+) graph$") public void givenTheBinaryTreeGraph(int number) throws Throwable { createGraph(); String resourceName = "/org/vertexium/cypher/tck/binary-tree-" + number + ".cyp"; InputStream treeIn = this.getClass().getResourceAsStream(resourceName); if (treeIn == null) { throw new VertexiumException("Could not find '" + resourceName + "'"); } String cyp = IOUtils.toString(treeIn); CypherCompilerContext compilerContext = new CypherCompilerContext(ctx.getFunctions()); VertexiumCypherQuery.parse(compilerContext, cyp).execute(ctx); } private void createGraph() { InMemoryGraph graph = InMemoryGraph.create(); Authorizations authorizations = graph.createAuthorizations(); ctx = new TestVertexiumCypherQueryContext(graph, authorizations); } @Given("^parameters are:$") public void givenParametersAre(DataTable parameters) throws Throwable { for (List<String> parameterRow : parameters.raw()) { String key = parameterRow.get(0); String valueString = parameterRow.get(1); Object value = parseParameterValue(valueString); ctx.setParameter(key, value); } } private Object parseParameterValue(String valueString) { valueString = valueString.trim(); CypherAstBase expression = CypherAstParser.getInstance().parseExpression(valueString); return ctx.getExpressionExecutor().executeExpression(ctx, expression, null); } @When("^executing(.*)query:$") public void whenExecutingQuery(String queryName, String queryString) throws Throwable { ctx.clearCounts(); lastResults = null; lastCompileTimeException = null; lastRuntimeException = null; try { CypherCompilerContext compilerContext = new CypherCompilerContext(ctx.getFunctions()); query = VertexiumCypherQuery.parse(compilerContext, queryString); } catch (Exception ex) { lastCompileTimeException = ex; } try { lastResults = query.execute(ctx); } catch (Exception ex) { lastRuntimeException = ex; } } @Given("^having executed:$") public void givenHavingExecuted(String queryString) throws Throwable { CypherCompilerContext compilerContext = new CypherCompilerContext(ctx.getFunctions()); VertexiumCypherQuery.parse(compilerContext, queryString).execute(ctx); } @Then("^the result should be, in order:$") public void thenTheResultShouldBeInOrder(DataTable expected) throws Throwable { thenTheResultShouldBe(expected); } @Then("^the result should be \\(ignoring element order for lists\\):$") public void thenTheResultShouldBeIgnoringElementOrderForLists(DataTable expected) throws Throwable { thenTheResultShouldBe(expected); } @Then("^the result should be:$") public void thenTheResultShouldBe(DataTable expected) throws Throwable { if (lastCompileTimeException != null) { throw lastCompileTimeException; } if (lastRuntimeException != null) { throw lastRuntimeException; } List<String> columnNames = expected.raw().get(0); if (expected.raw().size() > 0) { List<List<String>> expectedRows = expected.raw().stream() .skip(1) .collect(Collectors.toList()); List<List<String>> foundRows = lastResults.stream() .map(row -> columnNames.stream() .map(columnName -> row.getByName(columnName)) .map(obj -> ctx.getResultWriter().columnValueToString(ctx, obj)) .collect(Collectors.toList()) ) .collect(Collectors.toList()); expectedRows.sort(new RowComparator()); foundRows.sort(new RowComparator()); if (expectedRows.size() > 0) { System.out.println("Expected"); System.out.println(expected.raw().get(0).stream().collect(Collectors.joining(", "))); for (List<String> expectedRow : expectedRows) { System.out.println(expectedRow.stream().collect(Collectors.joining(", "))); } } System.out.println("Found"); System.out.println(columnNames.stream().collect(Collectors.joining(", "))); for (List<String> foundRow : foundRows) { System.out.println(foundRow.stream().collect(Collectors.joining(", "))); } if (expectedRows.size() > 0) { assertEquals("Header count", expected.raw().get(0).size(), lastResults.getColumnNames().size()); for (int colIdx = 0; colIdx < expected.raw().get(0).size(); colIdx++) { String expectedColumnName = expected.raw().get(0).get(colIdx); assertTrue("Header mismatch", lastResults.getColumnNames().contains(expectedColumnName)); } } assertEquals("result count", expectedRows.size(), foundRows.size()); if (expectedRows.size() > 0) { for (int row = 0; row < expectedRows.size(); row++) { List<String> expectedRow = expectedRows.get(row); List<String> foundRow = foundRows.get(row); for (int col = 0; col < expectedRow.size(); col++) { String expectedColumn = expectedRow.get(col); String foundColumn = foundRow.get(col); assertColumnValue(row, col, expectedColumn, foundColumn); } } } } } private final class RowComparator implements Comparator<List<String>> { @Override public int compare(List<String> list1, List<String> list2) { for (int i = 0; i < list1.size(); i++) { String o1 = list1.get(i); String o2 = list2.get(i); int c = o1.compareTo(o2); if (c != 0) { return c; } } return 0; } } private void assertColumnValue(int row, int column, String expected, String found) { if (expected == null && found == null) { return; } if (rowStringIsRelationship(expected) && rowStringIsRelationship(found)) { assertColumnValueRelationship(row, column, expected, found); } else if (rowStringIsNode(expected) && rowStringIsNode(found)) { assertColumnValueNode(row, column, expected, found); } else if (rowStringIsList(expected) && rowStringIsList(found)) { assertColumnValueList(row, column, expected, found); } else if (rowStringIsMap(expected) && rowStringIsMap(found)) { assertColumnValueMap(expected, found); } else { assertEquals(row + ":" + column, expected, found); } } private boolean rowStringIsNode(String string) { return NODE_REGEX.matcher(string).matches(); } private void assertColumnValueNode(int row, int column, String expected, String found) { Matcher expectedMatch = NODE_REGEX.matcher(expected); Matcher foundMatch = NODE_REGEX.matcher(found); assertTrue(expectedMatch.matches()); assertTrue(foundMatch.matches()); String[] expectedLabels = expectedMatch.group(1).split(":"); String[] foundLabels = foundMatch.group(1).split(":"); assertEquals(expectedMatch.group(1) + " does not equal length of " + foundMatch.group(1), expectedLabels.length, foundLabels.length); Arrays.sort(expectedLabels); Arrays.sort(foundLabels); for (int i = 0; i < expectedLabels.length; i++) { assertEquals(expectedMatch.group(1) + " does not equal length of " + foundMatch.group(1), expectedLabels[i], foundLabels[i]); } assertColumnValue(row, column, expectedMatch.group(2), foundMatch.group(2)); } private boolean rowStringIsRelationship(String string) { return RELATIONSHIP_REGEX.matcher(string).matches(); } private void assertColumnValueRelationship(int row, int column, String expected, String found) { Matcher expectedMatch = RELATIONSHIP_REGEX.matcher(expected); Matcher foundMatch = RELATIONSHIP_REGEX.matcher(found); assertTrue(expectedMatch.matches()); assertTrue(foundMatch.matches()); assertColumnValue(row, column, expectedMatch.group(1), foundMatch.group(1)); assertColumnValue(row, column, expectedMatch.group(2), foundMatch.group(2)); } private boolean rowStringIsMap(String columnString) { return columnString.startsWith("{") && columnString.endsWith("}"); } private void assertColumnValueMap(String expected, String found) { try { Map<?, ?> expectedMap = columnValueToMap(expected); Map<?, ?> foundMap = columnValueToMap(found); assertEquals(expectedMap.keySet(), foundMap.keySet()); for (Object key : expectedMap.keySet()) { Object expectedValue = expectedMap.get(key); Object foundValue = foundMap.get(key); assertEquals(expectedValue, foundValue); } } catch (Exception ex) { assertEquals(expected, found); } } private Map<?, ?> columnValueToMap(String columnValue) { CypherAstBase expression = CypherAstParser.getInstance().parseExpression(columnValue); return (Map) ctx.getExpressionExecutor().executeExpression(ctx, expression, null); } private boolean rowStringIsList(String columnString) { return columnString.startsWith("[") && columnString.endsWith("]"); } private void assertColumnValueList(int row, int column, String expected, String found) { List<String> expectedList = columnValueToList(expected); List<String> foundList = columnValueToList(found); assertEquals(expectedList.size(), foundList.size()); for (int i = 0; i < expectedList.size(); i++) { String expectedListItem = expectedList.get(i); String foundListItem = foundList.get(i); assertColumnValue(row, column, expectedListItem, foundListItem); } } private List<String> columnValueToList(String columnValue) { columnValue = columnValue.substring(1, columnValue.length() - 1); return Arrays.stream(columnValue.split(",")) .map(String::trim) .sorted() .collect(Collectors.toList()); } @Then("^a (.*) should be raised at compile time: (.*)$") public void thenASyntaxErrorShouldBeRaisedAtCompileTime(String errorType, String error) throws Throwable { Exception ex = this.lastCompileTimeException; if (ex == null) { if (lastRuntimeException != null) { lastRuntimeException.printStackTrace(); // TODO do we care? LOGGER.warn("statement should have resulted in a compile time exception, but resulted in a runtime exception"); ex = lastRuntimeException; } else { Assume.assumeTrue("statement should have resulted in a compile time exception", false); } } if (!ex.getClass().getName().contains(errorType)) { ex.printStackTrace(); // TODO do we care? LOGGER.warn("exception type should contain \"" + errorType + "\", but only contained \"" + ex.getClass().getName() + "\": " + ex); } if (ex.getMessage() == null || !ex.getMessage().contains(error)) { ex.printStackTrace(); // TODO do we care? LOGGER.warn("exception should contain \"" + error + "\", but only contained \"" + ex.getMessage() + "\""); } } @Then("^a (.*) should be raised at runtime: (.*)$") public void thenATypeErrorShouldBeRaisedAtRuntime(String errorType, String error) throws Throwable { if (lastRuntimeException == null) { if (lastCompileTimeException != null) { fail("statement should have resulted in a runtime exception, but resulted in a compile time exception"); } else { fail("statement should have resulted in a runtime exception"); } } if (!lastRuntimeException.getClass().getName().contains(errorType)) { lastRuntimeException.printStackTrace(); // TODO do we care? LOGGER.warn("exception type should contain \"" + errorType + "\", but only contained \"" + lastRuntimeException.getClass().getName() + "\": " + lastRuntimeException); } if (!lastRuntimeException.getMessage().contains(error)) { lastRuntimeException.printStackTrace(); // TODO do we care? LOGGER.warn("exception should contain \"" + error + "\", but only contained \"" + lastRuntimeException.getMessage() + "\""); } } @Then("^the result should be empty$") public void thenTheResultsShouldBeEmpty() throws Throwable { if (lastCompileTimeException != null) { throw lastCompileTimeException; } if (lastRuntimeException != null) { throw lastRuntimeException; } assertEquals(0, lastResults.size()); } @Then("^no side effects$") public void noSideEffects() throws Throwable { assertEquals("+node", 0, ctx.getPlusNodeCount()); assertEquals("+relationship", 0, ctx.getPlusRelationshipCount()); assertEquals("+label", 0, ctx.getPlusLabelCount()); assertEquals("+property", 0, ctx.getPlusPropertyCount()); assertEquals("-node", 0, ctx.getMinusNodeCount()); assertEquals("-relationship", 0, ctx.getMinusRelationshipCount()); assertEquals("-label", 0, ctx.getMinusLabelCount()); assertEquals("-property", 0, ctx.getMinusPropertyCount()); } @Then("^the side effects should be:$") public void thenTheSideEffectsShouldBe(DataTable table) throws Throwable { for (List<String> tableRow : table.raw()) { if (tableRow.size() == 2 && tableRow.get(0).equals("+nodes")) { int plusNodes = Integer.parseInt(tableRow.get(1)); assertEquals("+nodes", plusNodes, ctx.getPlusNodeCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("+relationships")) { int plusRelationships = Integer.parseInt(tableRow.get(1)); assertEquals("+relationships", plusRelationships, ctx.getPlusRelationshipCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("+labels")) { int plusLabels = Integer.parseInt(tableRow.get(1)); assertEquals("+labels", plusLabels, ctx.getPlusLabelCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("+properties")) { int plusProperties = Integer.parseInt(tableRow.get(1)); assertEquals("+properties", plusProperties, ctx.getPlusPropertyCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("-nodes")) { int minusNodes = Integer.parseInt(tableRow.get(1)); assertEquals("-nodes", minusNodes, ctx.getMinusNodeCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("-relationships")) { int minusRelationships = Integer.parseInt(tableRow.get(1)); assertEquals("-relationships", minusRelationships, ctx.getMinusRelationshipCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("-labels")) { int minusLabels = Integer.parseInt(tableRow.get(1)); assertEquals("-labels", minusLabels, ctx.getMinusLabelCount()); } else if (tableRow.size() == 2 && tableRow.get(0).equals("-properties")) { int minusProperties = Integer.parseInt(tableRow.get(1)); assertEquals("-properties", minusProperties, ctx.getMinusPropertyCount()); } else { fail("Unhandled side effect row: " + tableRow.stream().collect(Collectors.joining(","))); } } } }