// Copyright 2017 Twitter. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.twitter.heron.common.network;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.Message;
import com.twitter.heron.common.basics.ByteAmount;
import com.twitter.heron.common.basics.NIOLooper;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Class to help simplify testing HeronServer implementations by consolidating common test
* functionality. A new instance of this class should be used for each test. After initializing the
* class, tests should call start() which guarantees the following:
* <ol>
* <li>the server has started</li>
* <li>the client has been launched</li>
* </ol>
* Upon test completion, the stop() method should be called to properly release resources. This
* method is safe to be called on an instance that was never properly started.
*/
public class HeronServerTester {
  /** Loopback host name made available to tests. */
  public static final String SERVER_HOST = "127.0.0.1";

  /** Socket options passed to the default test client created by this class. */
  public static final HeronSocketOptions TEST_SOCKET_OPTIONS = new HeronSocketOptions(
      ByteAmount.fromMegabytes(100), Duration.ofMillis(100),
      ByteAmount.fromMegabytes(100), Duration.ofMillis(100),
      ByteAmount.fromMegabytes(5),
      ByteAmount.fromMegabytes(5));

  /** Timeout applied by {@link #await(CountDownLatch)}. */
  private static final Duration DEFAULT_LATCH_TIMEOUT = Duration.ofSeconds(2);

  /** How long the client thread waits for the server-started signal before failing. */
  private static final Duration SERVER_START_TIMEOUT = Duration.ofSeconds(2);

  /** A response timeout made available to tests; it is not referenced inside this class. */
  public static final Duration RESPONSE_RECEIVED_TIMEOUT = Duration.ofSeconds(10);

  private final HeronServer server;
  // Two threads: one runs the server loop, the other runs the client loop.
  private final ExecutorService threadsPool;
  private HeronClient client;
  // Released by the server thread right after server.start() returns.
  private final CountDownLatch serverStartedSignal;
  // Only set by the request/response constructor; null when a custom client is supplied.
  private CountDownLatch responseReceivedSignal;
  // Null means start() does not wait for a response.
  private Duration responseReceivedTimeout;

  /**
   * Constructor to use for the common use case of sending a single request and taking some action
   * upon a response. The start() method does the following:
   * <ol>
   * <li>Initializes the server and a default client</li>
   * <li>Calls sendRequest on the client using the message from TestRequestHandler</li>
   * <li>Invokes the TestResponseHandler.handleResponse method once a response arrives</li>
   * <li>Waits up to responseReceivedTimeout for the response to have been handled</li>
   * </ol>
   *
   * @param server the server to test
   * @param requestHandler the request handler to use to build the request and response builder
   * @param responseHandler the handler to handle the received response
   * @param responseReceivedTimeout how long to wait for the response; if null, start() does not
   * wait
   * @throws IOException if the default client (or its NIOLooper) cannot be created
   */
  public HeronServerTester(HeronServer server,
                           TestRequestHandler requestHandler,
                           TestResponseHandler responseHandler,
                           Duration responseReceivedTimeout) throws IOException {
    this(server);
    this.responseReceivedSignal = new CountDownLatch(1);
    this.responseReceivedTimeout = responseReceivedTimeout;
    this.client = new TestClient(new NIOLooper(),
        server.getEndpoint().getHostName(), server.getEndpoint().getPort(),
        responseReceivedSignal, requestHandler, responseHandler);
  }

  /**
   * Constructor where a custom client is to be used. The start() method starts both the server and
   * the client, and does not wait for any response.
   *
   * @param server server to test
   * @param client client to test
   */
  public HeronServerTester(HeronServer server, HeronClient client) {
    this(server);
    this.client = client;
  }

  /** Shared initialization: stores the server and creates the thread pool and start latch. */
  private HeronServerTester(HeronServer server) {
    this.server = server;
    this.threadsPool = Executors.newFixedThreadPool(2);
    this.serverStartedSignal = new CountDownLatch(1);
  }

  /**
   * Launches the server and the client on background threads. When a response timeout was
   * configured, blocks until the response has been handled and fails the test if the timeout
   * elapses first.
   *
   * @throws InterruptedException declared for callers; kept for interface compatibility
   */
  public void start() throws InterruptedException {
    // First run Server
    runServer();

    // Then run Client
    runClient();

    // Wait to make sure message was sent and response was received
    if (responseReceivedTimeout != null) {
      await(responseReceivedSignal, responseReceivedTimeout);
    }
  }

  /**
   * Shuts down the worker threads, stops the server and (if present) the client, and asks both
   * NIOLoopers to exit their loops.
   */
  public void stop() {
    threadsPool.shutdownNow();

    server.stop();

    if (client != null) {
      client.stop();
      client.getNIOLooper().exitLoop();
    }

    server.getNIOLooper().exitLoop();
  }

  /**
   * Convenience method to wait on a latch with a default timeout. If the timeout is reached, the
   * fail() method will be invoked.
   *
   * @param latch the latch to wait on
   */
  public static void await(CountDownLatch latch) {
    await(latch, DEFAULT_LATCH_TIMEOUT);
  }

