package DBProxy.Plugins.proxy; import DBProxy.Core.CMySQLLoadBalancer; import DBProxy.Core.Engine; import DBProxy.MySQL.Protocol.Com_Initdb; import DBProxy.MySQL.Protocol.Com_Query; import DBProxy.MySQL.Protocol.Flags; import DBProxy.MySQL.Protocol.Handshake; import DBProxy.MySQL.Protocol.HandshakeResponse; import DBProxy.MySQL.Protocol.Packet; import DBProxy.MySQL.Protocol.ResultSet; import java.net.Socket; import java.net.UnknownHostException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.BufferedInputStream; import org.apache.log4j.Logger; import DBProxy.Plugins.Base; import Publisher.CPublisherManager; import java.util.regex.Pattern; public class Proxy extends Base { public Pattern regUpdateIncident = Pattern.compile("UPDATE Incident [ (=),\'a-zA-Z0-9]* incidentStatus = 'On-going'"); public Pattern regInsertIncident = Pattern.compile("INSERT INTO Incident [ (),\'a-zA-Z0-9]* VALUES [ (),\'a-zA-Z0-9]* 'On-going'"); public Logger logger = Logger.getLogger("Plugin.Proxy"); // MySql server stuff public String mysqlHost = ""; public int mysqlPort = 0; public Socket mysqlSocket = null; public InputStream mysqlIn = null; public OutputStream mysqlOut = null; public void init(Engine context) throws IOException, UnknownHostException { this.logger.trace("init"); if (context.query.startsWith("SELECT") || context.query.startsWith("SHOW")) { this.mysqlSocket = CMySQLLoadBalancer.getDataStorage(); } else { this.mysqlSocket = CMySQLLoadBalancer.getMasterStorage(); } // Connect to the mysql server on the other side this.logger.info("Connected to mysql server at " + this.mysqlHost + ":" + this.mysqlPort); this.mysqlIn = new BufferedInputStream(this.mysqlSocket.getInputStream(), 16384); this.mysqlOut = this.mysqlSocket.getOutputStream(); } public void read_handshake(Engine context) throws IOException { this.logger.trace("read_handshake"); byte[] packet = Packet.read_packet(this.mysqlIn); context.handshake = Handshake.loadFromPacket(packet); // Remove some flags from the reply context.handshake.removeCapabilityFlag(Flags.CLIENT_COMPRESS); context.handshake.removeCapabilityFlag(Flags.CLIENT_SSL); context.handshake.removeCapabilityFlag(Flags.CLIENT_LOCAL_FILES); // Set the default result set creation to the server's character set ResultSet.characterSet = context.handshake.characterSet; // Set Replace the packet in the buffer context.buffer.add(context.handshake.toPacket()); } public void send_handshake(Engine context) throws IOException { this.logger.trace("send_handshake"); Packet.write(context.clientOut, context.buffer); context.clear_buffer(); } public void read_auth(Engine context) throws IOException { this.logger.trace("read_auth"); byte[] packet = Packet.read_packet(context.clientIn); context.buffer.add(packet); context.authReply = HandshakeResponse.loadFromPacket(packet); if (!context.authReply.hasCapabilityFlag(Flags.CLIENT_PROTOCOL_41)) { this.logger.fatal("We do not support Protocols under 4.1"); context.halt(); return; } context.authReply.removeCapabilityFlag(Flags.CLIENT_COMPRESS); context.authReply.removeCapabilityFlag(Flags.CLIENT_SSL); context.authReply.removeCapabilityFlag(Flags.CLIENT_LOCAL_FILES); context.schema = context.authReply.schema; } public void send_auth(Engine context) throws IOException { this.logger.trace("send_auth"); Packet.write(this.mysqlOut, context.buffer); context.clear_buffer(); } public void read_auth_result(Engine context) throws IOException { this.logger.trace("read_auth_result"); byte[] packet = Packet.read_packet(this.mysqlIn); context.buffer.add(packet); if (Packet.getType(packet) != Flags.OK) { this.logger.fatal("Auth is not okay!"); } } public void send_auth_result(Engine context) throws IOException { this.logger.trace("read_auth_result"); Packet.write(context.clientOut, context.buffer); context.clear_buffer(); } public void read_query(Engine context) throws IOException { this.logger.trace("read_query"); context.bufferResultSet = false; byte[] packet = Packet.read_packet(context.clientIn); context.buffer.add(packet); context.sequenceId = Packet.getSequenceId(packet); this.logger.trace("Client sequenceId: " + context.sequenceId); switch (Packet.getType(packet)) { case Flags.COM_QUIT: this.logger.trace("COM_QUIT"); context.halt(); break; // Extract out the new default schema case Flags.COM_INIT_DB: this.logger.trace("COM_INIT_DB"); context.schema = Com_Initdb.loadFromPacket(packet).schema; break; // Query case Flags.COM_QUERY: this.logger.trace("COM_QUERY"); context.query = Com_Query.loadFromPacket(packet).query; break; default: break; } } public void send_query(Engine context) throws IOException { this.logger.trace("send_query"); Packet.write(this.mysqlOut, context.buffer); context.clear_buffer(); } public void read_query_result(Engine context) throws IOException { this.logger.trace("read_query_result"); byte[] packet = Packet.read_packet(this.mysqlIn); context.buffer.add(packet); context.sequenceId = Packet.getSequenceId(packet); switch (Packet.getType(packet)) { case Flags.OK: case Flags.ERR: break; default: context.buffer = Packet.read_full_result_set(this.mysqlIn, context.clientOut, context.buffer, context.bufferResultSet); break; } if (this.regUpdateIncident.matcher(context.query).find()) { CPublisherManager.publishOngoingIncident(context.query); } else if (this.regInsertIncident.matcher(context.query).find()) { CPublisherManager.publishOngoingIncident(context.query); } } public void send_query_result(Engine context) throws IOException { this.logger.trace("send_query_result"); Packet.write(context.clientOut, context.buffer); context.clear_buffer(); } public void cleanup(Engine context) { this.logger.trace("cleanup"); if (this.mysqlSocket == null) { return; } try { this.mysqlSocket.close(); } catch (IOException e) { } } }