/**
* Copyright (C) 2009-2013 FoundationDB, LLC
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.foundationdb.sql.pg;
import com.foundationdb.server.error.InvalidParameterValueException;
import com.foundationdb.util.tap.InOutTap;
import com.foundationdb.util.tap.Tap;
import java.net.*;
import java.io.*;
import java.nio.charset.Charset;
/**
 * Basic implementation of Postgres wire protocol for SQL integration.
 *
 * <p>Incoming messages are read in full from the socket into a byte buffer;
 * the {@link DataInput} methods of this class then read from that buffer.
 * Outgoing messages are accumulated in a byte buffer through the
 * {@link DataOutput} methods and written to the (buffered) socket stream by
 * {@link #sendMessage()}.
 *
 * <p>See https://www.postgresql.org/docs/current/protocol.html
 */
public class PostgresMessenger implements DataInput, DataOutput
{
    /*** Message Formats ***/
    /** Request code for a cancel request: 1234 in the high 16 bits, 5678 in the low 16 bits. */
    public static final int VERSION_CANCEL = 80877102; // 12345678
    /** Request code for an SSL request: 1234 in the high 16 bits, 5679 in the low 16 bits. */
    public static final int VERSION_SSL = 80877103; // 12345679

    public static final int AUTHENTICATION_OK = 0;
    public static final int AUTHENTICATION_KERBEROS_V5 = 2;
    public static final int AUTHENTICATION_CLEAR_TEXT = 3;
    public static final int AUTHENTICATION_MD5 = 5;
    public static final int AUTHENTICATION_SCM = 6;
    public static final int AUTHENTICATION_GSS = 7;
    public static final int AUTHENTICATION_SSPI = 9;
    public static final int AUTHENTICATION_GSS_CONTINUE = 8;

    private static final InOutTap WAIT_TAP = Tap.createTimer("sql: msg: wait");
    private static final InOutTap RECV_TAP = Tap.createTimer("sql: msg: recv");
    private static final InOutTap XMIT_TAP = Tap.createTimer("sql: msg: xmit");

    /** Socket read timeout (ms) used while waiting for the next message type byte. */
    private static final int IDLE_INTERVAL = 100;

    /** Length of the length word itself, which the declared message length includes. */
    private static final int LENGTH_FIELD_SIZE = 4;

    private final Socket socket;
    private final InputStream inputStream;
    private final OutputStream outputStream;
    private final DataInputStream dataInput;

    // State of the message most recently read.
    private byte[] rawMessageInput;
    private DataInputStream messageInput;

    // State of the message currently being built.
    private ByteArrayOutputStream byteOutput;
    private DataOutputStream messageOutput;

    private String encoding = "UTF-8";

    /**
     * Wrap the given connected socket.
     *
     * @param socket the client connection
     * @throws SocketException if TCP_NODELAY cannot be set
     * @throws IOException if the socket streams cannot be obtained
     */
    public PostgresMessenger(Socket socket) throws SocketException, IOException {
        this.socket = socket;
        // We flush() when we mean it.
        // So, turn off kernel delay, but wrap a buffer so every
        // message isn't its own packet.
        socket.setTcpNoDelay(true);
        inputStream = socket.getInputStream();
        dataInput = new DataInputStream(inputStream);
        outputStream = new BufferedOutputStream(socket.getOutputStream());
    }

    /** The socket's raw (unbuffered) input stream. */
    InputStream getInputStream() {
        return inputStream;
    }

    /** The buffered output stream wrapping the socket's output. */
    OutputStream getOutputStream() {
        return outputStream;
    }

    /** The encoding used for strings. */
    public String getEncoding() {
        return encoding;
    }

