package io.mycat.backend.postgresql;
import io.mycat.backend.postgresql.packet.AuthenticationPacket;
import io.mycat.backend.postgresql.packet.AuthenticationPacket.AuthType;
import io.mycat.backend.postgresql.packet.BackendKeyData;
import io.mycat.backend.postgresql.packet.DataRow;
import io.mycat.backend.postgresql.packet.DataRow.DataColumn;
import io.mycat.backend.postgresql.packet.Parse;
import io.mycat.backend.postgresql.packet.PasswordMessage;
import io.mycat.backend.postgresql.packet.PostgreSQLPacket;
import io.mycat.backend.postgresql.packet.PostgreSQLPacket.DateType;
import io.mycat.backend.postgresql.packet.Query;
import io.mycat.backend.postgresql.packet.Terminate;
import io.mycat.backend.postgresql.utils.PIOUtils;
import io.mycat.backend.postgresql.utils.PacketUtils;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.TimeZone;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.alibaba.fastjson.JSON;
/*************
 * Experimental command-line harness that speaks the PostgreSQL v3
 * frontend/backend protocol directly, using the packet classes from
 * {@code io.mycat.backend.postgresql.packet}.
 *
 * <p>In blocking mode (the default) it connects to {@code localhost:5432},
 * sends a StartupMessage, answers the authentication request, runs a few
 * hard-coded statements and prints the backend packets as JSON. In NIO mode
 * it hands the conversation over to {@link TCPClientReadThread}.
 *
 * <p>The credentials default to the historical hard-coded values and can be
 * overridden with the {@code pg.user}, {@code pg.password} and
 * {@code pg.database} system properties.
 *
 * @author Coollf
 *
 */
public class PostgresqlKnightriders {
    private static final Logger logger = LoggerFactory
            .getLogger(PostgresqlKnightriders.class);

    /** Size of the single read buffer used by {@link #readParsePacket(Socket)}. */
    private static final int READ_BUFFER_SIZE = 1024 * 10;

    /**
     * Entry point; see the class comment for what the harness does.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String user = System.getProperty("pg.user", "postgres");
        String password = System.getProperty("pg.password", "coollf");
        String database = System.getProperty("pg.database", "mycat");
        String appName = "MyCat-Server";
        List<String[]> paramList = new ArrayList<String[]>();
        paramList.add(new String[] { "user", user });
        paramList.add(new String[] { "database", database });
        paramList.add(new String[] { "client_encoding", "UTF8" });
        paramList.add(new String[] { "DateStyle", "ISO" });
        paramList.add(new String[] { "TimeZone", createPostgresTimeZone() });
        paramList.add(new String[] { "extra_float_digits", "3" });
        paramList.add(new String[] { "application_name", appName });
        boolean nio = false;
        Socket socket = null;
        try {
            if (nio) {
                startNioClient();
            } else {
                // Only the blocking path uses this socket; the NIO path opens its own channel.
                socket = new Socket("localhost", 5432);
                runBlockingSession(socket, user, password,
                        paramList.toArray(new String[0][]));
            }
            // Keep the JVM (and the NIO reader thread, if started) alive until input arrives.
            System.in.read();
            System.in.read();
        } catch (Exception e) {
            logger.error("PostgreSQL protocol harness failed", e);
        } finally {
            closeQuietly(socket);
        }
    }

    /**
     * Opens a non-blocking channel, registers it for read/write readiness and
     * starts a {@link TCPClientReadThread} on it.
     *
     * <p>The reader thread sends the startup packet itself and performs all
     * reads, so this thread must not read from the channel concurrently.
     *
     * @throws IOException if the channel or selector cannot be opened
     */
    private static void startNioClient() throws IOException {
        // NOTE(review): port 5210 differs from the 5432 used by the blocking path;
        // presumably a proxy under test — confirm.
        SocketChannel channel = SocketChannel
                .open(new InetSocketAddress("localhost", 5210));
        channel.configureBlocking(false);
        Selector selector = Selector.open();
        channel.register(selector, SelectionKey.OP_READ
                | SelectionKey.OP_WRITE);
        // The constructor starts the reader thread.
        new TCPClientReadThread(selector);
    }

