/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.network;
import com.github.ambry.commons.SSLFactory;
import com.github.ambry.utils.Time;
import java.io.EOFException;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.UnresolvedAddressException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A selector doing non-blocking multi-connection network I/O.
 * <p>
 * This class works with {@link NetworkSend} and {@link NetworkReceive} to transmit network requests and responses.
 * <p>
 * A connection can be added to the selector by doing
 *
 * <pre>
 * String connectionId = selector.connect(new InetSocketAddress("linkedin.com", port), 64000, 64000, portType);
 * </pre>
 *
 * The connect call does not block on the creation of the TCP connection, so the connect method only begins initiating
 * the connection. The successful invocation of this method does not mean a valid connection has been established. The
 * call returns a unique id that identifies this connection.
 * <p>
 * Sending requests, receiving responses, processing connection completions, and disconnections on the existing
 * connections are all done using the <code>poll()</code> call.
 *
 * <pre>
 * List&lt;NetworkSend&gt; requestsToSend = ...; // each NetworkSend carries the id of the connection it targets
 * selector.poll(TIMEOUT_MS, requestsToSend);
 * </pre>
 *
 * The selector maintains several lists (completed sends, completed receives, new connections and disconnections)
 * that are available via getters and are reset at the beginning of each call to <code>poll()</code>.
 * <p>
 * This class is not thread safe!
 */
public class Selector implements Selectable {
  private static final Logger logger = LoggerFactory.getLogger(Selector.class);

  private final java.nio.channels.Selector nioSelector;
  // connection id -> selection key; the key's attachment is the Transmission of that connection
  private final Map<String, SelectionKey> keyMap;
  private final List<NetworkSend> completedSends;
  private final List<NetworkReceive> completedReceives;
  private final List<String> disconnected;
  // ids closed since the last hand-over; moved into "disconnected" at the end of each poll
  private final List<String> closedConnections;
  private final List<String> connected;
  // ids whose connect finished but whose transmission did not yet report ready()
  private final Set<String> unreadyConnections;
  private final Time time;
  private final NetworkMetrics metrics;
  private final AtomicLong connectionIdGenerator;
  private final AtomicLong numActiveConnections;
  private final SSLFactory sslFactory;

  /**
   * Create a new selector
   * @param metrics The {@link NetworkMetrics} updated by this selector; the active connection counter is registered
   *                with it
   * @param time The {@link Time} instance used to measure select and I/O durations
   * @param sslFactory The {@link SSLFactory} handed to {@link TransmissionFactory} when transmissions are created
   * @throws IOException if the underlying nio selector cannot be opened
   */
  public Selector(NetworkMetrics metrics, Time time, SSLFactory sslFactory) throws IOException {
    this.nioSelector = java.nio.channels.Selector.open();
    this.time = time;
    this.keyMap = new HashMap<>();
    this.completedSends = new ArrayList<>();
    this.completedReceives = new ArrayList<>();
    this.connected = new ArrayList<>();
    this.disconnected = new ArrayList<>();
    this.closedConnections = new ArrayList<>();
    this.metrics = metrics;
    this.connectionIdGenerator = new AtomicLong(0);
    this.numActiveConnections = new AtomicLong(0);
    this.unreadyConnections = new HashSet<>();
    metrics.registerSelectorActiveConnections(numActiveConnections);
    this.sslFactory = sslFactory;
  }

  /**
   * Generate an unique connection id of the form {@code localHost:localPort-remoteHost:remotePort_suffix}, where the
   * suffix is a counter that is incremented for every generated id.
   * @param channel The channel between two hosts
   * @return The id for the connection that was created
   */
  private String generateConnectionId(SocketChannel channel) {
    Socket socket = channel.socket();
    String localHost = socket.getLocalAddress().getHostAddress();
    int localPort = socket.getLocalPort();
    String remoteHost = socket.getInetAddress().getHostAddress();
    int remotePort = socket.getPort();
    long connectionIdSuffix = connectionIdGenerator.getAndIncrement();
    return new StringBuilder().append(localHost)
        .append(":")
        .append(localPort)
        .append("-")
        .append(remoteHost)
        .append(":")
        .append(remotePort)
        .append("_")
        .append(connectionIdSuffix)
        .toString();
  }