    /**
     * Set the encoding used for strings.
     *
     * <p>The name may be wrapped in single quotes. <code>null</code> and
     * <code>UNICODE</code> select UTF-8, and <code>WIN<i>nnn</i></code> is
     * translated to the Java name <code>Cp<i>nnn</i></code>. The current
     * encoding is left unchanged if the name is rejected.
     *
     * @param encoding the client's name for the encoding
     * @throws InvalidParameterValueException if the resulting name is not a
     *         charset known to this JVM
     */
    public void setEncoding(String encoding) {
        String newEncoding = encoding;
        // Strip quotes first so that a quoted alias ('UNICODE', 'WIN1252') is
        // mapped like the unquoted one. The length check keeps a lone "'" from
        // producing an inverted substring range.
        if ((newEncoding != null) && (newEncoding.length() >= 2) &&
            newEncoding.startsWith("'") && newEncoding.endsWith("'")) {
            newEncoding = newEncoding.substring(1, newEncoding.length() - 1);
        }
        if ((newEncoding == null) || newEncoding.equalsIgnoreCase("UNICODE")) {
            newEncoding = "UTF-8";
        }
        else if (newEncoding.matches("WIN\\d+")) {
            newEncoding = "Cp" + newEncoding.substring(3);
        }
        try {
            // Called only to validate the name; the charset itself is not kept.
            Charset.forName(newEncoding);
        }
        catch (IllegalArgumentException ex) {
            throw new InvalidParameterValueException("unknown client_encoding '" +
                                                     encoding + "'");
        }
        this.encoding = newEncoding;
    }

    /**
     * Read the next message from the stream, expecting it to start with a
     * message type opcode. Equivalent to <code>readMessage(true)</code>.
     */
    protected PostgresMessages readMessage() throws IOException {
        return readMessage(true);
    }

    /**
     * Read the next message from the stream and make its body available
     * through the {@link DataInput} methods and {@link #getRawMessage}.
     *
     * @param hasType <code>true</code> if the message starts with a type
     *        opcode byte; <code>false</code> for a message that consists of
     *        only a length and body, which is reported as
     *        <code>STARTUP_MESSAGE_TYPE</code>
     * @return the type of the message read, or <code>EOF_TYPE</code> if the
     *         stream ended where a type opcode was expected
     * @throws IOException if the opcode is not an acceptable one, the declared
     *         length is implausible, or the underlying read fails
     */
    protected PostgresMessages readMessage(boolean hasType) throws IOException {
        PostgresMessages type;
        int code;
        if (hasType) {
            try {
                beforeIdle();
                code = readTypeCode();
                if (code < 0) {
                    // End of stream: decide this before validating the code,
                    // since -1 is not a message type.
                    return PostgresMessages.EOF_TYPE;
                }
                if (!PostgresMessages.readTypeCorrect(code)) {
                    throw new IOException ("Bad protocol read message: " + (char)code);
                }
                type = PostgresMessages.messageType(code);
            }
            finally {
                afterIdle();
            }
        }
        else {
            type = PostgresMessages.STARTUP_MESSAGE_TYPE;
            code = 0;
        }
        return readBody(type, code);
    }

    /** Wait for the type opcode byte, calling {@link #idle} on every read timeout. */
    private int readTypeCode() throws IOException {
        while (true) {
            try {
                return dataInput.read();
            }
            catch (SocketTimeoutException timeout) {
                idle();
            }
        }
    }

    /** Read the length word and body of a message whose type is already known. */
    private PostgresMessages readBody(PostgresMessages type, int code) throws IOException {
        RECV_TAP.in();
        try {
            int len = dataInput.readInt();
            // The declared length counts the length word itself, so anything
            // below 4 would give a negative body size.
            if ((len < LENGTH_FIELD_SIZE) || (len > type.maxSize()))
                throw new IOException(String.format("Implausible message length (%d) received.", len));
            // Bytes consumed from the stream: the opcode, if any, plus the
            // length word and body.
            int count = (code > 0) ? (len + 1) : len;
            int bodyLength = len - LENGTH_FIELD_SIZE;
            try {
                rawMessageInput = new byte[bodyLength];
                dataInput.readFully(rawMessageInput, 0, bodyLength);
                messageInput = new DataInputStream(new ByteArrayInputStream(rawMessageInput));
            }
            catch (OutOfMemoryError ex) {
                throw new IOException(String.format("Unable to allocate read buffer of length (%d)", bodyLength), ex);
            }
            bytesRead(count);
            return type;
        }
        finally {
            RECV_TAP.out();
        }
    }

