/* * Copyright 2016 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package io.reactivex.netty.client; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoopGroup; import io.netty.channel.epoll.EpollSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.util.concurrent.EventExecutorGroup; import io.reactivex.netty.RxNetty; import io.reactivex.netty.channel.ChannelSubscriberEvent; import io.reactivex.netty.channel.ConnectionCreationFailedEvent; import io.reactivex.netty.channel.DetachedChannelPipeline; import io.reactivex.netty.channel.WriteTransformer; import io.reactivex.netty.client.events.ClientEventListener; import io.reactivex.netty.events.Clock; import io.reactivex.netty.events.EventPublisher; import io.reactivex.netty.events.EventSource; import io.reactivex.netty.ssl.DefaultSslCodec; import io.reactivex.netty.ssl.SslCodec; import io.reactivex.netty.util.LoggingHandlerFactory; import rx.Observable; import rx.exceptions.Exceptions; import rx.functions.Action1; import rx.functions.Func0; import rx.functions.Func1; import javax.net.ssl.SSLEngine; import java.net.SocketAddress; import java.util.LinkedHashMap; import java.util.Map; import java.util.Map.Entry; import static io.reactivex.netty.HandlerNames.*; import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * A collection of state that a client holds. This supports the copy-on-write semantics of clients. * * @param <W> The type of objects written to the client owning this state. * @param <R> The type of objects read from the client owning this state. */ public class ClientState<W, R> { private final Observable<Host> hostStream; private final ConnectionProviderFactory<W, R> factory; private final DetachedChannelPipeline detachedPipeline; private final Map<ChannelOption<?>, Object> options; private final boolean isSecure; private final EventLoopGroup eventLoopGroup; private final Class<? extends Channel> channelClass; private final ChannelProviderFactory channelProviderFactory; protected ClientState(Observable<Host> hostStream, ConnectionProviderFactory<W, R> factory, DetachedChannelPipeline detachedPipeline, EventLoopGroup eventLoopGroup, Class<? extends Channel> channelClass) { this.eventLoopGroup = eventLoopGroup; this.channelClass = channelClass; options = new LinkedHashMap<>(); /// Same as netty bootstrap, order matters. this.hostStream = hostStream; this.factory = factory; this.detachedPipeline = detachedPipeline; isSecure = false; channelProviderFactory = new ChannelProviderFactory() { @Override public ChannelProvider newProvider(Host host, EventSource<? super ClientEventListener> eventSource, EventPublisher publisher, ClientEventListener clientPublisher) { return new ChannelProvider() { @Override public Observable<Channel> newChannel(Observable<Channel> input) { return input; } }; } }; } protected ClientState(ClientState<W, R> toCopy, ChannelOption<?> option, Object value) { options = new LinkedHashMap<>(toCopy.options); // Since, we are adding an option, copy it. options.put(option, value); detachedPipeline = toCopy.detachedPipeline; hostStream = toCopy.hostStream; factory = toCopy.factory; eventLoopGroup = toCopy.eventLoopGroup; channelClass = toCopy.channelClass; isSecure = toCopy.isSecure; channelProviderFactory = toCopy.channelProviderFactory; } protected ClientState(ClientState<?, ?> toCopy, DetachedChannelPipeline newPipeline, boolean secure) { final ClientState<W, R> toCopyCast = toCopy.cast(); options = toCopy.options; hostStream = toCopy.hostStream; factory = toCopyCast.factory; eventLoopGroup = toCopy.eventLoopGroup; channelClass = toCopy.channelClass; detachedPipeline = newPipeline; isSecure = secure; channelProviderFactory = toCopyCast.channelProviderFactory; } protected ClientState(ClientState<?, ?> toCopy, ChannelProviderFactory newFactory) { final ClientState<W, R> toCopyCast = toCopy.cast(); options = toCopy.options; hostStream = toCopy.hostStream; factory = toCopyCast.factory; eventLoopGroup = toCopy.eventLoopGroup; channelClass = toCopy.channelClass; detachedPipeline = toCopy.detachedPipeline; channelProviderFactory = newFactory; isSecure = toCopy.isSecure; } protected ClientState(ClientState<?, ?> toCopy, SslCodec sslCodec) { this(toCopy, toCopy.detachedPipeline.copy(new TailHandlerFactory(true)).configure(sslCodec), true); } public <T> ClientState<W, R> channelOption(ChannelOption<T> option, T value) { return new ClientState<>(this, option, value); } public <WW, RR> ClientState<WW, RR> addChannelHandlerFirst(String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addFirst(name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerFirst(EventExecutorGroup group, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addFirst(group, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerLast(String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addLast(name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerLast(EventExecutorGroup group, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addLast(group, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerBefore(String baseName, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addBefore(baseName, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerBefore(EventExecutorGroup group, String baseName, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addBefore(group, baseName, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerAfter(String baseName, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addAfter(baseName, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> addChannelHandlerAfter(EventExecutorGroup group, String baseName, String name, Func0<ChannelHandler> handlerFactory) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.addAfter(group, baseName, name, handlerFactory); return copy; } public <WW, RR> ClientState<WW, RR> pipelineConfigurator(Action1<ChannelPipeline> pipelineConfigurator) { ClientState<WW, RR> copy = copy(); copy.detachedPipeline.configure(pipelineConfigurator); return copy; } public ClientState<W, R> enableWireLogging(final LogLevel wireLoggingLevel) { return enableWireLogging(LoggingHandler.class.getName(), wireLoggingLevel); } public ClientState<W, R> enableWireLogging(String name, final LogLevel wireLoggingLevel) { return addChannelHandlerFirst(WireLogging.getName(), LoggingHandlerFactory.getFactory(name, wireLoggingLevel)); } public static <WW, RR> ClientState<WW, RR> create(ConnectionProviderFactory<WW, RR> factory, Observable<Host> hostStream) { return create(newChannelPipeline(new TailHandlerFactory(false)), factory, hostStream); } public static <WW, RR> ClientState<WW, RR> create(ConnectionProviderFactory<WW, RR> factory, Observable<Host> hostStream, EventLoopGroup eventLoopGroup, Class<? extends Channel> channelClass) { return new ClientState<>(hostStream, factory, newChannelPipeline(new TailHandlerFactory(false)), eventLoopGroup, channelClass); } public static <WW, RR> ClientState<WW, RR> create(DetachedChannelPipeline detachedPipeline, ConnectionProviderFactory<WW, RR> factory, Observable<Host> hostStream) { return create(detachedPipeline, factory, hostStream, defaultEventloopGroup(), defaultSocketChannelClass()); } public static <WW, RR> ClientState<WW, RR> create(DetachedChannelPipeline detachedPipeline, ConnectionProviderFactory<WW, RR> factory, Observable<Host> hostStream, EventLoopGroup eventLoopGroup, Class<? extends Channel> channelClass) { return new ClientState<>(hostStream, factory, detachedPipeline, eventLoopGroup, channelClass); } private static DetachedChannelPipeline newChannelPipeline(TailHandlerFactory thf) { return new DetachedChannelPipeline(thf) .addLast(WriteTransformer.getName(), new Func0<ChannelHandler>() { @Override public ChannelHandler call() { return new WriteTransformer(); } }); } public Bootstrap newBootstrap(final EventPublisher eventPublisher, final ClientEventListener eventListener) { final Bootstrap nettyBootstrap = new Bootstrap().group(eventLoopGroup) .channel(channelClass) .option(ChannelOption.AUTO_READ, false);// by default do not read content unless asked. for (Entry<ChannelOption<?>, Object> optionEntry : options.entrySet()) { // Type is just for safety for user of ClientState, internally in Bootstrap, types are thrown on the floor. @SuppressWarnings("unchecked") ChannelOption<Object> key = (ChannelOption<Object>) optionEntry.getKey(); nettyBootstrap.option(key, optionEntry.getValue()); } nettyBootstrap.handler(new ChannelInitializer<Channel>() { @Override protected void initChannel(Channel ch) throws Exception { ch.pipeline().addLast(ClientChannelActiveBufferingHandler.getName(), new ChannelActivityBufferingHandler(eventPublisher, eventListener)); } }); return nettyBootstrap; } public DetachedChannelPipeline unsafeDetachedPipeline() { return detachedPipeline; } public Map<ChannelOption<?>, Object> unsafeChannelOptions() { return options; } public ClientState<W, R> channelProviderFactory(ChannelProviderFactory factory) { return new ClientState<>(this, factory); } public ClientState<W, R> secure(Func1<ByteBufAllocator, SSLEngine> sslEngineFactory) { return secure(new DefaultSslCodec(sslEngineFactory)); } public ClientState<W, R> secure(SSLEngine sslEngine) { return secure(new DefaultSslCodec(sslEngine)); } public ClientState<W, R> secure(SslCodec sslCodec) { return new ClientState<>(this, sslCodec); } public ClientState<W, R> unsafeSecure() { return secure(new DefaultSslCodec(new Func1<ByteBufAllocator, SSLEngine>() { @Override public SSLEngine call(ByteBufAllocator allocator) { try { return SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .build() .newEngine(allocator); } catch (Exception e) { throw Exceptions.propagate(e); } } })); } private <WW, RR> ClientState<WW, RR> copy() { TailHandlerFactory newTail = new TailHandlerFactory(isSecure); return new ClientState<>(this, detachedPipeline.copy(newTail), isSecure); } public ConnectionProviderFactory<W, R> getFactory() { return factory; } public Observable<Host> getHostStream() { return hostStream; } public ChannelProviderFactory getChannelProviderFactory() { return channelProviderFactory; } @SuppressWarnings("unchecked") private <WW, RR> ClientState<WW, RR> cast() { return (ClientState<WW, RR>) this; } protected static class TailHandlerFactory implements Action1<ChannelPipeline> { private final boolean isSecure; public TailHandlerFactory(boolean isSecure) { this.isSecure = isSecure; } @Override public void call(ChannelPipeline pipeline) { ClientConnectionToChannelBridge.addToPipeline(pipeline, isSecure); } } public static EventLoopGroup defaultEventloopGroup() { return RxNetty.getRxEventLoopProvider().globalClientEventLoop(true); } public static Class<? extends Channel> defaultSocketChannelClass() { return RxNetty.isUsingNativeTransport() ? EpollSocketChannel.class : NioSocketChannel.class; } /** * Clients construct the pipeline, outside of the {@link ChannelInitializer} through {@link ChannelProvider}. * Thus channel registration and activation events may be lost due to a race condition when the channel is active * before the pipeline is configured. * This handler buffers, the channel events till the time, a subscriber appears for channel establishment. */ private static class ChannelActivityBufferingHandler extends ChannelDuplexHandler { private enum State { Initialized, Registered, Active, Inactive, ChannelSubscribed } private State state = State.Initialized; /** * Unregistered state will hide the active/inactive state, hence this is a different flag. */ private boolean unregistered; private long connectStartTimeNanos; private final EventPublisher eventPublisher; private final ClientEventListener eventListener; private ChannelActivityBufferingHandler(EventPublisher eventPublisher, ClientEventListener eventListener) { this.eventPublisher = eventPublisher; this.eventListener = eventListener; } @SuppressWarnings("unchecked") @Override public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) throws Exception { connectStartTimeNanos = Clock.newStartTimeNanos(); if (eventPublisher.publishingEnabled()) { eventListener.onConnectStart(); promise.addListener(new ChannelFutureListener() { @SuppressWarnings("unchecked") @Override public void operationComplete(ChannelFuture future) throws Exception { if (eventPublisher.publishingEnabled()) { long endTimeNanos = Clock.onEndNanos(connectStartTimeNanos); if (!future.isSuccess()) { eventListener.onConnectFailed(endTimeNanos, NANOSECONDS, future.cause()); } else { eventListener.onConnectSuccess(endTimeNanos, NANOSECONDS); } } } }); } super.connect(ctx, remoteAddress, localAddress, promise); } @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { if (State.ChannelSubscribed == state) { super.channelRegistered(ctx); } else { state = State.Registered; } } @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { if (State.ChannelSubscribed == state) { super.channelUnregistered(ctx); } else { unregistered = true; } } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { if (State.ChannelSubscribed == state) { super.channelActive(ctx); } else { state = State.Active; } } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (State.ChannelSubscribed == state) { super.channelInactive(ctx); } else { state = State.Inactive; } } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ChannelSubscriberEvent) { final State existingState = state; state = State.ChannelSubscribed; super.userEventTriggered(ctx, evt); final ChannelPipeline pipeline = ctx.channel().pipeline(); switch (existingState) { case Initialized: break; case Registered: pipeline.fireChannelRegistered(); break; case Active: pipeline.fireChannelRegistered(); pipeline.fireChannelActive(); break; case Inactive: pipeline.fireChannelRegistered(); pipeline.fireChannelActive(); pipeline.fireChannelInactive(); break; case ChannelSubscribed: // Duplicate event, ignore. break; } if (unregistered) { pipeline.fireChannelUnregistered(); } } else if (evt instanceof ConnectionCreationFailedEvent) { ConnectionCreationFailedEvent failedEvent = (ConnectionCreationFailedEvent) evt; onConnectFailedEvent(failedEvent); super.userEventTriggered(ctx, evt); } else { super.userEventTriggered(ctx, evt); } } @SuppressWarnings("unchecked") private void onConnectFailedEvent(ConnectionCreationFailedEvent event) { if (eventPublisher.publishingEnabled()) { eventListener.onConnectFailed(connectStartTimeNanos, NANOSECONDS, event.getThrowable()); } } } }