/* * Copyright (C) 2012-2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.nifty.core; import com.facebook.nifty.ssl.SslPlaintextHandler; import com.facebook.nifty.ssl.SslServerConfiguration; import com.google.common.base.Preconditions; import edu.umd.cs.findbugs.annotations.SuppressWarnings; import io.airlift.log.Logger; import org.apache.thrift.protocol.TProtocolFactory; import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelFutureListener; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipelineFactory; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.Channels; import org.jboss.netty.channel.ServerChannelFactory; import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.channel.socket.nio.NioServerBossPool; import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; import org.jboss.netty.channel.socket.nio.NioWorkerPool; import org.jboss.netty.handler.ssl.SslHandler; import org.jboss.netty.handler.timeout.IdleStateHandler; import org.jboss.netty.util.ExternalResourceReleasable; import org.jboss.netty.util.ThreadNameDeterminer; import 
javax.inject.Inject;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Netty server transport for a single {@link ThriftServerDef}.
 *
 * <p>It binds a server channel and builds a per-connection pipeline. That pipeline decodes framed
 * Thrift messages, dispatches them through {@link NiftyDispatcher}, and encodes responses back
 * into Thrift frames. Optional SSL, connection limiting and client idle-timeout handling are
 * also added to the pipeline.
 */
public class NettyServerTransport implements ExternalResourceReleasable {
    private static final Logger log = Logger.get(NettyServerTransport.class);

    // Passed to IdleStateHandler to disable the writer-idle and all-idle checks; only reader idle
    // (the client idle timeout) is monitored.
    private static final int NO_WRITER_IDLE_TIMEOUT = 0;
    private static final int NO_ALL_IDLE_TIMEOUT = 0;

    private final int requestedPort;
    private int actualPort;
    private final ChannelPipelineFactory pipelineFactory;
    private ServerBootstrap bootstrap;
    private final ChannelGroup allChannels;
    private ExecutorService bossExecutor;
    private ExecutorService ioWorkerExecutor;
    // Only set by start() (i.e. when this transport created the factory itself); a factory passed
    // to start(ServerChannelFactory) is not stored here and therefore never shut down by stop().
    private ServerChannelFactory channelFactory;
    private Channel serverChannel;
    private final ThriftServerDef def;
    private final NettyServerConfig nettyServerConfig;
    private final ChannelStatistics channelStatistics;
    // Read each time a pipeline is built, so updates only affect pipelines built afterwards.
    private final AtomicReference<SslServerConfiguration> sslConfiguration = new AtomicReference<>();

    /**
     * Creates a transport with a default {@link NettyServerConfig} and a fresh channel group.
     *
     * @param def the server definition to serve.
     */
    public NettyServerTransport(final ThriftServerDef def) {
        this(def, NettyServerConfig.newBuilder().build(), new DefaultChannelGroup());
    }

