package com.koushikdutta.async;
import android.os.Build;
import com.koushikdutta.async.callback.CompletedCallback;
import com.koushikdutta.async.callback.DataCallback;
import com.koushikdutta.async.callback.WritableCallback;
import com.koushikdutta.async.util.Allocator;
import com.koushikdutta.async.wrapper.AsyncSocketWrapper;
import org.apache.http.conn.ssl.StrictHostnameVerifier;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
public class AsyncSSLSocketWrapper implements AsyncSocketWrapper, AsyncSSLSocket {
/**
 * Receives the outcome of an SSL handshake started through
 * {@code AsyncSSLSocketWrapper.handshake(...)}.
 */
public interface HandshakeCallback {
/**
 * Invoked when the handshake finishes.
 *
 * @param e      the failure; this class passes null on success
 * @param socket the negotiated socket; this class passes null on failure
 */
public void onHandshakeCompleted(Exception e, AsyncSSLSocket socket);
}
// Shared context created by the static initializer; may be null or uninitialized if that setup fails.
static SSLContext defaultSSLContext;
// The underlying transport that carries the encrypted bytes.
AsyncSocket mSocket;
// Buffered sink over mSocket that receives the wrapped (encrypted) output.
BufferedDataSink mSink;
// Re-entrancy guard: true while dataCallback is inside its unwrap loop.
boolean mUnwrapping;
// Engine that performs the wrap/unwrap operations and drives the handshake.
SSLEngine engine;
// Set to true by handleHandshakeStatus once the engine reports the handshake is done.
boolean finishedHandshake;
private int mPort;
// Peer host name; when non-null it is checked against the peer certificate in client mode.
private String mHost;
// Re-entrancy guard: true while write() is inside its wrap loop.
private boolean mWrapping;
// Optional custom verifier; when null a StrictHostnameVerifier is used instead.
HostnameVerifier hostnameVerifier;
// Pending handshake listener; cleared once it has been notified.
HandshakeCallback handshakeCallback;
// Peer certificate chain captured during client-mode verification.
X509Certificate[] peerCertificates;
WritableCallback mWriteableCallback;
DataCallback mDataCallback;
// Optional trust managers; when null the platform default TrustManagerFactory is consulted.
TrustManager[] trustManagers;
// True when this end acts as the TLS client (enables peer verification after the handshake).
boolean clientMode;
// Builds defaultSSLContext once at class-load time. On newer platforms the
// "Default" context is used; on older ones (or if that lookup fails) a custom
// "TLS" context is created with a trust manager that strips a problematic
// critical extension from the presented certificates.
static {
// following is the "trust the system" certs setup
try {
// critical extension 2.5.29.15 is implemented improperly prior to 4.0.3.
// https://code.google.com/p/android/issues/detail?id=9307
// https://groups.google.com/forum/?fromgroups=#!topic/netty/UCfqPPk5O4s
// certs that use this extension will throw in Cipher.java.
// fallback is to use a custom SSLContext, and hack around the x509 extension.
// The exception thrown here is only a jump into the fallback path in the catch below.
if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.ICE_CREAM_SANDWICH_MR1)
throw new Exception();
defaultSSLContext = SSLContext.getInstance("Default");
}
catch (Exception ex) {
try {
defaultSSLContext = SSLContext.getInstance("TLS");
// NOTE(review): this trust manager never rejects a chain; it only removes the
// 2.5.29.15 OID. Client-mode trust checks are performed separately in
// handleHandshakeStatus — confirm that every user of this context goes through it.
TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
public java.security.cert.X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) {
}
public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) {
for (X509Certificate cert : certs) {
if (cert != null && cert.getCriticalExtensionOIDs() != null)
cert.getCriticalExtensionOIDs().remove("2.5.29.15");
}
}
} };
defaultSSLContext.init(null, trustAllCerts, null);
}
catch (Exception ex2) {
// Both attempts failed: log both causes. defaultSSLContext is left as-is.
ex.printStackTrace();
ex2.printStackTrace();
}
}
}
/**
 * Returns the shared context prepared by the static initializer.
 *
 * @return the context; may be null if the static setup failed
 */