    /**
     * Begin outgoing message of given type.
     * Writes the type byte and a zero placeholder for the length, which
     * {@link #sendMessage()} fills in.
     */
    protected void beginMessage(int type) throws IOException {
        byteOutput = new ByteArrayOutputStream();
        messageOutput = new DataOutputStream(byteOutput);
        messageOutput.write(type);
        messageOutput.writeInt(0);
    }

    /**
     * Send outgoing message: patch the length into the placeholder and write
     * the message to the buffered output stream. Does not flush the stream.
     */
    protected void sendMessage() throws IOException {
        messageOutput.flush();
        byte[] msg = byteOutput.toByteArray();
        // check we're writing an allowed message.
        assert PostgresMessages.writeTypeCorrect((int)msg[0]) : "Invalid write message: " + (char)msg[0];
        // The length covers everything except the leading type byte,
        // stored big-endian in bytes 1..4.
        int len = msg.length - 1;
        msg[1] = (byte)(len >> 24);
        msg[2] = (byte)(len >> 16);
        msg[3] = (byte)(len >> 8);
        msg[4] = (byte)len;
        outputStream.write(msg);
        bytesWritten(len + 1);
    }

    /** Send outgoing message and optionally flush stream. */
    protected void sendMessage(boolean flush) throws IOException {
        sendMessage();
        if (flush) {
            flush();
        }
    }

    /** Flush the buffered output stream. */
    protected void flush() throws IOException {
        XMIT_TAP.in();
        try {
            outputStream.flush();
        }
        finally {
            XMIT_TAP.out();
        }
    }

    /** Save whatever portion of the current message there is so that
     * something asynchronous can be sent.
     * @return an opaque state object to pass to {@link #resumeMessage}
     */
    protected Object suspendMessage() throws IOException {
        messageOutput.flush();
        return byteOutput;
    }

    /** Restore the state from {@link #suspendMessage}. */
    protected void resumeMessage(Object state) throws IOException {
        byteOutput = (ByteArrayOutputStream)state;
        messageOutput = new DataOutputStream(byteOutput);
    }

