package org.properssl.sslcertx.mariadb.jdbc.internal.common.packet; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; public class PacketOutputStream extends OutputStream{ private static final int MAX_PACKET_LENGTH = 0x00ffffff; private static final int SEQNO_OFFSET = 3; private static final int HEADER_LENGTH = 4; private static final int MAX_SEQNO = 0xff; OutputStream baseStream; byte[] byteBuffer; int position; int seqNo; boolean compress; public PacketOutputStream(OutputStream baseStream) { this.baseStream = baseStream; byteBuffer = new byte[1024]; seqNo = -1; } public void setCompress(boolean value) { if (seqNo != -1) throw new AssertionError("setCompress on already started packet is illegal"); compress = value; } public void startPacket(int seqNo) throws IOException { if (this.seqNo != -1) { throw new IOException("Last packet not finished"); } this.seqNo = seqNo; position = HEADER_LENGTH; } public int getSeqNo() { return seqNo; } private void writeEmptyPacket(int seqNo) throws IOException { byteBuffer[0] = 0; byteBuffer[1] = 0; byteBuffer[2] = 0; byteBuffer[SEQNO_OFFSET] = (byte)seqNo; baseStream.write(byteBuffer, 0, 4); position = HEADER_LENGTH; } /* Used by LOAD DATA INFILE. End of data is indicated by packet of length 0. */ public void sendFile(InputStream is, int seq) throws IOException{ byte[] buffer = new byte[8192]; int len; while((len = is.read(buffer)) > 0) { startPacket(seq++); write(buffer, 0, len); finishPacket(); } writeEmptyPacket(seq); } public void finishPacket() throws IOException{ if (seqNo == -1) { throw new AssertionError("Packet not started"); } internalFlush(); baseStream.flush(); seqNo = -1; } @Override public void write(byte[] bytes, int off, int len) throws IOException{ if (seqNo == -1) { throw new AssertionError("Use PacketOutputStream.startPacket() before write()"); } if (seqNo == MAX_SEQNO) { throw new IOException("MySQL protocol limit reached, you cannot send more than 4GB of data"); } for (;;) { if (len == 0) break; int bytesToWrite= Math.min(len, MAX_PACKET_LENGTH + HEADER_LENGTH - position); // Grow buffer if required if (byteBuffer.length - position < bytesToWrite) { byte[] tmp = new byte[Math.min(MAX_PACKET_LENGTH + HEADER_LENGTH, 2*(byteBuffer.length + bytesToWrite))]; System.arraycopy(byteBuffer, 0, tmp, 0, position); byteBuffer = tmp; } System.arraycopy(bytes, off, byteBuffer, position, bytesToWrite); position += bytesToWrite; off += bytesToWrite; len -= bytesToWrite; if (position == MAX_PACKET_LENGTH + HEADER_LENGTH) { internalFlush(); } } } @Override public void flush() throws IOException { throw new AssertionError("Do not call flush() on PacketOutputStream. use finishPacket() instead."); } private void internalFlush() throws IOException { int dataLen = position - HEADER_LENGTH; byteBuffer[0] = (byte)(dataLen & 0xff); byteBuffer[1] = (byte)((dataLen >> 8) & 0xff); byteBuffer[2] = (byte)((dataLen >> 16) & 0xff); byteBuffer[SEQNO_OFFSET] = (byte)seqNo; baseStream.write(byteBuffer, 0, position); position = HEADER_LENGTH; seqNo++; } @Override public void write(byte[] bytes) throws IOException{ write(bytes, 0, bytes.length); } @Override public void write(int b) throws IOException { byte[] a={(byte)b}; write(a); } @Override public void close() throws IOException { baseStream.close(); byteBuffer = null; } }