public static SSLContext getDefaultSSLContext() {
return defaultSSLContext;
}
/**
 * Wraps {@code socket} in an SSL layer and starts the handshake.
 *
 * While the handshake is in flight the socket's closed callback is replaced so
 * that a premature close is reported to {@code callback} as a failure.
 *
 * @param socket        transport to secure
 * @param host          peer host name (used for hostname verification in client mode)
 * @param port          peer port
 * @param sslEngine     engine that drives the handshake
 * @param trustManagers trust managers to consult, or null for the platform default
 * @param verifier      hostname verifier, or null for the strict default
 * @param clientMode    true to act as the TLS client
 * @param callback      notified once with the handshake result
 */
public static void handshake(AsyncSocket socket,
                             String host, int port,
                             SSLEngine sslEngine,
                             TrustManager[] trustManagers, HostnameVerifier verifier, boolean clientMode,
                             final HandshakeCallback callback) {
    final AsyncSSLSocketWrapper ssl =
            new AsyncSSLSocketWrapper(socket, host, port, sslEngine, trustManagers, verifier, clientMode);
    ssl.handshakeCallback = callback;

    socket.setClosedCallback(new CompletedCallback() {
        @Override
        public void onCompleted(Exception ex) {
            // a close without an error is still a handshake failure
            final Exception failure = (ex != null)
                    ? ex
                    : new SSLException("socket closed during handshake");
            callback.onHandshakeCompleted(failure, null);
        }
    });

    try {
        ssl.engine.beginHandshake();
        final HandshakeStatus initial = ssl.engine.getHandshakeStatus();
        ssl.handleHandshakeStatus(initial);
    } catch (SSLException e) {
        ssl.report(e);
    }
}
// True once the underlying socket has signalled end-of-stream.
boolean mEnded;
// Exception (possibly null) delivered with the end-of-stream signal.
Exception mEndException;
// Decrypted data that has been unwrapped but not yet consumed downstream.
final ByteBufferList pending = new ByteBufferList();
/**
 * Stores the handshake parameters, puts the engine into the requested mode and
 * hooks this wrapper into the underlying socket's writeable, end and data events.
 */