    /**
     * Runs the blocking-mode conversation. It sends the startup packet and, if
     * the server requests a password, answers it. It then runs a simple SELECT,
     * an empty query and an extended-protocol Parse, printing every reply.
     *
     * <p>When the server replies with {@code AuthType.Ok} (no password
     * required), nothing further is sent. This mirrors the original harness.
     *
     * @param socket   connected socket to the PostgreSQL server
     * @param user     user name for the password message
     * @param password password for the password message
     * @param params   startup parameters as {@code {name, value}} pairs
     * @throws IOException            on socket errors, EOF, or an empty first reply
     * @throws IllegalAccessException as declared by {@link #readParsePacket(Socket)}
     */
    private static void runBlockingSession(Socket socket, String user,
            String password, String[][] params) throws IOException,
            IllegalAccessException {
        sendStartupPacket(socket, params);
        List<PostgreSQLPacket> firstReply = readParsePacket(socket);
        if (firstReply.isEmpty()) {
            throw new IOException(
                    "No packet received in reply to the startup message");
        }
        PostgreSQLPacket packet = firstReply.get(0);
        if (!(packet instanceof AuthenticationPacket)) {
            return;
        }
        AuthenticationPacket authPacket = (AuthenticationPacket) packet;
        AuthType aut = authPacket.getAuthType();
        if (aut == AuthType.Ok) {
            return;
        }

        PasswordMessage pak = new PasswordMessage(user, password, aut,
                authPacket.getSalt());
        // +1 for the message-type byte; the whole backing array is sent.
        ByteBuffer buffer = ByteBuffer.allocate(pak.getLength() + 1);
        pak.write(buffer);
        socket.getOutputStream().write(buffer.array());
        List<PostgreSQLPacket> sqlPacket = readParsePacket(socket);
        System.out.println(JSON.toJSONString(sqlPacket));

        sqlPacket = sendQuery(socket, "SELECT text_,timestamp_ from ump_types");
        printDataRows(sqlPacket);
        System.out.println(JSON.toJSONString(sqlPacket));

        sqlPacket = sendQuery(socket, "");
        System.out.println(JSON.toJSONString(sqlPacket));

        // Extended-query experiment: send a Parse for a parameterised INSERT.
        String sql = "INSERT into ump_coupon(id_,name_,time) VALUES (4 , ? , now());";
        Parse parse = new Parse(null, sql, DateType.UNKNOWN);
        ByteBuffer oby = ByteBuffer.allocate(parse.getPacketSize());
        parse.write(oby);
        OutputStream out = socket.getOutputStream();
        out.write(oby.array());
        // NOTE(review): a lone 0x00 byte is not a valid v3 frontend message type
        // (a Sync would be 'S' followed by an Int32 length). Presumably this
        // deliberately provokes a server reply — confirm before relying on it.
        out.write(new byte[] { 0 });
        List<PostgreSQLPacket> tre = readParsePacket(socket);
        System.out.println(JSON.toJSONString(tre));
        tre = readParsePacket(socket);
        System.out.println(tre);
    }

    /**
     * Sends a simple-protocol Query message and reads whatever the next read
     * returns.
     *
     * @param socket connected socket
     * @param sql    query text (may be empty)
     * @return packets parsed from the next read
     * @throws IOException            on socket errors or EOF
     * @throws IllegalAccessException as declared by {@link #readParsePacket(Socket)}
     */
    private static List<PostgreSQLPacket> sendQuery(Socket socket, String sql)
            throws IOException, IllegalAccessException {
        Query query = new Query(sql);
        // +1 for the message-type byte; the whole backing array is sent.
        ByteBuffer buffer = ByteBuffer.allocate(query.getLength() + 1);
        query.write(buffer);
        socket.getOutputStream().write(buffer.array());
        return readParsePacket(socket);
    }

