package org.infinispan.server.core.transport;

import java.net.SocketAddress;
import java.util.List;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslServer;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;

/**
 * Handles QOP of the SASL protocol.
 * <p>
 * Inbound data is read as frames made of a 4-byte length prefix followed by that many bytes, which are passed to
 * {@link SaslServer#unwrap(byte[], int, int)}; the result is forwarded up the pipeline. Outbound {@link ByteBuf}s
 * are passed to {@link SaslServer#wrap(byte[], int, int)} and written with a 4-byte length prefix.
 */
public class SaslQopHandler extends ByteToMessageDecoder implements ChannelOutboundHandler {

   /** Marker stored in the size fields when the corresponding property was not negotiated. */
   private static final int NOT_NEGOTIATED = -1;

   private final SaslServer server;
   /** Value of {@link Sasl#MAX_BUFFER}, or {@code -1} when absent. */
   private final int maxBufferSize;
   /** Value of {@link Sasl#RAW_SEND_SIZE}, or {@code -1} when absent. */
   private final int maxSendBufferSize;
   /** Length of the frame currently being assembled, or {@code -1} when the next prefix has not been read yet. */
   private int packetLength = -1;

   /**
    * Creates a handler bound to the given SASL server and reads its negotiated buffer sizes.
    *
    * @param server the SASL server used to wrap and unwrap data; it is disposed when this handler is removed
    * @throws NumberFormatException if a negotiated size is present but is not a parsable integer
    */
   public SaslQopHandler(SaslServer server) {
      this.server = server;
      this.maxBufferSize = negotiatedSize(server, Sasl.MAX_BUFFER);
      this.maxSendBufferSize = negotiatedSize(server, Sasl.RAW_SEND_SIZE);
   }

   /**
    * Reads an integer-valued negotiated property.
    *
    * @return the parsed value, or {@code -1} when the property is {@code null}
    */
   private static int negotiatedSize(SaslServer server, String property) {
      // getNegotiatedProperty is declared to return Object, so go through toString() rather than a hard String cast.
      Object value = server.getNegotiatedProperty(property);
      return value == null ? NOT_NEGOTIATED : Integer.parseInt(value.toString());
   }

   /**
    * Wraps the outbound buffer with the SASL server and writes the length-prefixed result.
    *
    * @param ctx     the handler context used to write downstream
    * @param msg     the outbound message; must be a {@link ByteBuf}
    * @param promise completed together with the write of the last part of the wrapped payload
    * @throws Exception if wrapping fails
    */
   @Override
   public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
      ByteBuf buffer = (ByteBuf) msg;
      final byte[] wrapped;
      try {
         int len = buffer.readableBytes();
         byte[] bytes;
         int offset;
         if (buffer.hasArray()) {
            // Heap buffer: hand the backing array to the SASL server without copying.
            bytes = buffer.array();
            offset = buffer.arrayOffset() + buffer.readerIndex();
         } else {
            bytes = new byte[len];
            buffer.readBytes(bytes);
            offset = 0;
         }
         wrapped = server.wrap(bytes, offset, len);
      } finally {
         // The original message is consumed here and a different buffer is written downstream, so it has to be
         // released on every path, including the backing-array path and a failing wrap().
         buffer.release();
      }

      ctx.write(ctx.alloc().buffer(4).writeInt(wrapped.length));

      if (maxSendBufferSize > 0 && wrapped.length > maxSendBufferSize) {
         // The wrapped data is bigger than maxSendBufferSize: split it into chunks and flush each one directly.
         int off = 0;
         int remaining = wrapped.length;
         while (remaining > maxSendBufferSize) {
            ctx.writeAndFlush(Unpooled.wrappedBuffer(wrapped, off, maxSendBufferSize));
            off += maxSendBufferSize;
            remaining -= maxSendBufferSize;
         }
         // The last chunk carries the caller's promise.
         ctx.writeAndFlush(Unpooled.wrappedBuffer(wrapped, off, remaining), promise);
      } else {
         ctx.write(Unpooled.wrappedBuffer(wrapped), promise);
      }
   }

   /**
    * Decodes one length-prefixed frame, unwraps it with the SASL server and adds the result to {@code out}.
    * Returns without output until both the 4-byte prefix and the complete frame body are available.
    * <p>
    * A frame whose declared length exceeds the negotiated maximum buffer size (or {@link Integer#MAX_VALUE} when
    * none was negotiated) fires a {@link TooLongFrameException} and closes the channel.
    *
    * @throws Exception if unwrapping fails
    */
   @Override
   protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
      int len = packetLength;
      if (len == -1) {
         if (in.readableBytes() < 4) {
            return;
         }
         // Keep the prefix as a long: casting an unsigned 32-bit value straight to int can turn it negative.
         long declared = in.readUnsignedInt();
         int limit = maxBufferSize != NOT_NEGOTIATED ? maxBufferSize : Integer.MAX_VALUE;
         // Storing the (clamped) length also parks the decoder on a rejected frame, so the bytes that follow are
         // not interpreted as new frames while the channel is closing.
         len = packetLength = (int) Math.min(declared, Integer.MAX_VALUE);
         if (declared > limit) {
            TooLongFrameException ex = new TooLongFrameException(
                  "Frame exceed exceed max buffer size: " + declared + " > " + limit);
            ctx.fireExceptionCaught(ex);
            ctx.close();
            return;
         }
      }
      if (len > in.readableBytes()) {
         return;
      }
      // reset packet length
      packetLength = -1;

      byte[] array;
      int offset;
      if (in.hasArray()) {
         offset = in.readerIndex() + in.arrayOffset();
         array = in.array();
         in.skipBytes(len);
      } else {
         offset = 0;
         array = new byte[len];
         in.readBytes(array);
      }
      out.add(Unpooled.wrappedBuffer(server.unwrap(array, offset, len)));
   }

   /** Passes the bind request on unchanged. */
   @Override
   public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
      ctx.bind(localAddress, promise);
   }

   /** Passes the connect request on unchanged. */
   @Override
   public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
         ChannelPromise promise) throws Exception {
      ctx.connect(remoteAddress, localAddress, promise);
   }

   /** Passes the disconnect request on unchanged. */
   @Override
   public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
      ctx.disconnect(promise);
   }

   /** Passes the close request on unchanged. */
   @Override
   public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
      ctx.close(promise);
   }

   /** Passes the deregister request on unchanged. */
   @Override
   public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
      ctx.deregister(promise);
   }

   /** Passes the read request on unchanged. */
   @Override
   public void read(ChannelHandlerContext ctx) throws Exception {
      ctx.read();
   }

   /** Passes the flush request on unchanged. */
   @Override
   public void flush(ChannelHandlerContext ctx) throws Exception {
      ctx.flush();
   }

   /**
    * Disposes the SASL server once this handler has been removed from the pipeline.
    *
    * @throws Exception if the superclass hook or {@link SaslServer#dispose()} fails
    */
   @Override
   protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
      try {
         super.handlerRemoved0(ctx);
      } finally {
         // Dispose even if the superclass hook throws.
         server.dispose();
      }
   }
}