    /**
     * Read null-terminated string from the current message, decoding it with
     * the current encoding.
     *
     * @throws IOException if the message ends before the terminating zero byte
     */
    public String readString() throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        while (true) {
            int b = messageInput.read();
            if (b < 0) throw new IOException("EOF in the middle of a string");
            if (b == 0) break;
            bytes.write(b);
        }
        return bytes.toString(encoding);
    }

    /** Return entire message body (the internal buffer itself, not a copy). */
    public byte[] getRawMessage() {
        return rawMessageInput;
    }

    /** Get the raw stream for current message. */
    public OutputStream getRawOutput() {
        return messageOutput;
    }

    /** Write null-terminated string, encoded with the current encoding. */
    public void writeString(String s) throws IOException {
        byte[] encoded = s.getBytes(encoding);
        messageOutput.write(encoded);
        messageOutput.write(0);
    }

    /** Write the raw contents of the given byte stream's buffer. */
    public void writeByteStream(ByteArrayOutputStream s) throws IOException {
        s.writeTo(messageOutput);
    }

    /*** DataInput ***/
    // All of these read from the body of the message most recently read.

    @Override
    public boolean readBoolean() throws IOException {
        return messageInput.readBoolean();
    }

    @Override
    public byte readByte() throws IOException {
        return messageInput.readByte();
    }

    @Override
    public char readChar() throws IOException {
        return messageInput.readChar();
    }

    @Override
    public double readDouble() throws IOException {
        return messageInput.readDouble();
    }

    @Override
    public float readFloat() throws IOException {
        return messageInput.readFloat();
    }

    @Override
    public void readFully(byte[] b) throws IOException {
        messageInput.readFully(b);
    }

    @Override
    public void readFully(byte[] b, int off, int len) throws IOException {
        messageInput.readFully(b, off, len);
    }

    @Override
    public int readInt() throws IOException {
        return messageInput.readInt();
    }

    @Override
    @SuppressWarnings("deprecation")
    public String readLine() throws IOException {
        return messageInput.readLine();
    }

    @Override
    public long readLong() throws IOException {
        return messageInput.readLong();
    }

    @Override
    public short readShort() throws IOException {
        return messageInput.readShort();
    }

    @Override
    public String readUTF() throws IOException {
        return messageInput.readUTF();
    }

    @Override
    public int readUnsignedByte() throws IOException {
        return messageInput.readUnsignedByte();
    }

    @Override
    public int readUnsignedShort() throws IOException {
        return messageInput.readUnsignedShort();
    }

    @Override
    public int skipBytes(int n) throws IOException {
        return messageInput.skipBytes(n);
    }

    /*** DataOutput ***/
    // All of these append to the message currently being built.

    @Override
    public void write(byte[] data) throws IOException {
        messageOutput.write(data);
    }

    @Override
    public void write(byte[] data, int ofs, int len) throws IOException {
        messageOutput.write(data, ofs, len);
    }

    @Override
    public void write(int v) throws IOException {
        messageOutput.write(v);
    }

    @Override
    public void writeBoolean(boolean v) throws IOException {
        messageOutput.writeBoolean(v);
    }

    @Override
    public void writeByte(int v) throws IOException {
        messageOutput.writeByte(v);
    }

    @Override
    public void writeBytes(String s) throws IOException {
        messageOutput.writeBytes(s);
    }

    @Override
    public void writeChar(int v) throws IOException {
        messageOutput.writeChar(v);
    }

    @Override
    public void writeChars(String s) throws IOException {
        messageOutput.writeChars(s);
    }

    @Override
    public void writeDouble(double v) throws IOException {
        messageOutput.writeDouble(v);
    }

    @Override
    public void writeFloat(float v) throws IOException {
        messageOutput.writeFloat(v);
    }

    @Override
    public void writeInt(int v) throws IOException {
        messageOutput.writeInt(v);
    }

    @Override
    public void writeLong(long v) throws IOException {
        messageOutput.writeLong(v);
    }

    @Override
    public void writeShort(int v) throws IOException {
        messageOutput.writeShort(v);
    }

    @Override
    public void writeUTF(String s) throws IOException {
        messageOutput.writeUTF(s);
    }

    /**
     * Called before waiting for a message type byte: enters the wait tap and
     * sets the socket read timeout to <code>IDLE_INTERVAL</code> ms.
     */
    public void beforeIdle() throws IOException {
        WAIT_TAP.in();
        socket.setSoTimeout(IDLE_INTERVAL);
    }

    /**
     * Called after waiting for a message type byte: removes the socket read
     * timeout and leaves the wait tap.
     */
    public void afterIdle() throws IOException {
        try {
            socket.setSoTimeout(0);
        }
        finally {
            // Leave the tap even when the socket call fails, so in/out stay paired.
            WAIT_TAP.out();
        }
    }

    /** Called every <code>IDLE_INTERVAL</code> ms. while waiting for a message.
     * Overridden to allow insertion of asynch notifications.
     */
    public void idle() {
    }

    /** Hook called with the number of bytes consumed by each message read; does nothing here. */
    public void bytesRead(int count) {
    }

    /** Hook called with the number of bytes written for each message sent; does nothing here. */
    public void bytesWritten(int count) {
    }
}