/* * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor * license agreements. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. Crate licenses * this file to you under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * However, if you have executed another commercial license agreement * with Crate these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial agreement. */ package io.crate.testing; import com.carrotsearch.randomizedtesting.RandomizedContext; import com.google.common.base.MoreObjects; import com.google.common.base.Throwables; import io.crate.action.sql.*; import io.crate.analyze.symbol.Field; import io.crate.data.Row; import io.crate.exceptions.SQLExceptions; import io.crate.protocols.postgres.types.PGType; import io.crate.protocols.postgres.types.PGTypes; import io.crate.shade.org.postgresql.util.PGobject; import io.crate.shade.org.postgresql.util.PSQLException; import io.crate.shade.org.postgresql.util.ServerErrorMessage; import io.crate.types.DataType; import io.crate.types.DataTypes; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse; import org.elasticsearch.action.support.AdapterActionFuture; import org.elasticsearch.client.Client; import org.elasticsearch.client.Requests; import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.common.Priority; import org.elasticsearch.common.io.stream.NotSerializableExceptionWrapper; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.rest.RestStatus; import org.hamcrest.Matchers; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.IOException; import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.sql.*; import java.util.*; import java.util.concurrent.TimeUnit; import static io.crate.action.sql.SQLOperations.Session.UNNAMED; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; public class SQLTransportExecutor { private static final String SQL_REQUEST_TIMEOUT = "CRATE_TESTS_SQL_REQUEST_TIMEOUT"; public final static int DEFAULT_SOFT_LIMIT = 10_000; public static final TimeValue REQUEST_TIMEOUT = new TimeValue(Long.parseLong( MoreObjects.firstNonNull(System.getenv(SQL_REQUEST_TIMEOUT), "5")), TimeUnit.SECONDS); private static final Logger LOGGER = Loggers.getLogger(SQLTransportExecutor.class); private final ClientProvider clientProvider; public SQLTransportExecutor(ClientProvider clientProvider) { this.clientProvider = clientProvider; } public SQLResponse exec(String statement) { return exec(statement, null, REQUEST_TIMEOUT); } public SQLResponse exec(String statement, Object... params) { return exec(statement, params, REQUEST_TIMEOUT); } public SQLBulkResponse execBulk(String statement, @Nullable Object[][] bulkArgs) { return executeBulk(statement, bulkArgs, REQUEST_TIMEOUT); } public SQLBulkResponse execBulk(String statement, @Nullable Object[][] bulkArgs, TimeValue timeout) { return executeBulk(statement, bulkArgs, timeout); } private SQLResponse exec(String stmt, @Nullable Object[] args, TimeValue timeout) { String pgUrl = clientProvider.pgUrl(); Random random = RandomizedContext.current().getRandom(); if (pgUrl != null && isJdbcEnabled()) { LOGGER.trace("Executing with pgJDBC: {}", stmt); return executeWithPg(stmt, args, pgUrl, random); } try { return execute(stmt, args).actionGet(timeout); } catch (ElasticsearchTimeoutException e) { LOGGER.error("Timeout on SQL statement: {}", e, stmt); throw e; } } /** * @return true if a class or method in the stacktrace contains a @UseJdbc annotation * and based on the ration provided * <p> * Method annotations have higher priority than class annotations. */ private boolean isJdbcEnabled() { StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); for (StackTraceElement element : stackTrace) { try { Class<?> ar = Class.forName(element.getClassName()); Method method = ar.getMethod(element.getMethodName()); UseJdbc annotation = method.getAnnotation(UseJdbc.class); if (annotation == null) { annotation = ar.getAnnotation(UseJdbc.class); if (annotation == null) { continue; } } double ratio = annotation.value(); assert ratio >= 0.0 && ratio <= 1.0; if (ratio == 0) { return false; } if (ratio == 1) { return true; } return RandomizedContext.current().getRandom().nextDouble() < ratio; } catch (NoSuchMethodException | ClassNotFoundException ignored) { } } return false; } public String jdbcUrl() { assertTrue("It seems like JDBC is not enabled for this test. Did you forget to annotate the test with @UseJdbc(value = 1)?", isJdbcEnabled()); return clientProvider.pgUrl(); } public ActionFuture<SQLResponse> execute(String stmt, @Nullable Object[] args) { return execute(stmt, args, clientProvider.sqlOperations().createSession( null, null, Option.NONE, DEFAULT_SOFT_LIMIT )); } public static ActionFuture<SQLResponse> execute(String stmt, @Nullable Object[] args, SQLOperations.Session session) { final AdapterActionFuture<SQLResponse, SQLResponse> actionFuture = new TestTransportActionFuture<>(); execute(stmt, args, actionFuture, session); return actionFuture; } private static void execute(String stmt, @Nullable Object[] args, ActionListener<SQLResponse> listener, SQLOperations.Session session) { try { session.parse(UNNAMED, stmt, Collections.<DataType>emptyList()); List<Object> argsList = args == null ? Collections.emptyList() : Arrays.asList(args); session.bind(UNNAMED, UNNAMED, argsList, null); List<Field> outputFields = session.describe('P', UNNAMED); if (outputFields == null) { ResultReceiver resultReceiver = new RowCountReceiver(listener); session.execute(UNNAMED, 0, resultReceiver); } else { ResultReceiver resultReceiver = new ResultSetReceiver(listener, outputFields); session.execute(UNNAMED, 0, resultReceiver); } session.sync(); } catch (Throwable t) { listener.onFailure(SQLExceptions.createSQLActionException(t)); } } private void execute(String stmt, @Nullable Object[][] bulkArgs, final ActionListener<SQLBulkResponse> listener) { SQLOperations.Session session = clientProvider.sqlOperations().createSession( null, null, Option.NONE, DEFAULT_SOFT_LIMIT ); try { session.parse(UNNAMED, stmt, Collections.<DataType>emptyList()); if (bulkArgs == null) { bulkArgs = new Object[0][]; } final SQLBulkResponse.Result[] results = new SQLBulkResponse.Result[bulkArgs.length]; if (results.length == 0) { session.bind(UNNAMED, UNNAMED, Collections.emptyList(), null); session.execute(UNNAMED, 0, new BaseResultReceiver()); } else { for (int i = 0; i < bulkArgs.length; i++) { session.bind(UNNAMED, UNNAMED, Arrays.asList(bulkArgs[i]), null); ResultReceiver resultReceiver = new BulkRowCountReceiver(results, i); session.execute(UNNAMED, 0, resultReceiver); } } List<Field> outputColumns = session.describe('P', UNNAMED); if (outputColumns != null) { throw new UnsupportedOperationException( "Bulk operations for statements that return result sets is not supported"); } session.sync().whenComplete((Object result, Throwable t) -> { if (t == null) { listener.onResponse(new SQLBulkResponse(results)); } else { listener.onFailure(SQLExceptions.createSQLActionException(t)); } }); } catch (Throwable t) { listener.onFailure(SQLExceptions.createSQLActionException(t)); } } private SQLResponse executeWithPg(String stmt, @Nullable Object[] args, String pgUrl, Random random) { try { Properties properties = new Properties(); if (random.nextBoolean()) { properties.setProperty("prepareThreshold", "-1"); // disable prepared statements } try (Connection conn = DriverManager.getConnection(pgUrl, properties)) { conn.setAutoCommit(true); PreparedStatement preparedStatement = conn.prepareStatement(stmt); if (args != null) { for (int i = 0; i < args.length; i++) { preparedStatement.setObject(i + 1, toJdbcCompatObject(conn, args[i])); } } return executeAndConvertResult(preparedStatement); } } catch (PSQLException e) { ServerErrorMessage serverErrorMessage = e.getServerErrorMessage(); StackTraceElement[] stacktrace; if (serverErrorMessage != null) { StackTraceElement stackTraceElement = new StackTraceElement( serverErrorMessage.getFile(), serverErrorMessage.getRoutine(), serverErrorMessage.getFile(), serverErrorMessage.getLine()); stacktrace = new StackTraceElement[]{stackTraceElement}; } else { stacktrace = new StackTraceElement[]{}; } throw new SQLActionException( e.getMessage(), 0, RestStatus.BAD_REQUEST, stacktrace); } catch (SQLException e) { throw new SQLActionException(e.getMessage(), 0, RestStatus.BAD_REQUEST); } } private static Object toJdbcCompatObject(Connection connection, Object arg) { if (arg == null) { return null; } if (arg instanceof Map) { // setObject with a Map would use hstore. But that only supports text values try { return toPGObjectJson(toJsonString(((Map) arg))); } catch (SQLException | IOException e) { throw Throwables.propagate(e); } } if (arg.getClass().isArray()) { arg = Arrays.asList((Object[]) arg); } if (arg instanceof Collection) { Collection values = (Collection) arg; if (values.isEmpty()) { return null; // TODO: can't insert empty list without knowing the type } if (values.iterator().next() instanceof Map) { try { return toPGObjectJson(toJsonString(values)); } catch (SQLException | IOException e) { throw Throwables.propagate(e); } } List<Object> convertedValues = new ArrayList<>(values.size()); PGType pgType = null; for (Object value : values) { convertedValues.add(toJdbcCompatObject(connection, value)); if (pgType == null && value != null) { pgType = PGTypes.get(DataTypes.guessType(value)); } } try { return connection.createArrayOf(pgType.typName(), convertedValues.toArray(new Object[0])); } catch (SQLException e) { /* * pg error message doesn't include a stacktrace. * Set a breakpoint in {@link io.crate.protocols.postgres.Messages#sendErrorResponse(Channel, Throwable)} * to inspect the error */ throw Throwables.propagate(e); } } return arg; } private static String toJsonString(Map value) throws IOException { XContentBuilder builder = JsonXContent.contentBuilder(); builder.map(value); builder.close(); return builder.bytes().utf8ToString(); } private static String toJsonString(Collection values) throws IOException { XContentBuilder builder; builder = JsonXContent.contentBuilder(); builder.startArray(); for (Object value : values) { builder.value(value); } builder.endArray(); builder.close(); return builder.bytes().utf8ToString(); } private static PGobject toPGObjectJson(String json) throws SQLException { PGobject pGobject = new PGobject(); pGobject.setType("json"); pGobject.setValue(json); return pGobject; } private SQLResponse executeAndConvertResult(PreparedStatement preparedStatement) throws SQLException { if (preparedStatement.execute()) { ResultSetMetaData metaData = preparedStatement.getMetaData(); ResultSet resultSet = preparedStatement.getResultSet(); List<Object[]> rows = new ArrayList<>(); List<String> columnNames = new ArrayList<>(metaData.getColumnCount()); DataType[] dataTypes = new DataType[metaData.getColumnCount()]; for (int i = 0; i < metaData.getColumnCount(); i++) { columnNames.add(metaData.getColumnName(i + 1)); } while (resultSet.next()) { Object[] row = new Object[metaData.getColumnCount()]; for (int i = 0; i < row.length; i++) { Object value; String typeName = metaData.getColumnTypeName(i + 1); value = getObject(resultSet, i, typeName); row[i] = value; } rows.add(row); } return new SQLResponse( columnNames.toArray(new String[0]), rows.toArray(new Object[0][]), dataTypes, rows.size() ); } else { int updateCount = preparedStatement.getUpdateCount(); if (updateCount < 0) { /* * In Crate -1 means row-count unknown, and -2 means error. In JDBC -2 means row-count unknown and -3 means error. * See {@link java.sql.Statement#EXECUTE_FAILED} */ updateCount += 1; } return new SQLResponse( new String[0], new Object[0][], new DataType[0], updateCount ); } } /** * retrieve the same type of object from the resultSet as the CrateClient would return */ private Object getObject(ResultSet resultSet, int i, String typeName) throws SQLException { Object value; switch (typeName) { // need to use explicit `get<Type>` for some because getObject would return a wrong type. // E.g. int2 would return Integer instead of short. case "int2": Integer intValue = (Integer) resultSet.getObject(i + 1); if (intValue == null) { return null; } value = intValue.shortValue(); break; case "byte": value = resultSet.getByte(i + 1); break; case "_json": List<Object> jsonObjects = new ArrayList<>(); for (Object json : (Object[]) resultSet.getArray(i + 1).getArray()) { jsonObjects.add(jsonToObject(((PGobject) json).getValue())); } value = jsonObjects.toArray(); break; case "json": String json = resultSet.getString(i + 1); value = jsonToObject(json); break; default: value = resultSet.getObject(i + 1); break; } if (value instanceof Timestamp) { value = ((Timestamp) value).getTime(); } else if (value instanceof Array) { value = ((Array) value).getArray(); } return value; } private Object jsonToObject(String json) { try { if (json != null) { byte[] bytes = json.getBytes(StandardCharsets.UTF_8); XContentParser parser = JsonXContent.jsonXContent.createParser(bytes); if (bytes.length >= 1 && bytes[0] == '[') { parser.nextToken(); return recursiveListToArray(parser.list()); } else { return parser.mapOrdered(); } } else { return null; } } catch (IOException e) { throw Throwables.propagate(e); } } private Object recursiveListToArray(Object value) { if (value instanceof List) { List list = (List) value; Object[] arr = list.toArray(new Object[0]); for (int i = 0; i < list.size(); i++) { arr[i] = recursiveListToArray(list.get(i)); } return arr; } return value; } private SQLBulkResponse executeBulk(String stmt, Object[][] bulkArgs, TimeValue timeout) { try { AdapterActionFuture<SQLBulkResponse, SQLBulkResponse> actionFuture = new TestTransportActionFuture<>(); execute(stmt, bulkArgs, actionFuture); return actionFuture.actionGet(timeout); } catch (ElasticsearchTimeoutException e) { LOGGER.error("Timeout on SQL statement: {}", e, stmt); throw e; } } public ClusterHealthStatus ensureGreen() { return ensureState(ClusterHealthStatus.GREEN); } public ClusterHealthStatus ensureYellowOrGreen() { return ensureState(ClusterHealthStatus.YELLOW); } private ClusterHealthStatus ensureState(ClusterHealthStatus state) { Client client = clientProvider.client(); ClusterHealthResponse actionGet = client.admin().cluster().health( Requests.clusterHealthRequest() .waitForStatus(state) .waitForEvents(Priority.LANGUID).waitForNoRelocatingShards(false) ).actionGet(); if (actionGet.isTimedOut()) { LOGGER.info("ensure state timed out, cluster state:\n{}\n{}", client.admin().cluster().prepareState().get().getState().prettyPrint(), client.admin().cluster().preparePendingClusterTasks().get().prettyPrint()); assertThat("timed out waiting for state", actionGet.isTimedOut(), equalTo(false)); } if (state == ClusterHealthStatus.YELLOW) { assertThat(actionGet.getStatus(), Matchers.anyOf(equalTo(state), equalTo(ClusterHealthStatus.GREEN))); } else { assertThat(actionGet.getStatus(), equalTo(state)); } return actionGet.getStatus(); } public interface ClientProvider { Client client(); @Nullable String pgUrl(); SQLOperations sqlOperations(); } private static class TestTransportActionFuture<R> extends AdapterActionFuture<R, R> { @Override protected R convert(R response) { return response; } @Override public void onFailure(Exception e) { Throwable cause = ExceptionsHelper.unwrapCause(e); if (cause instanceof NotSerializableExceptionWrapper) { NotSerializableExceptionWrapper wrapper = ((NotSerializableExceptionWrapper) cause); SQLActionException sae = SQLActionException.fromSerializationWrapper(wrapper); if (sae != null) { e = sae; } } super.onFailure(e); } } private static final DataType[] EMPTY_TYPES = new DataType[0]; private static final String[] EMPTY_NAMES = new String[0]; private static final Object[][] EMPTY_ROWS = new Object[0][]; /** * Wrapper for testing issues. Creates a {@link SQLResponse} from * query results. */ private static class ResultSetReceiver extends BaseResultReceiver { private final List<Object[]> rows = new ArrayList<>(); private final ActionListener<SQLResponse> listener; private final List<Field> outputFields; ResultSetReceiver(ActionListener<SQLResponse> listener, List<Field> outputFields) { this.listener = listener; this.outputFields = outputFields; } @Override public void setNextRow(Row row) { rows.add(row.materialize()); } @Override public void allFinished(boolean interrupted) { listener.onResponse(createSqlResponse()); super.allFinished(interrupted); } @Override public void fail(@Nonnull Throwable t) { listener.onFailure(SQLExceptions.createSQLActionException(t)); super.fail(t); } private SQLResponse createSqlResponse() { String[] outputNames = new String[outputFields.size()]; DataType[] outputTypes = new DataType[outputFields.size()]; for (int i = 0, outputFieldsSize = outputFields.size(); i < outputFieldsSize; i++) { Field field = outputFields.get(i); outputNames[i] = field.path().outputName(); outputTypes[i] = field.valueType(); } Object[][] rowsArr = rows.toArray(new Object[0][]); BytesRefUtils.ensureStringTypesAreStrings(outputTypes, rowsArr); return new SQLResponse( outputNames, rowsArr, outputTypes, rowsArr.length ); } } /** * Wrapper for testing issues. Creates a {@link SQLResponse} with * rowCount and duration of query execution. */ private static class RowCountReceiver extends BaseResultReceiver { private final ActionListener<SQLResponse> listener; private long rowCount; RowCountReceiver(ActionListener<SQLResponse> listener) { this.listener = listener; } @Override public void setNextRow(Row row) { rowCount = (long) row.get(0); } @Override public void allFinished(boolean interrupted) { SQLResponse sqlResponse = new SQLResponse( EMPTY_NAMES, EMPTY_ROWS, EMPTY_TYPES, rowCount ); listener.onResponse(sqlResponse); super.allFinished(interrupted); } @Override public void fail(@Nonnull Throwable t) { listener.onFailure(SQLExceptions.createSQLActionException(t)); super.fail(t); } } /** * Wraps results of bulk requests for testing. */ private static class BulkRowCountReceiver extends BaseResultReceiver { private final SQLBulkResponse.Result[] results; private final int resultIdx; private long rowCount; BulkRowCountReceiver(SQLBulkResponse.Result[] results, int resultIdx) { this.results = results; this.resultIdx = resultIdx; } @Override public void setNextRow(Row row) { rowCount = ((long) row.get(0)); } @Override public void allFinished(boolean interrupted) { results[resultIdx] = new SQLBulkResponse.Result(null, rowCount); super.allFinished(interrupted); } @Override public void fail(@Nonnull Throwable t) { results[resultIdx] = new SQLBulkResponse.Result(SQLExceptions.messageOf(t), rowCount); super.fail(t); } } }