/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.thrift.async; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import junit.framework.TestCase; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.server.ServerTestBase; import org.apache.thrift.server.THsHaServer; import org.apache.thrift.server.THsHaServer.Args; import org.apache.thrift.transport.TNonblockingServerSocket; import org.apache.thrift.transport.TNonblockingSocket; import thrift.test.CompactProtoTestStruct; import thrift.test.Srv; import thrift.test.Srv.Iface; import thrift.test.Srv.AsyncClient.Janky_call; import thrift.test.Srv.AsyncClient.onewayMethod_call; import thrift.test.Srv.AsyncClient.primitiveMethod_call; import thrift.test.Srv.AsyncClient.voidMethod_call; public class TestTAsyncClientManager extends TestCase { private THsHaServer server_; private Thread serverThread_; private TAsyncClientManager clientManager_; public void setUp() throws Exception { server_ = new THsHaServer(new Args(new TNonblockingServerSocket(ServerTestBase.PORT)).processor(new Srv.Processor(new SrvHandler()))); serverThread_ = new Thread(new Runnable() { public void run() { server_.serve(); } }); serverThread_.start(); clientManager_ = new TAsyncClientManager(); Thread.sleep(500); } public void tearDown() throws Exception { server_.stop(); clientManager_.stop(); serverThread_.join(); } public void testBasicCall() throws Exception { Srv.AsyncClient client = getClient(); basicCall(client); } public void testBasicCallWithTimeout() throws Exception { Srv.AsyncClient client = getClient(); client.setTimeout(5000); basicCall(client); } public void testTimeoutCall() throws Exception { final CountDownLatch latch = new CountDownLatch(1); Srv.AsyncClient client = getClient(); client.setTimeout(100); client.primitiveMethod(new AsyncMethodCallback<primitiveMethod_call>() { @Override public void onError(Exception exception) { try { if (!(exception instanceof TimeoutException)) { StringWriter sink = new StringWriter(); exception.printStackTrace(new PrintWriter(sink, true)); fail("expected TimeoutException but got " + sink.toString()); } } finally { latch.countDown(); } } @Override public void onComplete(primitiveMethod_call response) { try { fail("Should not have finished timed out call."); } finally { latch.countDown(); } } }); latch.await(2, TimeUnit.SECONDS); assertTrue(client.hasError()); assertTrue(client.getError() instanceof TimeoutException); } public void testVoidCall() throws Exception { final CountDownLatch latch = new CountDownLatch(1); final AtomicBoolean returned = new AtomicBoolean(false); Srv.AsyncClient client = getClient(); client.voidMethod(new FailureLessCallback<Srv.AsyncClient.voidMethod_call>() { @Override public void onComplete(voidMethod_call response) { try { response.getResult(); returned.set(true); } catch (TException e) { fail(e); } finally { latch.countDown(); } } }); latch.await(1, TimeUnit.SECONDS); assertTrue(returned.get()); } public void testOnewayCall() throws Exception { final CountDownLatch latch = new CountDownLatch(1); final AtomicBoolean returned = new AtomicBoolean(false); Srv.AsyncClient client = getClient(); client.onewayMethod(new FailureLessCallback<onewayMethod_call>() { @Override public void onComplete(onewayMethod_call response) { try { response.getResult(); returned.set(true); } catch (TException e) { fail(e); } finally { latch.countDown(); } } }); latch.await(1, TimeUnit.SECONDS); assertTrue(returned.get()); } public void testParallelCalls() throws Exception { // make multiple calls with deserialization in the selector thread (repro Eric's issue) int numThreads = 50; int numCallsPerThread = 100; List<JankyRunnable> runnables = new ArrayList<JankyRunnable>(); List<Thread> threads = new ArrayList<Thread>(); for (int i = 0; i < numThreads; i++) { JankyRunnable runnable = new JankyRunnable(numCallsPerThread); Thread thread = new Thread(runnable); thread.start(); threads.add(thread); runnables.add(runnable); } for (Thread thread : threads) { thread.join(); } int numSuccesses = 0; for (JankyRunnable runnable : runnables) { numSuccesses += runnable.getNumSuccesses(); } assertEquals(numThreads * numCallsPerThread, numSuccesses); } private Srv.AsyncClient getClient() throws IOException { TNonblockingSocket clientSocket = new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT); return new Srv.AsyncClient(new TBinaryProtocol.Factory(), clientManager_, clientSocket); } private void basicCall(Srv.AsyncClient client) throws Exception { final CountDownLatch latch = new CountDownLatch(1); final AtomicBoolean returned = new AtomicBoolean(false); client.Janky(1, new FailureLessCallback<Srv.AsyncClient.Janky_call>() { @Override public void onComplete(Janky_call response) { try { assertEquals(3, response.getResult()); returned.set(true); } catch (TException e) { fail(e); } finally { latch.countDown(); } } @Override public void onError(Exception exception) { try { StringWriter sink = new StringWriter(); exception.printStackTrace(new PrintWriter(sink, true)); fail("unexpected onError with exception " + sink.toString()); } finally { latch.countDown(); } } }); latch.await(100, TimeUnit.SECONDS); assertTrue(returned.get()); } public class SrvHandler implements Iface { // Use this method for a standard call testing @Override public int Janky(int arg) throws TException { assertEquals(1, arg); return 3; } // Using this method for timeout testing - sleeps for 1 second before returning @Override public int primitiveMethod() throws TException { try { Thread.sleep(1000); } catch (InterruptedException e) { e.printStackTrace(); } return 0; } @Override public void methodWithDefaultArgs(int something) throws TException { } @Override public CompactProtoTestStruct structMethod() throws TException { return null; } @Override public void voidMethod() throws TException { } @Override public void onewayMethod() throws TException { } } private static abstract class FailureLessCallback<T extends TAsyncMethodCall> implements AsyncMethodCallback<T> { @Override public void onError(Exception exception) { fail(exception); } } private static void fail(Exception exception) { StringWriter sink = new StringWriter(); exception.printStackTrace(new PrintWriter(sink, true)); fail("unexpected error " + sink.toString()); } private class JankyRunnable implements Runnable { private int numCalls_; private int numSuccesses_ = 0; private Srv.AsyncClient client_; public JankyRunnable(int numCalls) throws Exception { numCalls_ = numCalls; client_ = getClient(); client_.setTimeout(20000); } public int getNumSuccesses() { return numSuccesses_; } public void run() { for (int i = 0; i < numCalls_ && !client_.hasError(); i++) { final int iteration = i; try { // connect an async client final CountDownLatch latch = new CountDownLatch(1); final AtomicBoolean returned = new AtomicBoolean(false); client_.Janky(1, new AsyncMethodCallback<Srv.AsyncClient.Janky_call>() { @Override public void onComplete(Janky_call response) { try { assertEquals(3, response.getResult()); returned.set(true); latch.countDown(); } catch (TException e) { latch.countDown(); fail(e); } } @Override public void onError(Exception exception) { try { StringWriter sink = new StringWriter(); exception.printStackTrace(new PrintWriter(sink, true)); fail("unexpected onError on iteration " + iteration + ": " + sink.toString()); } finally { latch.countDown(); } } }); boolean calledBack = latch.await(30, TimeUnit.SECONDS); assertTrue("wasn't called back in time on iteration " + iteration, calledBack); assertTrue("onComplete not called on iteration " + iteration, returned.get()); this.numSuccesses_++; } catch (Exception e) { fail(e); } } } } }