package water.network; import water.H2O; import water.util.Log; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import java.io.IOException; import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.SocketChannel; /** * This class is based on: * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/jsse/JSSERefGuide.html">Oracle's JSSE guide.</a> * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/jsse/samples/sslengine/SSLEngineSimpleDemo.java">Oracle's SSLEngine demo.</a> * * It's a simple wrapper around SocketChannels which enables SSL/TLS * communication using {@link javax.net.ssl.SSLEngine}. */ class SSLSocketChannel implements ByteChannel { // Empty buffer for handshakes private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); // Buffer holding encrypted outgoing data private ByteBuffer netInBuffer; // Buffer holding encrypted incoming data private ByteBuffer netOutBuffer; // Buffer holding decrypted incoming data private ByteBuffer peerAppData; private SocketChannel channel = null; private SSLEngine sslEngine = null; private boolean closing = false; private boolean closed = false; private boolean handshakeComplete = false; SSLSocketChannel(SocketChannel channel, SSLEngine sslEngine) throws IOException { this.channel = channel; this.sslEngine = sslEngine; sslEngine.setEnableSessionCreation(true); SSLSession session = sslEngine.getSession(); prepareBuffers(session); handshake(); } @Override public boolean isOpen() { return channel.isOpen(); } @Override public void close() throws IOException { closing = true; sslEngine.closeOutbound(); sslEngine.getSession().invalidate(); netOutBuffer.clear(); channel.close(); closed = true; } private void prepareBuffers(SSLSession session) throws SocketException { int appBufferSize = session.getApplicationBufferSize(); // Less is not more. More is more. Bigger than the app buffer size so successful unwraps() don't cause BUFFER_OVERFLOW // Value 64 was based on other frameworks using it and some manual testing. Might require tuning in the future. peerAppData = ByteBuffer.allocate(appBufferSize + 64); int netBufferSize = session.getPacketBufferSize(); netInBuffer = ByteBuffer.allocate(netBufferSize); netOutBuffer = ByteBuffer.allocate(netBufferSize); } // ----------------------------------------------------------- // HANDSHAKE // ----------------------------------------------------------- private SSLEngineResult.HandshakeStatus hs; private void handshake() throws IOException { Log.debug("Starting SSL handshake..."); sslEngine.beginHandshake(); hs = sslEngine.getHandshakeStatus(); SSLEngineResult initHandshakeStatus; while (!handshakeComplete) { switch (hs) { case NOT_HANDSHAKING: { //should never happen throw new IOException("NOT_HANDSHAKING during handshake"); } case FINISHED: handshakeComplete = !netOutBuffer.hasRemaining(); break; case NEED_WRAP: { initHandshakeStatus = handshakeWrap(); if ( initHandshakeStatus.getStatus() == SSLEngineResult.Status.OK ){ if (hs == SSLEngineResult.HandshakeStatus.NEED_TASK) { tasks(); } } break; } case NEED_UNWRAP: { initHandshakeStatus = handshakeUnwrap(); if ( initHandshakeStatus.getStatus() == SSLEngineResult.Status.OK ){ if (hs == SSLEngineResult.HandshakeStatus.NEED_TASK) { tasks(); } } break; } // SSL needs to perform some delegating tasks before it can continue. // Those tasks will be run in the same thread and can be blocking. case NEED_TASK: tasks(); break; } } Log.debug("SSL handshake finished successfully!"); } private synchronized SSLEngineResult handshakeWrap() throws IOException { netOutBuffer.clear(); SSLEngineResult wrapResult = sslEngine.wrap(EMPTY_BUFFER, netOutBuffer); netOutBuffer.flip(); hs = wrapResult.getHandshakeStatus(); channel.write(netOutBuffer); return wrapResult; } private synchronized SSLEngineResult handshakeUnwrap() throws IOException { if (netInBuffer.position() == netInBuffer.limit()) { netInBuffer.clear(); } channel.read(netInBuffer); SSLEngineResult unwrapResult; peerAppData.clear(); do { netInBuffer.flip(); unwrapResult = sslEngine.unwrap(netInBuffer, peerAppData); netInBuffer.compact(); hs = unwrapResult.getHandshakeStatus(); switch (unwrapResult.getStatus()) { case OK: case BUFFER_UNDERFLOW: { if (unwrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { tasks(); } break; } case BUFFER_OVERFLOW: { int applicationBufferSize = sslEngine.getSession().getApplicationBufferSize(); if (applicationBufferSize > peerAppData.capacity()) { ByteBuffer b = ByteBuffer.allocate(applicationBufferSize + peerAppData.position()); peerAppData.flip(); b.put(peerAppData); peerAppData = b; } else { peerAppData.compact(); } break; } default: throw new IOException("Failed to SSL unwrap with status " + unwrapResult.getStatus()); } } while(unwrapResult.getStatus() == SSLEngineResult.Status.OK && hs == SSLEngineResult.HandshakeStatus.NEED_UNWRAP); return unwrapResult; } // ----------------------------------------------------------- // READ AND WRITE // ----------------------------------------------------------- @Override public int read(ByteBuffer dst) throws IOException { if (closing || closed) return -1; return unwrap(dst); } private synchronized int unwrap(ByteBuffer dst) throws IOException { int read = 0; // We have outstanding data in our incoming decrypted buffer, use that data first to fill dst if(!dst.hasRemaining()) { return 0; } if(peerAppData.position() != 0) { read += copy(peerAppData, dst); return read; } if(netInBuffer.position() == 0) { channel.read(netInBuffer); } while(netInBuffer.position() != 0) { netInBuffer.flip(); // We still might have left data here if dst was smaller than the amount of data in peerAppData if(peerAppData.position() != 0) { peerAppData.compact(); } SSLEngineResult unwrapResult = sslEngine.unwrap(netInBuffer, peerAppData); switch (unwrapResult.getStatus()) { case OK: { unwrapResult.bytesProduced(); if (unwrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) tasks(); break; } case BUFFER_OVERFLOW: { int applicationBufferSize = sslEngine.getSession().getApplicationBufferSize(); if (applicationBufferSize > peerAppData.capacity()) { int appSize = applicationBufferSize; ByteBuffer b = ByteBuffer.allocate(appSize + peerAppData.position()); peerAppData.flip(); b.put(peerAppData); peerAppData = b; } else { // We tried to unwrap data into peerAppData which means there's leftover in netInBuffer // the upcoming read should read int potential new data after the leftover netInBuffer.position(netInBuffer.limit()); netInBuffer.limit(netInBuffer.capacity()); peerAppData.compact(); if(!dst.hasRemaining()) { return read; } } break; } case BUFFER_UNDERFLOW: { int packetBufferSize = sslEngine.getSession().getPacketBufferSize(); if (packetBufferSize > netInBuffer.capacity()) { int netSize = packetBufferSize; if (netSize > netInBuffer.capacity()) { ByteBuffer b = ByteBuffer.allocate(netSize); netInBuffer.flip(); b.put(netInBuffer); netInBuffer = b; } } else { // We have some leftover data from unwrap but no enough. // We need to read in more data from the socket AFTER the current data. netInBuffer.position(netInBuffer.limit()); netInBuffer.limit(netInBuffer.capacity()); channel.read(netInBuffer); continue; } break; } default: throw new IOException("Failed to SSL unwrap with status " + unwrapResult.getStatus()); } if (peerAppData != dst && dst.hasRemaining()) { peerAppData.flip(); read += copy(peerAppData, dst); if(!dst.hasRemaining()) { netInBuffer.compact(); return read; } } netInBuffer.compact(); } return read; } private int copy(ByteBuffer src, ByteBuffer dst) { int toCopy = Math.min(src.remaining(), dst.remaining()); dst.put(src.array(), src.position(), toCopy); src.position(src.position() + toCopy); if(!src.hasRemaining()) { src.clear(); } return toCopy; } @Override public int write(ByteBuffer src) throws IOException { if(closing || closed) { throw new IOException("Cannot perform socket write, the socket is closed (or being closed)."); } int wrote = 0; // src can be much bigger than what our SSL session allows to send in one go while (src.hasRemaining()) { netOutBuffer.clear(); SSLEngineResult wrapResult = sslEngine.wrap(src, netOutBuffer); netOutBuffer.flip(); if (wrapResult.getStatus() == SSLEngineResult.Status.OK) { if (wrapResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) tasks(); } while (netOutBuffer.hasRemaining()) { wrote += channel.write(netOutBuffer); } } return wrote; } // ----------------------------------------------------------- // MISC // ----------------------------------------------------------- private void tasks() { Runnable r; while ( (r = sslEngine.getDelegatedTask()) != null) { r.run(); } hs = sslEngine.getHandshakeStatus(); } public SocketChannel channel() { return channel; } SSLEngine getEngine() { return sslEngine; } boolean isHandshakeComplete() { return handshakeComplete; } }