    /**
     * Prints every column of every {@link DataRow} in {@code packets}, decoded
     * as UTF-8, one column per line.
     *
     * @param packets packets returned by a query
     * @throws IOException if UTF-8 is unsupported (never on a compliant JVM)
     */
    private static void printDataRows(List<PostgreSQLPacket> packets)
            throws IOException {
        for (PostgreSQLPacket p : packets) {
            if (!(p instanceof DataRow)) {
                continue;
            }
            for (DataColumn c : ((DataRow) p).getColumns()) {
                byte[] data = c.getData();
                // NOTE(review): guards against SQL NULL columns possibly carrying
                // null data — confirm how DataRow represents NULL.
                System.out.println(data == null ? "null" : new String(data,
                        "utf-8"));
            }
        }
    }

    /**
     * Performs a single read from the socket and parses the bytes into packets.
     *
     * <p>NOTE(review): one read of at most {@value #READ_BUFFER_SIZE} bytes may
     * split a packet or miss later ones. The caller gets whatever arrived in
     * that read.
     *
     * @param socket connected socket
     * @return the parsed packets
     * @throws EOFException           if the server has closed the connection
     * @throws IOException            on other socket errors
     * @throws IllegalAccessException as propagated from packet parsing
     */
    private static List<PostgreSQLPacket> readParsePacket(Socket socket)
            throws IOException, IllegalAccessException {
        byte[] bytes = new byte[READ_BUFFER_SIZE];
        int leg = socket.getInputStream().read(bytes, 0, bytes.length);
        if (leg < 0) {
            // InputStream.read returns -1 at end of stream; do not feed that to the parser.
            throw new EOFException("Connection closed by PostgreSQL server");
        }
        return PacketUtils.parsePacket(bytes, 0, leg);
    }

    /**
     * Closes {@code socket} if it is non-null, logging (not propagating) any failure.
     */
    private static void closeQuietly(Socket socket) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (IOException e) {
            logger.warn("Failed to close PostgreSQL socket", e);
        }
    }

    /**
     * Convert Java time zone to postgres time zone. All others stay the same
     * except that GMT+nn changes to GMT-nn and vice versa.
     *
     * @return The current JVM time zone in postgresql format.
     */
    private static String createPostgresTimeZone() {
        String tz = TimeZone.getDefault().getID();
        if (tz.length() <= 3 || !tz.startsWith("GMT")) {
            return tz;
        }
        char sign = tz.charAt(3);
        String start;
        if (sign == '+') {
            start = "GMT-";
        } else if (sign == '-') {
            start = "GMT+";
        } else {
            // unknown type
            return tz;
        }
        return start + tz.substring(4);
    }

    /**
     * Encodes and sends a protocol-3.0 StartupMessage. The message is an Int32
     * length, Int16 major and minor versions, NUL-terminated name/value
     * strings, and a final NUL.
     *
     * @param socket connected socket
     * @param params startup parameters as {@code {name, value}} pairs
     * @throws IOException on socket errors
     */
    private static void sendStartupPacket(Socket socket, String[][] params)
            throws IOException {
        OutputStream sout = socket.getOutputStream();
        if (logger.isDebugEnabled()) {
            StringBuilder details = new StringBuilder();
            for (int i = 0; i < params.length; ++i) {
                if (i != 0) {
                    details.append(", ");
                }
                details.append(params[i][0]);
                details.append("=");
                details.append(params[i][1]);
            }
            logger.debug(" FE=> StartupPacket(" + details + ")");
        }
        // Precalculate message length and encode params.
        int length = 4 + 4;
        byte[][] encodedParams = new byte[params.length * 2][];
        for (int i = 0; i < params.length; ++i) {
            encodedParams[i * 2] = params[i][0].getBytes("UTF-8");
            encodedParams[i * 2 + 1] = params[i][1].getBytes("UTF-8");
            length += encodedParams[i * 2].length + 1
                    + encodedParams[i * 2 + 1].length + 1;
        }
        length += 1; // Terminating \0
        ByteBuffer buffer = ByteBuffer.allocate(length);
        PIOUtils.SendInteger4(length, buffer);
        PIOUtils.SendInteger2(3, buffer); // protocol major
        PIOUtils.SendInteger2(0, buffer); // protocol minor
        for (byte[] encodedParam : encodedParams) {
            PIOUtils.Send(encodedParam, buffer);
            PIOUtils.SendChar(0, buffer);
        }
        // Write the terminating NUL explicitly instead of relying on the zero-filled capacity.
        PIOUtils.SendChar(0, buffer);
        sout.write(buffer.array());
    }
}
/**
 * Background loop for the experimental non-blocking PostgreSQL client.
 *
 * <p>The constructor starts a new thread that runs {@link #run()} on the given
 * selector:
 * <ul>
 * <li>On the first writable event it sends a StartupMessage.</li>
 * <li>When an authentication request other than {@code Ok} arrives, it queues
 * a PasswordMessage that is written on the next writable event.</li>
 * <li>Every batch of received packets is printed as JSON.</li>
 * </ul>
 *
 * <p>Credentials default to the historical hard-coded values and can be
 * overridden with the {@code pg.user}, {@code pg.password} and
 * {@code pg.database} system properties.
 */
