/** * Copyright 2016 LinkedIn Corp. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ package com.github.ambry.network; import com.codahale.metrics.MetricRegistry; import com.github.ambry.commons.SSLFactory; import com.github.ambry.commons.TestSSLUtils; import com.github.ambry.config.ClusterMapConfig; import com.github.ambry.config.ConnectionPoolConfig; import com.github.ambry.config.NetworkConfig; import com.github.ambry.config.SSLConfig; import com.github.ambry.config.VerifiableProperties; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import org.junit.After; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; /** * Test for the blocking channel connection pool */ public class BlockingChannelConnectionPoolTest { private SocketServer server1 = null; private SocketServer server2 = null; private SocketServer server3 = null; private static File trustStoreFile = null; private static SSLFactory sslFactory; private static SSLConfig sslConfig; private static ClusterMapConfig plainTextClusterMapConfig; private static ClusterMapConfig sslEnabledClusterMapConfig; private static SSLConfig serverSSLConfig1; private static SSLConfig serverSSLConfig2; private static SSLConfig serverSSLConfig3; private static SSLSocketFactory sslSocketFactory; /** * Run only once for all tests */ @BeforeClass public static void initializeTests() throws Exception { trustStoreFile = File.createTempFile("truststore", ".jks"); serverSSLConfig1 = new SSLConfig(TestSSLUtils.createSslProps("DC2,DC3", SSLFactory.Mode.SERVER, trustStoreFile, "server1")); serverSSLConfig2 = new SSLConfig(TestSSLUtils.createSslProps("DC1,DC3", SSLFactory.Mode.SERVER, trustStoreFile, "server2")); serverSSLConfig3 = new SSLConfig(TestSSLUtils.createSslProps("DC1,DC2", SSLFactory.Mode.SERVER, trustStoreFile, "server3")); VerifiableProperties sslClientProps = TestSSLUtils.createSslProps("DC1,DC2,DC3", SSLFactory.Mode.CLIENT, trustStoreFile, "client"); sslConfig = new SSLConfig(sslClientProps); sslEnabledClusterMapConfig = new ClusterMapConfig(sslClientProps); Properties props = new Properties(); props.setProperty("clustermap.cluster.name", "test"); props.setProperty("clustermap.datacenter.name", "dc1"); props.setProperty("clustermap.host.name", "localhost"); plainTextClusterMapConfig = new ClusterMapConfig(new VerifiableProperties(props)); sslFactory = new SSLFactory(sslConfig); SSLContext sslContext = sslFactory.getSSLContext(); sslSocketFactory = sslContext.getSocketFactory(); } public BlockingChannelConnectionPoolTest() throws Exception { Properties props = new Properties(); props.setProperty("port", "6667"); props.setProperty("clustermap.cluster.name", "test"); props.setProperty("clustermap.datacenter.name", "dc1"); props.setProperty("clustermap.host.name", "localhost"); VerifiableProperties propverify = new VerifiableProperties(props); NetworkConfig config = new NetworkConfig(propverify); ArrayList<Port> ports = new ArrayList<Port>(); ports.add(new Port(6667, PortType.PLAINTEXT)); ports.add(new Port(7667, PortType.SSL)); server1 = new SocketServer(config, serverSSLConfig1, new MetricRegistry(), ports); server1.start(); props.setProperty("port", "6668"); propverify = new VerifiableProperties(props); config = new NetworkConfig(propverify); ports = new ArrayList<Port>(); ports.add(new Port(6668, PortType.PLAINTEXT)); ports.add(new Port(7668, PortType.SSL)); server2 = new SocketServer(config, serverSSLConfig2, new MetricRegistry(), ports); server2.start(); props.setProperty("port", "6669"); propverify = new VerifiableProperties(props); config = new NetworkConfig(propverify); ports = new ArrayList<Port>(); ports.add(new Port(6669, PortType.PLAINTEXT)); ports.add(new Port(7669, PortType.SSL)); server3 = new SocketServer(config, serverSSLConfig3, new MetricRegistry(), ports); server3.start(); } @After public void cleanup() { server1.shutdown(); server2.shutdown(); server3.shutdown(); } class BlockingChannelInfoThread implements Runnable { private final BlockingChannelInfo channelInfo; private final CountDownLatch channelCount; private final CountDownLatch shouldRelease; private final CountDownLatch releaseComplete; private final boolean destroyConnection; private final AtomicReference<Exception> exception; public BlockingChannelInfoThread(BlockingChannelInfo channelInfo, CountDownLatch channelCount, CountDownLatch shouldRelease, CountDownLatch releaseComplete, boolean destroyConnection, AtomicReference<Exception> exception) { this.channelInfo = channelInfo; this.channelCount = channelCount; this.shouldRelease = shouldRelease; this.releaseComplete = releaseComplete; this.destroyConnection = destroyConnection; this.exception = exception; } @Override public void run() { try { BlockingChannel channel = channelInfo.getBlockingChannel(1000); channelCount.countDown(); if (shouldRelease.await(1000, TimeUnit.MILLISECONDS)) { if (destroyConnection) { channelInfo.destroyBlockingChannel(channel); } else { channelInfo.releaseBlockingChannel(channel); } } else if (exception.get() == null) { exception.set(new Exception("Timed out waiting for signal to release connections")); } } catch (Exception e) { exception.set(e); } finally { releaseComplete.countDown(); } } } //@Test public void testBlockingChannelInfoForPlainText() throws Exception { testBlockingChannelInfo("127.0.0.1", new Port(6667, PortType.PLAINTEXT), 5, 5); } @Test public void testBlockingChannelInfoForSSL() throws Exception { testBlockingChannelInfo("127.0.0.1", new Port(7667, PortType.SSL), 5, 5); } /** * Tests how connection failures are handled by BlockingChannelInfo. */ @Test public void testConnectionFailureCases() throws InterruptedException, ConnectionPoolTimeoutException, IOException { int port = 6680; String host = "127.0.0.1"; SocketServer server = startServer(port); Properties props = new Properties(); props.setProperty("clustermap.cluster.name", "test"); props.setProperty("clustermap.datacenter.name", "dc1"); props.setProperty("clustermap.host.name", "localhost"); BlockingChannelInfo channelInfo = new BlockingChannelInfo(new ConnectionPoolConfig(new VerifiableProperties(props)), host, new Port(port, PortType.PLAINTEXT), new MetricRegistry(), sslSocketFactory, sslConfig); // ask for N no of connections Assert.assertEquals(channelInfo.getNumberOfConnections(), 0); BlockingChannel blockingChannel1 = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 1); BlockingChannel blockingChannel2 = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 2); BlockingChannel blockingChannel3 = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 3); // realease 2 of them back to pool channelInfo.releaseBlockingChannel(blockingChannel2); channelInfo.releaseBlockingChannel(blockingChannel3); Assert.assertEquals("Available connections count mismatch ", 2, channelInfo.availableConnections.getValue().intValue()); // shutdown server server.shutdown(); // destroy one of the connections and verify that the available connections cleaned up channelInfo.destroyBlockingChannel(blockingChannel1); Assert.assertEquals("Available connections should have not been cleaned up", 0, channelInfo.availableConnections.getValue().intValue()); // bring up the server startServer(port); // ask for 2 more connections BlockingChannel blockingChannel4 = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 1); BlockingChannel blockingChannel5 = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 2); // release one of them back to pool channelInfo.releaseBlockingChannel(blockingChannel4); // verify that destroy connection will not trigger clean up of available connections // as connection recreation should have passed channelInfo.destroyBlockingChannel(blockingChannel5); Assert.assertEquals("Available connections should not have been cleaned up", 2, channelInfo.availableConnections.getValue().intValue()); } private void testBlockingChannelInfo(String host, Port port, int maxConnectionsPerPortPlainText, int maxConnectionsPerPortSSL) throws Exception { Properties props = new Properties(); props.put("connectionpool.max.connections.per.port.plain.text", "" + maxConnectionsPerPortPlainText); props.put("connectionpool.max.connections.per.port.ssl", "" + maxConnectionsPerPortSSL); props.put("clustermap.cluster.name", "test"); props.put("clustermap.datacenter.name", "dc1"); props.put("clustermap.host.name", "localhost"); int maxConnectionsPerHost = (port.getPortType() == PortType.PLAINTEXT) ? maxConnectionsPerPortPlainText : maxConnectionsPerPortSSL; createAndReleaseSingleChannelTest(props, host, port); overSubscriptionTest(props, host, port, maxConnectionsPerHost, true); overSubscriptionTest(props, host, port, maxConnectionsPerHost, false); underSubscriptionTest(props, host, port, (maxConnectionsPerHost / 2)); } private void createAndReleaseSingleChannelTest(Properties props, String host, Port port) throws InterruptedException, ConnectionPoolTimeoutException { BlockingChannelInfo channelInfo = new BlockingChannelInfo(new ConnectionPoolConfig(new VerifiableProperties(props)), host, port, new MetricRegistry(), sslSocketFactory, sslConfig); Assert.assertEquals(channelInfo.getNumberOfConnections(), 0); BlockingChannel blockingChannel = channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), 1); channelInfo.releaseBlockingChannel(blockingChannel); Assert.assertEquals(channelInfo.getNumberOfConnections(), 1); } private void overSubscriptionTest(Properties props, String host, Port port, int maxConnectionsPerHost, boolean destroyConnection) throws Exception { AtomicReference<Exception> exception = new AtomicReference<Exception>(); BlockingChannelInfo channelInfo = new BlockingChannelInfo(new ConnectionPoolConfig(new VerifiableProperties(props)), host, port, new MetricRegistry(), sslSocketFactory, sslConfig); CountDownLatch channelCount = new CountDownLatch(maxConnectionsPerHost); CountDownLatch shouldRelease = new CountDownLatch(1); CountDownLatch releaseComplete = new CountDownLatch(2 * maxConnectionsPerHost); for (int i = 0; i < maxConnectionsPerHost; i++) { BlockingChannelInfoThread infoThread = new BlockingChannelInfoThread(channelInfo, channelCount, shouldRelease, releaseComplete, destroyConnection, exception); Thread t = new Thread(infoThread); t.start(); } awaitCountdown(channelCount, 1000, exception, "Timed out while waiting for channel count to reach " + maxConnectionsPerHost); Assert.assertEquals(channelInfo.getNumberOfConnections(), maxConnectionsPerHost); // try "maxConnectionsPerHost" more connections channelCount = new CountDownLatch(maxConnectionsPerHost); for (int i = 0; i < maxConnectionsPerHost; i++) { BlockingChannelInfoThread infoThread = new BlockingChannelInfoThread(channelInfo, channelCount, shouldRelease, releaseComplete, destroyConnection, exception); Thread t = new Thread(infoThread); t.start(); } Assert.assertEquals(channelInfo.getNumberOfConnections(), maxConnectionsPerHost); shouldRelease.countDown(); awaitCountdown(channelCount, 1000, exception, "Timed out while waiting for channel count to reach " + maxConnectionsPerHost); Assert.assertEquals(channelInfo.getNumberOfConnections(), maxConnectionsPerHost); awaitCountdown(releaseComplete, 2000, exception, "Timed out while waiting for channels to be released"); channelInfo.cleanup(); Assert.assertEquals(channelInfo.getNumberOfConnections(), 0); } private void underSubscriptionTest(Properties props, String host, Port port, int underSubscriptionCount) throws Exception { AtomicReference<Exception> exception = new AtomicReference<Exception>(); BlockingChannelInfo channelInfo = new BlockingChannelInfo(new ConnectionPoolConfig(new VerifiableProperties(props)), host, port, new MetricRegistry(), sslSocketFactory, sslConfig); CountDownLatch channelCount = new CountDownLatch(underSubscriptionCount); CountDownLatch shouldRelease = new CountDownLatch(1); CountDownLatch releaseComplete = new CountDownLatch(underSubscriptionCount); for (int i = 0; i < underSubscriptionCount; i++) { BlockingChannelInfoThread infoThread = new BlockingChannelInfoThread(channelInfo, channelCount, shouldRelease, releaseComplete, true, exception); Thread t = new Thread(infoThread); t.start(); } shouldRelease.countDown(); awaitCountdown(releaseComplete, 2000, exception, "Timed out while waiting for channels to be released"); Assert.assertEquals(channelInfo.getNumberOfConnections(), underSubscriptionCount); channelInfo.getBlockingChannel(1000); Assert.assertEquals(channelInfo.getNumberOfConnections(), underSubscriptionCount); channelInfo.cleanup(); Assert.assertEquals(channelInfo.getNumberOfConnections(), 0); } /** * Starts up an ambry server given a port number in localhost * * @param port the port number over which ambry server needs to be started * @return the {@link SocketServer} referring to ambry's instance * @throws IOException * @throws InterruptedException */ private SocketServer startServer(int port) throws IOException, InterruptedException { Properties props = new Properties(); props.setProperty("port", "" + port); props.setProperty("clustermap.cluster.name", "test"); props.setProperty("clustermap.datacenter.name", "dc1"); props.setProperty("clustermap.host.name", "localhost"); VerifiableProperties propverify = new VerifiableProperties(props); NetworkConfig config = new NetworkConfig(propverify); ArrayList<Port> ports = new ArrayList<Port>(); ports.add(new Port(port, PortType.PLAINTEXT)); ports.add(new Port(port + 1000, PortType.SSL)); SocketServer server = new SocketServer(config, serverSSLConfig1, new MetricRegistry(), ports); server.start(); return server; } class ConnectionPoolThread implements Runnable { private final AtomicReference<Exception> exception; private final Map<String, CountDownLatch> channelCount; private final ConnectionPool connectionPool; private final boolean destroyConnection; private final CountDownLatch shouldRelease; private final CountDownLatch releaseComplete; private Map<String, Port> channelToPortMap; public ConnectionPoolThread(Map<String, CountDownLatch> channelCount, Map<String, Port> channelToPortMap, ConnectionPool connectionPool, boolean destroyConnection, CountDownLatch shouldRelease, CountDownLatch releaseComplete, AtomicReference<Exception> e) { this.channelCount = channelCount; this.channelToPortMap = channelToPortMap; this.connectionPool = connectionPool; this.destroyConnection = destroyConnection; this.shouldRelease = shouldRelease; this.releaseComplete = releaseComplete; this.exception = e; } @Override public void run() { try { List<ConnectedChannel> connectedChannels = new ArrayList<ConnectedChannel>(); for (String channelStr : channelCount.keySet()) { Port port = channelToPortMap.get(channelStr); ConnectedChannel channel = connectionPool.checkOutConnection("localhost", new Port(port.getPort(), port.getPortType()), 1000); connectedChannels.add(channel); channelCount.get(channelStr).countDown(); } if (shouldRelease.await(5000, TimeUnit.MILLISECONDS)) { for (ConnectedChannel channel : connectedChannels) { if (destroyConnection) { connectionPool.destroyConnection(channel); } else { connectionPool.checkInConnection(channel); } } } else if (exception.get() == null) { exception.set(new Exception("Timed out waiting for signal to release connections")); } } catch (Exception e) { exception.set(e); } finally { releaseComplete.countDown(); } } } @Test public void testBlockingChannelConnectionPool() throws Exception { Properties props = new Properties(); props.put("connectionpool.max.connections.per.port.plain.text", "5"); props.put("connectionpool.max.connections.per.port.ssl", "5"); props.put("clustermap.cluster.name", "test"); props.put("clustermap.datacenter.name", "dc1"); props.put("clustermap.host.name", "localhost"); ConnectionPool connectionPool = new BlockingChannelConnectionPool(new ConnectionPoolConfig(new VerifiableProperties(props)), sslConfig, plainTextClusterMapConfig, new MetricRegistry()); connectionPool.start(); CountDownLatch shouldRelease = new CountDownLatch(1); CountDownLatch releaseComplete = new CountDownLatch(10); AtomicReference<Exception> exception = new AtomicReference<Exception>(); Map<String, CountDownLatch> channelCount = new HashMap<String, CountDownLatch>(); channelCount.put("localhost" + 6667, new CountDownLatch(5)); channelCount.put("localhost" + 6668, new CountDownLatch(5)); channelCount.put("localhost" + 6669, new CountDownLatch(5)); Map<String, Port> channelToPortMap = new HashMap<String, Port>(); channelToPortMap.put("localhost" + 6667, new Port(6667, PortType.PLAINTEXT)); channelToPortMap.put("localhost" + 6668, new Port(6668, PortType.PLAINTEXT)); channelToPortMap.put("localhost" + 6669, new Port(6669, PortType.PLAINTEXT)); for (int i = 0; i < 10; i++) { ConnectionPoolThread connectionPoolThread = new ConnectionPoolThread(channelCount, channelToPortMap, connectionPool, false, shouldRelease, releaseComplete, exception); Thread t = new Thread(connectionPoolThread); t.start(); } for (String channelStr : channelCount.keySet()) { awaitCountdown(channelCount.get(channelStr), 1000, exception, "Timed out waiting for channel count to reach 5"); } // reset for (String channelStr : channelCount.keySet()) { channelCount.put(channelStr, new CountDownLatch(5)); } shouldRelease.countDown(); for (String channelStr : channelCount.keySet()) { awaitCountdown(channelCount.get(channelStr), 1000, exception, "Timed out waiting for channel count to reach 5"); } awaitCountdown(releaseComplete, 2000, exception, "Timed out while waiting for channels to be released"); connectionPool.shutdown(); } //@Test public void testSSLBlockingChannelConnectionPool() throws Exception { Properties props = new Properties(); props.put("connectionpool.max.connections.per.port.plain.text", "5"); props.put("connectionpool.max.connections.per.port.ssl", "5"); props.put("clustermap.cluster.name", "test"); props.put("clustermap.datacenter.name", "dc1"); props.put("clustermap.host.name", "localhost"); ConnectionPool connectionPool = new BlockingChannelConnectionPool(new ConnectionPoolConfig(new VerifiableProperties(props)), sslConfig, sslEnabledClusterMapConfig, new MetricRegistry()); connectionPool.start(); CountDownLatch shouldRelease = new CountDownLatch(1); CountDownLatch releaseComplete = new CountDownLatch(10); AtomicReference<Exception> exception = new AtomicReference<Exception>(); Map<String, CountDownLatch> channelCount = new HashMap<String, CountDownLatch>(); channelCount.put("localhost" + 7667, new CountDownLatch(5)); channelCount.put("localhost" + 7668, new CountDownLatch(5)); channelCount.put("localhost" + 7669, new CountDownLatch(5)); Map<String, Port> channelToPortMap = new HashMap<String, Port>(); channelToPortMap.put("localhost" + 7667, new Port(7667, PortType.SSL)); channelToPortMap.put("localhost" + 7668, new Port(7668, PortType.SSL)); channelToPortMap.put("localhost" + 7669, new Port(7669, PortType.SSL)); for (int i = 0; i < 10; i++) { ConnectionPoolThread connectionPoolThread = new ConnectionPoolThread(channelCount, channelToPortMap, connectionPool, false, shouldRelease, releaseComplete, exception); Thread t = new Thread(connectionPoolThread); t.start(); } for (String channelStr : channelCount.keySet()) { awaitCountdown(channelCount.get(channelStr), 1000, exception, "Timed out waiting for channel count to reach 5"); } // reset for (String channelStr : channelCount.keySet()) { channelCount.put(channelStr, new CountDownLatch(5)); } shouldRelease.countDown(); for (String channelStr : channelCount.keySet()) { awaitCountdown(channelCount.get(channelStr), 1000, exception, "Timed out waiting for channel count to reach 5"); } awaitCountdown(releaseComplete, 2000, exception, "Timed out while waiting for channels to be released"); connectionPool.shutdown(); } private void awaitCountdown(CountDownLatch countDownLatch, long timeoutMs, AtomicReference<Exception> exception, String errMsg) throws Exception { if (!countDownLatch.await(timeoutMs, TimeUnit.MILLISECONDS)) { if (exception.get() == null) { exception.set(new Exception(errMsg)); } throw exception.get(); } } }