private AsyncSSLSocketWrapper(AsyncSocket socket,
                              String host, int port,
                              SSLEngine sslEngine,
                              TrustManager[] trustManagers, HostnameVerifier verifier, boolean clientMode) {
    this.mSocket = socket;
    this.mHost = host;
    this.mPort = port;
    this.hostnameVerifier = verifier;
    this.trustManagers = trustManagers;
    this.clientMode = clientMode;
    this.engine = sslEngine;
    this.engine.setUseClientMode(clientMode);

    this.mSink = new BufferedDataSink(socket);
    this.mSink.setWriteableCallback(new WritableCallback() {
        @Override
        public void onWriteable() {
            // forward sink writeability to whoever is listening on the SSL socket
            final WritableCallback listener = mWriteableCallback;
            if (listener != null) {
                listener.onWriteable();
            }
        }
    });

    // on pause, the emitter is paused to prevent the buffered
    // socket and itself from firing.
    // on resume, emitter is resumed, ssl buffer is flushed as well
    this.mSocket.setEndCallback(new CompletedCallback() {
        @Override
        public void onCompleted(Exception ex) {
            if (mEnded) {
                return;
            }
            mEnded = true;
            mEndException = ex;
            // defer the end notification while decrypted data is still waiting to be consumed
            if (pending.hasRemaining() || mEndCallback == null) {
                return;
            }
            mEndCallback.onCompleted(ex);
        }
    });

    this.mSocket.setDataCallback(dataCallback);
}
// Receives data from the underlying socket, runs it through engine.unwrap and
// collects the produced output in `pending`, which is then emitted downstream.
// Also invoked with an empty list by handleHandshakeStatus when the engine needs to unwrap.
final DataCallback dataCallback = new DataCallback() {
// Allocator for unwrap output buffers; minimum allocation starts at 8192 bytes.
final Allocator allocator = new Allocator().setMinAlloc(8192);
// Input that could not be unwrapped yet and is kept for the next invocation.
final ByteBufferList buffered = new ByteBufferList();
@Override
public void onDataAvailable(DataEmitter emitter, ByteBufferList bb) {
// Re-entrancy guard: handleHandshakeStatus can call back into this method.
if (mUnwrapping)
return;
try {
mUnwrapping = true;
// presumably moves the incoming data into `buffered` — confirm against ByteBufferList.get
bb.get(buffered);
if (buffered.hasRemaining()) {
// presumably merges everything buffered into a single ByteBuffer — confirm against ByteBufferList.getAll
ByteBuffer all = buffered.getAll();
buffered.add(all);
}
ByteBuffer b = ByteBufferList.EMPTY_BYTEBUFFER;
while (true) {
// Current input exhausted: take the next buffered chunk, if any.
if (b.remaining() == 0 && buffered.size() > 0) {
b = buffered.remove();
}
// Snapshots used below to detect whether this iteration made any progress.
int remaining = b.remaining();
int before = pending.remaining();
SSLEngineResult res;
{
// wrap to prevent access to the readBuf
ByteBuffer readBuf = allocator.allocate();
res = engine.unwrap(b, readBuf);
addToPending(pending, readBuf);
allocator.track(pending.remaining() - before);
}
if (res.getStatus() == Status.BUFFER_OVERFLOW) {
// Output buffer too small: double the allocation and force another pass.
allocator.setMinAlloc(allocator.getMinAlloc() * 2);
remaining = -1;
}
else if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
// Not enough input for a full record: put the chunk back.
buffered.addFirst(b);
// Nothing else to combine it with; wait for more data.
if (buffered.size() <= 1) {
break;
}
// pack it
remaining = -1;
b = buffered.getAll();
buffered.addFirst(b);
b = ByteBufferList.EMPTY_BYTEBUFFER;
}
handleHandshakeStatus(res.getHandshakeStatus());
// No input consumed and no output produced: stop and keep the leftover input.
if (b.remaining() == remaining && before == pending.remaining()) {
buffered.addFirst(b);
break;
}
}
AsyncSSLSocketWrapper.this.onDataAvailable();
}
catch (SSLException ex) {
ex.printStackTrace();
report(ex);
}
finally {
mUnwrapping = false;
}
}
};
/**
 * Emits whatever decrypted data is waiting in {@code pending} and, if the
 * underlying socket has already ended and nothing is left to emit, delivers
 * the deferred end notification.
 */
public void onDataAvailable() {
    Util.emitAllData(this, pending);

    final boolean drainedAfterEnd = mEnded && !pending.hasRemaining();
    if (drainedAfterEnd && mEndCallback != null) {
        mEndCallback.onCompleted(mEndException);
    }
}
/** Returns the SSLEngine this wrapper uses for wrap/unwrap operations. */
@Override
public SSLEngine getSSLEngine() {
return engine;
}
/**
 * Flips a freshly written buffer for reading and either appends it to
 * {@code out} or, when it holds no data, hands it back for reuse.
 *
 * @param out      list receiving the buffer when it contains data
 * @param mReadTmp buffer in write mode; flipped by this method
 */