  /**
   * Begin connecting to the given address and add the connection to this selector and returns an id that identifies
   * the connection
   * <p>
   * Note that this call only initiates the connection, which will be completed on a future {@link #poll(long)}
   * call. Check {@link #connected()} to see which (if any) connections have completed after a given poll call.
   * <p>
   * If any step after opening the channel fails, the channel is closed before the exception is propagated so that
   * no socket is leaked.
   * @param address The address to connect to
   * @param sendBufferSize The networkSend buffer size for the new connection
   * @param receiveBufferSize The receive buffer size for the new connection
   * @param portType {@link PortType} which represents the type of connection to establish
   * @return The id for the connection that was created
   * @throws IOException if the address cannot be resolved, or if configuring, connecting or registering the channel
   *                     or creating its transmission fails
   */
  @Override
  public String connect(InetSocketAddress address, int sendBufferSize, int receiveBufferSize, PortType portType)
      throws IOException {
    SocketChannel channel = SocketChannel.open();
    try {
      channel.configureBlocking(false);
      Socket socket = channel.socket();
      socket.setKeepAlive(true);
      socket.setSendBufferSize(sendBufferSize);
      socket.setReceiveBufferSize(receiveBufferSize);
      socket.setTcpNoDelay(true);
      try {
        channel.connect(address);
      } catch (UnresolvedAddressException e) {
        throw new IOException("Can't resolve address: " + address, e);
      }
      String connectionId = generateConnectionId(channel);
      SelectionKey key = channel.register(this.nioSelector, SelectionKey.OP_CONNECT);
      Transmission transmission;
      try {
        transmission =
            TransmissionFactory.getTransmission(connectionId, channel, key, address.getHostName(), address.getPort(),
                time, metrics, portType, sslFactory, SSLFactory.Mode.CLIENT);
      } catch (IOException e) {
        // pass the exception as the throwable argument so that the stack trace is logged
        logger.error("IOException on transmission creation", e);
        throw e;
      }
      key.attach(transmission);
      this.keyMap.put(connectionId, key);
      numActiveConnections.set(this.keyMap.size());
      return connectionId;
    } catch (IOException | RuntimeException e) {
      // the connection is not tracked by this selector at this point, so nobody else would close the channel
      closeChannelQuietly(channel);
      throw e;
    }
  }

  /**
   * Register the nioSelector with an existing channel
   * Use this on server-side, when a connection is accepted by a different thread but processed by the Selector
   * Note that we are not checking if the connection id is valid - since the connection already exists
   * @param channel The already established channel to register for reads
   * @param portType {@link PortType} which represents the type of the connection
   * @return The id generated for the connection
   * @throws IOException if the channel cannot be registered or its transmission cannot be created. In the latter case
   *                     the channel is closed before the exception is propagated
   */
  public String register(SocketChannel channel, PortType portType) throws IOException {
    Socket socket = channel.socket();
    String connectionId = generateConnectionId(channel);
    SelectionKey key = channel.register(nioSelector, SelectionKey.OP_READ);
    Transmission transmission;
    try {
      transmission =
          TransmissionFactory.getTransmission(connectionId, channel, key, socket.getInetAddress().getHostAddress(),
              socket.getPort(), time, metrics, portType, sslFactory, SSLFactory.Mode.SERVER);
    } catch (IOException e) {
      // pass the exception as the throwable argument so that the stack trace is logged
      logger.error("IOException on transmission creation", e);
      closeChannelQuietly(channel);
      throw e;
    }
    key.attach(transmission);
    this.keyMap.put(connectionId, key);
    numActiveConnections.set(this.keyMap.size());
    return connectionId;
  }

  /**
   * Close the given channel on a failure path. A failure to close is only logged so that it cannot replace the
   * exception that caused the close in the first place.
   * @param channel The channel to close
   */
  private void closeChannelQuietly(SocketChannel channel) {
    try {
      channel.close();
    } catch (IOException e) {
      logger.error("Exception closing channel after a failed connection setup", e);
    }
  }

  /**
   * Cancel the selection key of the connection with the given id (if there is one). Only the key is cancelled here;
   * this method does not touch the bookkeeping of this selector (key map, disconnected list).
   * @param connectionId The id of the connection to disconnect
   */
  @Override
  public void disconnect(String connectionId) {
    SelectionKey key = this.keyMap.get(connectionId);
    if (key != null) {
      key.cancel();
    }
  }

