package org.webpieces.ssl.impl; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import org.webpieces.util.logging.Logger; import org.webpieces.util.logging.LoggerFactory; import org.webpieces.data.api.BufferPool; import org.webpieces.ssl.api.AsyncSSLEngine; import org.webpieces.ssl.api.AsyncSSLEngineException; import org.webpieces.ssl.api.ConnectionState; import org.webpieces.ssl.api.SslListener; public class AsyncSSLEngine2Impl implements AsyncSSLEngine { private static final Logger log = LoggerFactory.getLogger(AsyncSSLEngine2Impl.class); private static final ByteBuffer EMPTY = ByteBuffer.allocate(0); private BufferPool pool; private SslMementoImpl mem; private SslListener listener; private Object wrapLock = new Object(); //unwrap would be a separate lock but it's better just to //order and not have to be thread safe...responsibility is on the client of this class. //This is because you would have to lock unwrap/listener.packetUnencrypted and client would own this //lock and for all we know another lock comes into the picture and deadlock occurs. Instead we do //not need to lock anyways. Just keep ordered and not call from multiple threads AT same time. //ie. use something like SessionExecutor //private Object unwrapLock = new Object(); private boolean clientInitiated; private AtomicBoolean fireClosed = new AtomicBoolean(false); private AtomicBoolean fireConnected = new AtomicBoolean(false); public AsyncSSLEngine2Impl(String loggingId, SSLEngine engine, BufferPool pool, SslListener listener) { this.pool = pool; this.listener = listener; ByteBuffer cachedOutBuffer = pool.nextBuffer(engine.getSession().getApplicationBufferSize()); this.mem = new SslMementoImpl(loggingId, engine, cachedOutBuffer); } @Override public void beginHandshake() { mem.compareSet(ConnectionState.NOT_STARTED, ConnectionState.CONNECTING); SSLEngine sslEngine = mem.getEngine(); log.trace(()->mem+"start handshake"); try { sslEngine.beginHandshake(); } catch (SSLException e) { throw new AsyncSSLEngineException(e); } sendHandshakeMessage(); } private void createRunnable() { SSLEngine sslEngine = mem.getEngine(); Runnable r = sslEngine.getDelegatedTask(); listener.runTask(new Runnable() { @Override public void run() { r.run(); runnableComplete(); } }); } private void runnableComplete() { SSLEngine sslEngine = mem.getEngine(); HandshakeStatus hsStatus = sslEngine.getHandshakeStatus(); ByteBuffer cached = mem.getCachedToProcess(); if(hsStatus == HandshakeStatus.NEED_UNWRAP) { //unwrap any previously incoming data... if(cached != null) { mem.setCachedEncryptedData(null); //wipe out the data we are now procesing log.trace(()->mem+"[AfterRunnable][socketToEngine] refeeding myself pos="+cached.position()+" lim="+cached.limit()); feedEncryptedPacketImpl(cached); } } else if(hsStatus == HandshakeStatus.NEED_WRAP) { log.trace(()->mem+"[Runnable]continuing handshake"); sendHandshakeMessage(); } else { throw new UnsupportedOperationException("need to support state="+hsStatus); } } private void sendHandshakeMessage() { try { sendHandshakeMessageImpl(); } catch (SSLException e) { throw new AsyncSSLEngineException(e); } } private void sendHandshakeMessageImpl() throws SSLException { SSLEngine sslEngine = mem.getEngine(); log.trace(()->mem+"sending handshake message"); //HELPER.eraseBuffer(empty); HandshakeStatus hsStatus = sslEngine.getHandshakeStatus(); if(hsStatus != HandshakeStatus.NEED_WRAP) throw new IllegalStateException("we should only be calling this method when hsStatus=NEED_WRAP. hsStatus="+hsStatus); while(hsStatus == HandshakeStatus.NEED_WRAP) { ByteBuffer engineToSocketData = pool.nextBuffer(sslEngine.getSession().getPacketBufferSize()); Status lastStatus = null; synchronized (wrapLock ) { //KEEEEEP This very small. wrap and then listener.packetEncrypted SSLEngineResult result = sslEngine.wrap(EMPTY, engineToSocketData); lastStatus = result.getStatus(); hsStatus = result.getHandshakeStatus(); final Status lastStatus2 = lastStatus; final HandshakeStatus hsStatus2 = hsStatus; log.trace(()->mem+"write packet pos="+engineToSocketData.position()+" lim="+ engineToSocketData.limit()+" status="+lastStatus2+" hs="+hsStatus2); if(lastStatus == Status.BUFFER_OVERFLOW || lastStatus == Status.BUFFER_UNDERFLOW) throw new RuntimeException("status not right, status="+lastStatus+" even though we sized the buffer to consume all?"); engineToSocketData.flip(); listener.sendEncryptedHandshakeData(engineToSocketData); } if(lastStatus == Status.CLOSED && !clientInitiated) { fireClose(); } } if(hsStatus == HandshakeStatus.NEED_WRAP || hsStatus == HandshakeStatus.NEED_TASK) throw new RuntimeException(mem+"BUG, need to implement more here status="+hsStatus); final HandshakeStatus hsStatus2 = hsStatus; log.trace(()->mem+"status="+hsStatus2+" isConn="+mem.getConnectionState()); if(hsStatus == HandshakeStatus.FINISHED) { fireLinkEstablished(); } } /** * This is synchronized as the socketToEngineData2 buffer is modified in this method * and modified in other methods that are called on other threads.(ie. the put is called) * */ @Override public void feedEncryptedPacket(ByteBuffer b) { if(mem.getConnectionState() == ConnectionState.DISCONNECTED) throw new IllegalStateException(mem+"SSLEngine is closed"); mem.compareSet(ConnectionState.NOT_STARTED, ConnectionState.CONNECTING); feedEncryptedPacketImpl(b); } private void feedEncryptedPacketImpl(ByteBuffer encryptedInData) { SSLEngine sslEngine = mem.getEngine(); HandshakeStatus hsStatus = sslEngine.getHandshakeStatus(); Status status = null; final HandshakeStatus hsStatus2 = hsStatus; log.trace(()->mem+"[sockToEngine] going to unwrap pos="+encryptedInData.position()+ " lim="+encryptedInData.limit()+" hsStatus="+hsStatus2+" cached="+mem.getCachedToProcess()); ByteBuffer encryptedData = encryptedInData; ByteBuffer cached = mem.getCachedToProcess(); if(cached != null) { encryptedData = combine(cached, encryptedData); mem.setCachedEncryptedData(null); } int i = 0; //stay in loop while we //1. need unwrap or not_handshaking or need_task AND //2. have data in buffer //3. have enough data in buffer(ie. not underflow) while(encryptedData.hasRemaining() && status != Status.BUFFER_UNDERFLOW && status != Status.CLOSED) { i++; SSLEngineResult result; ByteBuffer outBuffer = mem.getCachedOut(); try { result = sslEngine.unwrap(encryptedData, outBuffer); } catch(SSLException e) { AsyncSSLEngineException ee = new AsyncSSLEngineException("status="+status+" hsStatus="+hsStatus+" b="+encryptedData, e); throw ee; } finally { if(outBuffer.position() != 0) { outBuffer.flip(); listener.packetUnencrypted(outBuffer); //frequently the out buffer is not used so we only ask the pool for buffers AFTER it has been consumed/used ByteBuffer newCachedOut = pool.nextBuffer(sslEngine.getSession().getApplicationBufferSize()); mem.setCachedOut(newCachedOut); } } status = result.getStatus(); hsStatus = result.getHandshakeStatus(); final ByteBuffer data = encryptedData; final Status status2 = status; final HandshakeStatus hsStatus3 = hsStatus; log.trace(()->mem+"[sockToEngine] unwrap done pos="+data.position()+" lim="+ data.limit()+" status="+status2+" hs="+hsStatus3); if(i > 1000) throw new RuntimeException(this+"Bug, stuck in loop, bufIn="+encryptedData+" bufOut="+outBuffer+ " hsStatus="+hsStatus+" status="+status); else if(hsStatus == HandshakeStatus.NEED_TASK) { //if status is need task, we need to break to run the task before other handshake //messages? break; } else if(status == Status.BUFFER_UNDERFLOW) { final ByteBuffer data1 = encryptedData; log.trace(()->"buffer underflow. data="+data1.remaining()); } } if(encryptedData.hasRemaining()) { mem.setCachedEncryptedData(encryptedData); } final ByteBuffer data2 = encryptedData; final Status status2 = status; final HandshakeStatus hsStatus3 = hsStatus; log.trace(()->mem+"[sockToEngine] reset pos="+data2.position()+" lim="+data2.limit()+" status="+status2+" hs="+hsStatus3); cleanAndFire(hsStatus, status, encryptedData); } private void cleanAndFire(HandshakeStatus hsStatus, Status status, ByteBuffer encryptedData) { if(!encryptedData.hasRemaining()) pool.releaseBuffer(encryptedData); //First if avoids case where the close handshake is still going on so we are not closed //yet I think(I am writing this from memory)... if(status == Status.CLOSED) { if(hsStatus == HandshakeStatus.NEED_WRAP) { mem.compareSet(ConnectionState.CONNECTED, ConnectionState.DISCONNECTING); sendHandshakeMessage(); } else { fireClose(); } } else if(hsStatus == HandshakeStatus.NEED_TASK) { createRunnable(); } else if(hsStatus == HandshakeStatus.NEED_UNWRAP) { //just need to wait for more data } else if(hsStatus == HandshakeStatus.NEED_WRAP) { sendHandshakeMessage(); } else if(hsStatus ==HandshakeStatus.FINISHED) { fireLinkEstablished(); } else if(hsStatus == HandshakeStatus.NOT_HANDSHAKING) { //nothing to do. packet already fed } else { throw new UnsupportedOperationException("need to support state="+hsStatus+" status="+status); } } private void fireLinkEstablished() { boolean shouldFire = fireConnected.compareAndSet(false, true); if(shouldFire) { mem.compareSet(ConnectionState.CONNECTING, ConnectionState.CONNECTED); listener.encryptedLinkEstablished(); } } private void fireClose() { //fire only ONCE... boolean shouldFire = fireClosed.compareAndSet(false, true); if(shouldFire) { mem.compareSet(ConnectionState.DISCONNECTING, ConnectionState.DISCONNECTED); listener.closed(clientInitiated); } } private ByteBuffer combine(ByteBuffer cachedToProcessLaterData, ByteBuffer encryptedData) { int size = cachedToProcessLaterData.remaining()+encryptedData.remaining(); ByteBuffer nextBuffer = pool.nextBuffer(size); nextBuffer.put(cachedToProcessLaterData); nextBuffer.put(encryptedData); nextBuffer.flip(); pool.releaseBuffer(cachedToProcessLaterData); pool.releaseBuffer(encryptedData); return nextBuffer; } @Override public CompletableFuture<Void> feedPlainPacket(ByteBuffer buffer) { try { return feedPlainPacketImpl(buffer); } catch (SSLException e) { throw new AsyncSSLEngineException(e); } } @SuppressWarnings("rawtypes") public CompletableFuture<Void> feedPlainPacketImpl(ByteBuffer buffer) throws SSLException { if(mem.getConnectionState() != ConnectionState.CONNECTED) throw new IllegalStateException(mem+" SSLEngine is not connected right now"); else if(!buffer.hasRemaining()) throw new IllegalArgumentException("your buffer has no readable data"); SSLEngine sslEngine = mem.getEngine(); log.trace(()->mem+"feedPlainPacket [in-buffer] pos="+buffer.position()+" lim="+buffer.limit()); List<CompletableFuture> futures = new ArrayList<>(); while(buffer.hasRemaining()) { ByteBuffer engineToSocketData = pool.nextBuffer(sslEngine.getSession().getPacketBufferSize()); synchronized(wrapLock) { SSLEngineResult result = sslEngine.wrap(buffer, engineToSocketData); Status status = result.getStatus(); HandshakeStatus hsStatus = result.getHandshakeStatus(); if(status != Status.OK) throw new RuntimeException("Bug, status="+status+" instead of OK. hsStatus="+ hsStatus+" Something went wrong and we could not encrypt the data"); log.trace(()->mem+"SSLListener.packetEncrypted pos="+engineToSocketData.position()+ " lim="+engineToSocketData.limit()+" hsStatus="+hsStatus+" status="+status); engineToSocketData.flip(); CompletableFuture future = listener.packetEncrypted(engineToSocketData); futures.add(future); } } pool.releaseBuffer(buffer); CompletableFuture[] array = futures.toArray(new CompletableFuture[0]); return CompletableFuture.allOf(array); } @Override public void close() { clientInitiated = true; if(mem.getConnectionState() == ConnectionState.NOT_STARTED) { listener.closed(true); return; } mem.compareSet(ConnectionState.CONNECTED, ConnectionState.DISCONNECTING); SSLEngine engine = mem.getEngine(); engine.closeOutbound(); HandshakeStatus status = engine.getHandshakeStatus(); switch (status) { case NEED_WRAP: sendHandshakeMessage(); break; case NOT_HANDSHAKING: if(ConnectionState.DISCONNECTED != mem.getConnectionState()) throw new IllegalStateException("state="+mem.getConnectionState()+" hsStatus="+status+" should not be able to occur"); break; default: //we WILL hit this and need to fix if other end closes...try closing both ends!!! throw new RuntimeException(mem+"bug, status not handled in close="+status); } } @Override public ConnectionState getConnectionState() { return mem.getConnectionState(); } }