/* * Copyright 2015 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package io.reactivex.netty.protocol.tcp.server; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.handler.logging.LogLevel; import io.netty.util.concurrent.EventExecutorGroup; import io.reactivex.netty.protocol.tcp.server.events.TcpServerEventListener; import io.reactivex.netty.protocol.tcp.server.events.TcpServerEventPublisher; import io.reactivex.netty.server.ServerState; import io.reactivex.netty.ssl.SslCodec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import rx.Subscription; import rx.functions.Action1; import rx.functions.Func0; import rx.functions.Func1; import javax.net.ssl.SSLEngine; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; public class TcpServerImpl<R, W> extends TcpServer<R, W> { private static final Logger logger = LoggerFactory.getLogger(TcpServerImpl.class); protected enum ServerStatus {Created, Starting, Started, Shutdown} private final ServerState<R, W> state; private ChannelFuture bindFuture; protected final AtomicReference<ServerStatus> serverStateRef; public TcpServerImpl(SocketAddress socketAddress) { state = TcpServerState.create(socketAddress); serverStateRef = new AtomicReference<>(ServerStatus.Created); } public TcpServerImpl(SocketAddress socketAddress, EventLoopGroup parent, EventLoopGroup child, Class<? extends ServerChannel> channelClass) { state = TcpServerState.create(socketAddress, parent, child, channelClass); serverStateRef = new AtomicReference<>(ServerStatus.Created); } private TcpServerImpl(ServerState<R, W> state) { this.state = state; serverStateRef = new AtomicReference<>(ServerStatus.Created); } @Override public <T> TcpServer<R, W> channelOption(ChannelOption<T> option, T value) { return copy(state.channelOption(option, value)); } @Override public <T> TcpServer<R, W> clientChannelOption(ChannelOption<T> option, T value) { return copy(state.clientChannelOption(option, value)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerFirst(String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerFirst(name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerFirst(EventExecutorGroup group, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerFirst(group, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerLast(String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerLast(name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerLast(EventExecutorGroup group, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerLast(group, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerBefore(String baseName, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerBefore(baseName, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerBefore(EventExecutorGroup group, String baseName, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerBefore(group, baseName, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerAfter(String baseName, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerAfter(baseName, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> addChannelHandlerAfter(EventExecutorGroup group, String baseName, String name, Func0<ChannelHandler> handlerFactory) { return copy(state.<RR, WW>addChannelHandlerAfter(group, baseName, name, handlerFactory)); } @Override public <RR, WW> TcpServer<RR, WW> pipelineConfigurator(Action1<ChannelPipeline> pipelineConfigurator) { return copy(state.<RR, WW>pipelineConfigurator(pipelineConfigurator)); } @Override public TcpServer<R, W> secure(Func1<ByteBufAllocator, SSLEngine> sslEngineFactory) { return copy(((TcpServerState<R, W>)state).secure(sslEngineFactory)); } @Override public TcpServer<R, W> secure(SSLEngine sslEngine) { return copy(((TcpServerState<R, W>)state).secure(sslEngine)); } @Override public TcpServer<R, W> secure(SslCodec sslCodec) { return copy(((TcpServerState<R, W>)state).secure(sslCodec)); } @Override public TcpServer<R, W> unsafeSecure() { return copy(((TcpServerState<R, W>)state).unsafeSecure()); } @Override @Deprecated public TcpServer<R, W> enableWireLogging(LogLevel wireLoggingLevel) { return copy(state.<W, R>enableWireLogging(wireLoggingLevel)); } @Override public TcpServer<R, W> enableWireLogging(String name, LogLevel wireLoggingLevel) { return copy(state.<W, R>enableWireLogging(name, wireLoggingLevel)); } @Override public int getServerPort() { final SocketAddress localAddress = getServerAddress(); if (localAddress instanceof InetSocketAddress) { return ((InetSocketAddress) localAddress).getPort(); } else { return 0; } } @Override public SocketAddress getServerAddress() { SocketAddress localAddress; if (null != bindFuture && bindFuture.isDone()) { localAddress = bindFuture.channel().localAddress(); } else { localAddress = state.getServerAddress(); } return localAddress; } @Override public TcpServer<R, W> start(final ConnectionHandler<R, W> connectionHandler) { if (!serverStateRef.compareAndSet(ServerStatus.Created, ServerStatus.Starting)) { throw new IllegalStateException("Server already started"); } try { Action1<ChannelPipeline> handlerFactory = new Action1<ChannelPipeline>() { @Override public void call(ChannelPipeline pipeline) { TcpServerState<R, W> tcpState = (TcpServerState<R, W>) state; TcpServerConnectionToChannelBridge.addToPipeline(pipeline, connectionHandler, tcpState.getEventPublisher(), tcpState.isSecure()); } }; final TcpServerState<R, W> newState = (TcpServerState<R, W>) state.pipelineConfigurator(handlerFactory); bindFuture = newState.getBootstrap().bind(newState.getServerAddress()).sync(); if (!bindFuture.isSuccess()) { throw new RuntimeException(bindFuture.cause()); } } catch (InterruptedException e) { throw new RuntimeException(e); } serverStateRef.set(ServerStatus.Started); // It will come here only if this was the thread that transitioned to Starting logger.info("Rx server started at port: " + getServerPort()); return this; } @Override public void shutdown() { if (!serverStateRef.compareAndSet(ServerStatus.Started, ServerStatus.Shutdown)) { throw new IllegalStateException("The server is already shutdown."); } else { try { bindFuture.channel().close().sync(); } catch (InterruptedException e) { logger.error("Interrupted while waiting for the server socket to close.", e); } } } @Override public void awaitShutdown() { ServerStatus status = serverStateRef.get(); switch (status) { case Created: case Starting: throw new IllegalStateException("Server not started yet."); case Started: try { bindFuture.channel().closeFuture().await(); } catch (InterruptedException e) { Thread.interrupted(); // Reset the interrupted status logger.error("Interrupted while waiting for the server socket to close.", e); } break; case Shutdown: // Nothing to do as it is already shutdown. break; } } @Override public void awaitShutdown(long duration, TimeUnit timeUnit) { ServerStatus status = serverStateRef.get(); switch (status) { case Created: case Starting: throw new IllegalStateException("Server not started yet."); case Started: try { bindFuture.channel().closeFuture().await(duration, timeUnit); } catch (InterruptedException e) { Thread.interrupted(); // Reset the interrupted status logger.error("Interrupted while waiting for the server socket to close.", e); } break; case Shutdown: // Nothing to do as it is already shutdown. break; } } @Override public TcpServerEventPublisher getEventPublisher() { return ((TcpServerState<R, W>)state).getEventPublisher(); } @Override public Subscription subscribe(TcpServerEventListener listener) { return ((TcpServerState<R, W>)state).getEventPublisher().subscribe(listener); } private static <RR, WW> TcpServer<RR, WW> copy(ServerState<RR, WW> newState) { return new TcpServerImpl<>(newState); } }