/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.network;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.github.ambry.commons.SSLFactory;
import com.github.ambry.config.ClusterMapConfig;
import com.github.ambry.config.ConnectionPoolConfig;
import com.github.ambry.config.SSLConfig;
import java.io.IOException;
import java.net.SocketException;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class BlockingChannelInfo {
private final ArrayBlockingQueue<BlockingChannel> blockingChannelAvailableConnections;
private final ArrayBlockingQueue<BlockingChannel> blockingChannelActiveConnections;
private final AtomicInteger numberOfConnections;
private final ConnectionPoolConfig config;
private final ReadWriteLock rwlock;
private final Object lock;
private final String host;
private final Port port;
private final Logger logger = LoggerFactory.getLogger(getClass());
protected Gauge<Integer> availableConnections;
private Gauge<Integer> activeConnections;
private Gauge<Integer> totalNumberOfConnections;
private int maxConnectionsPerHostPerPort;
private final SSLSocketFactory sslSocketFactory;
private final SSLConfig sslConfig;
private final MetricRegistry registry;
public BlockingChannelInfo(ConnectionPoolConfig config, String host, Port port, MetricRegistry registry,
SSLSocketFactory sslSocketFactory, SSLConfig sslConfig) {
this.config = config;
this.port = port;
this.registry = registry;
if (port.getPortType() == PortType.SSL) {
maxConnectionsPerHostPerPort = config.connectionPoolMaxConnectionsPerPortSSL;
} else {
maxConnectionsPerHostPerPort = config.connectionPoolMaxConnectionsPerPortPlainText;
}
this.blockingChannelAvailableConnections = new ArrayBlockingQueue<BlockingChannel>(maxConnectionsPerHostPerPort);
this.blockingChannelActiveConnections = new ArrayBlockingQueue<BlockingChannel>(maxConnectionsPerHostPerPort);
this.numberOfConnections = new AtomicInteger(0);
this.rwlock = new ReentrantReadWriteLock();
this.lock = new Object();
this.host = host;
this.sslSocketFactory = sslSocketFactory;
this.sslConfig = sslConfig;
availableConnections = new Gauge<Integer>() {
@Override
public Integer getValue() {
return blockingChannelAvailableConnections.size();
}
};
registry.register(
MetricRegistry.name(BlockingChannelInfo.class, host + "-" + port.getPort() + "-availableConnections"),
availableConnections);
activeConnections = new Gauge<Integer>() {
@Override
public Integer getValue() {
return blockingChannelActiveConnections.size();
}
};
registry.register(
MetricRegistry.name(BlockingChannelInfo.class, host + "-" + port.getPort() + "-activeConnections"),
activeConnections);
totalNumberOfConnections = new Gauge<Integer>() {
@Override
public Integer getValue() {
return numberOfConnections.intValue();
}
};
registry.register(
MetricRegistry.name(BlockingChannelInfo.class, host + "-" + port.getPort() + "-totalNumberOfConnections"),
totalNumberOfConnections);
logger.info("Starting blocking channel info for host {} and port {}", host, port.getPort());
}
public void releaseBlockingChannel(BlockingChannel blockingChannel) {
rwlock.readLock().lock();
try {
if (blockingChannelActiveConnections.remove(blockingChannel)) {
blockingChannelAvailableConnections.add(blockingChannel);
logger.trace(
"Adding connection to {}:{} back to pool. Current available connections {} Current active connections {}",
blockingChannel.getRemoteHost(), blockingChannel.getRemotePort(),
blockingChannelAvailableConnections.size(), blockingChannelActiveConnections.size());
} else {
logger.error("Tried to add invalid connection. Channel does not belong in the active queue. Host {} port {}"
+ " channel host {} channel port {}", host, port.getPort(), blockingChannel.getRemoteHost(),
blockingChannel.getRemotePort());
}
} finally {
rwlock.readLock().unlock();
}
}
public BlockingChannel getBlockingChannel(long timeoutInMs)
throws InterruptedException, ConnectionPoolTimeoutException {
rwlock.readLock().lock();
try {
// check if the max connections for this queue has reached or if there are any connections available
// in the available queue. The check in available queue is approximate and it could not have any
// connections when polled. In this case we just depend on an existing connection being placed back in
// the available pool
if (numberOfConnections.get() == maxConnectionsPerHostPerPort || blockingChannelAvailableConnections.size() > 0) {
BlockingChannel channel = blockingChannelAvailableConnections.poll(timeoutInMs, TimeUnit.MILLISECONDS);
if (channel != null) {
blockingChannelActiveConnections.add(channel);
logger.trace("Returning connection to " + channel.getRemoteHost() + ":" + channel.getRemotePort());
return channel;
} else if (numberOfConnections.get() == maxConnectionsPerHostPerPort) {
logger.error("Timed out trying to get a connection for host {} and port {}", host, port.getPort());
throw new ConnectionPoolTimeoutException(
"Could not get a connection to host " + host + " and port " + port.getPort());
}
}
synchronized (lock) {
// if the number of connections created for this host and port is less than the max allowed
// connections, we create a new one and add it to the available queue
if (numberOfConnections.get() < maxConnectionsPerHostPerPort) {
logger.trace("Planning to create a new connection for host {} and port {} ", host, port.getPort());
BlockingChannel channel = getBlockingChannelBasedOnPortType(host, port.getPort());
channel.connect();
numberOfConnections.incrementAndGet();
logger.trace("Created a new connection for host {} and port {}. Number of connections {}", host, port,
numberOfConnections.get());
blockingChannelActiveConnections.add(channel);
return channel;
}
}
BlockingChannel channel = blockingChannelAvailableConnections.poll(timeoutInMs, TimeUnit.MILLISECONDS);
if (channel == null) {
logger.error("Timed out trying to get a connection for host {} and port {}", host, port);
throw new ConnectionPoolTimeoutException(
"Could not get a connection to host " + host + " and port " + port.getPort());
}
blockingChannelActiveConnections.add(channel);
return channel;
} catch (SocketException e) {
logger.error("Socket exception when trying to connect to remote host {} and port {}", host, port.getPort());
throw new ConnectionPoolTimeoutException(
"Socket exception when trying to connect to remote host " + host + " port " + port.getPort(), e);
} catch (IOException e) {
logger.error("IOException when trying to connect to the remote host {} and port {}", host, port.getPort());
throw new ConnectionPoolTimeoutException(
"IOException when trying to connect to remote host " + host + " port " + port.getPort(), e);
} finally {
rwlock.readLock().unlock();
}
}
/**
* Returns BlockingChannel or SSLBlockingChannel depending on whether the port type is PlainText or SSL
* @param host upon which connection has to be established
* @param port upon which connection has to be established
* @return BlockingChannel
*/
private BlockingChannel getBlockingChannelBasedOnPortType(String host, int port) {
BlockingChannel channel = null;
if (this.port.getPortType() == PortType.PLAINTEXT) {
channel = new BlockingChannel(host, port, config.connectionPoolReadBufferSizeBytes,
config.connectionPoolWriteBufferSizeBytes, config.connectionPoolReadTimeoutMs,
config.connectionPoolConnectTimeoutMs);
} else if (this.port.getPortType() == PortType.SSL) {
channel = new SSLBlockingChannel(host, port, registry, config.connectionPoolReadBufferSizeBytes,
config.connectionPoolWriteBufferSizeBytes, config.connectionPoolReadTimeoutMs,
config.connectionPoolConnectTimeoutMs, sslSocketFactory, sslConfig);
}
return channel;
}
public void destroyBlockingChannel(BlockingChannel blockingChannel) {
rwlock.readLock().lock();
try {
boolean changed = blockingChannelActiveConnections.remove(blockingChannel);
if (!changed) {
logger.error("Invalid connection being destroyed. "
+ "Channel does not belong to this queue. queue host {} port {} channel host {} port {}", host,
port.getPort(), blockingChannel.getRemoteHost(), blockingChannel.getRemotePort());
throw new IllegalArgumentException("Invalid connection. Channel does not belong to this queue");
}
blockingChannel.disconnect();
// we ensure we maintain the current count of connections to the host to avoid synchronization across threads
// to create the connection
BlockingChannel channel =
getBlockingChannelBasedOnPortType(blockingChannel.getRemoteHost(), blockingChannel.getRemotePort());
channel.connect();
logger.trace("Destroying connection and adding new connection for host {} port {}", host, port.getPort());
blockingChannelAvailableConnections.add(channel);
} catch (Exception e) {
logger.error("Connection failure to remote host {} and port {} when destroying and recreating the connection",
host, port.getPort());
synchronized (lock) {
// decrement the number of connections to the host and port. we were not able to maintain the count
numberOfConnections.decrementAndGet();
// at this point we are good to clean up the available connections since re-creation failed
do {
BlockingChannel channel = blockingChannelAvailableConnections.poll();
if (channel == null) {
break;
}
channel.disconnect();
numberOfConnections.decrementAndGet();
} while (true);
}
} finally {
rwlock.readLock().unlock();
}
}
/**
* Returns the number of connections with this BlockingChannelInfo
* @return
*/
public int getNumberOfConnections() {
return this.numberOfConnections.intValue();
}
public void cleanup() {
rwlock.writeLock().lock();
logger.info("Cleaning all active and available connections for host {} and port {}", host, port.getPort());
try {
for (BlockingChannel channel : blockingChannelActiveConnections) {
channel.disconnect();
}
blockingChannelActiveConnections.clear();
for (BlockingChannel channel : blockingChannelAvailableConnections) {
channel.disconnect();
}
blockingChannelAvailableConnections.clear();
numberOfConnections.set(0);
logger.info("Cleaning completed for all active and available connections for host {} and port {}", host,
port.getPort());
} finally {
rwlock.writeLock().unlock();
}
}
}
/**
* A connection pool that uses BlockingChannel as the underlying connection.
* It is responsible for all the connection management. It helps to
* checkout a new connection, checkin an existing connection that has been
* checked out and destroy a connection in the case of an error
*/
public final class BlockingChannelConnectionPool implements ConnectionPool {
private final Map<String, BlockingChannelInfo> connections;
private final ConnectionPoolConfig config;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final MetricRegistry registry;
private final Timer connectionCheckOutTime;
private final Timer connectionCheckInTime;
private final Timer connectionDestroyTime;
private final AtomicInteger requestsWaitingToCheckoutConnectionCount;
private SSLSocketFactory sslSocketFactory;
private final SSLConfig sslConfig;
// Represents the total number to nodes connectedTo, i.e. if the blockingchannel has atleast 1 connection
private Gauge<Integer> totalNumberOfNodesConnectedTo;
// Represents the total number of connections, in other words, aggregate of the connections from all nodes
public Gauge<Integer> totalNumberOfConnections;
// Represents the number of requests waiting to checkout a connection
public Gauge<Integer> requestsWaitingToCheckoutConnection;
// Represents the number of sslSocketFactory Initializations by client
public Counter sslSocketFactoryClientInitializationCount;
// Represents the number of sslSocketFactory Initialization Error by client
public Counter sslSocketFactoryClientInitializationErrorCount;
public BlockingChannelConnectionPool(ConnectionPoolConfig config, SSLConfig sslConfig,
ClusterMapConfig clusterMapConfig, MetricRegistry registry) throws Exception {
connections = new ConcurrentHashMap<String, BlockingChannelInfo>();
this.config = config;
this.registry = registry;
this.sslConfig = sslConfig;
connectionCheckOutTime =
registry.timer(MetricRegistry.name(BlockingChannelConnectionPool.class, "connectionCheckOutTime"));
connectionCheckInTime =
registry.timer(MetricRegistry.name(BlockingChannelConnectionPool.class, "connectionCheckInTime"));
connectionDestroyTime =
registry.timer(MetricRegistry.name(BlockingChannelConnectionPool.class, "connectionDestroyTime"));
totalNumberOfNodesConnectedTo = new Gauge<Integer>() {
@Override
public Integer getValue() {
int noOfNodesConnectedTo = 0;
for (BlockingChannelInfo blockingChannelInfo : connections.values()) {
if (blockingChannelInfo.getNumberOfConnections() > 0) {
noOfNodesConnectedTo++;
}
}
return noOfNodesConnectedTo;
}
};
registry.register(MetricRegistry.name(BlockingChannelConnectionPool.class, "totalNumberOfNodesConnectedTo"),
totalNumberOfNodesConnectedTo);
totalNumberOfConnections = new Gauge<Integer>() {
@Override
public Integer getValue() {
int noOfConnections = 0;
for (BlockingChannelInfo blockingChannelInfo : connections.values()) {
noOfConnections += blockingChannelInfo.getNumberOfConnections();
}
return noOfConnections;
}
};
registry.register(MetricRegistry.name(BlockingChannelConnectionPool.class, "totalNumberOfConnections"),
totalNumberOfConnections);
requestsWaitingToCheckoutConnectionCount = new AtomicInteger(0);
requestsWaitingToCheckoutConnection = new Gauge<Integer>() {
@Override
public Integer getValue() {
return requestsWaitingToCheckoutConnectionCount.get();
}
};
registry.register(MetricRegistry.name(BlockingChannelConnectionPool.class, "requestsWaitingToCheckoutConnection"),
requestsWaitingToCheckoutConnection);
sslSocketFactoryClientInitializationCount = registry.counter(
MetricRegistry.name(BlockingChannelConnectionPool.class, "SslSocketFactoryClientInitializationCount"));
sslSocketFactoryClientInitializationErrorCount = registry.counter(
MetricRegistry.name(BlockingChannelConnectionPool.class, "SslSocketFactoryClientInitializationErrorCount"));
if (clusterMapConfig.clusterMapSslEnabledDatacenters.length() > 0) {
initializeSSLSocketFactory();
} else {
this.sslSocketFactory = null;
}
}
@Override
public void start() {
logger.info("BlockingChannelConnectionPool started");
}
@Override
public void shutdown() {
logger.info("Shutting down the BlockingChannelConnectionPool");
for (Map.Entry<String, BlockingChannelInfo> channels : connections.entrySet()) {
channels.getValue().cleanup();
}
}
private void initializeSSLSocketFactory() throws Exception {
try {
SSLFactory sslFactory = new SSLFactory(sslConfig);
SSLContext sslContext = sslFactory.getSSLContext();
this.sslSocketFactory = sslContext.getSocketFactory();
this.sslSocketFactoryClientInitializationCount.inc();
} catch (Exception e) {
this.sslSocketFactoryClientInitializationErrorCount.inc();
logger.error("SSLSocketFactory Client Initialization Error ", e);
throw e;
}
}
@Override
public ConnectedChannel checkOutConnection(String host, Port port, long timeoutInMs)
throws IOException, InterruptedException, ConnectionPoolTimeoutException {
final Timer.Context context = connectionCheckOutTime.time();
try {
requestsWaitingToCheckoutConnectionCount.incrementAndGet();
BlockingChannelInfo blockingChannelInfo = connections.get(host + port.getPort());
if (blockingChannelInfo == null) {
synchronized (this) {
blockingChannelInfo = connections.get(host + port.getPort());
if (blockingChannelInfo == null) {
logger.trace("Creating new blocking channel info for host {} and port {}", host, port.getPort());
blockingChannelInfo = new BlockingChannelInfo(config, host, port, registry, sslSocketFactory, sslConfig);
connections.put(host + port.getPort(), blockingChannelInfo);
} else {
logger.trace("Using already existing BlockingChannelInfo for " + host + ":" + port.getPort()
+ " in synchronized block");
}
}
} else {
logger.trace("Using already existing BlockingChannelInfo for " + host + ":" + port.getPort());
}
return blockingChannelInfo.getBlockingChannel(timeoutInMs);
} finally {
requestsWaitingToCheckoutConnectionCount.decrementAndGet();
context.stop();
}
}
@Override
public void checkInConnection(ConnectedChannel connectedChannel) {
final Timer.Context context = connectionCheckInTime.time();
try {
BlockingChannelInfo blockingChannelInfo =
connections.get(connectedChannel.getRemoteHost() + connectedChannel.getRemotePort());
if (blockingChannelInfo == null) {
logger.error("Unexpected state in connection pool. Host {} and port {} not found to checkin connection",
connectedChannel.getRemoteHost(), connectedChannel.getRemotePort());
throw new IllegalArgumentException("Connection does not belong to the pool");
}
blockingChannelInfo.releaseBlockingChannel((BlockingChannel) connectedChannel);
logger.trace("Checking in connection for host {} and port {}", connectedChannel.getRemoteHost(),
connectedChannel.getRemotePort());
} finally {
context.stop();
}
}
@Override
public void destroyConnection(ConnectedChannel connectedChannel) {
final Timer.Context context = connectionDestroyTime.time();
try {
BlockingChannelInfo blockingChannelInfo =
connections.get(connectedChannel.getRemoteHost() + connectedChannel.getRemotePort());
if (blockingChannelInfo == null) {
logger.error("Unexpected state in connection pool. Host {} and port {} not found to checkin connection",
connectedChannel.getRemoteHost(), connectedChannel.getRemotePort());
throw new IllegalArgumentException("Connection does not belong to the pool");
}
blockingChannelInfo.destroyBlockingChannel((BlockingChannel) connectedChannel);
logger.trace("Destroying connection for host {} and port {}", connectedChannel.getRemoteHost(),
connectedChannel.getRemotePort());
} finally {
context.stop();
}
}
}