/* * Licensed to Crate under one or more contributor license agreements. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. Crate licenses this file * to you under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. * * However, if you have executed another commercial license agreement * with Crate these terms will supersede the license and you may use the * software solely pursuant to the terms of the relevant commercial * agreement. */ package io.crate.protocols.postgres; import com.google.common.annotations.VisibleForTesting; import io.crate.action.sql.ResultReceiver; import io.crate.action.sql.SQLOperations; import io.crate.action.sql.SessionContext; import io.crate.analyze.symbol.Field; import io.crate.analyze.symbol.Symbols; import io.crate.operation.auth.Authentication; import io.crate.operation.auth.AuthenticationMethod; import io.crate.operation.auth.HbaProtocol; import io.crate.protocols.postgres.types.PGType; import io.crate.protocols.postgres.types.PGTypes; import io.crate.types.DataType; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.network.InetAddresses; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.*; import org.jboss.netty.handler.codec.frame.FrameDecoder; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.function.BiConsumer; import static io.crate.protocols.postgres.ConnectionContext.State.STARTUP_HEADER; import static io.crate.protocols.postgres.FormatCodes.getFormatCode; /** * ConnectionContext for the Postgres wire protocol.<br /> * This class handles the message flow and dispatching * <p> * <p> * <pre> * Client Server * * (optional ssl negotiation) * * * | SSLRequest | * |--------------------------------->| * | | * | 'S' | 'N' | error | (always N - ssl not supported) * |<---------------------------------| * * * startup: * The authentication flow is handled by implementations of {@link AuthenticationMethod}. * * | | * | StartupMessage | * |--------------------------------->| * | | * | Authentication<Method> | * | or | * | AuthenticationOK | * | or | * | ErrorResponse | * |<---------------------------------| * | | * | ParameterStatus | * |<---------------------------------| * | | * | ReadyForQuery | * |<---------------------------------| * * * Simple Query: * * + + * | Q (query) | * |--------------------------------->| * | | * | RowDescription | * |<---------------------------------| * | | * | DataRow | * |<---------------------------------| * | DataRow | * |<---------------------------------| * | CommandComplete | * |<---------------------------------| * | ReadyForQuery | * |<---------------------------------| * * Extended Query * * + + * | Parse | * |--------------------------------->| * | | * | ParseComplete or ErrorResponse | * |<---------------------------------| * | | * | Bind | * |--------------------------------->| * | | * | BindComplete or ErrorResponse | * |<---------------------------------| * | | * | Describe (optional) | * |--------------------------------->| * | | * | RowDescription (optional) | * |<-------------------------------- | * | | * | Execute | * |--------------------------------->| * | | * | DataRow | | * | CommandComplete | | * | EmptyQueryResponse | | * | ErrorResponse | * |<---------------------------------| * | | * | Sync | * |--------------------------------->| * | | * | ReadyForQuery | * |<---------------------------------| * </pre> * <p> * Take a look at {@link Messages} to see how the messages are structured. * <p> * See https://www.postgresql.org/docs/current/static/protocol-flow.html for a more detailed description of the message flow */ class ConnectionContext { private static final Logger LOGGER = Loggers.getLogger(ConnectionContext.class); final MessageDecoder decoder; final MessageHandler handler; private final SQLOperations sqlOperations; private final Authentication authService; private int msgLength; private byte msgType; private SQLOperations.Session session; private boolean ignoreTillSync = false; private SessionContext sessionContext; enum State { SSL_NEG, STARTUP_HEADER, STARTUP_BODY, MSG_HEADER, MSG_BODY } private State state = STARTUP_HEADER; ConnectionContext(SQLOperations sqlOperations, Authentication authService) { this.sqlOperations = sqlOperations; this.authService = authService; decoder = new MessageDecoder(); handler = new MessageHandler(); } private static void traceLogProtocol(int protocol) { if (LOGGER.isTraceEnabled()) { int major = protocol >> 16; int minor = protocol & 0x0000FFFF; LOGGER.trace("protocol {}.{}", major, minor); } } private static String readCString(ChannelBuffer buffer) { byte[] bytes = new byte[buffer.bytesBefore((byte) 0) + 1]; if (bytes.length == 0) { return null; } buffer.readBytes(bytes); return new String(bytes, 0, bytes.length - 1, StandardCharsets.UTF_8); } private SessionContext readStartupMessage(ChannelBuffer buffer) { Properties properties = new Properties(); ChannelBuffer channelBuffer = buffer.readBytes(msgLength); while (true) { String key = readCString(channelBuffer); if (key == null) { break; } String value = readCString(channelBuffer); LOGGER.trace("payload: key={} value={}", key, value); if (!"".equals(key) && !"".equals(value)) { properties.setProperty(key, value); } } return new SessionContext(properties); } private static class ReadyForQueryCallback implements BiConsumer<Object, Throwable> { private final Channel channel; private ReadyForQueryCallback(Channel channel) { this.channel = channel; } @Override public void accept(Object o, Throwable t) { if (t == null) { onSuccess(o); } else { onFailure(t); } } private void onSuccess(@Nullable Object result) { if (result == null || result.equals(Boolean.FALSE)) { // only send ReadyForQuery if query was not interrupted Messages.sendReadyForQuery(channel); } } private void onFailure(@Nonnull Throwable t) { Messages.sendReadyForQuery(channel); } } private class MessageHandler extends SimpleChannelUpstreamHandler { @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { Object m = e.getMessage(); if (!(m instanceof ChannelBuffer)) { ctx.sendUpstream(e); return; } ChannelBuffer buffer = (ChannelBuffer) m; final Channel channel = ctx.getChannel(); try { dispatchState(buffer, channel); } catch (Throwable t) { ignoreTillSync = true; try { Messages.sendErrorResponse(channel, t); } catch (Throwable ti) { LOGGER.error("Error trying to send error to client: {}", t, ti); } } } private void dispatchState(ChannelBuffer buffer, Channel channel) { switch (state) { case SSL_NEG: state = STARTUP_HEADER; handleStartupHeader(buffer, channel); return; case STARTUP_HEADER: case MSG_HEADER: throw new IllegalStateException("Decoder should've processed the headers"); case STARTUP_BODY: state = State.MSG_HEADER; handleStartupBody(buffer, channel); return; case MSG_BODY: state = State.MSG_HEADER; LOGGER.trace("msg={} msgLength={} readableBytes={}", ((char) msgType), msgLength, buffer.readableBytes()); if (ignoreTillSync && msgType != 'S') { buffer.readBytes(msgLength); return; } dispatchMessage(buffer, channel); return; } throw new IllegalStateException("Illegal state: " + state); } private void dispatchMessage(ChannelBuffer buffer, Channel channel) { switch (msgType) { case 'Q': // Query (simple) handleSimpleQuery(buffer, channel); return; case 'P': handleParseMessage(buffer, channel); return; case 'B': handleBindMessage(buffer, channel); return; case 'D': handleDescribeMessage(buffer, channel); return; case 'E': handleExecute(buffer, channel); return; case 'H': handleFlush(channel); return; case 'S': handleSync(channel); return; case 'C': handleClose(buffer, channel); return; case 'X': // Terminate (called when jdbc connection is closed) closeSession(); channel.close(); return; default: Messages.sendErrorResponse(channel, new UnsupportedOperationException("Unsupported messageType: " + msgType)); } } private void closeSession() { if (session != null) { session.close(); session = null; } } @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { LOGGER.error("Uncaught exception: ", e.getCause()); } @Override public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { LOGGER.trace("channelDisconnected"); closeSession(); super.channelDisconnected(ctx, e); } } private void handleStartupBody(ChannelBuffer buffer, Channel channel) { sessionContext = readStartupMessage(buffer); session = sqlOperations.createSession(sessionContext); authenticate(channel); } private void authenticate(Channel channel) { InetAddress address = getRemoteAddress(channel); AuthenticationMethod authMethod = authService.resolveAuthenticationType(sessionContext.userName(), address, HbaProtocol.POSTGRES); if (authMethod == null) { String errorMessage = String.format( Locale.ENGLISH, "No valid auth.host_based entry found for host \"%s\", user \"%s\", schema \"%s\"", address.getHostAddress(), sessionContext.userName(), sessionContext.defaultSchema() ); Messages.sendAuthenticationError(channel, errorMessage); } else { authMethod.pgAuthenticate(channel, sessionContext) .whenComplete((success, throwable) -> { if (throwable == null) { if (success) { if (LOGGER.isTraceEnabled()) { LOGGER.trace("Authentication succeeded user \"{}\" and method \"{}\".", sessionContext.userName(), authMethod.name()); } sendReadyForQuery(channel); } else { LOGGER.warn("Authentication failed for user \"{}\" and method \"{}\".", sessionContext.userName(), authMethod.name()); } } else { Messages.sendAuthenticationError(channel, throwable.getMessage()); } }); } } private static InetAddress getRemoteAddress(Channel channel) { if (channel.getRemoteAddress() instanceof InetSocketAddress) { return ((InetSocketAddress) channel.getRemoteAddress()).getAddress(); } // In certain cases the channel is an EmbeddedChannel (e.g. in tests) // and this type of channel has an EmbeddedSocketAddress instance as remoteAddress // which does not have an address. // An embedded socket address is handled like a local connection via loopback. return InetAddresses.forString("127.0.0.1"); } private void sendReadyForQuery(Channel channel) { Messages.sendParameterStatus(channel, "server_version", "9.5.0"); Messages.sendParameterStatus(channel, "server_encoding", "UTF8"); Messages.sendParameterStatus(channel, "client_encoding", "UTF8"); Messages.sendParameterStatus(channel, "datestyle", "ISO"); Messages.sendReadyForQuery(channel); } private void handleStartupHeader(ChannelBuffer buffer, Channel channel) { buffer.readInt(); // sslCode ChannelBuffer channelBuffer = ChannelBuffers.buffer(1); channelBuffer.writeByte('N'); ChannelFuture channelFuture = channel.write(channelBuffer); if (LOGGER.isTraceEnabled()) { channelFuture.addListener(future -> LOGGER.trace("sent SSL neg: N")); } } /** * Flush Message * | 'H' | int32 len * <p> * Flush forces the backend to deliver any data pending in it's output buffers. */ private void handleFlush(Channel channel) { /* * Currently we don't buffer data. It is always send to the client immediately. * So flush would be a no-op except that we delay execution until sync to be able to execute bulk operations * more efficiently. * If a Client sends flush we also need to trigger execution because a Client is expecting to receive data after * a Flush. * * Note that there is no ReadyForQueryCallback here because handleSync will still be called and it is done there. */ try { session.sync(); } catch (Throwable t) { Messages.sendErrorResponse(channel, t); } } /** * Parse Message * header: * | 'P' | int32 len * <p> * body: * | string statementName | string query | int16 numParamTypes | * foreach param: * | int32 type_oid (zero = unspecified) */ private void handleParseMessage(ChannelBuffer buffer, final Channel channel) { String statementName = readCString(buffer); final String query = readCString(buffer); short numParams = buffer.readShort(); List<DataType> paramTypes = new ArrayList<>(numParams); for (int i = 0; i < numParams; i++) { int oid = buffer.readInt(); DataType dataType = PGTypes.fromOID(oid); if (dataType == null) { throw new IllegalArgumentException( String.format(Locale.ENGLISH, "Can't map PGType with oid=%d to Crate type", oid)); } paramTypes.add(dataType); } session.parse(statementName, query, paramTypes); Messages.sendParseComplete(channel); } /** * Bind Message * Header: * | 'B' | int32 len * <p> * Body: * <pre> * | string portalName | string statementName * | int16 numFormatCodes * foreach * | int16 formatCode * | int16 numParams * foreach * | int32 valueLength * | byteN value * | int16 numResultColumnFormatCodes * foreach * | int16 formatCode * </pre> */ private void handleBindMessage(ChannelBuffer buffer, Channel channel) { String portalName = readCString(buffer); String statementName = readCString(buffer); FormatCodes.FormatCode[] formatCodes = FormatCodes.fromBuffer(buffer); short numParams = buffer.readShort(); List<Object> params = createList(numParams); for (int i = 0; i < numParams; i++) { int valueLength = buffer.readInt(); if (valueLength == -1) { params.add(null); } else { DataType paramType = session.getParamType(statementName, i); PGType pgType = PGTypes.get(paramType); FormatCodes.FormatCode formatCode = getFormatCode(formatCodes, i); switch (formatCode) { case TEXT: params.add(pgType.readTextValue(buffer, valueLength)); break; case BINARY: params.add(pgType.readBinaryValue(buffer, valueLength)); break; default: Messages.sendErrorResponse(channel, new UnsupportedOperationException( String.format(Locale.ENGLISH, "Unsupported format code '%d' for param '%s'", formatCode.ordinal(), paramType.getName()))); return; } } } FormatCodes.FormatCode[] resultFormatCodes = FormatCodes.fromBuffer(buffer); session.bind(portalName, statementName, params, resultFormatCodes); Messages.sendBindComplete(channel); } private <T> List<T> createList(short size) { return size == 0 ? Collections.<T>emptyList() : new ArrayList<T>(size); } /** * Describe Message * Header: * | 'D' | int32 len * <p> * Body: * | 'S' = prepared statement or 'P' = portal * | string nameOfPortalOrStatement */ private void handleDescribeMessage(ChannelBuffer buffer, Channel channel) { byte type = buffer.readByte(); String portalOrStatement = readCString(buffer); Collection<Field> fields = session.describe((char) type, portalOrStatement); if (fields == null) { Messages.sendNoData(channel); } else { Messages.sendRowDescription(channel, fields, session.getResultFormatCodes(portalOrStatement)); } } /** * Execute Message * Header: * | 'E' | int32 len * <p> * Body: * | string portalName * | int32 maxRows (0 = unlimited) */ private void handleExecute(ChannelBuffer buffer, Channel channel) { String portalName = readCString(buffer); int maxRows = buffer.readInt(); String query = session.getQuery(portalName); if (query.isEmpty()) { // remove portal so that it doesn't stick around and no attempt to batch it with follow up statement is made session.close((byte) 'P', portalName); Messages.sendEmptyQueryResponse(channel); return; } List<? extends DataType> outputTypes = session.getOutputTypes(portalName); ResultReceiver resultReceiver; if (outputTypes == null) { // this is a DML query maxRows = 0; resultReceiver = new RowCountReceiver(query, channel); } else { // query with resultSet resultReceiver = new ResultSetReceiver(query, channel, outputTypes, session.getResultFormatCodes(portalName)); } session.execute(portalName, maxRows, resultReceiver); } private void handleSync(final Channel channel) { if (ignoreTillSync) { ignoreTillSync = false; session.clearState(); Messages.sendReadyForQuery(channel); return; } try { ReadyForQueryCallback readyForQueryCallback = new ReadyForQueryCallback(channel); session.sync().whenComplete(readyForQueryCallback); } catch (Throwable t) { Messages.sendErrorResponse(channel, t); Messages.sendReadyForQuery(channel); } } /** * | 'C' | int32 len | byte portalOrStatement | string portalOrStatementName | */ private void handleClose(ChannelBuffer buffer, Channel channel) { byte b = buffer.readByte(); String portalOrStatementName = readCString(buffer); session.close(b, portalOrStatementName); Messages.sendCloseComplete(channel); } @VisibleForTesting void handleSimpleQuery(ChannelBuffer buffer, final Channel channel) { String query = readCString(buffer); assert query != null : "query must not be nulL"; if (query.isEmpty() || ";".equals(query)) { Messages.sendEmptyQueryResponse(channel); Messages.sendReadyForQuery(channel); return; } try { session.parse("", query, Collections.<DataType>emptyList()); session.bind("", "", Collections.emptyList(), null); List<Field> fields = session.describe('P', ""); if (fields == null) { RowCountReceiver rowCountReceiver = new RowCountReceiver(query, channel); session.execute("", 0, rowCountReceiver); } else { Messages.sendRowDescription(channel, fields, null); ResultSetReceiver resultSetReceiver = new ResultSetReceiver(query, channel, Symbols.extractTypes(fields), null); session.execute("", 0, resultSetReceiver); } ReadyForQueryCallback readyForQueryCallback = new ReadyForQueryCallback(channel); session.sync().whenComplete(readyForQueryCallback); } catch (Throwable t) { session.clearState(); Messages.sendErrorResponse(channel, t); Messages.sendReadyForQuery(channel); } } /** * FrameDecoder that makes sure that a full message is in the buffer before delegating work to the MessageHandler */ private class MessageDecoder extends FrameDecoder { @Override protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception { switch (state) { /* * StartupMessage: * | int32 length | int32 protocol | [ string paramKey | string paramValue , ... ] */ case STARTUP_HEADER: if (buffer.readableBytes() < 8) { return null; } buffer.markReaderIndex(); msgLength = buffer.readInt() - 8; // exclude length itself and protocol if (msgLength == 0) { // SSL negotiation pkg LOGGER.trace("Received SSL negotiation pkg"); state = State.SSL_NEG; return buffer; } LOGGER.trace("Header pkgLength: {}", msgLength); int protocol = buffer.readInt(); traceLogProtocol(protocol); return nullOrBuffer(buffer, State.STARTUP_BODY); /* * Regular Data Packet: * | char tag | int32 len | payload */ case MSG_HEADER: if (buffer.readableBytes() < 5) { return null; } buffer.markReaderIndex(); msgType = buffer.readByte(); msgLength = buffer.readInt() - 4; // exclude length itself return nullOrBuffer(buffer, State.MSG_BODY); case MSG_BODY: case STARTUP_BODY: return nullOrBuffer(buffer, state); } throw new IllegalStateException("Invalid state " + state); } /** * return null if there aren't enough bytes to read the whole message. Otherwise returns the buffer. * <p> * If null is returned the decoder will be called again, otherwise the MessageHandler will be called next. */ private ChannelBuffer nullOrBuffer(ChannelBuffer buffer, State nextState) { if (buffer.readableBytes() < msgLength) { buffer.resetReaderIndex(); return null; } state = nextState; return buffer; } } }