/* * Copyright 2004-2011 H2 Group. Multiple-Licensed under the H2 License, * Version 1.0, and under the Eclipse Public License, Version 1.0 * (http://h2database.com/html/license.html). * Initial Developer: H2 Group */ package org.h2.dev.net; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.ServerSocket; import java.net.Socket; /** * This class helps debug the PostgreSQL network protocol. * It listens on one port, and sends the exact same data to another port. */ public class PgTcpRedirect { private static final boolean DEBUG = false; /** * This method is called when executing this application from the command * line. * * @param args the command line parameters */ public static void main(String... args) throws Exception { new PgTcpRedirect().loop(args); } private void loop(String... args) throws Exception { // MySQL protocol: // http://www.redferni.uklinux.net/mysql/MySQL-Protocol.html // PostgreSQL protocol: // http://developer.postgresql.org/pgdocs/postgres/protocol.html // int portServer = 9083, portClient = 9084; // int portServer = 3306, portClient = 3307; // H2 PgServer // int portServer = 5435, portClient = 5433; // PostgreSQL int portServer = 5432, portClient = 5433; for (int i = 0; i < args.length; i++) { if ("-client".equals(args[i])) { portClient = Integer.parseInt(args[++i]); } else if ("-server".equals(args[i])) { portServer = Integer.parseInt(args[++i]); } } ServerSocket listener = new ServerSocket(portClient); while (true) { Socket client = listener.accept(); Socket server = new Socket("localhost", portServer); TcpRedirectThread c = new TcpRedirectThread(client, server, true); TcpRedirectThread s = new TcpRedirectThread(server, client, false); new Thread(c).start(); new Thread(s).start(); } } /** * This is the working thread of the TCP redirector. */ private class TcpRedirectThread implements Runnable { private static final int STATE_INIT_CLIENT = 0, STATE_REGULAR = 1; private Socket read, write; private int state; private boolean client; TcpRedirectThread(Socket read, Socket write, boolean client) { this.read = read; this.write = write; this.client = client; state = client ? STATE_INIT_CLIENT : STATE_REGULAR; } String readStringNull(InputStream in) throws IOException { StringBuilder buff = new StringBuilder(); while (true) { int x = in.read(); if (x <= 0) { break; } buff.append((char) x); } return buff.toString(); } private void println(String s) { if (DEBUG) { System.out.println(s); } } private boolean processClient(InputStream inStream, OutputStream outStream) throws IOException { DataInputStream dataIn = new DataInputStream(inStream); ByteArrayOutputStream buff = new ByteArrayOutputStream(); DataOutputStream dataOut = new DataOutputStream(buff); if (state == STATE_INIT_CLIENT) { state = STATE_REGULAR; int len = dataIn.readInt(); dataOut.writeInt(len); len -= 4; byte[] data = new byte[len]; dataIn.readFully(data, 0, len); dataOut.write(data); dataIn = new DataInputStream(new ByteArrayInputStream(data, 0, len)); int version = dataIn.readInt(); if (version == 80877102) { println("CancelRequest"); println(" pid: " + dataIn.readInt()); println(" key: " + dataIn.readInt()); } else if (version == 80877103) { println("SSLRequest"); } else { println("StartupMessage"); println(" version " + version + " (" + (version >> 16) + "." + (version & 0xff) + ")"); while (true) { String param = readStringNull(dataIn); if (param.length() == 0) { break; } String value = readStringNull(dataIn); println(" param " + param + "=" + value); } } } else { int x = dataIn.read(); if (x < 0) { println("end"); return false; } // System.out.println(" x=" + (char)x+" " +x); dataOut.write(x); int len = dataIn.readInt(); dataOut.writeInt(len); len -= 4; byte[] data = new byte[len]; dataIn.readFully(data, 0, len); dataOut.write(data); dataIn = new DataInputStream(new ByteArrayInputStream(data, 0, len)); switch (x) { case 'B': { println("Bind"); println(" destPortal: " + readStringNull(dataIn)); println(" prepName: " + readStringNull(dataIn)); int formatCodesCount = dataIn.readShort(); for (int i = 0; i < formatCodesCount; i++) { println(" formatCode[" + i + "]=" + dataIn.readShort()); } int paramCount = dataIn.readShort(); for (int i = 0; i < paramCount; i++) { int paramLen = dataIn.readInt(); println(" length[" + i + "]=" + paramLen); byte[] d2 = new byte[paramLen]; dataIn.readFully(d2); } int resultCodeCount = dataIn.readShort(); for (int i = 0; i < resultCodeCount; i++) { println(" resultCodeCount[" + i + "]=" + dataIn.readShort()); } break; } case 'C': { println("Close"); println(" type: (S:prepared statement, P:portal): " + dataIn.read()); break; } case 'd': { println("CopyData"); break; } case 'c': { println("CopyDone"); break; } case 'f': { println("CopyFail"); println(" message: " + readStringNull(dataIn)); break; } case 'D': { println("Describe"); println(" type (S=prepared statement, P=portal): " + (char) dataIn.readByte()); println(" name: " + readStringNull(dataIn)); break; } case 'E': { println("Execute"); println(" name: " + readStringNull(dataIn)); println(" maxRows: " + dataIn.readShort()); break; } case 'H': { println("Flush"); break; } case 'F': { println("FunctionCall"); println(" objectId:" + dataIn.readInt()); int columns = dataIn.readShort(); for (int i = 0; i < columns; i++) { println(" formatCode[" + i + "]: " + dataIn.readShort()); } int count = dataIn.readShort(); for (int i = 0; i < count; i++) { int l = dataIn.readInt(); println(" len[" + i + "]: " + l); if (l >= 0) { for (int j = 0; j < l; j++) { dataIn.readByte(); } } } println(" resultFormat: " + dataIn.readShort()); break; } case 'P': { println("Parse"); println(" name:" + readStringNull(dataIn)); println(" query:" + readStringNull(dataIn)); int count = dataIn.readShort(); for (int i = 0; i < count; i++) { println(" [" + i + "]: " + dataIn.readInt()); } break; } case 'p': { println("PasswordMessage"); println(" password: " + readStringNull(dataIn)); break; } case 'Q': { println("Query"); println(" sql : " + readStringNull(dataIn)); break; } case 'S': { println("Sync"); break; } case 'X': { println("Terminate"); break; } default: println("############## UNSUPPORTED: " + (char) x); } } dataOut.flush(); byte[] buffer = buff.toByteArray(); printData(buffer, buffer.length); try { outStream.write(buffer, 0, buffer.length); outStream.flush(); } catch (IOException e) { e.printStackTrace(); } return true; } private boolean processServer(InputStream inStream, OutputStream outStream) throws IOException { DataInputStream dataIn = new DataInputStream(inStream); ByteArrayOutputStream buff = new ByteArrayOutputStream(); DataOutputStream dataOut = new DataOutputStream(buff); int x = dataIn.read(); if (x < 0) { println("end"); return false; } // System.out.println(" x=" + (char)x+" " +x); dataOut.write(x); int len = dataIn.readInt(); dataOut.writeInt(len); len -= 4; byte[] data = new byte[len]; dataIn.readFully(data, 0, len); dataOut.write(data); dataIn = new DataInputStream(new ByteArrayInputStream(data, 0, len)); switch (x) { case 'R': { println("Authentication"); int value = dataIn.readInt(); if (value == 0) { println(" Ok"); } else if (value == 2) { println(" KerberosV5"); } else if (value == 3) { println(" CleartextPassword"); } else if (value == 4) { println(" CryptPassword"); byte b1 = dataIn.readByte(); byte b2 = dataIn.readByte(); println(" salt1=" + b1 + " salt2=" + b2); } else if (value == 5) { println(" MD5Password"); byte b1 = dataIn.readByte(); byte b2 = dataIn.readByte(); byte b3 = dataIn.readByte(); byte b4 = dataIn.readByte(); println(" salt1=" + b1 + " salt2=" + b2 + " 3=" + b3 + " 4=" + b4); } else if (value == 6) { println(" SCMCredential"); } break; } case 'K': { println("BackendKeyData"); println(" process ID " + dataIn.readInt()); println(" key " + dataIn.readInt()); break; } case '2': { println("BindComplete"); break; } case '3': { println("CloseComplete"); break; } case 'C': { println("CommandComplete"); println(" command tag: " + readStringNull(dataIn)); break; } case 'd': { println("CopyData"); break; } case 'c': { println("CopyDone"); break; } case 'G': { println("CopyInResponse"); println(" format: " + dataIn.readByte()); int columns = dataIn.readShort(); for (int i = 0; i < columns; i++) { println(" formatCode[" + i + "]: " + dataIn.readShort()); } break; } case 'H': { println("CopyOutResponse"); println(" format: " + dataIn.readByte()); int columns = dataIn.readShort(); for (int i = 0; i < columns; i++) { println(" formatCode[" + i + "]: " + dataIn.readShort()); } break; } case 'D': { println("DataRow"); int columns = dataIn.readShort(); println(" columns : " + columns); for (int i = 0; i < columns; i++) { int l = dataIn.readInt(); if (l > 0) { for (int j = 0; j < l; j++) { dataIn.readByte(); } } // println(" ["+i+"] len: " + l); } break; } case 'I': { println("EmptyQueryResponse"); break; } case 'E': { println("ErrorResponse"); while (true) { int fieldType = dataIn.readByte(); if (fieldType == 0) { break; } String msg = readStringNull(dataIn); // http://developer.postgresql.org/pgdocs/postgres/protocol-error-fields.html // S Severity // C Code: the SQLSTATE code // M Message // D Detail // H Hint // P Position // p Internal position // q Internal query // W Where // F File // L Line // R Routine println(" fieldType: " + fieldType + " msg: " + msg); } break; } case 'V': { println("FunctionCallResponse"); int resultLen = dataIn.readInt(); println(" len: " + resultLen); break; } case 'n': { println("NoData"); break; } case 'N': { println("NoticeResponse"); while (true) { int fieldType = dataIn.readByte(); if (fieldType == 0) { break; } String msg = readStringNull(dataIn); // http://developer.postgresql.org/pgdocs/postgres/protocol-error-fields.html // S Severity // C Code: the SQLSTATE code // M Message // D Detail // H Hint // P Position // p Internal position // q Internal query // W Where // F File // L Line // R Routine println(" fieldType: " + fieldType + " msg: " + msg); } break; } case 'A': { println("NotificationResponse"); println(" processID: " + dataIn.readInt()); println(" condition: " + readStringNull(dataIn)); println(" information: " + readStringNull(dataIn)); break; } case 't': { println("ParameterDescription"); println(" processID: " + dataIn.readInt()); int count = dataIn.readShort(); for (int i = 0; i < count; i++) { println(" [" + i + "] objectId: " + dataIn.readInt()); } break; } case 'S': { println("ParameterStatus"); println(" parameter " + readStringNull(dataIn) + " = " + readStringNull(dataIn)); break; } case '1': { println("ParseComplete"); break; } case 's': { println("ParseComplete"); break; } case 'Z': { println("ReadyForQuery"); println(" status (I:idle, T:transaction, E:failed): " + (char) dataIn.readByte()); break; } case 'T': { println("RowDescription"); int columns = dataIn.readShort(); println(" columns : " + columns); for (int i = 0; i < columns; i++) { println(" [" + i + "]"); println(" name:" + readStringNull(dataIn)); println(" tableId:" + dataIn.readInt()); println(" columnId:" + dataIn.readShort()); println(" dataTypeId:" + dataIn.readInt()); println(" dataTypeSize (pg_type.typlen):" + dataIn.readShort()); println(" modifier (pg_attribute.atttypmod):" + dataIn.readInt()); println(" format code:" + dataIn.readShort()); } break; } default: println("############## UNSUPPORTED: " + (char) x); } dataOut.flush(); byte[] buffer = buff.toByteArray(); printData(buffer, buffer.length); try { outStream.write(buffer, 0, buffer.length); outStream.flush(); } catch (IOException e) { e.printStackTrace(); } return true; } public void run() { try { OutputStream out = write.getOutputStream(); InputStream in = read.getInputStream(); while (true) { boolean more; if (client) { more = processClient(in, out); } else { more = processServer(in, out); } if (!more) { break; } } try { read.close(); } catch (IOException e) { // ignore } try { write.close(); } catch (IOException e) { // ignore } } catch (Throwable e) { e.printStackTrace(); } } } /** * Print the uninterpreted byte array. * * @param buffer the byte array * @param len the length */ static synchronized void printData(byte[] buffer, int len) { if (DEBUG) { System.out.print(" "); for (int i = 0; i < len; i++) { int c = buffer[i] & 255; if (c >= ' ' && c <= 127 && c != '[' & c != ']') { System.out.print((char) c); } else { System.out.print("[" + Integer.toHexString(c) + "]"); } } System.out.println(); } } }