package org.jboss.pitbull.internal.nio.socket;
import org.jboss.pitbull.internal.logging.Logger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.TimeUnit;
/**
 * A {@link FreeChannel} whose reads and writes are passed through an {@link SSLEngine}.
 * <p>
 * Buffer conventions maintained by this class between method calls:
 * <ul>
 * <li>{@code inputBuffer} holds ciphertext read from the underlying channel and is kept in
 * <em>write</em> mode, so {@code inputBuffer.position()} is the number of bytes that have
 * been received but not yet unwrapped.</li>
 * <li>{@code appBuffer} holds unwrapped plaintext and is kept in <em>read</em> mode, so
 * {@code appBuffer.hasRemaining()} means plaintext is waiting to be handed to the caller.</li>
 * <li>{@code outputBuffer} is scratch space for a single wrapped record that is written
 * to the underlying channel immediately.</li>
 * </ul>
 * NOTE(review): nothing here is synchronized; presumably a channel is only ever driven by
 * one thread at a time - confirm against callers.
 *
 * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class SSLChannel extends FreeChannel
{
   /** Engine performing the handshake and the record encryption/decryption. */
   protected SSLEngine engine;
   /** Session obtained from the engine at construction time; used to size the buffers. */
   protected SSLSession sslSession;
   /** Ciphertext received from the peer, kept in write mode between calls. */
   protected ByteBuffer inputBuffer;
   /** Decrypted plaintext waiting for the caller, kept in read mode between calls. */
   protected ByteBuffer appBuffer;
   /** Scratch buffer receiving the output of {@code engine.wrap(...)}. */
   protected ByteBuffer outputBuffer;
   /** Most recently observed handshake status of the engine. */
   protected SSLEngineResult.HandshakeStatus handshakeStatus;
   /** Status of the most recent unwrap performed by this class; {@code null} until the first one. */
   protected SSLEngineResult.Status engineStatus = null;
   /** Empty source buffer used when wrapping pure handshake data. */
   protected ByteBuffer dummy = ByteBuffer.allocate(0);
   protected static final Logger log = Logger.getLogger(SSLChannel.class);

   /**
    * Allocates the network and application buffers from the session's advertised sizes
    * and starts the handshake on the supplied engine.
    *
    * @param channel the socket channel handed to the superclass
    * @param engine  the engine to use; {@code beginHandshake()} is invoked on it here
    * @throws Exception if the superclass constructor or {@code beginHandshake()} fails
    */
   public SSLChannel(SocketChannel channel, SSLEngine engine) throws Exception
   {
      super(channel);
      this.engine = engine;
      this.sslSession = engine.getSession();

      int packetBufferSize = sslSession.getPacketBufferSize();
      inputBuffer = ByteBuffer.allocate(packetBufferSize);
      outputBuffer = ByteBuffer.allocate(packetBufferSize);

      int applicationBufferSize = sslSession.getApplicationBufferSize();
      appBuffer = ByteBuffer.allocate(applicationBufferSize);

      // Move the position to the limit so that hasRemaining() returns false:
      // both buffers start out "empty" from a reader's point of view.
      appBuffer.position(appBuffer.limit());
      outputBuffer.position(outputBuffer.limit());

      engine.beginHandshake();
      handshakeStatus = engine.getHandshakeStatus();
   }

   /**
    * @return the session captured from the engine when this channel was constructed
    */
   @Override
   public SSLSession getSslSession()
   {
      return sslSession;
   }

   /**
    * Runs every delegated task the engine currently has, on the calling thread, and then
    * refreshes {@link #handshakeStatus}.
    */
   protected void executeEngineTasks()
   {
      Runnable task;
      while ((task = engine.getDelegatedTask()) != null)
      {
         task.run();
      }
      handshakeStatus = engine.getHandshakeStatus();
   }

   /**
    * Unwraps buffered ciphertext into {@code appBuffer} for as long as the engine reports
    * {@code OK} together with {@code NEED_UNWRAP}. Does nothing when {@code inputBuffer}
    * is empty. {@code appBuffer} must be in write mode when this is called.
    *
    * @return true if status == CLOSED
    * @throws IOException if the engine fails to unwrap
    */
   protected boolean needUnwrap() throws IOException
   {
      if (inputBuffer.position() == 0) return false;
      log.trace("needUnwrap()");
      SSLEngineResult res;
      inputBuffer.flip();
      do
      {
         res = engine.unwrap(inputBuffer, appBuffer);
         log.trace("Unwrapping:\n" + res);
         // During a handshake renegotiation we might need to perform
         // several unwraps to consume the handshake data.
      } while (res.getStatus() == SSLEngineResult.Status.OK &&
              res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP);

      // Always return inputBuffer to write mode, including on CLOSED, so that every
      // exit path leaves it in the same state.
      inputBuffer.compact();

      engineStatus = res.getStatus();
      handshakeStatus = res.getHandshakeStatus();

      // Should never happen, appBuffer must always have enough space
      // for an unwrap operation
      assert engineStatus != SSLEngineResult.Status.BUFFER_OVERFLOW :
              "Buffer should not overflow: " + res.toString();

      // The handshake status here can be different than NOT_HANDSHAKING
      // if the other peer closed the connection.
      if (engineStatus == SSLEngineResult.Status.CLOSED)
      {
         log.debug("Connection is being closed by peer.");
         return true;
      }
      return false;
   }

   /**
    * Drives the handshake as far as possible without reading from the underlying channel:
    * runs delegated tasks, unwraps already-buffered ciphertext and wraps/writes handshake
    * records. {@code appBuffer} must be in write mode when this is called because an
    * unwrap may deposit plaintext into it.
    *
    * @return {@code -1} if the peer closed the connection or a handshake write reported
    *         {@code -1}; {@code 0} when no further progress is possible without more input
    *         or the engine is not handshaking
    * @throws IOException if the engine or the underlying channel fails, or the engine
    *                     reports a handshake status this class does not handle
    */
   protected int processHandshake() throws IOException
   {
      log.trace("processHandshake()");
      try
      {
         for (; ; )
         {
            handshakeStatus = engine.getHandshakeStatus();
            switch (handshakeStatus)
            {
               case FINISHED:
                  log.trace("Handshake FINISHED");
                  return 0;
               case NEED_TASK:
                  log.trace("Handshake NEED_TASK");
                  executeEngineTasks();
                  break;
               case NEED_UNWRAP:
                  log.trace("Handshake NEED_UNWRAP");
                  if (inputBuffer.position() == 0
                          || engineStatus == SSLEngineResult.Status.BUFFER_UNDERFLOW)
                  {
                     // Nothing buffered, or the engine already said it needs more bytes:
                     // the caller has to read from the channel first.
                     return 0;
                  }
                  if (needUnwrap())
                  {
                     return -1;
                  }
                  // Re-evaluate the engine's status instead of falling through into
                  // NEED_WRAP: after an unwrap the engine may want a task, another
                  // unwrap, or nothing at all.
                  break;
               case NEED_WRAP:
                  log.trace("Handshake NEED_WRAP");
                  outputBuffer.clear();
                  SSLEngineResult res = engine.wrap(dummy, outputBuffer);
                  handshakeStatus = res.getHandshakeStatus();
                  outputBuffer.flip();
                  if (super.writeBlocking(outputBuffer) == -1)
                  {
                     log.debug("Error writing on NEED_WRAP");
                     return -1;
                  }
                  break;
               case NOT_HANDSHAKING:
                  log.trace("Handshake NOT_HANDSHAKING");
                  return 0;
               default:
                  // Without this an unhandled status would spin in the loop forever.
                  throw new IOException("Unexpected handshake status: " + handshakeStatus);
            }
         }
      }
      finally
      {
         handshakeStatus = engine.getHandshakeStatus();
         log.trace("End processHandshake() : {0}", handshakeStatus);
      }
   }

   /**
    * Moves as much pending plaintext as fits from {@code appBuffer} (read mode) into
    * {@code buf}.
    *
    * @param buf destination buffer
    * @return number of bytes transferred, possibly {@code 0}
    */
   private int readBuffer(ByteBuffer buf)
   {
      int count = Math.min(appBuffer.remaining(), buf.remaining());
      if (count == 0)
      {
         return 0;
      }
      // Temporarily narrow appBuffer to exactly 'count' bytes so the bulk put
      // cannot overflow the destination.
      int savedLimit = appBuffer.limit();
      appBuffer.limit(appBuffer.position() + count);
      try
      {
         buf.put(appBuffer);
      }
      finally
      {
         appBuffer.limit(savedLimit);
      }
      return count;
   }

   /** Bridge to {@code super.read} for use from the anonymous {@link ReadExecution} classes. */
   protected int readSuper(ByteBuffer buf) throws IOException
   {
      return super.read(buf);
   }

   /** Bridge to {@code super.readBlocking} for use from the anonymous {@link ReadExecution} classes. */
   protected int readBlockingSuper(ByteBuffer buf) throws IOException
   {
      return super.readBlocking(buf);
   }

   /** Bridge to the timed {@code super.readBlocking} for use from the anonymous {@link ReadExecution} classes. */
   protected int readBlockingSuper(ByteBuffer buf, long time, TimeUnit unit) throws IOException
   {
      return super.readBlocking(buf, time, unit);
   }

   /**
    * Strategy for pulling raw ciphertext from the underlying channel into a buffer.
    */
   protected interface ReadExecution
   {
      int read(ByteBuffer buf) throws IOException;
   }

   /**
    * Shared implementation of all read variants. Order of work:
    * <ol>
    * <li>hand out plaintext left over from an earlier call, if any;</li>
    * <li>advance the handshake using only already-buffered ciphertext;</li>
    * <li>read ciphertext through {@code execution} (skipped when complete records are
    * still buffered after an earlier {@code BUFFER_OVERFLOW});</li>
    * <li>unwrap it and copy the resulting plaintext into {@code buf}.</li>
    * </ol>
    *
    * @param buf       destination for plaintext
    * @param execution how to read ciphertext from the underlying channel
    * @return number of plaintext bytes placed in {@code buf} (may be {@code 0}, e.g. while
    *         handshaking), a value below {@code 1} returned by {@code execution}, or
    *         {@code -1} when handshake processing reports the channel closed
    * @throws IOException if the engine or the underlying channel fails
    */
   protected int readExecution(ByteBuffer buf, ReadExecution execution) throws IOException
   {
      log.trace("<--- enter - read");
      try
      {
         // Decide on "is plaintext pending" rather than "did we copy something":
         // if buf has no room we must not fall through and wipe the pending plaintext.
         boolean plaintextPending = appBuffer.hasRemaining();
         int bufBytesRead = readBuffer(buf);
         log.trace("Bytes read from buffer: {0}", bufBytesRead);
         if (plaintextPending)
         {
            return bufBytesRead;
         }

         // nothing in appBuffer so switch it to write mode for the handshake step
         appBuffer.clear();

         // do everything but reading from channel
         int status = processHandshake();
         log.trace("finished 1st processEngine");

         appBuffer.flip();
         if (appBuffer.hasRemaining())
         {
            return readBuffer(buf);
         }
         // From here appBuffer is empty *in read mode*. It stays that way until ciphertext
         // has actually arrived, so the early returns below cannot leave it looking full
         // of stale bytes on the next call.

         if (status == -1)
         {
            log.trace("processEngine resulted in closed channel");
            return -1;
         }

         // If the last unwrap stopped because appBuffer was full, complete records may
         // still sit in inputBuffer; process those before asking the channel for more.
         boolean retryBuffered = inputBuffer.position() > 0
                 && engineStatus == SSLEngineResult.Status.BUFFER_OVERFLOW;
         if (!retryBuffered)
         {
            int bytesRead = execution.read(inputBuffer);
            log.trace("Bytes read into inputBuffer: {0}", bytesRead);
            if (bytesRead < 1) return bytesRead;
         }

         // Now that we have bytes in the buffer, do something with it.
         appBuffer.clear();
         log.trace("Start loop--");
         do
         {
            status = processHandshake();
            if (status == -1)
            {
               log.trace("channel closed after processHandshake");
               // restore read mode before leaving
               appBuffer.flip();
               return -1;
            }
            log.trace("Unwrapping");
            inputBuffer.flip();
            SSLEngineResult res = engine.unwrap(inputBuffer, appBuffer);
            log.trace("Unwrapped: {0}", res);
            handshakeStatus = res.getHandshakeStatus();
            engineStatus = res.getStatus();
            inputBuffer.compact();
            // After compact() the buffer is in write mode: position() is the amount of
            // ciphertext still pending, whereas hasRemaining() would only report free space.
         } while (engineStatus == SSLEngineResult.Status.OK && inputBuffer.position() > 0);

         log.trace("--After loop:");
         log.trace("HandshakeStatus: {0}", handshakeStatus);
         log.trace("Engine Status: {0}", engineStatus);

         // handle any need-task, need-wrap
         processHandshake();

         // Prepare the buffer to be read by the caller.
         log.trace("prepare buffers");
         appBuffer.flip();
         log.trace("remaining inputBuffer: {0}", inputBuffer.position());
         return readBuffer(buf);
      }
      finally
      {
         log.trace("---> exit - read");
      }
   }

   /**
    * Reads plaintext using the superclass's plain {@code read} to fetch ciphertext.
    *
    * @see #readExecution(ByteBuffer, ReadExecution)
    */
   @Override
   public int read(ByteBuffer buf) throws IOException
   {
      log.trace("read()");
      return readExecution(buf,
              new ReadExecution()
              {
                 @Override
                 public int read(ByteBuffer dst) throws IOException
                 {
                    return readSuper(dst);
                 }
              });
   }

   /**
    * Reads plaintext using the superclass's {@code readBlocking} to fetch ciphertext.
    *
    * @see #readExecution(ByteBuffer, ReadExecution)
    */
   @Override
   public int readBlocking(ByteBuffer buf) throws IOException
   {
      log.trace("readBlocking()");
      return readExecution(buf,
              new ReadExecution()
              {
                 @Override
                 public int read(ByteBuffer dst) throws IOException
                 {
                    return readBlockingSuper(dst);
                 }
              });
   }

   /**
    * Reads plaintext using the superclass's timed {@code readBlocking} to fetch ciphertext.
    *
    * @see #readExecution(ByteBuffer, ReadExecution)
    */
   @Override
   public int readBlocking(final ByteBuffer buf, final long time, final TimeUnit unit) throws IOException
   {
      log.trace("readBlocking() with timeout");
      return readExecution(buf,
              new ReadExecution()
              {
                 @Override
                 public int read(ByteBuffer dst) throws IOException
                 {
                    return readBlockingSuper(dst, time, unit);
                 }
              });
   }

   /**
    * Delegates to {@link #writeBlocking(ByteBuffer)}.
    */
   @Override
   public int write(ByteBuffer buf) throws IOException
   {
      return writeBlocking(buf);
   }

   /** Bridge to {@code super.writeBlocking}. */
   protected int writeBlockingSuper(ByteBuffer buf) throws IOException
   {
      return super.writeBlocking(buf);
   }

   /**
    * Wraps {@code buffer} record by record and writes each record through the superclass
    * until the source is drained.
    * <p>
    * NOTE(review): {@code wrap} may report {@code OK} while consuming nothing when the engine
    * is mid-handshake; in that case this loop would not terminate. Presumably callers only
    * write after the handshake has been completed by a read - confirm.
    *
    * @param buffer plaintext to send
    * @return the number of bytes that were remaining in {@code buffer} on entry, or
    *         {@code -1} if the underlying write reported {@code -1}
    * @throws IOException if the engine reports any status other than {@code OK} or the
    *                     underlying write fails
    */
   @Override
   public int writeBlocking(ByteBuffer buffer) throws IOException
   {
      int size = buffer.remaining();
      while (buffer.hasRemaining())
      {
         outputBuffer.clear();
         SSLEngineResult res = engine.wrap(buffer, outputBuffer);
         if (res.getStatus() != SSLEngineResult.Status.OK)
         {
            throw new IOException("Illegal status for write: " + res.getStatus());
         }
         // Prepare the buffer for reading
         outputBuffer.flip();
         int result = super.writeBlocking(outputBuffer);
         if (result == -1) return -1;
      }
      // Return the number of bytes read
      // from the source buffer
      return size;
   }

   /**
    * Timed variant of {@link #writeBlocking(ByteBuffer)}: keeps wrapping and writing while
    * the source has data and the time budget is not used up.
    *
    * @param buffer plaintext to send
    * @param time   maximum time to spend
    * @param unit   unit of {@code time}
    * @return the number of source bytes consumed by the engine (may be less than what was
    *         remaining if the budget ran out), or {@code -1} if the underlying write
    *         reported {@code -1}
    * @throws IOException if the engine reports any status other than {@code OK} or the
    *                     underlying write fails
    */
   @Override
   public int writeBlocking(ByteBuffer buffer, long time, TimeUnit unit) throws IOException
   {
      long timeRemaining = unit.toMillis(time);
      long now = System.currentTimeMillis();
      int numBufWritten = 0;
      while (buffer.hasRemaining() && timeRemaining > 0L)
      {
         outputBuffer.clear();
         SSLEngineResult res = engine.wrap(buffer, outputBuffer);
         if (res.getStatus() != SSLEngineResult.Status.OK)
         {
            throw new IOException("Illegal status for write: " + res.getStatus());
         }
         // Prepare the buffer for reading
         outputBuffer.flip();
         int result = super.writeBlocking(outputBuffer, timeRemaining, TimeUnit.MILLISECONDS);
         if (result == -1) return -1;
         numBufWritten += res.bytesConsumed();

         // Charge the elapsed wall-clock time against the budget; a clock that moved
         // backwards is treated as zero elapsed time.
         long current = System.currentTimeMillis();
         timeRemaining -= Math.max(current - now, 0L);
         now = current;
      }
      return numBufWritten;
   }
}