  /**
   * Interrupt the selector if it is blocked waiting to do I/O.
   */
  @Override
  public void wakeup() {
    nioSelector.wakeup();
  }

  /**
   * Close this selector and all associated connections
   */
  @Override
  public void close() {
    for (SelectionKey key : this.nioSelector.keys()) {
      close(key);
    }
    try {
      this.nioSelector.close();
    } catch (IOException e) {
      metrics.selectorNioCloseErrorCount.inc();
      logger.error("Exception closing nioSelector:", e);
    }
  }

  /**
   * Tells whether or not this selector is open.
   *
   * @return {@code true} if, and only if, this selector is open
   */
  @Override
  public boolean isOpen() {
    return nioSelector.isOpen();
  }

  /**
   * Queue the given request for sending in the subsequent {@link #poll(long)} calls
   * @param networkSend The NetworkSend that is ready to be sent
   * @throws IllegalStateException if this selector has no connection with the id of the given send
   */
  public void send(NetworkSend networkSend) {
    SelectionKey key = keyForId(networkSend.getConnectionId());
    if (key == null) {
      throw new IllegalStateException("Attempt to send data to a null key");
    }
    Transmission transmission = getTransmission(key);
    try {
      transmission.setNetworkSend(networkSend);
    } catch (CancelledKeyException e) {
      logger.debug("Ignoring response for closed socket.");
      close(key);
    }
  }

  /**
   * Do whatever I/O can be done on each connection without blocking. This includes completing connections, completing
   * disconnections, initiating new sends, or making progress on in-progress sends or receives.
   * <p>
   *
   * When this call is completed the user can check for completed sends, receives, connections or disconnects using
   * {@link #completedSends()}, {@link #completedReceives()}, {@link #connected()}, {@link #disconnected()}. These
   * lists will be cleared at the beginning of each {@link #poll(long)} call and repopulated by the call if any
   * completed I/O.
   *
   * @param timeoutMs The amount of time to wait, in milliseconds. If negative, wait indefinitely.
   *
   * @throws IOException If the select operation on the underlying nio selector fails
   */
  @Override
  public void poll(long timeoutMs) throws IOException {
    poll(timeoutMs, null);
  }

  /**
   * Firstly initiate the provided sends. Then do whatever I/O can be done on each connection without blocking.
   * This includes completing connections, completing disconnections, initiating new sends,
   * or making progress on in-progress sends or receives.
   * <p>
   *
   * When this call is completed the user can check for completed sends, receives, connections or disconnects using
   * {@link #completedSends()}, {@link #completedReceives()}, {@link #connected()}, {@link #disconnected()}. These
   * lists will be cleared at the beginning of each {@link #poll(long, List)} call and repopulated by the call if any
   * completed I/O.
   *
   * @param timeoutMs The amount of time to wait, in milliseconds. If negative, wait indefinitely.
   * @param sends The list of new sends to begin, may be {@code null}
   *
   * @throws IOException If the select operation on the underlying nio selector fails
   * @throws IllegalStateException If a send is given for which we have no existing connection
   */
  @Override
  public void poll(long timeoutMs, List<NetworkSend> sends) throws IOException {
    clear();
    // register for write interest on any new sends
    if (sends != null) {
      for (NetworkSend networkSend : sends) {
        send(networkSend);
      }
    }
    // check ready keys
    long startSelect = time.milliseconds();
    int readyKeys = select(timeoutMs);
    long endSelect = time.milliseconds();
    this.metrics.selectorSelectTime.update(endSelect - startSelect);
    this.metrics.selectorSelectCount.inc();
    if (readyKeys > 0) {
      Iterator<SelectionKey> iter = nioSelector.selectedKeys().iterator();
      while (iter.hasNext()) {
        SelectionKey key = iter.next();
        // the selected key set is not cleared by the nio selector, so every handled key is removed here
        iter.remove();
        processSelectedKey(key);
      }
      checkUnreadyConnectionsStatus();
      this.metrics.selectorIOCount.inc();
    }
    // hand over every connection closed since the previous hand-over, including closes done outside of poll
    disconnected.addAll(closedConnections);
    closedConnections.clear();
    long endIo = time.milliseconds();
    this.metrics.selectorIOTime.update(endIo - endSelect);
  }

