package org.scribble.net.session; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; public class SSLSocketChannelWrapper extends BinaryChannelWrapper { private SSLEngine engine; private ByteBuffer EMPTY = ByteBuffer.allocate(0); private ByteBuffer myAppData; private ByteBuffer myNetData; //private ByteBuffer peerAppData = ByteBuffer.allocate(16916); // Hacked constant //private ByteBuffer peerNetData; public SSLSocketChannelWrapper() { } private void init(InetSocketAddress addr) throws NoSuchAlgorithmException, KeyManagementException { SSLContext sslContext = SSLContext.getInstance("TLS"); sslContext.init(null, null, null); this.engine = sslContext.createSSLEngine(addr.getHostName(), addr.getPort()); } @Override public void clientHandshake() throws IOException, KeyManagementException, NoSuchAlgorithmException { SocketChannel s = getSelectableChannel(); init((InetSocketAddress) s.getRemoteAddress()); this.engine.setUseClientMode(true); SSLSession session = this.engine.getSession(); this.myAppData = ByteBuffer.allocate(session.getApplicationBufferSize()); // 16916 this.myNetData = ByteBuffer.allocate(session.getPacketBufferSize()); // 16921 //this.peerAppData = ByteBuffer.allocate(16916); // Hacked constant //this.peerNetData = ByteBuffer.allocate(session.getPacketBufferSize()); doHandshake(s, this.engine, this.myNetData); } @Override public void serverHandshake() throws IOException, KeyManagementException, NoSuchAlgorithmException { SocketChannel s = getSelectableChannel(); init((InetSocketAddress) s.getRemoteAddress()); throw new RuntimeException("TODO"); } @Override public SocketChannel getSelectableChannel() { return (SocketChannel) super.getSelectableChannel(); } @Override protected void writeWrappedBytes(byte[] bs) throws IOException { getSelectableChannel().write(ByteBuffer.wrap(bs)); } @Override protected void readWrappedBytesIntoBuffer() throws IOException { ByteBuffer bb = getWrapped(); getSelectableChannel().read(bb); } @Override public byte[] wrap(byte[] bs) throws IOException { this.myAppData.put(bs); this.myAppData.flip(); this.myNetData.clear(); while (this.myAppData.hasRemaining()) { SSLEngineResult res = this.engine.wrap(this.myAppData, this.myNetData); if (res.getStatus() == SSLEngineResult.Status.OK) { this.myAppData.compact(); this.myAppData.flip(); } else { // Handle other status: BUFFER_OVERFLOW (write some first), BUFFER_UNDERFLOW, CLOSED throw new RuntimeException("TODO: " + res.getStatus()); } } this.myAppData.compact(); // Should be same as clear here (i.e. empty) this.myNetData.flip(); byte[] res = new byte[this.myNetData.remaining()]; System.arraycopy(this.myNetData.array(), this.myNetData.position(), res, 0, res.length); return res; } // Decode from this.wrapped (this.getUnwrapped()) into this.bb @Override public void unwrap() throws IOException // Decode from this.wrapped (this.getUnwrapped()) into this.bb { ByteBuffer peerNetData = getWrapped(); ByteBuffer peerAppData = getBuffer(); //for (boolean done = false; !done; ) { //s.read(peerNetData); // Assumes there is something to read (or else blocked) peerNetData.flip(); SSLEngineResult res = this.engine.unwrap(peerNetData, peerAppData); peerNetData.compact(); if (res.getStatus() == SSLEngineResult.Status.OK) { //done = true; } else if (res.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) { //peerNetData.flip(); //return; } else if (res.getStatus() == SSLEngineResult.Status.CLOSED) { // FIXME: check if closed is OK or not (if not, throw exception) } else { // Handle other status: BUFFER_OVERFLOW throw new RuntimeException("TODO: " + res.getStatus()); } } } private void doHandshake(SocketChannel s, SSLEngine engine, ByteBuffer myNetData) throws IOException { ByteBuffer peerNetData = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); ByteBuffer peerAppData = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); engine.beginHandshake(); // Not explicitly needed for initial handshake SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus(); while (hs != SSLEngineResult.HandshakeStatus.FINISHED && hs != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { switch (hs) { case NEED_UNWRAP: { for (boolean done = false; !done; ) { peerNetData.flip(); SSLEngineResult res = engine.unwrap(peerNetData, peerAppData); peerNetData.compact(); hs = res.getHandshakeStatus(); switch (res.getStatus()) { case OK: { done = true; break; } case BUFFER_UNDERFLOW: { if (s.read(peerNetData) < 0) { throw new RuntimeException("TODO: "); } break; } default: { // Handle other status: BUFFER_OVERFLOW, CLOSED throw new RuntimeException("TODO: " + res.getStatus()); } } } break; } case NEED_WRAP: { myNetData.clear(); SSLEngineResult res = engine.wrap(EMPTY, myNetData); hs = res.getHandshakeStatus(); switch (res.getStatus()) { case OK: { myNetData.flip(); while (myNetData.hasRemaining()) { if (s.write(myNetData) < 0) { throw new RuntimeException("TODO: "); } } myNetData.compact(); break; } default: { // Handle other status: BUFFER_OVERFLOW (write some first to free up space), BUFFER_UNDERFLOW, CLOSED throw new RuntimeException("TODO: " + hs); } } break; } case NEED_TASK: { engine.getDelegatedTask().run(); hs = engine.getHandshakeStatus(); break; } default: { // Handle other status: // FINISHED or NOT_HANDSHAKING throw new RuntimeException("TODO: " + hs); } } } } private void doShutdown() throws SSLException, IOException { SocketChannel s = getSelectableChannel(); this.engine.closeOutbound(); myNetData.clear(); while (!this.engine.isOutboundDone()) { SSLEngineResult res = this.engine.wrap(EMPTY, myNetData); if (res.getStatus() != SSLEngineResult.Status.CLOSED) // FIXME: is this the correct state to check for? or OK? (or both) { throw new RuntimeException("TODO: " + res.getStatus()); } this.myNetData.flip(); while (this.myNetData.hasRemaining()) { int num1 = s.write(this.myNetData); if (num1 == -1) { throw new RuntimeException("TODO: "); } this.myNetData.compact(); this.myNetData.flip(); } } //s1.close(); } @Override public synchronized void close() throws IOException { doShutdown(); super.close(); } }