/*
 * Copyright © 2014-2015 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.tephra.distributed;

import co.cask.tephra.TransactionServiceMain;
import co.cask.tephra.TxConstants;
import co.cask.tephra.runtime.ConfigModule;
import co.cask.tephra.runtime.DiscoveryModules;
import co.cask.tephra.runtime.TransactionClientModule;
import co.cask.tephra.runtime.TransactionModules;
import co.cask.tephra.runtime.ZKModule;
import com.google.common.base.Throwables;
import com.google.inject.Guice;
import com.google.inject.Injector;
import org.apache.hadoop.conf.Configuration;
import org.apache.twill.discovery.DiscoveryServiceClient;
import org.apache.twill.internal.zookeeper.InMemoryZKServer;
import org.apache.twill.zookeeper.ZKClientService;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;

/**
 * Tests that {@link PooledClientProvider} honours the configured maximum number of pooled clients
 * ({@code data.tx.client.count}) and the configured obtain timeout ({@code data.tx.client.obtain.timeout}).
 */
public class PooledClientProviderTest {

  /** Value written to {@code data.tx.client.count} for this test. */
  public static final int MAX_CLIENT_COUNT = 3;

  /** Value written to {@code data.tx.client.obtain.timeout} for this test, in milliseconds. */
  public static final long CLIENT_OBTAIN_TIMEOUT = 10;

  @ClassRule
  public static TemporaryFolder tmpFolder = new TemporaryFolder();

  /**
   * Starts an in-memory ZooKeeper server and a {@link TransactionServiceMain} on a background thread,
   * then runs the pool checks in {@link #startClientAndTestPool(Configuration)} against them.
   *
   * @throws Exception if any of the services fails to start or stop, or the pool checks fail
   */
  @Test
  public void testClientConnectionPoolMaximumNumberOfClients() throws Exception {
    // We need a server for the client to connect to
    InMemoryZKServer zkServer = InMemoryZKServer.builder().setDataDir(tmpFolder.newFolder()).build();
    zkServer.startAndWait();

    try {
      Configuration conf = new Configuration();
      conf.set(TxConstants.Service.CFG_DATA_TX_ZOOKEEPER_QUORUM, zkServer.getConnectionStr());
      conf.set(TxConstants.Manager.CFG_TX_SNAPSHOT_DIR, tmpFolder.newFolder().getAbsolutePath());
      conf.set("data.tx.client.count", Integer.toString(MAX_CLIENT_COUNT));
      conf.set("data.tx.client.obtain.timeout", Long.toString(CLIENT_OBTAIN_TIMEOUT));

      final TransactionServiceMain main = new TransactionServiceMain(conf);
      final CountDownLatch latch = new CountDownLatch(1);
      // Written by the starter thread before latch.countDown() and read by this thread only after
      // latch.await(), so the latch provides the required visibility.
      final Throwable[] startupFailure = new Throwable[1];
      Thread t = new Thread() {
        @Override
        public void run() {
          try {
            main.start();
          } catch (Throwable e) {
            startupFailure[0] = e;
          } finally {
            // Always release the waiting test thread; otherwise a failed startup would hang the test.
            latch.countDown();
          }
        }
      };

      try {
        t.start();
        // Wait for service to startup
        latch.await();
        if (startupFailure[0] != null) {
          // Surface the startup failure on the test thread instead of losing it in the background thread.
          throw Throwables.propagate(startupFailure[0]);
        }
        startClientAndTestPool(conf);
      } finally {
        main.stop();
        t.join();
      }
    } finally {
      zkServer.stopAndWait();
    }
  }