void addToPending(ByteBufferList out, ByteBuffer mReadTmp) {
    mReadTmp.flip();
    if (!mReadTmp.hasRemaining()) {
        // nothing was produced; return the buffer to the pool
        ByteBufferList.reclaim(mReadTmp);
        return;
    }
    out.add(mReadTmp);
}
/** Delegates to the underlying socket's end(). */
@Override
public void end() {
mSocket.end();
}
/** Returns the host name supplied when the wrapper was created (may be null). */
public String getHost() {
return mHost;
}
/** Returns the port supplied when the wrapper was created. */
public int getPort() {
return mPort;
}
/**
 * Advances the handshake according to {@code status} and, once the engine
 * reports that handshaking is over, verifies the peer (client mode only) and
 * notifies the handshake callback.
 *
 * <ul>
 *   <li>NEED_TASK: runs every delegated task the engine currently has queued.</li>
 *   <li>NEED_WRAP: calls {@code write(writeList)} so the engine can emit handshake data.</li>
 *   <li>NEED_UNWRAP: re-enters the data callback with an empty list to consume buffered input.</li>
 * </ul>
 *
 * In client mode the peer chain is checked against each X509 trust manager
 * (the platform defaults when none were supplied) and the host name is verified
 * when a host is known. Failures are passed to {@code report}.
 *
 * @param status the handshake status returned by the most recent engine operation
 * @throws RuntimeException if the default trust manager algorithm is unavailable
 */
private void handleHandshakeStatus(HandshakeStatus status) {
    if (status == HandshakeStatus.NEED_TASK) {
        // getDelegatedTask() returns null when nothing is queued, and the engine
        // may queue more than one task; drain them all.
        Runnable task;
        while ((task = engine.getDelegatedTask()) != null) {
            task.run();
        }
    }
    if (status == HandshakeStatus.NEED_WRAP) {
        write(writeList);
    }
    if (status == HandshakeStatus.NEED_UNWRAP) {
        dataCallback.onDataAvailable(this, new ByteBufferList());
    }
    try {
        if (!finishedHandshake && (engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING || engine.getHandshakeStatus() == HandshakeStatus.FINISHED)) {
            if (clientMode) {
                TrustManager[] trustManagers = this.trustManagers;
                if (trustManagers == null) {
                    TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
                    tmf.init((KeyStore) null);
                    trustManagers = tmf.getTrustManagers();
                }
                boolean trusted = false;
                Exception peerUnverifiedCause = null;
                for (TrustManager tm : trustManagers) {
                    // Only X509 trust managers can validate a certificate chain;
                    // skip anything else instead of failing with a ClassCastException.
                    if (!(tm instanceof X509TrustManager)) {
                        continue;
                    }
                    try {
                        X509TrustManager xtm = (X509TrustManager) tm;
                        peerCertificates = (X509Certificate[]) engine.getSession().getPeerCertificates();
                        xtm.checkServerTrusted(peerCertificates, "SSL");
                        if (mHost != null) {
                            if (hostnameVerifier == null) {
                                StrictHostnameVerifier verifier = new StrictHostnameVerifier();
                                verifier.verify(mHost, StrictHostnameVerifier.getCNs(peerCertificates[0]), StrictHostnameVerifier.getDNSSubjectAlts(peerCertificates[0]));
                            }
                            else {
                                if (!hostnameVerifier.verify(mHost, engine.getSession())) {
                                    throw new SSLException("hostname <" + mHost + "> has been denied");
                                }
                            }
                        }
                        // first trust manager that accepts the chain wins
                        trusted = true;
                        break;
                    }
                    catch (GeneralSecurityException ex) {
                        peerUnverifiedCause = ex;
                    }
                    catch (SSLException ex) {
                        peerUnverifiedCause = ex;
                    }
                }
                finishedHandshake = true;
                if (!trusted) {
                    AsyncSSLException e = new AsyncSSLException(peerUnverifiedCause);
                    // gives the handshake callback a chance to mark the failure as ignored
                    report(e);
                    if (!e.getIgnore())
                        throw e;
                }
            }
            else {
                finishedHandshake = true;
            }
            final HandshakeCallback completed = handshakeCallback;
            if (completed == null) {
                // report() has already delivered a failure and torn the socket down
                // (the "ignored" trust failure path); there is nobody left to notify.
                return;
            }
            completed.onHandshakeCompleted(null, this);
            handshakeCallback = null;
            mSocket.setClosedCallback(null);
            // handshake can complete during a wrap, so make sure that the call
            // stack and wrap flag is cleared before invoking writable
            getServer().post(new Runnable() {
                @Override
                public void run() {
                    if (mWriteableCallback != null)
                        mWriteableCallback.onWriteable();
                }
            });
            onDataAvailable();
        }
    }
    catch (NoSuchAlgorithmException ex) {
        throw new RuntimeException(ex);
    }
    catch (GeneralSecurityException ex) {
        report(ex);
    }
    catch (AsyncSSLException ex) {
        report(ex);
    }
}
/**
 * Chooses an output buffer size for wrapping {@code remaining} bytes.
 *
 * The result is 150% of {@code remaining} (rounded down), or 8192 when that
 * works out to zero. The arithmetic is done in {@code long} and clamped to
 * {@code Integer.MAX_VALUE}, so very large inputs can no longer overflow
 * into a negative or undersized value.
 *
 * @param remaining number of bytes waiting to be written
 * @return the buffer size to allocate
 */
