/*
* Licensed to Crate under one or more contributor license agreements.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership. Crate licenses this file
* to you under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial
* agreement.
*/
package io.crate.protocols.postgres;
import io.crate.analyze.symbol.Field;
import io.crate.data.Row;
import io.crate.exceptions.SQLExceptions;
import io.crate.protocols.postgres.types.PGType;
import io.crate.protocols.postgres.types.PGTypes;
import io.crate.types.DataType;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.logging.Loggers;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
/**
* Regular data packet is in the following format:
* <p>
* +----------+-----------+----------+
* | char tag | int32 len | payload |
* +----------+-----------+----------+
* <p>
* The tag indicates the message type, the second field is the length of the packet
* (excluding the tag, but including the length itself)
* <p>
* <p>
* See https://www.postgresql.org/docs/9.2/static/protocol-message-formats.html
*/
public class Messages {
private final static Logger LOGGER = Loggers.getLogger(Messages.class);
public static ChannelFuture sendAuthenticationOK(Channel channel) {
ChannelBuffer buffer = ChannelBuffers.buffer(9);
buffer.writeByte('R');
buffer.writeInt(8); // size excluding char
buffer.writeInt(0);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentAuthenticationOK");
}
});
}
return channelFuture;
}
/**
* | 'C' | int32 len | str commandTag
*
* @param query :the query
* @param rowCount : number of rows in the result set or number of rows affected by the DML statement
*/
static void sendCommandComplete(Channel channel, String query, long rowCount) {
query = query.split(" ", 2)[0].toUpperCase(Locale.ENGLISH);
String commandTag;
/*
* from https://www.postgresql.org/docs/current/static/protocol-message-formats.html:
*
* For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows inserted.
* oid is the object ID of the inserted row if rows is 1 and the target table has OIDs; otherwise oid is 0.
*/
if ("BEGIN".equals(query)) {
commandTag = "BEGIN";
} else if ("INSERT".equals(query)) {
commandTag = "INSERT 0 " + rowCount;
} else {
commandTag = query + " " + rowCount;
}
byte[] commandTagBytes = commandTag.getBytes(StandardCharsets.UTF_8);
int length = 4 + commandTagBytes.length + 1;
ChannelBuffer buffer = ChannelBuffers.buffer(length + 1);
buffer.writeByte('C');
buffer.writeInt(length);
writeCString(buffer, commandTagBytes);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentCommandComplete");
}
});
}
}
/**
* ReadyForQuery (B)
* <p>
* Byte1('Z')
* Identifies the message type. ReadyForQuery is sent whenever the
* backend is ready for a new query cycle.
* <p>
* Int32(5)
* Length of message contents in bytes, including self.
* <p>
* Byte1
* Current backend transaction status indicator. Possible values are
* 'I' if idle (not in a transaction block); 'T' if in a transaction
* block; or 'E' if in a failed transaction block (queries will be
* rejected until block is ended).
*/
static void sendReadyForQuery(Channel channel) {
ChannelBuffer buffer = ChannelBuffers.buffer(6);
buffer.writeByte('Z');
buffer.writeInt(5);
buffer.writeByte('I');
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentReadyForQuery");
}
});
}
}
/**
* | 'S' | int32 len | str name | str value
* <p>
* See https://www.postgresql.org/docs/9.2/static/protocol-flow.html#PROTOCOL-ASYNC
* <p>
* > At present there is a hard-wired set of parameters for which ParameterStatus will be generated: they are
* <p>
* - server_version,
* - server_encoding,
* - client_encoding,
* - application_name,
* - is_superuser,
* - session_authorization,
* - DateStyle,
* - IntervalStyle,
* - TimeZone,
* - integer_datetimes,
* - standard_conforming_string
*/
static void sendParameterStatus(Channel channel, final String name, final String value) {
byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8);
byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8);
int length = 4 + nameBytes.length + 1 + valueBytes.length + 1;
ChannelBuffer buffer = ChannelBuffers.buffer(length + 1);
buffer.writeByte('S');
buffer.writeInt(length);
writeCString(buffer, nameBytes);
writeCString(buffer, valueBytes);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentParameterStatus {}={}", name, value);
}
});
}
}
public static ChannelFuture sendAuthenticationError(Channel channel, String message) {
byte[] msg = message.getBytes(StandardCharsets.UTF_8);
byte[] severity = "FATAL".getBytes(StandardCharsets.UTF_8);
byte[] errorCode = "28000".getBytes(StandardCharsets.UTF_8);
byte[] method = "ClientAuthentication".getBytes(StandardCharsets.UTF_8);
return sendErrorResponse(channel, message, msg, severity, null, null, method, errorCode);
}
static void sendErrorResponse(Channel channel, Throwable throwable) {
final String message = SQLExceptions.messageOf(throwable);
byte[] msg = message.getBytes(StandardCharsets.UTF_8);
byte[] severity = "ERROR".getBytes(StandardCharsets.UTF_8);
byte[] lineNumber = null;
byte[] fileName = null;
byte[] methodName = null;
StackTraceElement[] stackTrace = throwable.getStackTrace();
if (stackTrace != null && stackTrace.length > 0) {
StackTraceElement stackTraceElement = stackTrace[0];
lineNumber = String.valueOf(stackTraceElement.getLineNumber()).getBytes(StandardCharsets.UTF_8);
if (stackTraceElement.getFileName() != null) {
fileName = stackTraceElement.getFileName().getBytes(StandardCharsets.UTF_8);
}
if (stackTraceElement.getMethodName() != null) {
methodName = stackTraceElement.getMethodName().getBytes(StandardCharsets.UTF_8);
}
}
// See https://www.postgresql.org/docs/9.2/static/errcodes-appendix.html
// need to add a throwable -> error code mapping later on
byte[] errorCode;
if (throwable instanceof IllegalArgumentException || throwable instanceof UnsupportedOperationException) {
// feature_not_supported
errorCode = "0A000".getBytes(StandardCharsets.UTF_8);
} else {
// internal_error
errorCode = "XX000".getBytes(StandardCharsets.UTF_8);
}
sendErrorResponse(channel, message, msg, severity, lineNumber, fileName, methodName, errorCode);
}
/**
* 'E' | int32 len | char code | str value | \0 | char code | str value | \0 | ... | \0
* <p>
* char code / str value -> key-value fields
* example error fields are: message, detail, hint, error position
* <p>
* See https://www.postgresql.org/docs/9.2/static/protocol-error-fields.html for a list of error codes
*/
private static ChannelFuture sendErrorResponse(Channel channel, String message, byte[] msg, byte[] severity, byte[] lineNumber, byte[] fileName, byte[] methodName, byte[] errorCode) {
int length = 4 +
1 + (severity.length + 1) +
1 + (msg.length + 1) +
1 + (errorCode.length + 1) +
(fileName != null ? 1 + (fileName.length + 1) : 0) +
(lineNumber != null ? 1 + (lineNumber.length + 1) : 0) +
(methodName != null ? 1 + (methodName.length + 1) : 0) +
1;
ChannelBuffer buffer = ChannelBuffers.buffer(length + 1);
buffer.writeByte('E');
buffer.writeInt(length);
buffer.writeByte('S');
writeCString(buffer, severity);
buffer.writeByte('M');
writeCString(buffer, msg);
buffer.writeByte(('C'));
writeCString(buffer, errorCode);
if (fileName != null) {
buffer.writeByte('F');
writeCString(buffer, fileName);
}
if (lineNumber != null) {
buffer.writeByte('L');
writeCString(buffer, lineNumber);
}
if (methodName != null) {
buffer.writeByte('R');
writeCString(buffer, methodName);
}
buffer.writeByte(0);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentErrorResponse msg={}", message);
}
});
}
return channelFuture;
}
/**
* Byte1('D')
* Identifies the message as a data row.
* <p>
* Int32
* Length of message contents in bytes, including self.
* <p>
* Int16
* The number of column values that follow (possibly zero).
* <p>
* Next, the following pair of fields appear for each column:
* <p>
* Int32
* The length of the column value, in bytes (this count does not include itself).
* Can be zero. As a special case, -1 indicates a NULL column value. No value bytes follow in the NULL case.
* <p>
* ByteN
* The value of the column, in the format indicated by the associated format code. n is the above length.
*/
static void sendDataRow(Channel channel, Row row, List<? extends DataType> columnTypes, @Nullable FormatCodes.FormatCode[] formatCodes) {
int length = 4 + 2;
assert columnTypes.size() == row.numColumns()
: "Number of columns in the row must match number of columnTypes. Row: " + row + " types: " + columnTypes;
ChannelBuffer buffer = ChannelBuffers.dynamicBuffer();
buffer.writeByte('D');
buffer.writeInt(0); // will be set at the end
buffer.writeShort(row.numColumns());
for (int i = 0; i < row.numColumns(); i++) {
DataType dataType = columnTypes.get(i);
PGType pgType = PGTypes.get(dataType);
Object value = row.get(i);
if (value == null) {
buffer.writeInt(-1);
length += 4;
} else {
FormatCodes.FormatCode formatCode = FormatCodes.getFormatCode(formatCodes, i);
switch (formatCode) {
case TEXT:
length += pgType.writeAsText(buffer, value);
break;
case BINARY:
length += pgType.writeAsBinary(buffer, value);
break;
default:
throw new AssertionError("Unrecognized formatCode: " + formatCode);
}
}
}
buffer.setInt(1, length);
channel.write(buffer);
}
static void writeCString(ChannelBuffer buffer, byte[] valBytes) {
buffer.writeBytes(valBytes);
buffer.writeByte(0);
}
/**
* RowDescription (B)
* <p>
* | 'T' | int32 len | int16 numCols
* <p>
* For each field:
* <p>
* | string name | int32 table_oid | int16 attr_num | int32 oid | int16 typlen | int32 type_modifier | int16 format_code
* <p>
* See https://www.postgresql.org/docs/current/static/protocol-message-formats.html
*/
static void sendRowDescription(Channel channel, Collection<Field> columns, @Nullable FormatCodes.FormatCode[] formatCodes) {
int length = 4 + 2;
int columnSize = 4 + 2 + 4 + 2 + 4 + 2;
ChannelBuffer buffer = ChannelBuffers.dynamicBuffer(
length + (columns.size() * (10 + columnSize))); // use 10 as an estimate for columnName length
buffer.writeByte('T');
buffer.writeInt(0); // will be set at the end
buffer.writeShort(columns.size());
int idx = 0;
for (Field column : columns) {
byte[] nameBytes = column.path().outputName().getBytes(StandardCharsets.UTF_8);
length += nameBytes.length + 1;
length += columnSize;
writeCString(buffer, nameBytes);
buffer.writeInt(0); // table_oid
buffer.writeShort(0); // attr_num
PGType pgType = PGTypes.get(column.valueType());
buffer.writeInt(pgType.oid());
buffer.writeShort(pgType.typeLen());
buffer.writeInt(pgType.typeMod());
buffer.writeShort(FormatCodes.getFormatCode(formatCodes, idx).ordinal());
idx++;
}
buffer.setInt(1, length);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace("sentRowDescription");
}
});
}
}
/**
* ParseComplete
* | '1' | int32 len |
*/
static void sendParseComplete(Channel channel) {
sendShortMsg(channel, '1', "sentParseComplete");
}
/**
* BindComplete
* | '2' | int32 len |
*/
static void sendBindComplete(Channel channel) {
sendShortMsg(channel, '2', "sentBindComplete");
}
/**
* EmptyQueryResponse
* | 'I' | int32 len |
*/
static void sendEmptyQueryResponse(Channel channel) {
sendShortMsg(channel, 'I', "sentEmptyQueryResponse");
}
/**
* NoData
* | 'n' | int32 len |
*/
static void sendNoData(Channel channel) {
sendShortMsg(channel, 'n', "sentNoData");
}
/**
* Send a message that just contains the msgType and the msg length
*/
private static void sendShortMsg(Channel channel, char msgType, final String traceLogMsg) {
ChannelBuffer buffer = ChannelBuffers.buffer(5);
buffer.writeByte(msgType);
buffer.writeInt(4);
ChannelFuture channelFuture = channel.write(buffer);
if (LOGGER.isTraceEnabled()) {
channelFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
LOGGER.trace(traceLogMsg);
}
});
}
}
static void sendPortalSuspended(Channel channel) {
sendShortMsg(channel, 's', "sentPortalSuspended");
}
/**
* CloseComplete
* | '3' | int32 len |
*/
static void sendCloseComplete(Channel channel) {
sendShortMsg(channel, '3', "sentCloseComplete");
}
}