package org.limewire.nio.ssl; import java.io.IOException; import java.net.InetAddress; import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.util.concurrent.Executor; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import org.limewire.logging.Log; import org.limewire.logging.LogFactory; import org.limewire.nio.ByteBufferCache; import org.limewire.nio.NIODispatcher; import org.limewire.nio.NIOSocket; import org.limewire.nio.channel.ChannelReader; import org.limewire.nio.channel.InterestReadableByteChannel; import org.limewire.nio.channel.InterestWritableByteChannel; import org.limewire.nio.channel.ThrottleReader; import org.limewire.nio.observer.ConnectObserver; /** * An {@link NIOSocket} that uses SSL/TLS for transfer encoding. * <p> * {@link AbstractSSLSocket} can be configured to support any cipher suite * and {@link SSLContext}. */ public abstract class AbstractSSLSocket extends NIOSocket { private final static Log LOG = LogFactory.getLog(TLSNIOSocket.class); private volatile SSLReadWriteChannel sslLayer; private volatile InterestReadableByteChannel baseReader; private volatile InterestWritableByteChannel baseWriter; public AbstractSSLSocket(InetAddress addr, int port, InetAddress localAddr, int localPort) throws IOException { super(addr, port, localAddr, localPort); } public AbstractSSLSocket(InetAddress addr, int port) throws IOException { super(addr, port); } public AbstractSSLSocket(String addr, int port, InetAddress localAddr, int localPort) throws IOException { super(addr, port, localAddr, localPort); } public AbstractSSLSocket(String addr, int port) throws UnknownHostException, IOException { super(addr, port); } public AbstractSSLSocket() throws IOException { super(); } AbstractSSLSocket(Socket socket) { super(socket); } /** * Returns the {@link SSLContext} that should be used to generate new * {@link SSLEngine SSLEngines}. */ protected abstract SSLContext getSSLContext(); /** * Returns the {@link Executor} that should be used to process long-lived tasks * generated from an {@link SSLEngine}. * * @see {@link HandshakeStatus#NEED_TASK} */ protected Executor getSSLExecutor() { return SSLUtils.getExecutor(); } /** * Returns the {@link ByteBufferCache} that should be used to retrieve & * return {@link ByteBuffer ByteBuffers} for use reading & writing data. */ protected ByteBufferCache getByteBufferCache() { return NIODispatcher.instance().getBufferCache(); } /** * Returns the {@link Executor} that should be used to write & read data to * & from the network. */ protected Executor getNetworkExecutor() { return NIODispatcher.instance().getScheduledExecutorService(); } /** * Returns the cipher suites that are allowed to be used by the {@link SSLEngine}. * A return value of null means that the default cipher suites are enabled. * * @see SSLEngine#getEnabledCipherSuites() * @see SSLEngine#getSupportedCipherSuites() */ protected abstract String[] getCipherSuites(); @Override public boolean connect(SocketAddress addr, int timeout, ConnectObserver observer) { return super.connect(addr, timeout, new SSLConnectInitializer(addr, observer)); } @Override protected InterestReadableByteChannel getBaseReadChannel() { if(baseReader == null) { sslLayer.setReadChannel(super.getBaseReadChannel()); baseReader = sslLayer; } return baseReader; } @Override protected InterestWritableByteChannel getBaseWriteChannel() { if(baseWriter == null) { sslLayer.setWriteChannel(super.getBaseWriteChannel()); baseWriter = sslLayer; } return baseWriter; } @Override protected void installThrottle(ThrottleReader throttle, ChannelReader reader) { // The goal is to insert the throttle such that // READER -> READER -> SSL LAYER -> THROTTLE -> SOCKET // so... do everything the same as the super, except when connecting // the throttle to the socket we don't connect it to the SSL layer, // instead we connect it to the real socket. ChannelReader lastChannel = reader; // go down the chain of ChannelReaders and find the last one to set our source while(lastChannel.getReadChannel() instanceof ChannelReader) { lastChannel = (ChannelReader)lastChannel.getReadChannel(); } if(throttle != lastChannel) { lastChannel.setReadChannel(throttle); throttle.setReadChannel(super.getBaseReadChannel()); } } @Override protected void initIncomingSocket() { super.initIncomingSocket(); sslLayer = new SSLReadWriteChannel(getSSLContext(), getSSLExecutor(), getByteBufferCache(), getNetworkExecutor()); sslLayer.initialize(getRemoteSocketAddress(), SSLUtils.getTLSCipherSuites(), false, false); } @Override protected void initOutgoingSocket() throws IOException { super.initOutgoingSocket(); sslLayer = new SSLReadWriteChannel(getSSLContext(), getSSLExecutor(), getByteBufferCache(), getNetworkExecutor()); } @Override protected void shutdownObservers() { if(sslLayer != null) sslLayer.shutdown(); super.shutdownObservers(); } /* package */ SSLReadWriteChannel getSSLChannel() { return sslLayer; } @Override /* Overridden to retrieve the soTimeout from the socket if we're still handshaking. */ public long getReadTimeout() { if(sslLayer != null && sslLayer.isHandshaking()) { try { return getSoTimeout(); } catch(SocketException se) { return 0; } } else { return super.getReadTimeout(); } } /** * A delegating connector that forces the TLS Layer to be initialized * prior to informing the real <code>ConnectObserver</code> about the connection. */ private class SSLConnectInitializer implements ConnectObserver { private final ConnectObserver delegate; private final SocketAddress addr; public SSLConnectInitializer(SocketAddress addr, ConnectObserver delegate) { this.delegate = delegate; this.addr = addr; } public void handleConnect(Socket socket) throws IOException { if(LOG.isDebugEnabled()) { LOG.debug("Initializing SSL/TLS connection to " + getInetAddress().getHostAddress() + ":" + getPort() + ", open " + sslLayer.isOpen() + ", handshaking " + sslLayer.isHandshaking()); } sslLayer.initialize(addr, getCipherSuites(), true, false); if(LOG.isDebugEnabled()) { LOG.debug("Initialized SSL/TLS connection to " + getInetAddress().getHostAddress() + ":" + getPort() + ", open " + sslLayer.isOpen() + ", handshaking " + sslLayer.isHandshaking()); } delegate.handleConnect(socket); } public void handleIOException(IOException iox) { if(LOG.isDebugEnabled()) { LOG.debug(iox + ", " + getInetAddress().getHostAddress() + ":" + getPort() + ", open " + sslLayer.isOpen() + ", handshaking " + sslLayer.isHandshaking()); } delegate.handleIOException(iox); } public void shutdown() { if(LOG.isDebugEnabled()) { LOG.debug("Shutting down SSL/TLS connection to " + getInetAddress().getHostAddress() + ":" + getPort() + ", open " + sslLayer.isOpen() + ", handshaking " + sslLayer.isHandshaking()); } delegate.shutdown(); } } }