  /**
   * Convenience method to wait on a latch with the given timeout. If the timeout is reached before
   * the latch is released, or the waiting thread is interrupted, the fail() method will be invoked.
   *
   * @param latch the latch to wait on
   * @param timeout the maximum time to wait for the latch to reach zero
   */
  public static void await(CountDownLatch latch, Duration timeout) {
    try {
      // CountDownLatch.await returns false (it does not throw) when the timeout elapses, so the
      // return value has to be checked explicitly.
      if (!latch.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
        fail(String.format(
            "Await latch failed to release until timeout of %s was reached. Check latch logic.",
            timeout));
      }
    } catch (InterruptedException e) {
      // Preserve the interrupt status for code further up the stack before failing.
      Thread.currentThread().interrupt();
      fail(String.format(
          "Interrupted while awaiting latch with timeout of %s: %s", timeout, e));
    }
  }

  /** Starts the server on a pool thread, signals that it started, then runs its looper. */
  private void runServer() {
    threadsPool.execute(() -> {
      server.start();
      serverStartedSignal.countDown();
      server.getNIOLooper().loop();
    });
  }

  /**
   * On a pool thread, waits for the server-started signal, then starts the client and runs its
   * looper. The client is stopped when that thread finishes, whether normally or not.
   */
  private void runClient() {
    threadsPool.execute(() -> {
      try {
        await(serverStartedSignal, SERVER_START_TIMEOUT);
        client.start();
        client.getNIOLooper().loop();
      } finally {
        client.stop();
      }
    });
  }

  /**
   * Abstract class to simplify writing test clients by providing default method impls.
   * Errors and unexpected incoming messages fail the test; the other callbacks are no-ops.
   */
  public abstract static class AbstractTestClient extends HeronClient {
    protected AbstractTestClient(NIOLooper looper,
                                 String host, int port,
                                 HeronSocketOptions options) {
      super(looper, host, port, options);
    }

    @Override
    public void onError() {
      fail("Error in client while talking to server");
    }

    @Override
    public void onConnect(StatusCode status) {
    }

    @Override
    public void onResponse(StatusCode status, Object ctx, Message response) {
    }

    @Override
    public void onIncomingMessage(Message request) {
      fail("Incoming message not expected on the client");
    }

    @Override
    public void onClose() {
    }
  }

  /**
   * Default client: sends the request from the TestRequestHandler once connected and hands the
   * response to the TestResponseHandler, releasing the response latch afterwards.
   */
  private static class TestClient extends AbstractTestClient {
    private static final Logger LOG = Logger.getLogger(TestClient.class.getName());
    private final TestRequestHandler requestHandler;
    private final TestResponseHandler responseHandler;
    private final CountDownLatch responseReceivedSignal;

    TestClient(NIOLooper looper, String host, int port, CountDownLatch responseReceivedSignal,
               TestRequestHandler requestHandler, TestResponseHandler responseHandler) {
      super(looper, host, port, TEST_SOCKET_OPTIONS);
      this.requestHandler = requestHandler;
      this.responseHandler = responseHandler;
      this.responseReceivedSignal = responseReceivedSignal;
    }

    @Override
    public void onConnect(StatusCode status) {
      if (status != StatusCode.OK) {
        fail("Connection with server failed, onConnect status: " + status);
      } else {
        LOG.info("Connected with Metrics Manager Server");
        sendRequest(requestHandler.getRequestMessage(), requestHandler.getResponseBuilder());
      }
    }

    @Override
    public void onResponse(StatusCode status, Object ctx, Message response) {
      try {
        responseHandler.handleResponse(this, status, ctx, response);
        // SUPPRESS CHECKSTYLE IllegalCatch
      } catch (Exception e) {
        fail("Exception while handling response:" + e);
      } finally {
        // Release the latch only after the handler has run (even if it failed), so that a thread
        // blocked in start() does not resume while the handler is still executing.
        responseReceivedSignal.countDown();
      }
    }
  }

  /**
   * Interface to provide the Message to be sent upon onConnect and the expected Message.Builder
   * to be used for the response.
   */
  public interface TestRequestHandler {
    Message getRequestMessage();

    Message.Builder getResponseBuilder();
  }

  /**
   * Interface to handle a response received by the client from the server.
   */
  public interface TestResponseHandler {
    void handleResponse(HeronClient client,
                        StatusCode status,
                        Object ctx,
                        Message response) throws Exception;
  }

  /**
   * Generic SuccessResponseHandler that asserts that the response status code is OK and that the
   * message is of the expected type. After those assertions, delegates to delegate for additional
   * assertions if set.
   */
  public static final class SuccessResponseHandler
      implements TestResponseHandler {
    private final Class<? extends GeneratedMessage> expectedMessageClass;
    private final TestResponseHandler delegate;

    public SuccessResponseHandler(Class<? extends GeneratedMessage> expectedMessageClass) {
      this(expectedMessageClass, null);
    }

    public SuccessResponseHandler(Class<? extends GeneratedMessage> expectedMessageClass,
                                  TestResponseHandler delegate) {
      this.expectedMessageClass = expectedMessageClass;
      this.delegate = delegate;
    }

    @Override
    public void handleResponse(HeronClient client,
                               StatusCode status,
                               Object ctx, Message response) throws Exception {
      // Check the status first: a non-OK status is the more informative failure.
      assertEquals("Unexpected response code received from the server.", StatusCode.OK, status);

      // Class.isInstance is false for null, so a missing response becomes an assertion failure
      // rather than a NullPointerException.
      assertTrue(String.format(
          "Unexpected response message class received from the server. Expected: %s, Found: %s",
          expectedMessageClass.getName(),
          response == null ? "null" : response.getClass().getName()),
          expectedMessageClass.isInstance(response));

      if (delegate != null) {
        delegate.handleResponse(client, status, ctx, response);
      }
    }
  }
}