package com.koushikdutta.async; import android.os.Build; import com.koushikdutta.async.callback.CompletedCallback; import com.koushikdutta.async.callback.DataCallback; import com.koushikdutta.async.callback.WritableCallback; import com.koushikdutta.async.wrapper.AsyncSocketWrapper; import org.apache.http.conn.ssl.StrictHostnameVerifier; import java.nio.ByteBuffer; import java.security.KeyStore; import java.security.cert.X509Certificate; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; public class AsyncSSLSocketWrapper implements AsyncSocketWrapper, AsyncSSLSocket { AsyncSocket mSocket; BufferedDataEmitter mEmitter; BufferedDataSink mSink; ByteBuffer mReadTmp = ByteBufferList.obtain(8192); boolean mUnwrapping = false; HostnameVerifier hostnameVerifier; @Override public void end() { mSocket.end(); } public AsyncSSLSocketWrapper(AsyncSocket socket, String host, int port) { this(socket, host, port, sslContext, null, null, true); } TrustManager[] trustManagers; boolean clientMode; public AsyncSSLSocketWrapper(AsyncSocket socket, String host, int port, SSLContext sslContext, TrustManager[] trustManagers, HostnameVerifier verifier, boolean clientMode) { mSocket = socket; hostnameVerifier = verifier; this.clientMode = clientMode; this.trustManagers = trustManagers; if (sslContext == null) sslContext = AsyncSSLSocketWrapper.sslContext; if (host != null) { engine = sslContext.createSSLEngine(host, port); } else { engine = sslContext.createSSLEngine(); } mHost = host; mPort = port; engine.setUseClientMode(clientMode); mSink = new BufferedDataSink(socket); mSink.setMaxBuffer(0); // SSL needs buffering of data written during handshake. // aka exhcange.setDatacallback mEmitter = new BufferedDataEmitter(socket); final ByteBufferList transformed = new ByteBufferList(); mEmitter.setDataCallback(new DataCallback() { @Override public void onDataAvailable(DataEmitter emitter, ByteBufferList bb) { if (mUnwrapping) return; try { mUnwrapping = true; mReadTmp.position(0); mReadTmp.limit(mReadTmp.capacity()); ByteBuffer b = ByteBufferList.EMPTY_BYTEBUFFER; while (true) { if (b.remaining() == 0 && bb.size() > 0) { b = bb.remove(); } int remaining = b.remaining(); SSLEngineResult res = engine.unwrap(b, mReadTmp); if (res.getStatus() == Status.BUFFER_OVERFLOW) { addToPending(transformed); mReadTmp = ByteBufferList.obtain(mReadTmp.remaining() * 2); remaining = -1; } else if (res.getStatus() == Status.BUFFER_UNDERFLOW) { bb.addFirst(b); if (bb.size() <= 1) { break; } remaining = -1; b = bb.getAll(); } handleResult(res); if (b.remaining() == remaining) { bb.addFirst(b); break; } } addToPending(transformed); Util.emitAllData(AsyncSSLSocketWrapper.this, transformed); } catch (Exception ex) { ex.printStackTrace(); report(ex); } finally { mUnwrapping = false; } } }); } void addToPending(ByteBufferList out) { if (mReadTmp.position() > 0) { mReadTmp.limit(mReadTmp.position()); mReadTmp.position(0); out.add(mReadTmp); mReadTmp = ByteBufferList.obtain(mReadTmp.capacity()); } } static SSLContext sslContext; static { // following is the "trust the system" certs setup try { // critical extension 2.5.29.15 is implemented improperly prior to 4.0.3. // https://code.google.com/p/android/issues/detail?id=9307 // https://groups.google.com/forum/?fromgroups=#!topic/netty/UCfqPPk5O4s // certs that use this extension will throw in Cipher.java. // fallback is to use a custom SSLContext, and hack around the x509 extension. if (Build.VERSION.SDK_INT <= 15) throw new Exception(); sslContext = SSLContext.getInstance("Default"); } catch (Exception ex) { try { sslContext = SSLContext.getInstance("TLS"); TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() { public java.security.cert.X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) { } public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) { for (X509Certificate cert : certs) { if (cert != null && cert.getCriticalExtensionOIDs() != null) cert.getCriticalExtensionOIDs().remove("2.5.29.15"); } } } }; sslContext.init(null, trustAllCerts, null); } catch (Exception ex2) { ex.printStackTrace(); ex2.printStackTrace(); } } } SSLEngine engine; boolean finishedHandshake = false; private String mHost; public String getHost() { return mHost; } private int mPort; public int getPort() { return mPort; } private void handleResult(SSLEngineResult res) { if (res.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { final Runnable task = engine.getDelegatedTask(); task.run(); } if (res.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) { write(ByteBufferList.EMPTY_BYTEBUFFER); } if (res.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) { mEmitter.onDataAvailable(); } try { if (!finishedHandshake && (engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING || engine.getHandshakeStatus() == HandshakeStatus.FINISHED)) { if (clientMode) { TrustManager[] trustManagers = this.trustManagers; if (trustManagers == null) { TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); tmf.init((KeyStore) null); trustManagers = tmf.getTrustManagers(); } boolean trusted = false; for (TrustManager tm : trustManagers) { try { X509TrustManager xtm = (X509TrustManager) tm; peerCertificates = (X509Certificate[]) engine.getSession().getPeerCertificates(); xtm.checkServerTrusted(peerCertificates, "SSL"); if (mHost != null) { if (hostnameVerifier == null) { StrictHostnameVerifier verifier = new StrictHostnameVerifier(); verifier.verify(mHost, StrictHostnameVerifier.getCNs(peerCertificates[0]), StrictHostnameVerifier.getDNSSubjectAlts(peerCertificates[0])); } else { hostnameVerifier.verify(mHost, engine.getSession()); } } trusted = true; break; } catch (Exception ex) { ex.printStackTrace(); } } finishedHandshake = true; if (!trusted) { AsyncSSLException e = new AsyncSSLException(); report(e); if (!e.getIgnore()) throw e; } } if (mWriteableCallback != null) mWriteableCallback.onWriteable(); mEmitter.onDataAvailable(); } } catch (Exception ex) { report(ex); } } private void writeTmp() { mWriteTmp.limit(mWriteTmp.position()); mWriteTmp.position(0); if (mWriteTmp.remaining() > 0) mSink.write(mWriteTmp); } boolean checkWrapResult(SSLEngineResult res) { if (res.getStatus() == Status.BUFFER_OVERFLOW) { mWriteTmp = ByteBufferList.obtain(mWriteTmp.remaining() * 2); return false; } return true; } private boolean mWrapping = false; ByteBuffer mWriteTmp = ByteBufferList.obtain(8192); @Override public void write(ByteBuffer bb) { if (mWrapping) return; if (mSink.remaining() > 0) return; mWrapping = true; int remaining; SSLEngineResult res = null; do { // if the handshake is finished, don't send // 0 bytes of data, since that makes the ssl connection die. // it wraps a 0 byte package, and craps out. if (finishedHandshake && bb.remaining() == 0) { mWrapping = false; return; } remaining = bb.remaining(); mWriteTmp.position(0); mWriteTmp.limit(mWriteTmp.capacity()); try { res = engine.wrap(bb, mWriteTmp); if (!checkWrapResult(res)) remaining = -1; writeTmp(); handleResult(res); } catch (SSLException e) { report(e); } } while ((remaining != bb.remaining() || (res != null && res.getHandshakeStatus() == HandshakeStatus.NEED_WRAP)) && mSink.remaining() == 0); mWrapping = false; } @Override public void write(ByteBufferList bb) { if (mWrapping) return; if (mSink.remaining() > 0) return; mWrapping = true; int remaining; SSLEngineResult res = null; do { // if the handshake is finished, don't send // 0 bytes of data, since that makes the ssl connection die. // it wraps a 0 byte package, and craps out. if (finishedHandshake && bb.remaining() == 0) { mWrapping = false; return; } remaining = bb.remaining(); mWriteTmp.position(0); mWriteTmp.limit(mWriteTmp.capacity()); try { ByteBuffer[] arr = bb.getAllArray(); res = engine.wrap(arr, mWriteTmp); bb.addAll(arr); if (!checkWrapResult(res)) remaining = -1; writeTmp(); handleResult(res); } catch (SSLException e) { report(e); } } while ((remaining != bb.remaining() || (res != null && res.getHandshakeStatus() == HandshakeStatus.NEED_WRAP)) && mSink.remaining() == 0); mWrapping = false; } WritableCallback mWriteableCallback; @Override public void setWriteableCallback(WritableCallback handler) { mWriteableCallback = handler; } @Override public WritableCallback getWriteableCallback() { return mWriteableCallback; } private void report(Exception e) { CompletedCallback cb = getEndCallback(); if (cb != null) cb.onCompleted(e); } DataCallback mDataCallback; @Override public void setDataCallback(DataCallback callback) { mDataCallback = callback; } @Override public DataCallback getDataCallback() { return mDataCallback; } @Override public boolean isChunked() { return mSocket.isChunked(); } @Override public boolean isOpen() { return mSocket.isOpen(); } @Override public void close() { mSocket.close(); } @Override public void setClosedCallback(CompletedCallback handler) { mSocket.setClosedCallback(handler); } @Override public CompletedCallback getClosedCallback() { return mSocket.getClosedCallback(); } @Override public void setEndCallback(CompletedCallback callback) { mSocket.setEndCallback(callback); } @Override public CompletedCallback getEndCallback() { return mSocket.getEndCallback(); } @Override public void pause() { mSocket.pause(); } @Override public void resume() { mSocket.resume(); } @Override public boolean isPaused() { return mSocket.isPaused(); } @Override public AsyncServer getServer() { return mSocket.getServer(); } @Override public AsyncSocket getSocket() { return mSocket; } @Override public DataEmitter getDataEmitter() { return mSocket; } X509Certificate[] peerCertificates; @Override public X509Certificate[] getPeerCertificates() { return peerCertificates; } }