  /**
   * Handle a single selected key: finish a pending connect, continue preparing a connected but not yet ready
   * transmission, or perform the read or write the key is ready for. Any exception closes the connection of the key.
   * @param key The selected key to process
   */
  private void processSelectedKey(SelectionKey key) {
    Transmission transmission = getTransmission(key);
    try {
      if (key.isConnectable()) {
        transmission.finishConnect();
        if (transmission.ready()) {
          connected.add(transmission.getConnectionId());
          metrics.selectorConnectionCreated.inc();
        } else {
          // reported as connected later on, once the transmission becomes ready
          unreadyConnections.add(transmission.getConnectionId());
        }
      }
      /* if channel is not ready, finish prepare */
      if (transmission.isConnected() && !transmission.ready()) {
        transmission.prepare();
        return;
      }
      if (key.isReadable() && transmission.ready()) {
        read(key, transmission);
      } else if (key.isWritable() && transmission.ready()) {
        write(key, transmission);
      } else if (!key.isValid()) {
        close(key);
      }
    } catch (IOException e) {
      String socketDescription = socketDescription(channel(key));
      if (e instanceof EOFException || e instanceof ConnectException) {
        metrics.selectorDisconnectedErrorCount.inc();
        logger.error("Connection {} disconnected", socketDescription, e);
      } else {
        metrics.selectorIOErrorCount.inc();
        logger.warn("Error in I/O with connection to {}", socketDescription, e);
      }
      close(key);
    } catch (Exception e) {
      metrics.selectorKeyOperationErrorCount.inc();
      logger.error("closing key on exception remote host {}", channel(key).socket().getRemoteSocketAddress(), e);
      close(key);
    }
  }

  /**
   * Check readiness for unready connections and add to completed list if ready
   */
  private void checkUnreadyConnectionsStatus() {
    Iterator<String> iterator = unreadyConnections.iterator();
    while (iterator.hasNext()) {
      String connId = iterator.next();
      if (isChannelReady(connId)) {
        connected.add(connId);
        iterator.remove();
        metrics.selectorConnectionCreated.inc();
      }
    }
  }

  /**
   * Generate the description for a SocketChannel: the remote address if the socket has one, the local address
   * otherwise.
   */
  private String socketDescription(SocketChannel channel) {
    Socket socket = channel.socket();
    if (socket == null) {
      return "[unconnected socket]";
    } else if (socket.getInetAddress() != null) {
      return socket.getInetAddress().toString();
    } else {
      return socket.getLocalAddress().toString();
    }
  }

  /**
   * Returns {@code true} if channel is ready to send or receive data, {@code false} otherwise
   * @param connectionId upon which readiness is checked for
   * @return true if channel is ready to accept reads/writes, false otherwise. An id that is not known to this
   *         selector (for instance because the connection has been closed) is reported as not ready.
   */
  public boolean isChannelReady(String connectionId) {
    SelectionKey key = keyForId(connectionId);
    if (key == null) {
      return false;
    }
    Transmission transmission = getTransmission(key);
    return transmission != null && transmission.ready();
  }

  /**
   * @return the sends completed by the most recent poll call
   */
  @Override
  public List<NetworkSend> completedSends() {
    return this.completedSends;
  }

  /**
   * @return the receives completed by the most recent poll call
   */
  @Override
  public List<NetworkReceive> completedReceives() {
    return this.completedReceives;
  }

  /**
   * @return the ids of the connections reported as disconnected by the most recent poll call
   */
  @Override
  public List<String> disconnected() {
    return this.disconnected;
  }

  /**
   * @return the ids of the connections reported as connected by the most recent poll call
   */
  @Override
  public List<String> connected() {
    return this.connected;
  }

  /**
   * @return the number of connections currently tracked by this selector
   */
  public long getNumActiveConnections() {
    return numActiveConnections.get();
  }

  /**
   * Clear the results from the prior poll
   */
  private void clear() {
    completedSends.clear();
    completedReceives.clear();
    connected.clear();
    disconnected.clear();
  }