    /**
     * Creates a transport.
     *
     * @param def the server definition to serve (port, codecs, security, limits, ...).
     * @param nettyServerConfig Netty-level configuration (executors, thread counts, timer,
     *     bootstrap options).
     * @param allChannels group used for channel statistics and shut down together with a
     *     transport-owned channel factory in {@link #stop()}.
     */
    @Inject
    public NettyServerTransport(
            final ThriftServerDef def,
            final NettyServerConfig nettyServerConfig,
            final ChannelGroup allChannels) {
        this.def = def;
        this.nettyServerConfig = nettyServerConfig;
        this.requestedPort = def.getServerPort();
        this.allChannels = allChannels;
        // connectionLimiter must be instantiated exactly once (and thus outside the pipeline factory)
        final ConnectionLimiter connectionLimiter = new ConnectionLimiter(def.getMaxConnections());
        this.channelStatistics = new ChannelStatistics(allChannels);
        this.sslConfiguration.set(this.def.getSslConfiguration());

        this.pipelineFactory = new ChannelPipelineFactory() {
            @Override
            public ChannelPipeline getPipeline() throws Exception {
                ChannelPipeline cp = Channels.pipeline();
                TProtocolFactory inputProtocolFactory = def.getDuplexProtocolFactory().getInputProtocolFactory();
                NiftySecurityHandlers securityHandlers = def.getSecurityFactory().getSecurityHandlers(def, nettyServerConfig);
                cp.addLast("connectionContext", new ConnectionContextHandler());
                cp.addLast("connectionLimiter", connectionLimiter);
                cp.addLast(ChannelStatistics.NAME, channelStatistics);
                cp.addLast("encryptionHandler", securityHandlers.getEncryptionHandler());
                cp.addLast("ioDispatcher", new NiftyIODispatcher());
                cp.addLast("frameCodec", def.getThriftFrameCodecFactory().create(def.getMaxFrameSize(), inputProtocolFactory));
                if (def.getClientIdleTimeout() != null) {
                    // Add handlers to detect idle client connections and disconnect them
                    cp.addLast("idleTimeoutHandler", new IdleStateHandler(nettyServerConfig.getTimer(),
                            def.getClientIdleTimeout().toMillis(),
                            NO_WRITER_IDLE_TIMEOUT,
                            NO_ALL_IDLE_TIMEOUT,
                            TimeUnit.MILLISECONDS));
                    cp.addLast("idleDisconnectHandler", new IdleDisconnectHandler());
                }

                cp.addLast("authHandler", securityHandlers.getAuthenticationHandler());
                cp.addLast("dispatcher", new NiftyDispatcher(def, nettyServerConfig.getTimer()));
                cp.addLast("exceptionLogger", new NiftyExceptionLogger());

                // SSL handling is added with addFirst so it sits ahead of every handler above.
                SslServerConfiguration serverConfiguration = sslConfiguration.get();
                if (serverConfiguration != null) {
                    if (serverConfiguration.allowPlaintext) {
                        cp.addFirst("ssl_plaintext", new SslPlaintextHandler(serverConfiguration, "ssl"));
                    } else {
                        SslHandler handler = serverConfiguration.createHandler();
                        cp.addFirst("ssl", handler);
                    }
                }
                return cp;
            }
        };
    }

    /**
     * Starts the transport using NIO boss/worker pools built from {@link NettyServerConfig}.
     * The channel factory created here is owned by this transport and shut down in {@link #stop()}.
     */
    public void start() {
        bossExecutor = nettyServerConfig.getBossExecutor();
        int bossThreadCount = nettyServerConfig.getBossThreadCount();
        ioWorkerExecutor = nettyServerConfig.getWorkerExecutor();
        int ioWorkerThreadCount = nettyServerConfig.getWorkerThreadCount();
        channelFactory = new NioServerSocketChannelFactory(
                new NioServerBossPool(bossExecutor, bossThreadCount, ThreadNameDeterminer.CURRENT),
                new NioWorkerPool(ioWorkerExecutor, ioWorkerThreadCount, ThreadNameDeterminer.CURRENT));
        start(channelFactory);
    }

    /**
     * Binds the server channel using the given factory and notifies the transport attach observer,
     * if any. The factory is not shut down by {@link #stop()} unless it was created by {@link #start()}.
     *
     * @param serverChannelFactory factory used to create the server channel.
     * @throws IllegalStateException if the bound port is 0, or differs from a non-zero requested port.
     */
    public void start(ServerChannelFactory serverChannelFactory) {
        bootstrap = new ServerBootstrap(serverChannelFactory);
        bootstrap.setOptions(nettyServerConfig.getBootstrapOptions());
        bootstrap.setPipelineFactory(pipelineFactory);
        serverChannel = bootstrap.bind(new InetSocketAddress(requestedPort));
        InetSocketAddress actualSocket = (InetSocketAddress) serverChannel.getLocalAddress();
        actualPort = actualSocket.getPort();
        Preconditions.checkState(actualPort != 0 && (actualPort == requestedPort || requestedPort == 0),
                "bound to port %s but requested port %s", actualPort, requestedPort);
        log.info("started transport %s:%s", def.getName(), actualPort);
        if (def.getTransportAttachObserver() != null) {
            def.getTransportAttachObserver().attachTransport(this);
        }
    }

