package com.koushikdutta.async.http; import com.koushikdutta.async.AsyncSocket; import com.koushikdutta.async.ByteBufferList; import com.koushikdutta.async.DataEmitter; import com.koushikdutta.async.NullDataCallback; import com.koushikdutta.async.callback.CompletedCallback; import com.koushikdutta.async.callback.ConnectCallback; import com.koushikdutta.async.callback.ContinuationCallback; import com.koushikdutta.async.future.Cancellable; import com.koushikdutta.async.future.Continuation; import com.koushikdutta.async.future.SimpleCancellable; import com.koushikdutta.async.future.TransformFuture; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.util.HashSet; import java.util.Hashtable; public class AsyncSocketMiddleware extends SimpleMiddleware { String scheme; int port; public AsyncSocketMiddleware(AsyncHttpClient client, String scheme, int port) { mClient = client; this.scheme = scheme; this.port = port; } public int getSchemePort(URI uri) { if (!uri.getScheme().equals(scheme)) return -1; if (uri.getPort() == -1) { return port; } else { return uri.getPort(); } } public AsyncSocketMiddleware(AsyncHttpClient client) { this(client, "http", 80); } AsyncHttpClient mClient; private Hashtable<String, HashSet<AsyncSocket>> mSockets = new Hashtable<String, HashSet<AsyncSocket>>(); protected ConnectCallback wrapCallback(ConnectCallback callback, URI uri, int port) { return callback; } boolean connectAllAddresses; public boolean getConnectAllAddresses() { return connectAllAddresses; } public void setConnectAllAddresses(boolean connectAllAddresses) { this.connectAllAddresses = connectAllAddresses; } String proxyHost; int proxyPort; InetSocketAddress proxyAddress; public void disableProxy() { proxyPort = -1; proxyHost = null; proxyAddress = null; } public void enableProxy(String host, int port) { proxyHost = host; proxyPort = port; proxyAddress = null; } String computeLookup(URI uri, int port, AsyncHttpRequest request) { String proxy; if (proxyHost != null) proxy = proxyHost + ":" + proxyPort; else proxy = ""; if (request.proxyHost != null) proxy = request.getProxyHost() + ":" + request.proxyPort; return uri.getScheme() + "//" + uri.getHost() + ":" + port + "?proxy=" + proxy; } @Override public Cancellable getSocket(final GetSocketData data) { final URI uri = data.request.getUri(); final int port = getSchemePort(data.request.getUri()); if (port == -1) { return null; } final String lookup = computeLookup(uri, port, data.request); data.state.putBoolean(getClass().getCanonicalName() + ".owned", true); synchronized (this) { final HashSet<AsyncSocket> sockets = mSockets.get(lookup); if (sockets != null) { for (final AsyncSocket socket: sockets) { if (socket.isOpen()) { sockets.remove(socket); socket.setClosedCallback(null); mClient.getServer().post(new Runnable() { @Override public void run() { data.request.logd("Reusing keep-alive socket"); data.connectCallback.onConnectCompleted(null, socket); } }); // just a noop/dummy, as this can't actually be cancelled. return new SimpleCancellable(); } } } } if (!connectAllAddresses || proxyHost != null || data.request.getProxyHost() != null) { // just default to connecting to a single address data.request.logd("Connecting socket"); String unresolvedHost; int unresolvedPort; if (data.request.getProxyHost() != null) { unresolvedHost = data.request.getProxyHost(); unresolvedPort = data.request.getProxyPort(); // set the host and port explicitly for proxied connections data.request.getHeaders().getHeaders().setStatusLine(data.request.getProxyRequestLine().toString()); } else if (proxyHost != null) { unresolvedHost = proxyHost; unresolvedPort = proxyPort; // set the host and port explicitly for proxied connections data.request.getHeaders().getHeaders().setStatusLine(data.request.getProxyRequestLine().toString()); } else { unresolvedHost = uri.getHost(); unresolvedPort = port; } return mClient.getServer().connectSocket(unresolvedHost, unresolvedPort, wrapCallback(data.connectCallback, uri, port)); } // try to connect to everything... data.request.logv("Resolving domain and connecting to all available addresses"); return new TransformFuture<AsyncSocket, InetAddress[]>() { Exception lastException; @Override protected void error(Exception e) { super.error(e); data.connectCallback.onConnectCompleted(e, null); } @Override protected void transform(final InetAddress[] result) throws Exception { Continuation keepTrying = new Continuation(new CompletedCallback() { @Override public void onCompleted(Exception ex) { // if it completed, that means that the connection failed if (lastException == null) lastException = new Exception("Unable to connect to remote address"); setComplete(lastException); } }); for (final InetAddress address: result) { keepTrying.add(new ContinuationCallback() { @Override public void onContinue(Continuation continuation, final CompletedCallback next) throws Exception { mClient.getServer().connectSocket(new InetSocketAddress(address, port), wrapCallback(new ConnectCallback() { @Override public void onConnectCompleted(Exception ex, AsyncSocket socket) { assert !isDone(); // try the next address if (ex != null) { lastException = ex; next.onCompleted(null); return; } // if the socket is no longer needed, just hang onto it... if (isDone() || isCancelled()) { data.request.logd("Recycling extra socket leftover from cancelled operation"); idleSocket(socket); recycleSocket(socket, data.request); return; } if (setComplete(null, socket)) { data.connectCallback.onConnectCompleted(ex, socket); } } }, uri, port)); } }); } keepTrying.start(); } } .from(mClient.getServer().getAllByName(uri.getHost())); } public int getConnectionPoolCount() { int ret = 0; synchronized (this) { for (HashSet<AsyncSocket> sockets: mSockets.values()) { ret += sockets.size(); } } return ret; } private void recycleSocket(final AsyncSocket socket, AsyncHttpRequest request) { if (socket == null) return; URI uri = request.getUri(); int port = getSchemePort(uri); String lookup = computeLookup(uri, port, request); // nothing here will block... synchronized (this) { HashSet<AsyncSocket> sockets = mSockets.get(lookup); if (sockets == null) { sockets = new HashSet<AsyncSocket>(); mSockets.put(lookup, sockets); } final HashSet<AsyncSocket> ss = sockets; sockets.add(socket); // should not get any data after this point... // if so, eat it and disconnect. socket.setClosedCallback(new CompletedCallback() { @Override public void onCompleted(Exception ex) { synchronized (AsyncSocketMiddleware.this) { ss.remove(socket); } socket.setClosedCallback(null); } }); } } private void idleSocket(final AsyncSocket socket) { socket.setEndCallback(null); socket.setWriteableCallback(null); socket.setDataCallback(new NullDataCallback() { @Override public void onDataAvailable(DataEmitter emitter, ByteBufferList bb) { super.onDataAvailable(emitter, bb); bb.recycle(); socket.close(); } }); } @Override public void onRequestComplete(final OnRequestCompleteData data) { if (!data.state.getBoolean(getClass().getCanonicalName() + ".owned", false)) { return; } idleSocket(data.socket); if (data.exception != null || !data.socket.isOpen()) { data.socket.close(); return; } String kas = data.headers.getConnection(); if (kas == null || !"keep-alive".toLowerCase().equals(kas.toLowerCase())) { data.socket.close(); return; } data.request.logd("Recycling keep-alive socket"); recycleSocket(data.socket, data.request); } }