/** * Copyright 2016 LinkedIn Corp. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ package com.github.ambry.network; import com.codahale.metrics.MetricRegistry; import com.github.ambry.commons.SSLFactory; import com.github.ambry.config.NetworkConfig; import com.github.ambry.config.SSLConfig; import com.github.ambry.utils.ByteBufferInputStream; import com.github.ambry.utils.SystemTime; import com.github.ambry.utils.Time; import com.github.ambry.utils.Utils; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketException; import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectionKey; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A NIO socket server. The threading model is * 1 Acceptor thread that handles new connections * N Processor threads that each have their own selector and read requests from sockets * M Handler threads that handle requests and produce responses back to the processor threads for writing. */ public class SocketServer implements NetworkServer { private final String host; private final int port; private final int numProcessorThreads; private final int maxQueuedRequests; private final int sendBufferSize; private final int recvBufferSize; private final int maxRequestSize; private final ArrayList<Processor> processors; private volatile ArrayList<Acceptor> acceptors; private final SocketRequestResponseChannel requestResponseChannel; private Logger logger = LoggerFactory.getLogger(getClass()); private final ServerNetworkMetrics metrics; private final HashMap<PortType, Port> ports; private SSLFactory sslFactory; public SocketServer(NetworkConfig config, SSLConfig sslConfig, MetricRegistry registry, ArrayList<Port> portList) { this.host = config.hostName; this.port = config.port; this.numProcessorThreads = config.numIoThreads; this.maxQueuedRequests = config.queuedMaxRequests; this.sendBufferSize = config.socketSendBufferBytes; this.recvBufferSize = config.socketReceiveBufferBytes; this.maxRequestSize = config.socketRequestMaxBytes; processors = new ArrayList<Processor>(numProcessorThreads); requestResponseChannel = new SocketRequestResponseChannel(numProcessorThreads, maxQueuedRequests); metrics = new ServerNetworkMetrics(requestResponseChannel, registry, processors); this.acceptors = new ArrayList<Acceptor>(); this.ports = new HashMap<PortType, Port>(); this.validatePorts(portList); this.initializeSSLFactory(sslConfig); } public String getHost() { return host; } public int getPort() { return port; } public int getSSLPort() { Port sslPort = ports.get(PortType.SSL); if (sslPort != null) { return sslPort.getPort(); } throw new IllegalStateException("No SSL Port Exists for Server " + host + ":" + port); } private void initializeSSLFactory(SSLConfig sslConfig) { if (ports.get(PortType.SSL) != null) { try { this.sslFactory = new SSLFactory(sslConfig); metrics.sslFactoryInitializationCount.inc(); } catch (Exception e) { metrics.sslFactoryInitializationErrorCount.inc(); throw new IllegalStateException("Exception thrown during initialization of SSLFactory ", e); } } } public int getNumProcessorThreads() { return numProcessorThreads; } public int getMaxQueuedRequests() { return maxQueuedRequests; } public int getSendBufferSize() { return sendBufferSize; } public int getRecvBufferSize() { return recvBufferSize; } public int getMaxRequestSize() { return maxRequestSize; } @Override public RequestResponseChannel getRequestResponseChannel() { return requestResponseChannel; } private void validatePorts(ArrayList<Port> portList) { HashSet<PortType> portTypeSet = new HashSet<PortType>(); for (Port port : portList) { if (portTypeSet.contains(port.getPortType())) { throw new IllegalArgumentException("Not more than one port of same type is allowed : " + port.getPortType()); } else { portTypeSet.add(port.getPortType()); this.ports.put(port.getPortType(), port); } } } public void start() throws IOException, InterruptedException { logger.info("Starting {} processor threads", numProcessorThreads); for (int i = 0; i < numProcessorThreads; i++) { processors.add(i, new Processor(i, maxRequestSize, requestResponseChannel, metrics, sslFactory)); Utils.newThread("ambry-processor-" + port + " " + i, processors.get(i), false).start(); } requestResponseChannel.addResponseListener(new ResponseListener() { @Override public void onResponse(int processorId) { processors.get(processorId).wakeup(); } }); // start accepting connections logger.info("Starting acceptor threads"); Acceptor plainTextAcceptor = new Acceptor(port, processors, sendBufferSize, recvBufferSize, metrics); this.acceptors.add(plainTextAcceptor); Utils.newThread("ambry-acceptor", plainTextAcceptor, false).start(); Port sslPort = ports.get(PortType.SSL); if (sslPort != null) { SSLAcceptor sslAcceptor = new SSLAcceptor(sslPort.getPort(), processors, sendBufferSize, recvBufferSize, metrics); acceptors.add(sslAcceptor); Utils.newThread("ambry-sslacceptor", sslAcceptor, false).start(); } for (Acceptor acceptor : acceptors) { acceptor.awaitStartup(); } logger.info("Started server"); } public void shutdown() { try { logger.info("Shutting down server"); for (Acceptor acceptor : acceptors) { if (acceptor != null) { acceptor.shutdown(); } } for (Processor processor : processors) { processor.shutdown(); } logger.info("Shutdown completed"); } catch (Exception e) { logger.error("Error shutting down socket server {}", e); } } } /** * A base class with some helper variables and methods */ abstract class AbstractServerThread implements Runnable { private final CountDownLatch startupLatch; private final CountDownLatch shutdownLatch; private final AtomicBoolean alive; protected Logger logger = LoggerFactory.getLogger(getClass()); public AbstractServerThread() throws IOException { startupLatch = new CountDownLatch(1); shutdownLatch = new CountDownLatch(1); alive = new AtomicBoolean(false); } /** * Initiates a graceful shutdown by signaling to stop and waiting for the shutdown to complete */ public void shutdown() throws InterruptedException { alive.set(false); shutdownLatch.await(); } /** * Wait for the thread to completely start up */ public void awaitStartup() throws InterruptedException { startupLatch.await(); } /** * Record that the thread startup is complete */ protected void startupComplete() { alive.set(true); startupLatch.countDown(); } /** * Record that the thread shutdown is complete */ protected void shutdownComplete() { shutdownLatch.countDown(); } /** * Is the server still running? */ protected boolean isRunning() { return alive.get(); } } /** * Thread that accepts and configures new connections. */ class Acceptor extends AbstractServerThread { private final ArrayList<Processor> processors; private final int sendBufferSize; private final int recvBufferSize; private final ServerSocketChannel serverChannel; private final java.nio.channels.Selector nioSelector; private static final long selectTimeOutMs = 500; private final ServerNetworkMetrics metrics; protected Logger logger = LoggerFactory.getLogger(getClass()); public Acceptor(int port, ArrayList<Processor> processors, int sendBufferSize, int recvBufferSize, ServerNetworkMetrics metrics) throws IOException { this.processors = processors; this.sendBufferSize = sendBufferSize; this.recvBufferSize = recvBufferSize; this.serverChannel = openServerSocket(port); this.nioSelector = java.nio.channels.Selector.open(); this.metrics = metrics; } /** * Accept loop that checks for new connection attempts for a plain text port */ public void run() { try { serverChannel.register(nioSelector, SelectionKey.OP_ACCEPT); startupComplete(); int currentProcessor = 0; while (isRunning()) { int ready = nioSelector.select(selectTimeOutMs); if (ready > 0) { Set<SelectionKey> keys = nioSelector.selectedKeys(); Iterator<SelectionKey> iter = keys.iterator(); while (iter.hasNext() && isRunning()) { SelectionKey key = null; try { key = iter.next(); iter.remove(); if (key.isAcceptable()) { accept(key, processors.get(currentProcessor)); } else { throw new IllegalStateException("Unrecognized key state for acceptor thread."); } // round robin to the next processor thread currentProcessor = (currentProcessor + 1) % processors.size(); } catch (Exception e) { key.cancel(); metrics.acceptConnectionErrorCount.inc(); logger.debug("Error in accepting new connection", e); } } } } logger.debug("Closing server socket and selector."); serverChannel.close(); nioSelector.close(); shutdownComplete(); super.shutdown(); } catch (Exception e) { metrics.acceptorShutDownErrorCount.inc(); logger.error("Error during shutdown of acceptor thread", e); } } /* * Create a server socket to listen for connections on. */ private ServerSocketChannel openServerSocket(int port) throws IOException { InetSocketAddress address = new InetSocketAddress(port); ServerSocketChannel serverChannel = ServerSocketChannel.open(); serverChannel.configureBlocking(false); serverChannel.socket().bind(address); logger.info("Awaiting socket connections on {}:{}", address.getHostName(), port); return serverChannel; } /* * Accept a new connection */ protected void accept(SelectionKey key, Processor processor) throws SocketException, IOException { SocketChannel socketChannel = acceptConnection(key); processor.accept(socketChannel, PortType.PLAINTEXT); } protected SocketChannel acceptConnection(SelectionKey key) throws SocketException, IOException { ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel(); serverSocketChannel.socket().setReceiveBufferSize(recvBufferSize); SocketChannel socketChannel = serverSocketChannel.accept(); socketChannel.configureBlocking(false); socketChannel.socket().setTcpNoDelay(true); socketChannel.socket().setSendBufferSize(sendBufferSize); logger.trace("Accepted connection from {} on {}. sendBufferSize " + "[actual|requested]: [{}|{}] recvBufferSize [actual|requested]: [{}|{}]", socketChannel.socket().getInetAddress(), socketChannel.socket().getLocalSocketAddress(), socketChannel.socket().getSendBufferSize(), sendBufferSize, socketChannel.socket().getReceiveBufferSize(), recvBufferSize); return socketChannel; } public void shutdown() throws InterruptedException { nioSelector.wakeup(); super.shutdown(); } } /** * Thread that accepts and configures new connections for an SSL Port */ class SSLAcceptor extends Acceptor { public SSLAcceptor(int port, ArrayList<Processor> processors, int sendBufferSize, int recvBufferSize, ServerNetworkMetrics metrics) throws IOException { super(port, processors, sendBufferSize, recvBufferSize, metrics); } /* * Accept a new connection */ @Override protected void accept(SelectionKey key, Processor processor) throws SocketException, IOException { SocketChannel socketChannel = acceptConnection(key); processor.accept(socketChannel, PortType.SSL); } } /** * Thread that processes all requests from a single connection. There are N of these running in parallel * each of which has its own selectors */ class Processor extends AbstractServerThread { private final int maxRequestSize; private final SocketRequestResponseChannel channel; private final int id; private final Time time; private final ConcurrentLinkedQueue<SocketChannelPortTypePair> newConnections = new ConcurrentLinkedQueue<SocketChannelPortTypePair>(); private final Selector selector; private final ServerNetworkMetrics metrics; private static final long pollTimeoutMs = 300; Processor(int id, int maxRequestSize, RequestResponseChannel channel, ServerNetworkMetrics metrics, SSLFactory sslFactory) throws IOException { this.maxRequestSize = maxRequestSize; this.channel = (SocketRequestResponseChannel) channel; this.id = id; this.time = SystemTime.getInstance(); selector = new Selector(metrics, time, sslFactory); this.metrics = metrics; } public void run() { try { startupComplete(); while (isRunning()) { // setup any new connections that have been queued up configureNewConnections(); // register any new responses for writing processNewResponses(); selector.poll(pollTimeoutMs); // handle completed receives List<NetworkReceive> completedReceives = selector.completedReceives(); for (NetworkReceive networkReceive : completedReceives) { String connectionId = networkReceive.getConnectionId(); SocketServerRequest req = new SocketServerRequest(id, connectionId, new ByteBufferInputStream(networkReceive.getReceivedBytes().getPayload())); channel.sendRequest(req); } } } catch (Exception e) { logger.error("Error in processor thread", e); } finally { logger.debug("Closing server socket and selector."); try { closeAll(); shutdownComplete(); super.shutdown(); } catch (InterruptedException ie) { metrics.processorShutDownErrorCount.inc(); logger.error("InterruptedException on processor shutdown ", ie); } } } private void processNewResponses() throws InterruptedException, IOException { SocketServerResponse curr = (SocketServerResponse) channel.receiveResponse(id); while (curr != null) { curr.onDequeueFromResponseQueue(); SocketServerRequest request = (SocketServerRequest) curr.getRequest(); String connectionId = request.getConnectionId(); try { if (curr.getPayload() == null) { // We should never need to send an empty response. If the payload is empty, we will assume error // and close the connection logger.trace("Socket server received no response and hence closing the connection"); selector.close(connectionId); } else { logger.trace("Socket server received response to send, registering for write: {}", curr); NetworkSend networkSend = new NetworkSend(connectionId, curr.getPayload(), curr.getMetrics(), time); selector.send(networkSend); } } catch (IllegalStateException e) { metrics.processNewResponseErrorCount.inc(); logger.debug("Error in processing new responses", e); } finally { curr = (SocketServerResponse) channel.receiveResponse(id); } } } /** * Queue up a new connection for reading */ public void accept(SocketChannel socketChannel, PortType portType) { newConnections.add(new SocketChannelPortTypePair(socketChannel, portType)); wakeup(); } /** * Close all open connections */ private void closeAll() { selector.close(); } /** * Register any new connections that have been queued up */ private void configureNewConnections() throws ClosedChannelException, IOException { while (newConnections.size() > 0) { SocketChannelPortTypePair socketChannelPortTypePair = newConnections.poll(); logger.debug("Processor {} listening to new connection from {}", id, socketChannelPortTypePair.getSocketChannel().socket().getRemoteSocketAddress()); try { selector.register(socketChannelPortTypePair.getSocketChannel(), socketChannelPortTypePair.getPortType()); } catch (IOException e) { logger.error("Error on registering new connection ", e); } } } /** * Initiates a graceful shutdown by signaling to stop and waiting for the shutdown to complete */ public void shutdown() throws InterruptedException { selector.wakeup(); super.shutdown(); } /** * Wakes up the thread for selection. */ public void wakeup() { selector.wakeup(); } class SocketChannelPortTypePair { private SocketChannel socketChannel; private PortType portType; public SocketChannelPortTypePair(SocketChannel socketChannel, PortType portType) { this.socketChannel = socketChannel; this.portType = portType; } public PortType getPortType() { return portType; } public SocketChannel getSocketChannel() { return this.socketChannel; } } }