  /**
   * Creates a {@link PooledClientProvider} from the given configuration and verifies two things:
   * that at most {@link #MAX_CLIENT_COUNT} distinct clients are handed out, and that exactly one of
   * {@code MAX_CLIENT_COUNT + 1} concurrent requests times out when the clients are held for longer
   * than the obtain timeout.
   *
   * @param conf configuration carrying the ZooKeeper quorum and the client pool settings
   * @throws Exception if a client cannot be obtained when it is expected to be available
   */
  private void startClientAndTestPool(Configuration conf) throws Exception {
    Injector injector = Guice.createInjector(
      new ConfigModule(conf),
      new ZKModule(),
      new DiscoveryModules().getDistributedModules(),
      new TransactionModules().getDistributedModules(),
      new TransactionClientModule()
    );

    ZKClientService zkClient = injector.getInstance(ZKClientService.class);
    zkClient.startAndWait();
    try {
      final PooledClientProvider clientProvider =
        new PooledClientProvider(conf, injector.getInstance(DiscoveryServiceClient.class));

      // Test the simple case of get + return. Note: this also initializes the provider's pool, which
      // takes about one second (discovery). Doing it before we test the threads makes it so that one
      // thread doesn't take exceptionally longer than the others.
      try (CloseableThriftClient ignored = clientProvider.getCloseableClient()) {
        // do nothing with the client
      }

      ExecutorService executor = Executors.newFixedThreadPool(MAX_CLIENT_COUNT + 1);
      try {
        // Now race to get MAX_CLIENT_COUNT + 1 clients, exhausting the pool and requesting 1 more.
        // Each thread holds its client for half the obtain timeout.
        List<Future<Integer>> quickHolders = raceForClients(executor, clientProvider, CLIENT_OBTAIN_TIMEOUT / 2);
        Set<Integer> ids = new HashSet<Integer>();
        for (Future<Integer> id : quickHolders) {
          ids.add(id.get());
        }
        Assert.assertEquals(MAX_CLIENT_COUNT, ids.size());

        // Now try it again, with each thread holding onto its client for twice the
        // client.obtain.timeout value. One of the threads should throw a TimeoutException, because
        // the other threads don't release their clients within the configured timeout.
        List<Future<Integer>> slowHolders = raceForClients(executor, clientProvider, CLIENT_OBTAIN_TIMEOUT * 2);
        int numTimeoutExceptions = 0;
        for (Future<Integer> clientId : slowHolders) {
          try {
            clientId.get();
          } catch (ExecutionException expected) {
            Assert.assertEquals(TimeoutException.class, expected.getCause().getClass());
            numTimeoutExceptions++;
          }
        }
        // expect that exactly one of the threads hit the TimeoutException
        Assert.assertEquals(String.format("Expected one thread to not obtain a client within %s milliseconds.",
                                          CLIENT_OBTAIN_TIMEOUT),
                            1, numTimeoutExceptions);
      } finally {
        // Release the pool threads even when an assertion above fails.
        executor.shutdownNow();
      }
    } finally {
      zkClient.stopAndWait();
    }
  }

  /**
   * Submits {@code MAX_CLIENT_COUNT + 1} {@link RetrieveClient} tasks that all wait on a shared latch,
   * then releases the latch so that they request clients at roughly the same time.
   *
   * @param executor executor to run the tasks on
   * @param clientProvider provider to obtain clients from
   * @param holdClientMs how long, in milliseconds, each task holds its client before returning it
   * @return the futures of the submitted tasks, in submission order
   */
  private List<Future<Integer>> raceForClients(ExecutorService executor, PooledClientProvider clientProvider,
                                               long holdClientMs) {
    CountDownLatch startSignal = new CountDownLatch(1);
    List<Future<Integer>> futures = new ArrayList<Future<Integer>>();
    for (int i = 0; i < MAX_CLIENT_COUNT + 1; i++) {
      futures.add(executor.submit(new RetrieveClient(clientProvider, holdClientMs, startSignal)));
    }
    startSignal.countDown();
    return futures;
  }

  /**
   * Task that waits for a start signal, obtains a client from the pool, holds it for a fixed time and
   * returns the identity hash code of the underlying thrift client.
   */
  private static class RetrieveClient implements Callable<Integer> {
    private final PooledClientProvider pool;
    private final long holdClientMs;
    private final CountDownLatch begin;

    /**
     * @param pool provider to obtain the client from
     * @param holdClientMs how long, in milliseconds, to hold the client before returning it
     * @param begin latch to wait on before requesting the client
     */
    public RetrieveClient(PooledClientProvider pool, long holdClientMs, CountDownLatch begin) {
      this.pool = pool;
      this.holdClientMs = holdClientMs;
      this.begin = begin;
    }

    /**
     * @return the identity hash code of the thrift client that was obtained
     * @throws Exception if the client cannot be obtained, or the thread is interrupted while waiting
     */
    @Override
    public Integer call() throws Exception {
      begin.await();
      try (CloseableThriftClient client = pool.getCloseableClient()) {
        int id = System.identityHashCode(client.getThriftClient());
        // "use" the client for a configured amount of milliseconds
        Thread.sleep(holdClientMs);
        return id;
      }
    }
  }
}