class TCPClientReadThread implements Runnable {
    private static final Logger logger = LoggerFactory
            .getLogger(TCPClientReadThread.class);

    private static final String USER = System.getProperty("pg.user", "postgres");
    private static final String PASSWORD = System.getProperty("pg.password", "coollf");
    private static final String DATABASE = System.getProperty("pg.database", "odoo");

    private final Selector selector;
    /** Bytes queued for the next writable event; only touched by the loop thread. */
    private ByteBuffer bs;

    /**
     * Stores the selector and immediately starts the loop thread.
     *
     * @param selector selector with the client channel already registered
     */
    public TCPClientReadThread(Selector selector) {
        this.selector = selector;
        // NOTE(review): starting a thread from the constructor lets `this` escape.
        // It is kept because callers rely on `new TCPClientReadThread(selector)` starting the loop.
        new Thread(this).start();
    }

    /**
     * Selects until no keys are ready, the server closes the connection, or an
     * I/O error occurs.
     */
    @Override
    public void run() {
        boolean startupSent = false;
        boolean running = true;
        try {
            while (running && selector.select() > 0) {
                System.out.println(".....");
                // Walk every SelectionKey that has an I/O operation ready.
                for (SelectionKey sk : selector.selectedKeys()) {
                    if (!sk.isValid()) {
                        continue;
                    }
                    if (sk.isWritable()) {
                        handleWritable(sk, !startupSent);
                        startupSent = true;
                    }
                    if (sk.isValid() && sk.isReadable() && !handleReadable(sk)) {
                        running = false;
                    }
                }
                // The selector never removes keys itself. Without this, handled keys stay
                // selected with stale ready-sets and select() may stop reporting them.
                selector.selectedKeys().clear();
            }
            System.out.println("熄火了.....");
        } catch (IOException ex) {
            logger.error("PostgreSQL NIO client loop failed", ex);
        }
    }

    /**
     * Sends the startup packet if requested, flushes any queued bytes, then
     * switches the key's interest to reads.
     *
     * @param sk          the writable key
     * @param sendStartup whether the startup packet still has to be sent
     * @throws IOException on channel errors
     */
    private void handleWritable(SelectionKey sk, boolean sendStartup)
            throws IOException {
        SocketChannel sc = (SocketChannel) sk.channel();
        if (sendStartup) {
            sendStartupPacket(sc);
        }
        if (bs != null) {
            sc.write(bs);
            if (!bs.hasRemaining()) {
                bs = null; // fully flushed; a partial write keeps the rest queued
            }
        }
        sk.interestOps(SelectionKey.OP_READ);
    }