int calculateAlloc(int remaining) {
    // alloc 50% more than we need for writing
    final long alloc = (long) remaining * 3 / 2;
    if (alloc == 0)
        return 8192;
    return (int) Math.min(alloc, (long) Integer.MAX_VALUE);
}
// Holds wrapped (encrypted) output on its way to mSink; also passed to write() as the
// source list when the handshake needs a wrap.
ByteBufferList writeList = new ByteBufferList();
/**
 * Encrypts as much of {@code bb} as possible and forwards the result to the
 * buffered sink.
 *
 * The call is a no-op while another wrap is in progress or while the sink
 * still holds unsent data. Wrapping repeats while progress is made (or the
 * engine still needs to wrap) and the sink stays empty. An SSLException from
 * the engine is passed to {@code report} and ends the loop.
 *
 * @param bb source data; bytes the engine consumed are removed from it
 */
@Override
public void write(ByteBufferList bb) {
    // re-entrancy guard: handleHandshakeStatus may call back into write()
    if (mWrapping)
        return;
    // sink is backed up; nothing is consumed until it drains
    if (mSink.remaining() > 0)
        return;
    mWrapping = true;
    int remaining;
    SSLEngineResult res = null;
    ByteBuffer writeBuf = ByteBufferList.obtain(calculateAlloc(bb.remaining()));
    do {
        // if the handshake is finished, don't send
        // 0 bytes of data, since that makes the ssl connection die.
        // it wraps a 0 byte package, and craps out.
        if (finishedHandshake && bb.remaining() == 0)
            break;
        remaining = bb.remaining();
        try {
            ByteBuffer[] arr = bb.getAllArray();
            res = engine.wrap(arr, writeBuf);
            // put back whatever the engine did not consume
            bb.addAll(arr);
            writeBuf.flip();
            writeList.add(writeBuf);
            // Note: the list legitimately has data here whenever wrap produced
            // output, so no emptiness assertion is made before handing it to the sink.
            if (writeList.remaining() > 0)
                mSink.write(writeList);
            int previousCapacity = writeBuf.capacity();
            writeBuf = null;
            if (res.getStatus() == Status.BUFFER_OVERFLOW) {
                // output buffer too small: retry with twice the capacity
                writeBuf = ByteBufferList.obtain(previousCapacity * 2);
                remaining = -1;
            }
            else {
                writeBuf = ByteBufferList.obtain(calculateAlloc(bb.remaining()));
                handleHandshakeStatus(res.getHandshakeStatus());
            }
        }
        catch (SSLException e) {
            report(e);
            // The engine failed; retrying would only repeat the failure, and a
            // stale NEED_WRAP result could otherwise keep this loop spinning.
            break;
        }
    }
    while ((remaining != bb.remaining() || (res != null && res.getHandshakeStatus() == HandshakeStatus.NEED_WRAP)) && mSink.remaining() == 0);
    mWrapping = false;
    ByteBufferList.reclaim(writeBuf);
}
/** Stores the callback that is invoked when the buffered sink becomes writeable and after the handshake completes. */
@Override
public void setWriteableCallback(WritableCallback handler) {
mWriteableCallback = handler;
}
/** Returns the callback stored by setWriteableCallback (may be null). */
@Override
public WritableCallback getWriteableCallback() {
return mWriteableCallback;
}
/**
 * Routes an error to the appropriate listener.
 *
 * While a handshake callback is still pending, the error fails the handshake:
 * the underlying socket is silenced, ended and closed, and the callback is
 * notified once. Otherwise the error is forwarded to the end callback, if any.
 *
 * @param e the error to deliver
 */
