package org.threadly.litesockets.utils; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; import java.nio.ByteBuffer; import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSession; import org.threadly.concurrent.future.ListenableFuture; import org.threadly.concurrent.future.SettableListenableFuture; import org.threadly.litesockets.Client; import org.threadly.litesockets.buffers.MergedByteBuffers; import org.threadly.litesockets.buffers.ReuseableMergedByteBuffers; import org.threadly.util.ExceptionUtils; /** * This is a generic SSLClient that can be used to create an encrypted connection to a server. * By default it will not check the servers certs and will only do TLS connections. * * @author lwahlmeier */ public class SSLProcessor { //Not sure why but android needs this extra buffer //as the getApp/Packet buffers are not right. public static final int EXTRA_BUFFER_AMOUNT = 50; /** * This is how much extra buffer we allocate for read events so we do quite * as much allocating. We always have to have at least 1 buffer the size SSLEngine says * we need, but might use very little of it. This way we can allocate once and get multiple * read events with it before we have to throw away allocated bytes an allocate more. * */ public static final int PREALLOCATE_BUFFER_MULTIPLIER = 3; private static final ByteBuffer EMPTY_BYTEBUFFER = ByteBuffer.allocate(0); private final AtomicBoolean finishedHandshake = new AtomicBoolean(false); private final AtomicBoolean startedHandshake = new AtomicBoolean(false); private final SettableListenableFuture<SSLSession> handshakeFuture = new SettableListenableFuture<SSLSession>(false); private final MergedByteBuffers encryptedReadBuffers = new ReuseableMergedByteBuffers(false); private final MergedByteBuffers tempBuffers = new ReuseableMergedByteBuffers(false); private final SSLEngine ssle; private final Client client; private ByteBuffer writeBuffer; private ByteBuffer decryptedReadBuffer; public SSLProcessor(final Client client, final SSLEngine ssle) { this.client = client; this.ssle = ssle; } public boolean handShakeStarted() { return startedHandshake.get(); } /** * This lets you know if the connection is currently encrypted or not. * @return true if the connection is encrypted false if not. */ public boolean isEncrypted() { return (startedHandshake.get() && !ssle.getSession().getProtocol().equals("NONE")); } /** * <p>You can start the handshake by calling this method. If connect() has not been called * on the client this will happen as soon as it is. The future allows you to know * when the handshake has finished if if there was an error. While the handshake is processing all writes to the * socket will queue.</p> * * @return A ListenableFuture. If a result was given it succeeded, if there is an error it failed. The connection is closed on failures. */ public ListenableFuture<SSLSession> doHandShake() { if(startedHandshake.compareAndSet(false, true)) { try { ssle.beginHandshake(); if(ssle.getHandshakeStatus() == NEED_WRAP) { client.write(EMPTY_BYTEBUFFER); } client.getClientsSocketExecuter().watchFuture(handshakeFuture, client.getTimeout()); } catch (SSLException e) { this.handshakeFuture.setFailure(e); } } return handshakeFuture; } private void runTasks() { SSLEngineResult.HandshakeStatus hs = ssle.getHandshakeStatus(); while(hs == NEED_TASK) { final Runnable task = ssle.getDelegatedTask(); if(task != null) { ExceptionUtils.runRunnable(task); } hs = ssle.getHandshakeStatus(); } } private ByteBuffer getAppWriteBuffer() { if(this.writeBuffer == null || this.writeBuffer.remaining() < ssle.getSession().getPacketBufferSize()+EXTRA_BUFFER_AMOUNT) { this.writeBuffer = ByteBuffer.allocate(ssle.getSession().getPacketBufferSize()+EXTRA_BUFFER_AMOUNT); } return writeBuffer; } private ByteBuffer getDecryptedByteBuffer() { if(decryptedReadBuffer == null || decryptedReadBuffer.remaining() < ssle.getSession().getApplicationBufferSize()+EXTRA_BUFFER_AMOUNT) { decryptedReadBuffer = ByteBuffer.allocate(ssle.getSession().getApplicationBufferSize()+EXTRA_BUFFER_AMOUNT); } return decryptedReadBuffer; } public MergedByteBuffers encrypt(final ByteBuffer buffer) { final MergedByteBuffers mbb = new ReuseableMergedByteBuffers(false); if(!startedHandshake.get()){ mbb.add(buffer); return mbb; } ByteBuffer oldBB = buffer.duplicate(); if(finishedHandshake.get() && this.tempBuffers.remaining() > 0) { tempBuffers.add(buffer); oldBB = tempBuffers.pullBuffer(tempBuffers.remaining()); } ByteBuffer newBB; ByteBuffer tmpBB; boolean gotFinished = false; while (ssle.getHandshakeStatus() == NEED_WRAP || oldBB.remaining() > 0) { newBB = getAppWriteBuffer(); tmpBB = newBB.duplicate(); try { final SSLEngineResult res = ssle.wrap(oldBB, newBB); if(!finishedHandshake.get() && oldBB.remaining() > 0) { tempBuffers.add(oldBB); oldBB.position(oldBB.limit()); } if(!finishedHandshake.get() && res.getHandshakeStatus() == FINISHED) { gotFinished = true; } else { while (ssle.getHandshakeStatus() == NEED_TASK) { runTasks(); } } } catch (SSLHandshakeException e) { this.handshakeFuture.setFailure(e); client.close(); break; } catch (SSLException e) { throw new EncryptionException(e); } if(tmpBB.hasRemaining()) { tmpBB.limit(newBB.position()); mbb.add(tmpBB); } if(client.isClosed()) { break; } } writeBuffer = null; if(gotFinished && finishedHandshake.compareAndSet(false, true)) { handshakeFuture.setResult(ssle.getSession()); if(tempBuffers.remaining() > 0) { mbb.add(encrypt(EMPTY_BYTEBUFFER)); } } return mbb; } public MergedByteBuffers decrypt(final ByteBuffer bb) { MergedByteBuffers mbb = new ReuseableMergedByteBuffers(false); mbb.add(bb); return decrypt(mbb); } public ReuseableMergedByteBuffers decrypt(final MergedByteBuffers bb) { final ReuseableMergedByteBuffers mbb = new ReuseableMergedByteBuffers(false); if(!this.startedHandshake.get()) { mbb.add(bb); return mbb; } encryptedReadBuffers.add(bb); final ByteBuffer encBB = encryptedReadBuffers.pullBuffer(encryptedReadBuffers.remaining()); while(encBB.remaining() > 0) { final ByteBuffer dbb = getDecryptedByteBuffer(); final ByteBuffer newBB = dbb.duplicate(); SSLEngineResult res; try { res = ssle.unwrap(encBB, dbb); //We have to check both each time till complete if(! handshakeFuture.isDone()) { processHandshake(res.getHandshakeStatus()); processHandshake(ssle.getHandshakeStatus()); } } catch (SSLException e) { throw new EncryptionException(e); } newBB.limit(dbb.position()); if(newBB.hasRemaining()) { mbb.add(newBB); } else if (res.getStatus() == Status.BUFFER_UNDERFLOW) { if(encBB.hasRemaining()) { encryptedReadBuffers.add(encBB); } break; } } return mbb; } private void finishHandshake() { if(this.finishedHandshake.compareAndSet(false, true)){ handshakeFuture.setResult(ssle.getSession()); if(tempBuffers.remaining() > 0) { client.write(EMPTY_BYTEBUFFER); //make the client write to flush tempBuffers } } } private void processHandshake(final HandshakeStatus status) throws SSLException { switch(status) { case NOT_HANDSHAKING: //Fix for older android versions, they dont send a finished case FINISHED: { if(handShakeStarted()) { finishHandshake(); } } break; case NEED_TASK: { runTasks(); } break; case NEED_WRAP: { client.write(EMPTY_BYTEBUFFER); } break; default: { } break; } } /** * Generic Exception to throw when we get an Encryption error. * * @author lwahlmeier * */ public static class EncryptionException extends RuntimeException { private static final long serialVersionUID = -2713992763314654069L; public EncryptionException(final Throwable t) { super(t); } } }