    /**
     * Reads one chunk from the channel, reacts to an authentication request,
     * prints the packets and switches the key's interest back to writes.
     *
     * @param sk the readable key
     * @return {@code false} if the server closed the connection (the channel is
     *         then closed), {@code true} otherwise
     * @throws IOException on channel errors
     */
    private boolean handleReadable(SelectionKey sk) throws IOException {
        SocketChannel sc = (SocketChannel) sk.channel();
        ByteBuffer buffer = ByteBuffer.allocate(1024);
        if (sc.read(buffer) < 0) {
            // End of stream: the server closed the connection.
            sk.cancel();
            sc.close();
            return false;
        }
        buffer.flip();
        // NOTE(review): a single 1 KiB read may split a packet; PacketUtils sees only this chunk.
        List<PostgreSQLPacket> ls = PacketUtils.parsePacket(buffer.array(), 0,
                buffer.limit());
        if (!ls.isEmpty() && ls.get(0) instanceof AuthenticationPacket) {
            AuthenticationPacket aut = (AuthenticationPacket) ls.get(0);
            if (aut.getAuthType() != AuthType.Ok) {
                PasswordMessage pak = new PasswordMessage(USER, PASSWORD,
                        aut.getAuthType(), aut.getSalt());
                ByteBuffer pending = ByteBuffer.allocate(pak.getLength() + 2);
                pak.write(pending);
                pending.flip();
                this.bs = pending;
            } else {
                logger.info("登陆成功啦啦啦....");
            }
        }
        sk.interestOps(SelectionKey.OP_WRITE);
        // Print what was received to the console.
        System.out.println("接收到来自服务器" + JSON.toJSONString(ls));
        return true;
    }

    /**
     * Encodes a protocol-3.0 StartupMessage and writes all of it to the channel.
     *
     * @param socketChannel connected (non-blocking) channel
     * @throws IOException on channel errors
     */
    private static void sendStartupPacket(SocketChannel socketChannel)
            throws IOException {
        String appName = "MyCat-Server";
        List<String[]> paramList = new ArrayList<String[]>();
        paramList.add(new String[] { "user", USER });
        paramList.add(new String[] { "database", DATABASE });
        paramList.add(new String[] { "client_encoding", "UTF8" });
        paramList.add(new String[] { "DateStyle", "ISO" });
        paramList.add(new String[] { "TimeZone", createPostgresTimeZone() });
        paramList.add(new String[] { "extra_float_digits", "3" });
        paramList.add(new String[] { "application_name", appName });
        String[][] params = paramList.toArray(new String[0][]);
        if (logger.isDebugEnabled()) {
            StringBuilder details = new StringBuilder();
            for (int i = 0; i < params.length; ++i) {
                if (i != 0) {
                    details.append(", ");
                }
                details.append(params[i][0]);
                details.append("=");
                details.append(params[i][1]);
            }
            logger.debug(" FE=> StartupPacket(" + details + ")");
        }
        // Precalculate message length and encode params.
        int length = 4 + 4;
        byte[][] encodedParams = new byte[params.length * 2][];
        for (int i = 0; i < params.length; ++i) {
            encodedParams[i * 2] = params[i][0].getBytes("UTF-8");
            encodedParams[i * 2 + 1] = params[i][1].getBytes("UTF-8");
            length += encodedParams[i * 2].length + 1
                    + encodedParams[i * 2 + 1].length + 1;
        }
        length += 1; // Terminating \0
        ByteBuffer buffer = ByteBuffer.allocate(length);
        PIOUtils.SendInteger4(length, buffer);
        PIOUtils.SendInteger2(3, buffer); // protocol major
        PIOUtils.SendInteger2(0, buffer); // protocol minor
        for (byte[] encodedParam : encodedParams) {
            PIOUtils.Send(encodedParam, buffer);
            PIOUtils.SendChar(0, buffer);
        }
        PIOUtils.Send(new byte[] { 0 }, buffer);
        buffer.flip();
        // A non-blocking write may be partial; keep writing until the whole message is out.
        while (buffer.hasRemaining()) {
            socketChannel.write(buffer);
        }
    }

    /**
     * Convert Java time zone to postgres time zone. All others stay the same
     * except that GMT+nn changes to GMT-nn and vice versa.
     *
     * @return The current JVM time zone in postgresql format.
     */
    private static String createPostgresTimeZone() {
        String tz = TimeZone.getDefault().getID();
        if (tz.length() <= 3 || !tz.startsWith("GMT")) {
            return tz;
        }
        char sign = tz.charAt(3);
        String start;
        if (sign == '+') {
            start = "GMT-";
        } else if (sign == '-') {
            start = "GMT+";
        } else {
            // unknown type
            return tz;
        }
        return start + tz.substring(4);
    }
}