  /**
   * Check for data, waiting up to the given timeout.
   *
   * @param ms Length of time to wait, in milliseconds. If negative, wait indefinitely. If zero, do not wait at all.
   * @return The number of keys ready
   * @throws IOException if the select operation fails
   */
  private int select(long ms) throws IOException {
    if (ms == 0L) {
      return this.nioSelector.selectNow();
    } else if (ms < 0L) {
      return this.nioSelector.select();
    } else {
      return this.nioSelector.select(ms);
    }
  }

  /**
   * Begin closing this connection by given connection id
   * @param connectionId The id of the connection to close. An unknown id is logged and counted as an error.
   */
  @Override
  public void close(String connectionId) {
    SelectionKey key = keyForId(connectionId);
    if (key == null) {
      metrics.selectorCloseKeyErrorCount.inc();
      logger.error("Attempt to close socket for which there is no open connection. Connection id {}", connectionId);
    } else {
      close(key);
    }
  }

  /**
   * Begin closing this connection by given key. If the key has a transmission attached, the connection is removed
   * from the bookkeeping of this selector and the transmission is closed. Otherwise the key is cancelled and its
   * channel is closed directly.
   */
  private void close(SelectionKey key) {
    Transmission transmission = getTransmission(key);
    if (transmission != null) {
      logger.debug("Closing connection from {}", transmission.getConnectionId());
      closedConnections.add(transmission.getConnectionId());
      keyMap.remove(transmission.getConnectionId());
      numActiveConnections.set(keyMap.size());
      unreadyConnections.remove(transmission.getConnectionId());
      try {
        transmission.close();
      } catch (IOException e) {
        logger.error("IOException thrown during closing of transmission with connectionId {} :",
            transmission.getConnectionId(), e);
      }
    } else {
      key.attach(null);
      key.cancel();
      SocketAddress address = null;
      try {
        SocketChannel socketChannel = channel(key);
        address = socketChannel.socket().getRemoteSocketAddress();
        socketChannel.socket().close();
        socketChannel.close();
      } catch (IOException e) {
        metrics.selectorCloseSocketErrorCount.inc();
        logger.error("Exception closing connection to remote host {} :", address, e);
      }
    }
    this.metrics.selectorConnectionClosed.inc();
  }

  /**
   * Get the selection key associated with this connection id
   * @return the key, or {@code null} if this selector does not track a connection with the given id
   */
  private SelectionKey keyForId(String id) {
    return this.keyMap.get(id);
  }

  /**
   * Process reads from ready sockets
   */
  private void read(SelectionKey key, Transmission transmission) throws IOException {
    long startTimeToReadInMs = time.milliseconds();
    try {
      boolean readComplete = transmission.read();
      if (readComplete) {
        this.completedReceives.add(transmission.getNetworkReceive());
        transmission.onReceiveComplete();
        transmission.clearReceive();
      }
    } finally {
      long readTime = time.milliseconds() - startTimeToReadInMs;
      logger.trace("SocketServer time spent on read per key {} = {}", transmission.getConnectionId(), readTime);
    }
  }

  /**
   * Process writes to ready sockets
   */
  private void write(SelectionKey key, Transmission transmission) throws IOException {
    long startTimeToWriteInMs = time.milliseconds();
    try {
      boolean sendComplete = transmission.write();
      if (sendComplete) {
        logger.trace("Finished writing, registering for read on connection {}", transmission.getRemoteSocketAddress());
        transmission.onSendComplete();
        this.completedSends.add(transmission.getNetworkSend());
        metrics.sendInFlight.dec();
        transmission.clearSend();
        // drop the write interest and make sure the read interest is set
        key.interestOps((key.interestOps() & ~SelectionKey.OP_WRITE) | SelectionKey.OP_READ);
      }
    } finally {
      long writeTime = time.milliseconds() - startTimeToWriteInMs;
      logger.trace("SocketServer time spent on write per key {} = {}", transmission.getConnectionId(), writeTime);
    }
  }

  /**
   * Get the Transmission attached to the given selection key
   */
  private Transmission getTransmission(SelectionKey key) {
    return (Transmission) key.attachment();
  }

  /**
   * Get the socket channel associated with this selection key
   */
  private SocketChannel channel(SelectionKey key) {
    return (SocketChannel) key.channel();
  }
}