/* * Copyright (c) 2012 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * To change this template, choose Tools | Templates * and open the template in the editor. */ package org.eclipse.jetty.spdy; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import org.eclipse.jetty.io.AsyncEndPoint; import org.eclipse.jetty.io.ConnectedEndPoint; import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.nio.AsyncConnection; import org.eclipse.jetty.io.nio.SelectChannelEndPoint; import org.eclipse.jetty.io.nio.SelectorManager; import org.eclipse.jetty.io.nio.SslConnection; import org.eclipse.jetty.npn.NextProtoNego; import org.eclipse.jetty.spdy.api.Session; import org.eclipse.jetty.spdy.api.SessionFrameListener; import org.eclipse.jetty.spdy.generator.Generator; import org.eclipse.jetty.spdy.parser.Parser; import org.eclipse.jetty.util.component.AggregateLifeCycle; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.QueuedThreadPool; public class SPDYClient { private final Map<String, AsyncConnectionFactory> factories = new ConcurrentHashMap<>(); private final short version; private final Factory factory; private SocketAddress bindAddress; private long maxIdleTime; protected SPDYClient(short version, Factory factory) { this.version = version; this.factory = factory; } /** * @return the address to bind the socket channel to * @see #setBindAddress(SocketAddress) */ public SocketAddress getBindAddress() { return bindAddress; } /** * @param bindAddress the address to bind the socket channel to * @see #getBindAddress() */ public void setBindAddress(SocketAddress bindAddress) { this.bindAddress = bindAddress; } public Future<Session> connect(InetSocketAddress address, SessionFrameListener listener) throws IOException { if (!factory.isStarted()) throw new IllegalStateException(Factory.class.getSimpleName() + " is not started"); SocketChannel channel = SocketChannel.open(); if (bindAddress != null) channel.bind(bindAddress); channel.socket().setTcpNoDelay(true); channel.configureBlocking(false); SessionPromise result = new SessionPromise(this, listener); channel.connect(address); factory.selector.register(channel, result); return result; } public long getMaxIdleTime() { return maxIdleTime; } public void setMaxIdleTime(long maxIdleTime) { this.maxIdleTime = maxIdleTime; } protected String selectProtocol(List<String> serverProtocols) { if (serverProtocols == null) return "spdy/2"; for (String serverProtocol : serverProtocols) { for (String protocol : factories.keySet()) { if (serverProtocol.equals(protocol)) return protocol; } String protocol = factory.selectProtocol(serverProtocols); if (protocol != null) return protocol; } return null; } public AsyncConnectionFactory getAsyncConnectionFactory(String protocol) { for (Map.Entry<String, AsyncConnectionFactory> entry : factories.entrySet()) { if (protocol.equals(entry.getKey())) return entry.getValue(); } for (Map.Entry<String, AsyncConnectionFactory> entry : factory.factories.entrySet()) { if (protocol.equals(entry.getKey())) return entry.getValue(); } return null; } public void putAsyncConnectionFactory(String protocol, AsyncConnectionFactory factory) { factories.put(protocol, factory); } public AsyncConnectionFactory removeAsyncConnectionFactory(String protocol) { return factories.remove(protocol); } protected SSLEngine newSSLEngine(SslContextFactory sslContextFactory, SocketChannel channel) { String peerHost = channel.socket().getInetAddress().getHostAddress(); int peerPort = channel.socket().getPort(); SSLEngine engine = sslContextFactory.newSslEngine(peerHost, peerPort); engine.setUseClientMode(true); return engine; } public static class Factory extends AggregateLifeCycle { private final Map<String, AsyncConnectionFactory> factories = new ConcurrentHashMap<>(); private final Queue<Session> sessions = new ConcurrentLinkedQueue<>(); private final ByteBufferPool bufferPool = new StandardByteBufferPool(); private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); private final Executor threadPool; private final SslContextFactory sslContextFactory; private final SelectorManager selector; public Factory() { this(null, null); } public Factory(SslContextFactory sslContextFactory) { this(null, sslContextFactory); } public Factory(Executor threadPool) { this(threadPool, null); } public Factory(Executor threadPool, SslContextFactory sslContextFactory) { if (threadPool == null) threadPool = new QueuedThreadPool(); this.threadPool = threadPool; addBean(threadPool); this.sslContextFactory = sslContextFactory; if (sslContextFactory != null) addBean(sslContextFactory); selector = new ClientSelectorManager(); addBean(selector); factories.put("spdy/2", new ClientSPDYAsyncConnectionFactory()); } public SPDYClient newSPDYClient(short version) { return new SPDYClient(version, this); } @Override protected void doStop() throws Exception { closeConnections(); super.doStop(); } protected String selectProtocol(List<String> serverProtocols) { for (String serverProtocol : serverProtocols) { for (String protocol : factories.keySet()) { if (serverProtocol.equals(protocol)) return protocol; } } return null; } private boolean sessionOpened(Session session) { // Add sessions only if the factory is not stopping return isRunning() && sessions.offer(session); } private boolean sessionClosed(Session session) { // Remove sessions only if the factory is not stopping // to avoid concurrent removes during iterations return isRunning() && sessions.remove(session); } private void closeConnections() { for (Session session : sessions) session.goAway(); sessions.clear(); } protected Collection<Session> getSessions() { return Collections.unmodifiableCollection(sessions); } private class ClientSelectorManager extends SelectorManager { @Override public boolean dispatch(Runnable task) { try { threadPool.execute(task); return true; } catch (RejectedExecutionException x) { return false; } } @Override protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, SelectionKey key) throws IOException { SessionPromise attachment = (SessionPromise)key.attachment(); long maxIdleTime = attachment.client.getMaxIdleTime(); if (maxIdleTime < 0) maxIdleTime = getMaxIdleTime(); SelectChannelEndPoint result = new SelectChannelEndPoint(channel, selectSet, key, (int)maxIdleTime); AsyncConnection connection = newConnection(channel, result, attachment); result.setConnection(connection); return result; } @Override protected void endPointOpened(SelectChannelEndPoint endpoint) { } @Override protected void endPointUpgraded(ConnectedEndPoint endpoint, Connection oldConnection) { } @Override protected void endPointClosed(SelectChannelEndPoint endpoint) { endpoint.getConnection().onClose(); } @Override public AsyncConnection newConnection(final SocketChannel channel, AsyncEndPoint endPoint, Object attachment) { SessionPromise sessionPromise = (SessionPromise)attachment; final SPDYClient client = sessionPromise.client; try { if (sslContextFactory != null) { final AtomicReference<AsyncEndPoint> sslEndPointRef = new AtomicReference<>(); final AtomicReference<Object> attachmentRef = new AtomicReference<>(attachment); SSLEngine engine = client.newSSLEngine(sslContextFactory, channel); SslConnection sslConnection = new SslConnection(engine, endPoint) { @Override public void onClose() { sslEndPointRef.set(null); attachmentRef.set(null); super.onClose(); } }; endPoint.setConnection(sslConnection); AsyncEndPoint sslEndPoint = sslConnection.getSslEndPoint(); sslEndPointRef.set(sslEndPoint); // Instances of the ClientProvider inner class strong reference the // SslEndPoint (via lexical scoping), which strong references the SSLEngine. // Since NextProtoNego stores in a WeakHashMap the SSLEngine as key // and this instance as value, we are in the situation where the value // of a WeakHashMap refers indirectly to the key, which is bad because // the entry will never be removed from the WeakHashMap. // We use AtomicReferences to be captured via lexical scoping, // and we null them out above when the connection is closed. NextProtoNego.put(engine, new NextProtoNego.ClientProvider() { @Override public boolean supports() { return true; } @Override public void unsupported() { // Server does not support NPN, but this is a SPDY client, so hardcode SPDY ClientSPDYAsyncConnectionFactory connectionFactory = new ClientSPDYAsyncConnectionFactory(); AsyncEndPoint sslEndPoint = sslEndPointRef.get(); AsyncConnection connection = connectionFactory.newAsyncConnection(channel, sslEndPoint, attachmentRef.get()); sslEndPoint.setConnection(connection); } @Override public String selectProtocol(List<String> protocols) { String protocol = client.selectProtocol(protocols); if (protocol == null) return null; AsyncConnectionFactory connectionFactory = client.getAsyncConnectionFactory(protocol); AsyncEndPoint sslEndPoint = sslEndPointRef.get(); AsyncConnection connection = connectionFactory.newAsyncConnection(channel, sslEndPoint, attachmentRef.get()); sslEndPoint.setConnection(connection); return protocol; } }); AsyncConnection connection = new EmptyAsyncConnection(sslEndPoint); sslEndPoint.setConnection(connection); startHandshake(engine); return sslConnection; } else { AsyncConnectionFactory connectionFactory = new ClientSPDYAsyncConnectionFactory(); AsyncConnection connection = connectionFactory.newAsyncConnection(channel, endPoint, attachment); endPoint.setConnection(connection); return connection; } } catch (RuntimeException x) { sessionPromise.failed(x); throw x; } } private void startHandshake(SSLEngine engine) { try { engine.beginHandshake(); } catch (SSLException x) { throw new RuntimeException(x); } } } } private static class SessionPromise extends Promise<Session> { private final SPDYClient client; private final SessionFrameListener listener; private SessionPromise(SPDYClient client, SessionFrameListener listener) { this.client = client; this.listener = listener; } } private static class ClientSPDYAsyncConnectionFactory implements AsyncConnectionFactory { @Override public AsyncConnection newAsyncConnection(SocketChannel channel, AsyncEndPoint endPoint, Object attachment) { SessionPromise sessionPromise = (SessionPromise)attachment; Factory factory = sessionPromise.client.factory; CompressionFactory compressionFactory = new StandardCompressionFactory(); Parser parser = new Parser(compressionFactory.newDecompressor()); Generator generator = new Generator(factory.bufferPool, compressionFactory.newCompressor()); SPDYAsyncConnection connection = new ClientSPDYAsyncConnection(endPoint, factory.bufferPool, parser, factory); endPoint.setConnection(connection); StandardSession session = new StandardSession(sessionPromise.client.version, factory.bufferPool, factory.threadPool, factory.scheduler, connection, connection, 1, sessionPromise.listener, generator); parser.addListener(session); sessionPromise.completed(session); connection.setSession(session); factory.sessionOpened(session); return connection; } private class ClientSPDYAsyncConnection extends SPDYAsyncConnection { private final Factory factory; public ClientSPDYAsyncConnection(AsyncEndPoint endPoint, ByteBufferPool bufferPool, Parser parser, Factory factory) { super(endPoint, bufferPool, parser); this.factory = factory; } @Override public void onClose() { super.onClose(); factory.sessionClosed(getSession()); } } } }