    /**
     * Stops the transport: closes the server channel, shuts down the definition's executor if it is
     * an {@link ExecutorService}, shuts down a transport-owned channel factory, and detaches from
     * the transport attach observer, if any.
     *
     * @throws InterruptedException if interrupted while waiting for the server channel to close.
     */
    public void stop() throws InterruptedException {
        if (serverChannel != null) {
            log.info("stopping transport %s:%s", def.getName(), actualPort);
            // first stop accepting
            final CountDownLatch latch = new CountDownLatch(1);
            serverChannel.close().addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    try {
                        // stop and process remaining in-flight invocations
                        if (def.getExecutor() instanceof ExecutorService) {
                            ExecutorService exe = (ExecutorService) def.getExecutor();
                            ShutdownUtil.shutdownExecutor(exe, "dispatcher");
                        }
                    } finally {
                        // Always release stop(), even if executor shutdown throws; otherwise
                        // latch.await() below would block forever.
                        latch.countDown();
                    }
                }
            });
            latch.await();
            serverChannel = null;
        }

        // If the channelFactory was created by us, we should also clean it up. If the
        // channelFactory was passed in by NiftyBootstrap, then it may be shared so don't clean
        // it up.
        if (channelFactory != null) {
            ShutdownUtil.shutdownChannelFactory(channelFactory, bossExecutor, ioWorkerExecutor, allChannels);
        }

        if (def.getTransportAttachObserver() != null) {
            def.getTransportAttachObserver().detachTransport();
        }
    }

    /**
     * Returns the bound server channel.
     *
     * @return the server channel, or {@code null} before {@code start(...)} or after {@link #stop()}.
     */
    public Channel getServerChannel() {
        return serverChannel;
    }

    /**
     * Returns the port the transport is bound to, falling back to the requested port.
     *
     * @return the actual bound port once started, otherwise the requested port (which may be 0).
     */
    public int getPort() {
        if (actualPort != 0) {
            return actualPort;
        } else {
            return requestedPort; // may be 0 if server not yet started
        }
    }

    /**
     * Releases the bootstrap's external resources. This is a no-op if the transport was never
     * started, since no bootstrap exists yet.
     */
    @Override
    public void releaseExternalResources() {
        if (bootstrap != null) {
            bootstrap.releaseExternalResources();
        }
    }

    /**
     * Upstream handler that closes newly opened channels once more than {@code maxConnections}
     * are open. A non-positive limit disables counting entirely. A single instance is shared by
     * all pipelines so the count is global to this transport.
     */
    private static class ConnectionLimiter extends SimpleChannelUpstreamHandler {
        private final AtomicInteger numConnections;
        private final int maxConnections;

        public ConnectionLimiter(int maxConnections) {
            this.maxConnections = maxConnections;
            this.numConnections = new AtomicInteger(0);
        }

        @Override
        @SuppressWarnings("PMD.CollapsibleIfStatements")
        public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
            if (maxConnections > 0 && numConnections.incrementAndGet() > maxConnections) {
                ctx.getChannel().close();
                // numConnections will be decremented in channelClosed
                log.info("Accepted connection above limit (%s). Dropping.", maxConnections);
            }
            // The open event is still propagated even when the channel was just closed above.
            super.channelOpen(ctx, e);
        }

        @Override
        @SuppressWarnings("PMD.CollapsibleIfStatements")
        public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
            if (maxConnections > 0 && numConnections.decrementAndGet() < 0) {
                log.error("BUG in ConnectionLimiter");
            }
            super.channelClosed(ctx, e);
        }
    }

    /**
     * Returns the channel statistics for this transport.
     *
     * @return the metrics backed by this transport's {@link ChannelStatistics}.
     */
    public NiftyMetrics getMetrics() {
        return channelStatistics;
    }

    /**
     * Returns the current {@link SslServerConfiguration}.
     *
     * @return the configuration, or {@code null} if SSL is not configured.
     */
    public SslServerConfiguration getSSLConfiguration() {
        return sslConfiguration.get();
    }

    /**
     * Atomically replaces the current {@link SslServerConfiguration} with the provided one.
     * Only pipelines built after this call (i.e. new connections) see the new configuration.
     *
     * @param sslServerConfiguration the new configuration.
     */
    public void updateSSLConfiguration(SslServerConfiguration sslServerConfiguration) {
        sslConfiguration.set(sslServerConfiguration);
    }

    /**
     * Atomically replaces the current {@link SslServerConfiguration} with {@code updated} if and only if the
     * current configuration is {@code ==} to {@code expected}.
     *
     * @param expected the expected current configuration.
     * @param updated the new configuration.
     * @return true if the update succeeded, or false otherwise.
     */
    public boolean compareAndSetSSLConfiguration(SslServerConfiguration expected, SslServerConfiguration updated) {
        return sslConfiguration.compareAndSet(expected, updated);
    }
}