/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.teradata.tempto.assertions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.teradata.tempto.configuration.Configuration; import com.teradata.tempto.internal.convention.SqlResultDescriptor; import com.teradata.tempto.internal.query.QueryResultValueComparator; import com.teradata.tempto.query.QueryExecutionException; import com.teradata.tempto.query.QueryExecutor; import com.teradata.tempto.query.QueryResult; import org.assertj.core.api.AbstractAssert; import org.assertj.core.api.Assertions; import org.slf4j.Logger; import java.sql.JDBCType; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import static com.google.common.collect.Lists.newArrayList; import static com.teradata.tempto.assertions.QueryAssert.Row.row; import static com.teradata.tempto.internal.configuration.TestConfigurationFactory.testConfiguration; import static com.teradata.tempto.query.QueryResult.fromSqlIndex; import static com.teradata.tempto.query.QueryResult.toSqlIndex; import static java.lang.String.format; import static java.sql.JDBCType.INTEGER; import static java.util.Collections.singletonList; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; import static java.util.Optional.ofNullable; import static org.slf4j.LoggerFactory.getLogger; public class QueryAssert extends AbstractAssert<QueryAssert, QueryResult> { private static final Logger LOGGER = getLogger(QueryExecutor.class); private static final NumberFormat DECIMAL_FORMAT = new DecimalFormat("#0.00000000000"); private final List<Comparator<Object>> columnComparators; private final List<JDBCType> columnTypes; private QueryAssert(QueryResult actual) { super(actual, QueryAssert.class); this.columnComparators = getComparators(actual); this.columnTypes = actual.getColumnTypes(); } public static QueryAssert assertThat(QueryResult queryResult) { return new QueryAssert(queryResult); } public static QueryExecutionAssert assertThat(QueryCallback queryCallback) { QueryExecutionException executionException = null; try { queryCallback.executeQuery(); } catch (QueryExecutionException e) { executionException = e; } return new QueryExecutionAssert(ofNullable(executionException)); } public QueryAssert matches(SqlResultDescriptor sqlResultDescriptor) { if (sqlResultDescriptor.getExpectedTypes().isPresent()) { hasColumns(sqlResultDescriptor.getExpectedTypes().get()); } List<Row> rows = null; try { rows = sqlResultDescriptor.getRows(columnTypes); } catch (Exception e) { failWithMessage("Could not map expected file content to query column types; types=%s; content=<%s>; error=<%s>", columnTypes, sqlResultDescriptor.getOriginalContent(), e.getMessage()); } if (sqlResultDescriptor.isIgnoreOrder()) { contains(rows); } else { containsExactly(rows); } if (!sqlResultDescriptor.isIgnoreExcessRows()) { hasRowsCount(rows.size()); } return this; } public QueryAssert hasRowsCount(int resultCount) { if (actual.getRowsCount() != resultCount) { failWithMessage("Expected row count to be <%s>, but was <%s>; rows=%s", resultCount, actual.getRowsCount(), actual.rows()); } return this; } public QueryAssert hasNoRows() { return hasRowsCount(0); } public QueryAssert hasAnyRows() { if (actual.getRowsCount() == 0) { failWithMessage("Expected some rows to be returned from query"); } return this; } public QueryAssert hasColumnsCount(int columnCount) { if (actual.getColumnsCount() != columnCount) { failWithMessage("Expected column count to be <%s>, but was <%s> - columns <%s>", columnCount, actual.getColumnsCount(), actual.getColumnTypes()); } return this; } public QueryAssert hasColumns(List<JDBCType> expectedTypes) { hasColumnsCount(expectedTypes.size()); for (int i = 0; i < expectedTypes.size(); i++) { JDBCType expectedType = expectedTypes.get(i); JDBCType actualType = actual.getColumnType(toSqlIndex(i)); if (!actualType.equals(expectedType)) { failWithMessage("Expected <%s> column of type <%s>, but was <%s>, actual columns: %s", i, expectedType, actualType, actual.getColumnTypes()); } } return this; } public QueryAssert hasColumns(JDBCType... expectedTypes) { return hasColumns(Arrays.asList(expectedTypes)); } /** * Verifies that the actual result set contains all the given {@code rows} * * @param rows Rows to be matched * @return this */ public QueryAssert contains(List<Row> rows) { List<List<Object>> missingRows = newArrayList(); for (Row row : rows) { List<Object> expectedRow = row.getValues(); if (!containsRow(expectedRow)) { missingRows.add(expectedRow); } } if (!missingRows.isEmpty()) { failWithMessage(buildContainsMessage(missingRows)); } return this; } /** * @param rows Rows to be matched * @return this * @see #contains(java.util.List) */ public QueryAssert contains(Row... rows) { return contains(Arrays.asList(rows)); } /** * Verifies that the actual result set consist of only {@code rows} in any order * * @param rows Rows to be matched * @return this */ public QueryAssert containsOnly(List<Row> rows) { hasRowsCount(rows.size()); contains(rows); return this; } /** * @param rows Rows to be matched * @return this * @see #containsOnly(java.util.List) */ public QueryAssert containsOnly(Row... rows) { return containsOnly(Arrays.asList(rows)); } /** * Verifies that the actual result set equals to {@code rows}. * ResultSet in different order or with any extra rows perceived as not same * * @param rows Rows to be matched * @return this */ public QueryAssert containsExactly(List<Row> rows) { hasRowsCount(rows.size()); List<Integer> unequalRowsIndexes = newArrayList(); for (int rowIndex = 0; rowIndex < rows.size(); rowIndex++) { List<Object> expectedRow = rows.get(rowIndex).getValues(); List<Object> actualRow = actual.row(rowIndex); if (!rowsEqual(expectedRow, actualRow)) { unequalRowsIndexes.add(rowIndex); } } if (!unequalRowsIndexes.isEmpty()) { failWithMessage(buildContainsExactlyErrorMessage(unequalRowsIndexes, rows)); } return this; } /** * @param rows Rows to be matched * @return this * @see #containsExactly(java.util.List) */ public QueryAssert containsExactly(Row... rows) { return containsExactly(Arrays.asList(rows)); } /** * Verifies number of rows updated/inserted by last update query * * @param count Number of rows expected * @return this */ public QueryAssert updatedRowsCountIsEqualTo(int count) { hasRowsCount(1); hasColumnsCount(1); hasColumns(INTEGER); containsExactly(row(count)); return this; } private static List<Comparator<Object>> getComparators(QueryResult queryResult) { Configuration configuration = testConfiguration(); return queryResult.getColumnTypes().stream() .map(it -> QueryResultValueComparator.comparatorForType(it, configuration)) .collect(Collectors.toList()); } private String buildContainsMessage(List<List<Object>> missingRows) { StringBuilder msg = new StringBuilder("Could not find rows:"); appendRows(msg, missingRows); msg.append("\n\nactual rows:"); appendRows(msg, actual.rows()); return msg.toString(); } private void appendRows(StringBuilder msg, List<List<Object>> rows) { rows.stream().forEach(row -> msg.append('\n').append(row)); } private String buildContainsExactlyErrorMessage(List<Integer> unequalRowsIndexes, List<Row> rows) { StringBuilder msg = new StringBuilder("Not equal rows:"); for (Integer unequalRowsIndex : unequalRowsIndexes) { int unequalRowIndex = unequalRowsIndex; msg.append('\n'); msg.append(unequalRowIndex); msg.append(" - expected: "); msg.append(rows.get(unequalRowIndex)); msg.append('\n'); msg.append(unequalRowIndex); msg.append(" - actual: "); msg.append(new Row(actual.row(unequalRowIndex))); } return msg.toString(); } private boolean containsRow(List<Object> expectedRow) { for (int i = 0; i < actual.getRowsCount(); i++) { if (rowsEqual(expectedRow, actual.row(i))) { return true; } } return false; } private boolean rowsEqual(List<Object> expectedRow, List<Object> actualRow) { if (expectedRow.size() != actualRow.size()) { return false; } for (int i = 0; i < expectedRow.size(); ++i) { List<Object> acceptableValues = expectedRow.get(i) instanceof AcceptableValues ? ((AcceptableValues) expectedRow.get(i)).getValues() : singletonList(expectedRow.get(i)); Object actualValue = actualRow.get(i); if (!isAnyValueEqual(i, acceptableValues, actualValue)) { return false; } } return true; } private boolean isAnyValueEqual(int column, List<Object> expectedValues, Object actualValue) { for (Object expectedValue : expectedValues) { if (columnComparators.get(column).compare(actualValue, expectedValue) == 0) { return true; } } return false; } public <T> QueryAssert column(int columnIndex, JDBCType type, ColumnValuesAssert<T> columnValuesAssert) { if (fromSqlIndex(columnIndex) > actual.getColumnsCount()) { failWithMessage("Result contains only <%s> columns, extracting column <%s>", actual.getColumnsCount(), columnIndex); } JDBCType actualColumnType = actual.getColumnType(columnIndex); if (!type.equals(actualColumnType)) { failWithMessage("Expected <%s> column, to be type: <%s>, but was: <%s>", columnIndex, type, actualColumnType); } List<T> columnValues = actual.column(columnIndex); columnValuesAssert.assertColumnValues(Assertions.assertThat(columnValues)); return this; } public <T> QueryAssert column(String columnName, JDBCType type, ColumnValuesAssert<T> columnValuesAssert) { Optional<Integer> index = actual.tryFindColumnIndex(columnName); if (!index.isPresent()) { failWithMessage("No column with name: <%s>", columnName); } return column(index.get(), type, columnValuesAssert); } public static AcceptableValues anyOf(Object... values) { return new AcceptableValues(Arrays.asList(values)); } @FunctionalInterface public static interface QueryCallback { QueryResult executeQuery() throws QueryExecutionException; } public static class QueryExecutionAssert { private Optional<QueryExecutionException> executionExceptionOptional; public QueryExecutionAssert(Optional<QueryExecutionException> executionExceptionOptional) { this.executionExceptionOptional = executionExceptionOptional; } public QueryExecutionAssert failsWithMessage(String... expectedErrorMessages) { if (!executionExceptionOptional.isPresent()) { throw new AssertionError("Query did not fail as expected."); } QueryExecutionException executionException = executionExceptionOptional.get(); String exceptionMessage = executionException.getMessage(); LOGGER.debug("Query failed as expected with message: {}", exceptionMessage); for (String expectedErrorMessage : expectedErrorMessages) { if (!exceptionMessage.contains(expectedErrorMessage)) { throw new AssertionError(format( "Query failed with unexpected error message: '%s' \n Expected error message was '%s'", exceptionMessage, expectedErrorMessage )); } } return this; } } public static class Row { private final List<Object> values; public Row(Object... values) { this(newArrayList(values)); } public Row(List<Object> values) { this.values = requireNonNull(values, "values is null"); } public List<Object> getValues() { return values; } public static Row row(Object... values) { return new Row(values); } @Override public String toString() { StringBuilder msg = new StringBuilder(); for (Object value : values) { if (value instanceof Double || value instanceof Float) { msg.append(DECIMAL_FORMAT.format(value)); } else if (value == null) { msg.append("null"); } else { msg.append(value.toString()); } msg.append('|'); } return msg.toString(); } } public static class AcceptableValues { private final List<Object> values; public AcceptableValues(List<Object> values) { this.values = unmodifiableList(new ArrayList<>(requireNonNull(values, "values can not be null"))); } public List<Object> getValues() { return values; } @Override public String toString() { return "anyOf(" + Joiner.on(", ").join(values) + ")"; } } }