private void report(Exception e) {
    final HandshakeCallback pendingHandshake = handshakeCallback;

    if (pendingHandshake == null) {
        // handshake already settled: surface the error as an end-of-stream failure
        final CompletedCallback endListener = getEndCallback();
        if (endListener != null) {
            endListener.onCompleted(e);
        }
        return;
    }

    handshakeCallback = null;
    mSocket.setDataCallback(new DataCallback.NullDataCallback());
    mSocket.end();
    // handshake sets this callback. unset it.
    mSocket.setClosedCallback(null);
    mSocket.close();
    pendingHandshake.onHandshakeCompleted(e, null);
}
/** Stores the data callback on this wrapper; the underlying socket's callback is not changed. */
@Override
public void setDataCallback(DataCallback callback) {
mDataCallback = callback;
}
/** Returns the callback stored by setDataCallback (may be null). */
@Override
public DataCallback getDataCallback() {
return mDataCallback;
}
/** Delegates to the underlying socket's isChunked(). */
@Override
public boolean isChunked() {
return mSocket.isChunked();
}
/** Delegates to the underlying socket's isOpen(). */
@Override
public boolean isOpen() {
return mSocket.isOpen();
}
/** Closes the underlying socket. */
@Override
public void close() {
mSocket.close();
}
/** Installs the closed callback directly on the underlying socket. */
@Override
public void setClosedCallback(CompletedCallback handler) {
mSocket.setClosedCallback(handler);
}
/** Returns the closed callback currently installed on the underlying socket. */
@Override
public CompletedCallback getClosedCallback() {
return mSocket.getClosedCallback();
}
// Listener for end-of-stream; it is only notified once `pending` has been fully emitted.
CompletedCallback mEndCallback;
/** Stores the end callback on this wrapper; the underlying socket's end callback is not changed. */
@Override
public void setEndCallback(CompletedCallback callback) {
mEndCallback = callback;
}
/** Returns the callback stored by setEndCallback (may be null). */
@Override
public CompletedCallback getEndCallback() {
return mEndCallback;
}
/** Pauses the underlying socket. */
@Override
public void pause() {
mSocket.pause();
}
/** Resumes the underlying socket, then emits any decrypted data already waiting in `pending`. */
@Override
public void resume() {
mSocket.resume();
onDataAvailable();
}
/** Delegates to the underlying socket's isPaused(). */
@Override
public boolean isPaused() {
return mSocket.isPaused();
}
/** Returns the AsyncServer of the underlying socket. */
@Override
public AsyncServer getServer() {
return mSocket.getServer();
}
/** Returns the underlying (unencrypted) socket this wrapper sits on. */
@Override
public AsyncSocket getSocket() {
return mSocket;
}
/** Returns the underlying socket as the wrapped data emitter. */
@Override
public DataEmitter getDataEmitter() {
return mSocket;
}
/**
 * Returns the peer certificate chain captured during client-mode verification;
 * null if that verification has not run.
 */
@Override
public X509Certificate[] getPeerCertificates() {
return peerCertificates;
}
/** Always returns null; this wrapper does not declare a charset. */
@Override
public String charset() {
return null;
}
}