package org.limewire.nio.ssl;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.limewire.nio.ByteBufferCache;
import org.limewire.nio.channel.ChannelReader;
import org.limewire.nio.channel.ChannelWriter;
import org.limewire.nio.channel.InterestReadableByteChannel;
import org.limewire.nio.channel.InterestWritableByteChannel;
import org.limewire.nio.observer.Shutdownable;
import org.limewire.nio.observer.WriteObserver;
import org.limewire.util.BufferUtils;
import org.limewire.util.FileUtils;
/**
 * An SSL-capable layer that can transform incoming and outgoing
 * data according to the specified <code>SSLContext</code> and cipher suite.
 * <p>
 * Incoming data is read from the underlying read channel and unwrapped by the
 * <code>SSLEngine</code>; outgoing data is wrapped into an internal buffer that
 * is flushed to the underlying write channel from {@link #handleWrite()}.
 * Delegated engine tasks are run on the <code>sslBlockingExecutor</code>.
 */
class SSLReadWriteChannel implements InterestReadableByteChannel, InterestWritableByteChannel,
        ChannelReader, ChannelWriter {

    private static final Log LOG = LogFactory.getLog(SSLReadWriteChannel.class);

    /** The context from which to retrieve a new SSLEngine. */
    private final SSLContext context;
    /** An executor to perform blocking tasks (the engine's delegated tasks). */
    private final Executor sslBlockingExecutor;
    /** The engine managing this SSL session; created by initialize(). */
    private SSLEngine engine;
    /** The buffer which the underlying readSink is read into, and which is then unwrapped. */
    private ByteBuffer readIncoming;
    /**
     * A temporary buffer which data is unwrapped to when the caller's buffer
     * is too small; lazily allocated on the first BUFFER_OVERFLOW.
     */
    private ByteBuffer readOutgoing;
    /** The buffer which we wrap writes to. */
    private ByteBuffer writeOutgoing;
    /** The underlying channel to read from. */
    private volatile InterestReadableByteChannel readSink;
    /** The underlying channel to write to. */
    private volatile InterestWritableByteChannel writeSink;
    /** The last WriteObserver that indicated write interest. */
    private volatile WriteObserver writeWanter;
    /** True if handshaking indicated we need to immediately perform a wrap. */
    private volatile boolean needsHandshakeWrap = false;
    /** True if handshaking indicated we need to immediately perform an unwrap. */
    private volatile boolean needsHandshakeUnwrap = false;
    /** True if a read finished and data was still buffered. */
    private volatile boolean readDataLeft = false;
    /** True only after a single read has been performed. */
    private final AtomicBoolean firstReadDone = new AtomicBoolean(false);

    /*
     * Statistic gathering variables. '+=' on a volatile is not atomic; this relies
     * on reads and writes each being performed from a single thread (see taskLock).
     */
    private volatile long readConsumed;
    private volatile long readProduced;
    private volatile long writeConsumed;
    private volatile long writeProduced;

    /**
     * Whether or not this has been shutdown.
     * Shutting down must be atomic with regard to initializing, so that
     * we can guarantee all allocated buffers are released
     * properly.
     *
     * Shutdown is volatile so read/write/handleWrite can quickly
     * get it w/o locking.
     */
    private volatile boolean shutdown = false;
    /** Guards 'shutdown' together with buffer allocation, resizing and release. */
    private final Object initLock = new Object();

    /**
     * A lock used when changing to 'need task', to prevent
     * further writes/reads from changing interest.
     * Due to java bug: 6492872
     * (deadlock within SSLEngine if task/read/write called at same time,
     * fixed in 1.6u2 & 1.7),
     * Internally we only do reads/writes on a single thread, so will
     * never encounter a read/write collision, but we always do
     * tasks on another thread, and it's possible that interest may
     * be turned on for reading|writing, which would cause a read|write,
     * and possibly an engine.wrap or engine.unwrap.
     * To workaround this, we prevent any interest from being turned on
     * once a task is performed. When the task is finished, interest
     * will be turned on as necessary.
     */
    private final Object taskLock = new Object();
    private volatile boolean taskScheduled = false;

    /**
     * The last state of who interested us in reading must be kept,
     * so that after handshaking finishes, we can put reading into
     * the correct interest state. Otherwise, our options are:
     * 1) leave interest on, which could potentially loop forever
     *    if the connected socket closes.
     * 2) turn interest off, which could confuse any callers that
     *    had wanted to read data.
     *
     * Note that we don't have to do this for writing because writing
     * can successfully turn itself off.
     */
    private boolean readInterest = false;
    private final Object readInterestLock = new Object();

    private final ByteBufferCache byteBufferCache;
    private final Executor networkExecutor;

    /**
     * Creates an uninitialized channel; {@link #initialize} must be called
     * before reading or writing.
     *
     * @param context the context used to create the SSLEngine
     * @param sslBlockingExecutor executor that runs the engine's delegated tasks
     * @param byteBufferCache cache used to obtain and release the SSL buffers
     * @param networkExecutor executor used for forced reads and buffer release
     */
    public SSLReadWriteChannel(SSLContext context, Executor sslBlockingExecutor,
            ByteBufferCache byteBufferCache, Executor networkExecutor) {
        this.sslBlockingExecutor = sslBlockingExecutor;
        this.context = context;
        this.byteBufferCache = byteBufferCache;
        this.networkExecutor = networkExecutor;
    }

    /**
     * Initializes this channel, creating its SSLEngine using the given address
     * as the peer hint and enabling the given cipher suites, and allocating the
     * read/write packet buffers. Does nothing if already shut down.
     * <p>
     * If clientMode is disabled, client authentication can be turned on/off.
     *
     * @param addr the peer address, or null to create an engine without a peer hint;
     *             if non-null it must be an InetSocketAddress
     * @param cipherSuites cipher suites to enable, or null to keep the engine's defaults
     * @param clientMode true if the engine should act as the client side of the handshake
     * @param needClientAuth in server mode, whether client authentication is required
     * @throws IllegalArgumentException if addr is not an InetSocketAddress
     */
    void initialize(SocketAddress addr, String[] cipherSuites, boolean clientMode, boolean needClientAuth) {
        synchronized(initLock) {
            if(shutdown) {
                LOG.debug("Not initializing because already shutdown.");
                return;
            }
            if(addr != null) {
                if(!(addr instanceof InetSocketAddress))
                    throw new IllegalArgumentException("unsupported SocketAddress");
                InetSocketAddress iaddr = (InetSocketAddress)addr;
                // An unresolved address has no InetAddress; use its host name
                // instead of failing with a NullPointerException.
                String host = iaddr.isUnresolved()
                        ? iaddr.getHostName()
                        : iaddr.getAddress().getHostAddress();
                engine = context.createSSLEngine(host, iaddr.getPort());
            } else {
                engine = context.createSSLEngine();
            }
            if(cipherSuites != null) {
                engine.setEnabledCipherSuites(cipherSuites);
            }
            engine.setUseClientMode(clientMode);
            if(!clientMode) {
                // setNeedClientAuth overrides the want setting, so the net effect
                // is 'required' when needClientAuth is true and 'none' otherwise.
                engine.setWantClientAuth(needClientAuth);
                engine.setNeedClientAuth(needClientAuth);
            }
            SSLSession session = engine.getSession();
            readIncoming = byteBufferCache.getHeap(session.getPacketBufferSize());
            writeOutgoing = byteBufferCache.getHeap(session.getPacketBufferSize());
            if(LOG.isTraceEnabled())
                LOG.trace("Initialized engine: " + engine + ", session: " + session);
        }
    }

    /**
     * Reads from the underlying channel, unwraps, and transfers plaintext into dst.
     * Also drives handshake unwraps as needed.
     *
     * @return the number of plaintext bytes transferred, 0 if nothing could be
     *         read or a task is pending, or -1 on end-of-stream / engine closure
     * @throws ClosedChannelException if this channel has been shut down
     * @throws IOException if reading or unwrapping fails
     */
    public int read(ByteBuffer dst) throws IOException {
        if(shutdown)
            throw new ClosedChannelException();
        // If a task is scheduled, don't read anything! (see taskLock)
        if(taskScheduled)
            return 0;

        int transferred = 0;

        // If data was left in readOutgoing, pre-transfer it.
        if(readOutgoing != null && readOutgoing.position() > 0) {
            transferred += BufferUtils.transfer(readOutgoing, dst);
            if(readOutgoing.position() > 0) {
                LOG.debug("Transferred less than we have left!");
                return transferred;
            }
        }

        while(true) {
            // If we're not handshaking and there's no space to read into, exit early.
            // Must check separately for 'first read' and 'not handshaking', because
            // the engine isn't put into handshaking mode until a single read is done.
            if(firstReadDone.get() && !dst.hasRemaining() && engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
                LOG.debug("No room left to transfer data, exiting");
                return transferred;
            }

            int read = -1;
            while(readIncoming.hasRemaining() && (read = readSink.read(readIncoming)) > 0) {
                // keep filling readIncoming until it is full or the sink has nothing more
            }

            // If we last read EOF & nothing was put in readIncoming, EOF.
            // TODO: consider throwing ClosedChannelException instead when EOF arrives
            // before the first read completes or while still handshaking.
            if(read == -1 && readIncoming.position() == 0) {
                LOG.debug("Read EOF, no data to transfer. Connection finished");
                return -1;
            }

            // If we couldn't read anything, there's nothing to unwrap.
            if(readIncoming.position() == 0) {
                LOG.debug("Unable to read anything, exiting read loop");
                return 0;
            }

            readIncoming.flip();
            // Try unwrapping directly into dst first.
            SSLEngineResult result = countedUnwrap(dst);
            transferred += result.bytesProduced();
            SSLEngineResult.Status status = result.getStatus();

            // If dst didn't have enough space, unwrap via the intermediate buffer.
            if(status == Status.BUFFER_OVERFLOW) {
                result = unwrapViaReadOutgoing();
                status = result.getStatus();
                transferred += BufferUtils.transfer(readOutgoing, dst);
            }

            firstReadDone.set(true);
            // Keep any bytes the engine didn't consume (e.g. a partial record) for next time.
            if(readIncoming.hasRemaining()) {
                readDataLeft = true;
                readIncoming.compact();
            } else {
                readDataLeft = false;
                readIncoming.clear();
            }

            if(LOG.isDebugEnabled())
                LOG.debug("Read unwrap result: " + result + ", transferred: " + transferred);

            // If we were unable to interpret this packet because not enough was
            // read, then we must exit and wait for more to be read later.
            if(status == Status.BUFFER_UNDERFLOW) {
                if(transferred == 0 && read == -1) {
                    LOG.debug("Read EOF & underflow when unwrapping. Connection finished");
                    return -1;
                }
                return transferred;
            }

            // If the engine is closed and nothing was transferred,
            // return -1 to show the stream ended. Otherwise return
            // however much we were able to already transfer.
            if(status == Status.CLOSED)
                return transferred == 0 ? -1 : transferred;

            // We may be handshaking, which requires processing of data...
            if(!processHandshakeResult(true, false, result.getHandshakeStatus()))
                return transferred;
        }
    }

    /**
     * Unwraps from readIncoming into dst and adds the result's byte counts
     * to the read statistics.
     */
    private SSLEngineResult countedUnwrap(ByteBuffer dst) throws IOException {
        SSLEngineResult result = unwrap(engine, readIncoming, dst);
        readProduced += result.bytesProduced();
        readConsumed += result.bytesConsumed();
        return result;
    }

    /**
     * Unwraps readIncoming into the intermediate readOutgoing buffer, allocating
     * it on first use and growing the read buffers once if the session now needs
     * larger buffers than were allocated.
     *
     * @return the result of the last unwrap; never BUFFER_OVERFLOW
     * @throws IOException if shut down while allocating, or the unwrap fails
     * @throws IllegalStateException if the buffers cannot hold the unwrapped data
     */
    private SSLEngineResult unwrapViaReadOutgoing() throws IOException {
        // Initialize readOutgoing only if not shutdown,
        // but grab the lock after we've checked to make sure
        // it's non-null, to avoid lock every read.
        // Lock is only necessary for 'shutdown'.
        if(readOutgoing == null) {
            synchronized(initLock) {
                if(shutdown)
                    throw new IOException("Shutdown while sizing");
                readOutgoing = byteBufferCache.getHeap(engine.getSession().getApplicationBufferSize());
            }
        }
        SSLEngineResult result = countedUnwrap(readOutgoing);
        if(result.getStatus() != Status.BUFFER_OVERFLOW)
            return result;
        if(!canGrowReadBuffers())
            throw new IllegalStateException("cannot resize, and not enough room in fallback TLS buffer! readOutgoing: " + readOutgoing + ", readIncoming: " + readIncoming + ", packet size: " + engine.getSession().getPacketBufferSize() + ", appl size: " + engine.getSession().getApplicationBufferSize());
        return growReadBuffersAndUnwrap();
    }

    /**
     * Returns true if an overflowing unwrap can be retried with larger buffers:
     * nothing was consumed from readIncoming, and the session now reports a
     * packet or application buffer size larger than what we allocated.
     * <p>
     * See: http://download.java.net/jdk7/docs/technotes/guides/security/jsse/JSSERefGuide.html#SSLSession
     * The SSL/TLS protocols specify 16KB as the buffer size, but some implementations
     * do it wrong and make it 32KB. The Sun impl will dynamically resize up, so we
     * need to handle that. (This used to check only for the exact 16665 -> 33049
     * sizes of one JDK; other JDKs report different sizes.)
     * NOTE(review): assumes byteBufferCache.getHeap returns buffers whose capacity
     * is the requested size.
     */
    private boolean canGrowReadBuffers() {
        SSLSession session = engine.getSession();
        return readIncoming.position() == 0
                && (readIncoming.capacity() < session.getPacketBufferSize()
                    || readOutgoing.capacity() < session.getApplicationBufferSize());
    }

    /**
     * Replaces readIncoming and readOutgoing with buffers sized to the session's
     * current requirements, carrying over the pending incoming data, and retries
     * the unwrap into readOutgoing.
     *
     * @return the result of the retried unwrap
     * @throws IOException if shut down while resizing, or the unwrap fails
     * @throws IllegalStateException if the retried unwrap still overflows
     */
    private SSLEngineResult growReadBuffersAndUnwrap() throws IOException {
        // Grab the lock to make sure shutdown() doesn't recycle the buffers as we do this.
        synchronized(initLock) {
            if(shutdown)
                throw new IOException("Shutdown while resizing.");

            // Transfer data from old readIncoming to newIncoming.
            ByteBuffer newIncoming = byteBufferCache.getHeap(engine.getSession().getPacketBufferSize());
            BufferUtils.transfer(readIncoming, newIncoming, false);
            newIncoming.flip();
            assert newIncoming.limit() == readIncoming.position();
            assert newIncoming.position() == 0;
            byteBufferCache.release(readIncoming);
            readIncoming = newIncoming;

            // Replace outgoing with upgraded version.
            assert readOutgoing.position() == 0;
            byteBufferCache.release(readOutgoing);
            readOutgoing = byteBufferCache.getHeap(engine.getSession().getApplicationBufferSize());

            // ... and try again!
            SSLEngineResult result = countedUnwrap(readOutgoing);
            if(result.getStatus() == Status.BUFFER_OVERFLOW)
                throw new IllegalStateException("tried resizing, but still not enough room in fallback TLS buffer! readOutgoing: " + readOutgoing + ", readIncoming: " + readIncoming + ", packet size: " + engine.getSession().getPacketBufferSize() + ", appl size: " + engine.getSession().getApplicationBufferSize());
            return result;
        }
    }

    /**
     * See {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)}.
     * Any RuntimeException or Error thrown by the engine is rethrown wrapped
     * in an IOException, so callers only have to handle IOException.
     */
    private SSLEngineResult wrap(SSLEngine engine, ByteBuffer src, ByteBuffer dst) throws SSLException, IOException {
        assert src != null;
        assert dst != null;
        try {
            return engine.wrap(src, dst);
        } catch(RuntimeException re) {
            throw new IOException(re);
        } catch(Error e) {
            throw new IOException(e);
        }
    }

    /**
     * See {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)}.
     * Any RuntimeException or Error thrown by the engine is rethrown wrapped
     * in an IOException, so callers only have to handle IOException.
     */
    private SSLEngineResult unwrap(SSLEngine engine, ByteBuffer src, ByteBuffer dst) throws SSLException, IOException {
        assert dst != null;
        assert src != null;
        try {
            return engine.unwrap(src, dst);
        } catch(RuntimeException re) {
            throw new IOException(re);
        } catch(Error e) {
            throw new IOException(e);
        }
    }

    /**
     * Processes a single handshake result, adjusting read/write interest.
     * <ul>
     * <li>NEED_TASK: schedules the delegated task(s) and returns false.</li>
     * <li>NEED_WRAP: turns on write interest; returns true only if called while writing.</li>
     * <li>NEED_UNWRAP: turns on read interest; returns true only if called while reading.</li>
     * <li>Otherwise: returns true.</li>
     * </ul>
     *
     * @param reading true if called from read()
     * @param writing true if called from write()
     * @param hs the handshake status to process
     * @return true if the caller may continue its read/write loop
     */
    private boolean processHandshakeResult(boolean reading, boolean writing, HandshakeStatus hs) {
        if(LOG.isTraceEnabled())
            LOG.trace("Processing result from: " + engine + ", result: " + hs);
        needsHandshakeWrap = false;
        needsHandshakeUnwrap = false;
        switch(hs) {
        case NEED_TASK:
            needTask();
            return false;
        case NEED_WRAP:
            needsHandshakeWrap = true;
            // IMPORTANT: read interest must be turned off before write
            // interest is turned on. This is necessary because
            // if write interest is turned on from the TASK thread,
            // it can get immediate notification from NIODispatcher,
            // which can continue processing more results, and will
            // set read interest on. Then, when the context switches
            // back to this thread, it would turn read interest off.
            // So: TURN READ INTEREST OFF FIRST!
            readSink.interestRead(false);
            writeSink.interestWrite(this, true);
            return writing;
        case NEED_UNWRAP:
            // IMPORTANT: write interest must be turned off before read
            // interest is turned on. This is necessary because
            // if read interest is turned on from the TASK thread,
            // it can get immediate notification from NIODispatcher,
            // which can continue processing more results, and will
            // set write interest on. Then, when the context switches
            // back to this thread, it would turn write interest off.
            // So: TURN WRITE INTEREST OFF FIRST!
            writeSink.interestWrite(null, false);
            synchronized(readInterestLock) {
                needsHandshakeUnwrap = true;
                readSink.interestRead(true);
            }
            // If we had previously buffered read data, force a read, since
            // no new network data may arrive to trigger one.
            if(readDataLeft && !reading)
                networkExecutor.execute(new Runnable() {
                    public void run() {
                        try {
                            read(BufferUtils.getEmptyBuffer());
                        } catch(IOException iox) {
                            FileUtils.close(SSLReadWriteChannel.this);
                        }
                    }
                });
            return reading;
        case FINISHED:
            synchronized(readInterestLock) {
                // set interest to what our observer wanted.
                readSink.interestRead(readInterest);
            }
            writeSink.interestWrite(this, true);
            // Intentional fall-through: the handshake is done, so the caller may continue.
        case NOT_HANDSHAKING:
        default:
            // no change.
            return true;
        }
    }

    /**
     * The engine needs to run some tasks before proceeding.
     * Turns off all interest, submits every delegated task to sslBlockingExecutor,
     * then submits a final runnable that clears taskScheduled and processes the
     * engine's next handshake status.
     * NOTE(review): this relies on sslBlockingExecutor running tasks serially,
     * in submission order, so the final runnable runs after the delegated tasks.
     */
    private void needTask() {
        synchronized(taskLock) {
            taskScheduled = true;
            readSink.interestRead(false);
            writeSink.interestWrite(null, false);
        }

        // Run as many tasks as possible, and then add another
        // that will process the next state.
        while(true) {
            final Runnable runner = engine.getDelegatedTask();
            if(runner == null) {
                sslBlockingExecutor.execute(new Runnable() {
                    public void run() {
                        synchronized(taskLock) {
                            taskScheduled = false;
                        }
                        HandshakeStatus status = engine.getHandshakeStatus();
                        if(LOG.isDebugEnabled())
                            LOG.debug("Task(s) finished, status: " + status);
                        processHandshakeResult(false, false, status);
                    }
                });
                break;
            } else {
                sslBlockingExecutor.execute(runner);
            }
        }
    }

    /**
     * Wraps as much of src as fits into the internal outgoing buffer.
     * Nothing is written to the network here; {@link #handleWrite()} flushes it.
     * At least one wrap is always performed, even for an empty src, so that
     * handshake data can be produced.
     *
     * @return the number of plaintext bytes consumed from src
     * @throws ClosedChannelException if shut down, or if the engine is closed
     *         and the underlying channels are no longer open
     * @throws IOException if wrapping fails
     * @throws IllegalStateException if an empty outgoing buffer cannot hold one wrapped record
     */
    public int write(ByteBuffer src) throws IOException {
        if(shutdown)
            throw new ClosedChannelException();
        // If a task is scheduled, don't write anything! (see taskLock)
        if(taskScheduled)
            return 0;

        int consumed = 0;
        // do...while because we want to force one write even with empty buffers
        do {
            boolean wasEmpty = writeOutgoing.position() == 0;
            SSLEngineResult result = wrap(engine, src, writeOutgoing);
            writeProduced += result.bytesProduced();
            writeConsumed += result.bytesConsumed();
            if(LOG.isDebugEnabled())
                LOG.debug("Wrap result: " + result);
            consumed += result.bytesConsumed();
            SSLEngineResult.Status status = result.getStatus();
            if(status == Status.CLOSED && !isOpen())
                throw new ClosedChannelException();

            if(!processHandshakeResult(false, true, result.getHandshakeStatus()))
                break;

            if(status == Status.BUFFER_OVERFLOW) {
                // An overflow into an empty buffer can never succeed; otherwise
                // stop and let handleWrite flush what's pending first.
                if(wasEmpty)
                    throw new IllegalStateException("outgoing TLS buffer not large enough!");
                else
                    break;
            }
        } while(src.hasRemaining());

        return consumed;
    }

    /**
     * Flushes pending wrapped data to the underlying channel, performs any
     * pending handshake wrap, and lets the interested WriteObserver write.
     *
     * @return true if wrapped data is still pending, false if everything was written
     * @throws ClosedChannelException if this channel has been shut down
     * @throws IllegalStateException if no underlying write channel is set
     */
    public boolean handleWrite() throws IOException {
        if(shutdown)
            throw new ClosedChannelException();
        InterestWritableByteChannel source = writeSink;
        if(source == null)
            throw new IllegalStateException("writing with no source.");

        while(true) {
            // Step 1: See if there is any pending data to be written.
            if(writeOutgoing.position() > 0) {
                writeOutgoing.flip();
                source.write(writeOutgoing);
                if(writeOutgoing.hasRemaining()) {
                    writeOutgoing.compact();
                    return true; // there is still data that is pending a write.
                }
                writeOutgoing.clear();
            }

            // Step 2: If we need to do a handshake wrap, do that.
            if(needsHandshakeWrap) {
                LOG.debug("Forcing a handshake wrap");
                write(BufferUtils.getEmptyBuffer());
                if(writeOutgoing.position() > 0)
                    continue;
            }

            // Step 3: Tell any interested parties to write data.
            WriteObserver interested = writeWanter;
            if(interested != null) {
                if(LOG.isDebugEnabled())
                    LOG.debug("Telling interested parties to write. (a " + interested + ")");
                interested.handleWrite();
            }

            // If no data after that, we've written everything we want -- exit.
            if(writeOutgoing.position() == 0) {
                // We have nothing left to write, however, it is possible
                // that between the above check for interested.handleWrite & here,
                // we got pre-empted and another thread turned on interest.
                synchronized(this) {
                    if(writeWanter == null) // no observer? good, we can turn interest off
                        source.interestWrite(this, false);
                    // else, we've got nothing to write, but our observer might.
                }
                return false;
            }
        }
    }

    /**
     * Releases any resources that were acquired by the channel.
     * If the underlying channels are still open, this method only propagates
     * the shutdown call, instead of shutting down this channel, as it can
     * still be used by other channels.
     */
    public void shutdown() {
        // Decided under the lock so that only the thread which actually flips
        // 'shutdown' releases the buffers; re-reading the volatile flag after
        // leaving the lock could let two threads release the same buffers.
        boolean releaseBuffers = false;
        synchronized(initLock) {
            if(shutdown)
                return;
            if(!isOpen()) {
                LOG.debug("Shutting down SSL channel");
                shutdown = true;
                releaseBuffers = true;
            }
        }

        if(releaseBuffers) {
            networkExecutor.execute(new Runnable() {
                public void run() {
                    if(readIncoming != null)
                        byteBufferCache.release(readIncoming);
                    if(readOutgoing != null)
                        byteBufferCache.release(readOutgoing);
                    if(writeOutgoing != null)
                        byteBufferCache.release(writeOutgoing);
                }
            });
        }

        Shutdownable observer = writeWanter;
        if(observer != null)
            observer.shutdown();
    }

    /** Returns the underlying channel this reads from. */
    public InterestReadableByteChannel getReadChannel() {
        return readSink;
    }

    /** Sets the underlying channel this reads from. */
    public void setReadChannel(InterestReadableByteChannel newChannel) {
        this.readSink = newChannel;
    }

    /** Returns the underlying channel this writes to. */
    public InterestWritableByteChannel getWriteChannel() {
        return writeSink;
    }

    /** Sets the underlying channel this writes to. */
    public void setWriteChannel(InterestWritableByteChannel newChannel) {
        this.writeSink = newChannel;
    }

    /**
     * Closes both underlying channels. The write channel is closed even if
     * closing the read channel fails; unset channels are skipped.
     */
    public void close() throws IOException {
        InterestReadableByteChannel reader = readSink;
        InterestWritableByteChannel writer = writeSink;
        try {
            if(reader != null)
                reader.close();
        } finally {
            if(writer != null)
                writer.close();
        }
    }

    /** Returns true only if both underlying channels are set and open. */
    public boolean isOpen() {
        return readSink != null && readSink.isOpen() && writeSink != null && writeSink.isOpen();
    }

    /** Reacts to an IOException by shutting down. */
    public void handleIOException(IOException iox) {
        shutdown();
    }

    /**
     * Records the caller's read interest. The underlying channel's interest is
     * on only if no task is scheduled and either the caller wants to read or
     * the handshake needs an unwrap.
     */
    public void interestRead(boolean status) {
        synchronized(taskLock) {
            synchronized(readInterestLock) {
                readInterest = status;
                boolean interest = !taskScheduled && (needsHandshakeUnwrap || status);
                readSink.interestRead(interest);
            }
        }
    }

    /**
     * Records the observer that wants to write (or clears it if status is false).
     * Underlying write interest is turned on whenever no task is scheduled,
     * regardless of status, so pending data can be flushed; handleWrite() turns
     * it back off once nothing is left and no observer is registered.
     */
    public synchronized void interestWrite(WriteObserver observer, boolean status) {
        this.writeWanter = status ? observer : null;
        InterestWritableByteChannel source = writeSink;
        if(source != null) {
            synchronized(taskLock) {
                source.interestWrite(this, !taskScheduled);
            }
        }
    }

    /** Returns the total number of bytes that this has produced from unwrapping reads. */
    long getReadBytesProduced() {
        return readProduced;
    }

    /** Returns the total number of bytes that this has consumed while unwrapping reads. */
    long getReadBytesConsumed() {
        return readConsumed;
    }

    /** Returns the total number of bytes that this has produced from wrapping writes. */
    long getWrittenBytesProduced() {
        return writeProduced;
    }

    /** Returns the total number of bytes that this has consumed while wrapping writes. */
    long getWrittenBytesConsumed() {
        return writeConsumed;
    }

    /** Returns the SSLSession this channel uses, or null if not yet initialized. */
    SSLSession getSession() {
        return engine != null ? engine.getSession() : null;
    }

    /** Returns true if we're currently handshaking (or haven't completed a first read). */
    boolean isHandshaking() {
        return !firstReadDone.get() || engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING;
    }

    /**
     * Returns true if wrapped data is waiting to be written, either here or in
     * the underlying write channel. Safe to call before initialization.
     */
    public boolean hasBufferedOutput() {
        InterestWritableByteChannel channel = this.writeSink;
        boolean pendingHere = writeOutgoing != null && writeOutgoing.position() > 0;
        return pendingHere || (channel != null && channel.hasBufferedOutput());
    }
}