// Copyright 2016 Twitter. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.twitter.heron.common.network;
import java.nio.channels.SelectableChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Logger;
import com.google.protobuf.Message;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import com.twitter.heron.common.basics.ByteAmount;
import com.twitter.heron.common.basics.NIOLooper;
import com.twitter.heron.common.basics.SysUtils;
import com.twitter.heron.proto.testing.Tests;
/**
* HeronServer Tester.
*/
public class HeronServerTest {
// Number of messages the client sends (and the server is expected to receive) in testSendMessage
private static final int N = 10;
// How long runBase() sleeps after starting the server, and again after starting the client
private static final int WAIT_TIME_MS = 2 * 1000;
// Payload carried by the messages the client sends in testSendMessage
private static final String MESSAGE = "message";
private static final Logger LOG = Logger.getLogger(HeronServerTest.class.getName());
// Host used by both the test server and the test client
private static final String SERVER_HOST = "127.0.0.1";
// Socket options shared by SimpleHeronServer and SimpleHeronClient.
// NOTE(review): the meaning of each argument is defined by the HeronSocketOptions
// constructor, which is not visible here -- confirm before changing any value.
private static final HeronSocketOptions TEST_SOCKET_OPTIONS = new HeronSocketOptions(
ByteAmount.fromMegabytes(100), Duration.ofMillis(100),
ByteAmount.fromMegabytes(100), Duration.ofMillis(100),
ByteAmount.fromMegabytes(5),
ByteAmount.fromMegabytes(5));
// The following state variables record which callbacks ran, so tests can verify correctness.
// They are volatile because they are written by tasks on the worker pool and read by the
// test thread.
private volatile boolean isOnConnectedInvoked = false;
private volatile boolean isOnRequestInvoked = false;
private volatile boolean isOnMessageInvoked = false;
private volatile boolean isOnCloseInvoked = false;
private volatile boolean isTimerEventInvoked = false;
private volatile boolean isClientReceivedResponse = false;
// Fixture objects, created in before() and released in after()
private HeronServer heronServer;
private NIOLooper serverLooper;
private HeronClient heronClient;
private NIOLooper clientLooper;
// Runs the server loop and the client loop (see runServer() and runClient())
private ExecutorService threadsPool;
// Control whether the request/response exchange and the follow-up messages are exercised
private volatile boolean isRequestNeed = false;
private volatile boolean isMessageNeed = false;
// Count of EchoServerResponse messages seen by SimpleHeronServer.onMessage()
private volatile int messagesReceieved = 0;
/**
* JUnit rule used by tests that expect an exception to be thrown.
*/
@Rule
public final ExpectedException exception = ExpectedException.none();
/**
 * Builds a fresh server/client pair, each with its own looper, before every test.
 * Both endpoints use SERVER_HOST and the same port obtained from SysUtils.getFreePort().
 */
@Before
public void before() throws Exception {
  final int port = SysUtils.getFreePort();

  this.serverLooper = new NIOLooper();
  this.heronServer = new SimpleHeronServer(this.serverLooper, SERVER_HOST, port);

  this.clientLooper = new NIOLooper();
  this.heronClient = new SimpleHeronClient(this.clientLooper, SERVER_HOST, port);

  // Two workers: one for the server loop and one for the client loop
  this.threadsPool = Executors.newFixedThreadPool(2);
}
/**
 * Tears down the fixture created in before().
 *
 * Every component is null-checked because before() may have failed partway, in which
 * case some fields were never assigned. The stop calls are chained with try/finally so
 * that a failure while stopping the server still lets the client be stopped and the
 * loopers be asked to exit; the first exception raised is still propagated unless a
 * later cleanup step throws as well.
 */
@After
public void after() throws Exception {
  if (threadsPool != null) {
    threadsPool.shutdownNow();
  }

  try {
    if (heronServer != null) {
      heronServer.stop();
    }
  } finally {
    try {
      if (heronClient != null) {
        heronClient.stop();
      }
    } finally {
      if (serverLooper != null) {
        serverLooper.exitLoop();
      }
      if (clientLooper != null) {
        clientLooper.exitLoop();
      }
    }
  }

  heronServer = null;
  heronClient = null;
  serverLooper = null;
  clientLooper = null;
  threadsPool = null;

  // Reset the state
  isOnConnectedInvoked = false;
  isOnRequestInvoked = false;
  isOnMessageInvoked = false;
  isOnCloseInvoked = false;
  isTimerEventInvoked = false;
  isClientReceivedResponse = false;
}
/**
 * Method: registerOnMessage(Message.Builder builder)
 *
 * Registers an EchoServerResponse builder and checks that the server's message map
 * holds it under the proto's full name.
 */
@Test
public void testRegisterOnMessage() throws Exception {
  Message.Builder m = Tests.EchoServerResponse.newBuilder();
  heronServer.registerOnMessage(m);

  Map<String, Message.Builder> registered = heronServer.getMessageMap();
  // Without this check the loop below would pass vacuously if nothing was registered
  Assert.assertFalse("registerOnMessage() did not add any entry", registered.isEmpty());

  for (Map.Entry<String, Message.Builder> message : registered.entrySet()) {
    Assert.assertEquals("heron.proto.testing.EchoServerResponse", message.getKey());
    Assert.assertEquals(m, message.getValue());
  }
}
/**
 * Method: registerOnRequest(Message.Builder builder)
 *
 * Registers an EchoServerRequest builder and checks that the server's request map
 * holds it under the proto's full name.
 */
@Test
public void testRegisterOnRequest() throws Exception {
  Message.Builder r = Tests.EchoServerRequest.newBuilder();
  heronServer.registerOnRequest(r);

  Map<String, Message.Builder> registered = heronServer.getRequestMap();
  // Without this check the loop below would pass vacuously if nothing was registered
  Assert.assertFalse("registerOnRequest() did not add any entry", registered.isEmpty());

  for (Map.Entry<String, Message.Builder> request : registered.entrySet()) {
    Assert.assertEquals("heron.proto.testing.EchoServerRequest", request.getKey());
    Assert.assertEquals(r, request.getValue());
  }
}
/**
 * Method: start()
 *
 * The server is expected to report a successful start.
 */
@Test
public void testStart() throws Exception {
  final boolean started = heronServer.start();
  Assert.assertTrue(started);
}
/**
 * Method: stop()
 *
 * After a connected server is stopped, its accept channel must be closed and it must
 * hold no active connections.
 */
@Test
public void testClose() throws Exception {
  runBase();
  heronServer.stop();

  final ServerSocketChannel listener = heronServer.getAcceptChannel();
  Assert.assertNotNull(listener);
  Assert.assertFalse(listener.isOpen());

  final Map<SocketChannel, SocketChannelHelper> connections =
      heronServer.getActiveConnections();
  Assert.assertNotNull(connections);
  Assert.assertEquals(0, connections.size());
}
/**
 * Method: handleAccept(SelectableChannel channel)
 *
 * Once the client has connected, the accept channel is still open and exactly one
 * active connection is tracked.
 */
@Test
public void testHandleAccept() throws Exception {
  runBase();

  final ServerSocketChannel listener = heronServer.getAcceptChannel();
  Assert.assertNotNull(listener);
  Assert.assertTrue(listener.isOpen());

  final Map<SocketChannel, SocketChannelHelper> connections =
      heronServer.getActiveConnections();
  Assert.assertNotNull(connections);
  Assert.assertEquals(1, connections.size());
}
/**
 * Method: handleRead(SelectableChannel channel)
 *
 * Calling handleRead on an active connection must not throw. The call is skipped
 * when no connection exists.
 */
@Test
public void testHandleRead() throws Exception {
  runBase();

  final Map<SocketChannel, SocketChannelHelper> connections =
      heronServer.getActiveConnections();
  if (connections.isEmpty()) {
    // Nothing is connected, so there is no channel to read from
    return;
  }

  // The call itself is the check: it has to return without an exception
  final SocketChannel channel = connections.keySet().iterator().next();
  heronServer.handleRead(channel);
}
/**
 * Method: handleWrite(SelectableChannel channel)
 *
 * Calling handleWrite on an active connection must not throw.
 */
@Test
public void testHandleWrite() throws Exception {
  runBase();

  // The call itself is the check: it has to return without an exception
  final SocketChannel channel =
      heronServer.getActiveConnections().keySet().iterator().next();
  heronServer.handleWrite(channel);
}
/**
 * Method: handleConnect(SelectableChannel channel)
 *
 * The test expects handleConnect on the server to raise a RuntimeException.
 */
@Test
public void testHandleConnect() throws Exception {
  final SelectableChannel noChannel = null;

  exception.expect(RuntimeException.class);
  heronServer.handleConnect(noChannel);
}
/**
 * Method: handleError(SelectableChannel channel)
 *
 * After handleError is invoked for the single active connection, the server must no
 * longer track any connection.
 */
@Test
public void testHandleError() throws Exception {
  runBase();

  final Map<SocketChannel, SocketChannelHelper> connections =
      heronServer.getActiveConnections();
  final SelectableChannel failing = connections.keySet().iterator().next();

  heronServer.handleError(failing);

  Assert.assertEquals(0, connections.size());
}
/**
 * Method: getNIOLooper()
 *
 * A constructed server must expose a non-null looper.
 */
@Test
public void testGetNIOLooper() throws Exception {
  final NIOLooper looper = heronServer.getNIOLooper();
  Assert.assertNotNull(looper);
}
/**
 * Method: registerTimerEvent(Duration timer, Runnable task), with a one-second duration
 *
 * Registers a task that sets a flag, runs the server/client, and checks that the flag
 * was set by the time runBase() returns.
 */
@Test
public void testRegisterTimerEventInSeconds() throws Exception {
  final Runnable markTimerFired = () -> isTimerEventInvoked = true;

  heronServer.registerTimerEvent(Duration.ofSeconds(1), markTimerFired);
  runBase();

  Assert.assertTrue(isTimerEventInvoked);
}
/**
 * Method: sendResponse(REQID rid, SocketChannel channel, Message response)
 *
 * With the request flag enabled, the client sends an echo request on connect; the
 * server must have seen the request and the client must have received the response.
 */
@Test
public void testSendResponse() throws Exception {
  this.isRequestNeed = true;
  runBase();

  final boolean serverSawRequest = this.isOnRequestInvoked;
  Assert.assertTrue(serverSawRequest);

  final boolean clientSawResponse = this.isClientReceivedResponse;
  Assert.assertTrue(clientSawResponse);

  this.isRequestNeed = false;
}
/**
 * Method: sendMessage(SocketChannel channel, Message message)
 *
 * With both flags enabled, the client sends N messages after receiving its echo
 * response; the server must have been called back and counted exactly N messages.
 */
@Test
public void testSendMessage() throws Exception {
  this.isRequestNeed = true;
  this.isMessageNeed = true;
  runBase();

  final boolean serverSawMessage = this.isOnMessageInvoked;
  Assert.assertTrue(serverSawMessage);

  final int received = this.messagesReceieved;
  Assert.assertEquals(N, received);

  // Put the switches and the counter back to their initial values
  this.isRequestNeed = false;
  this.isMessageNeed = false;
  this.messagesReceieved = 0;
}
/**
 * Method: registerTimerEvent(Duration timer, Runnable task), with a zero duration
 *
 * Registers a task that sets a flag using Duration.ZERO, runs the server/client, and
 * checks that the flag was set by the time runBase() returns.
 */
@Test
public void testRegisterTimerEventInNanoSeconds() throws Exception {
  final Runnable markTimerFired = () -> isTimerEventInvoked = true;

  heronServer.registerTimerEvent(Duration.ZERO, markTimerFired);
  runBase();

  Assert.assertTrue(isTimerEventInvoked);
}
/**
 * Method: onConnect(SocketChannel channel)
 *
 * Invokes onConnect directly on an active connection and checks the callback flag.
 */
@Test
public void testOnConnect() throws Exception {
  runBase();

  final SocketChannel channel =
      heronServer.getActiveConnections().keySet().iterator().next();
  heronServer.onConnect(channel);

  Assert.assertTrue(isOnConnectedInvoked);
}
/**
 * Method: onRequest(REQID rid, SocketChannel channel, Message request)
 *
 * Invokes onRequest directly with a null request on an active connection and checks
 * the callback flag.
 */
@Test
public void testOnRequest() throws Exception {
  runBase();

  final Map<SocketChannel, SocketChannelHelper> connections =
      heronServer.getActiveConnections();
  final REQID requestId = REQID.generate();
  final SocketChannel channel = connections.keySet().iterator().next();
  heronServer.onRequest(requestId, channel, null);

  Assert.assertTrue(isOnRequestInvoked);
}
/**
 * Method: onMessage(SocketChannel channel, Message message)
 *
 * Invokes onMessage directly with a null message on an active connection and checks
 * the callback flag.
 */
@Test
public void testOnMessage() throws Exception {
  runBase();

  final SocketChannel channel =
      heronServer.getActiveConnections().keySet().iterator().next();
  heronServer.onMessage(channel, null);

  Assert.assertTrue(isOnMessageInvoked);
}
/**
 * Method: onClose(SocketChannel channel)
 *
 * Invokes onClose directly on an active connection and checks the callback flag.
 */
@Test
public void testOnClose() throws Exception {
  runBase();

  final SocketChannel channel =
      heronServer.getActiveConnections().keySet().iterator().next();
  heronServer.onClose(channel);

  Assert.assertTrue(isOnCloseInvoked);
}
/**
 * Submits a task to the worker pool that starts the server and then runs its looper.
 */
private void runServer() {
  threadsPool.execute(() -> {
    heronServer.start();
    heronServer.getNIOLooper().loop();
  });
}
/**
 * Submits a task to the worker pool that starts the client and then runs its looper.
 */
private void runClient() {
  threadsPool.execute(() -> {
    heronClient.start();
    heronClient.getNIOLooper().loop();
  });
}
/**
* Starts the server and then the client on the worker pool, sleeping WAIT_TIME_MS
* after each start. The sleeps are the only synchronization: nothing here verifies
* that the server is listening or that the client actually connected.
*
* @throws Exception if the calling thread is interrupted while sleeping
*/
private void runBase() throws Exception {
// Start the server first so that it is (presumably) listening before the client dials
runServer();
// Give the server time to finish starting
Thread.sleep(WAIT_TIME_MS);
// Then start the client
runClient();
// Give the client time to connect; callers assume the connection exists after this
Thread.sleep(WAIT_TIME_MS);
}
/**
 * Test double for HeronServer: every callback records that it ran in the enclosing
 * test's flags, and requests/messages are echoed or counted depending on the
 * isRequestNeed / isMessageNeed switches.
 */
private class SimpleHeronServer extends HeronServer {
  SimpleHeronServer(NIOLooper looper, String host, int port) {
    super(looper, host, port, TEST_SOCKET_OPTIONS);
  }

  /**
   * Records the connection and registers the request/message builders that the
   * current test has asked for.
   */
  @Override
  public void onConnect(SocketChannel socketChannel) {
    final String greeting = "Server got a new connection from host:port:"
        + socketChannel.socket().getRemoteSocketAddress();
    LOG.info(greeting);
    isOnConnectedInvoked = true;

    // Requests are registered only by tests that exercise sendResponse or sendMessage
    if (isRequestNeed) {
      registerOnRequest(Tests.EchoServerRequest.newBuilder());
    }

    // Messages of type EchoServerResponse are registered only when testing sendMessage
    if (isMessageNeed) {
      registerOnMessage(Tests.EchoServerResponse.newBuilder());
    }
  }

  /**
   * Records the call; for an EchoServerRequest, echoes its payload back when the
   * request switch is on. A null request is accepted and ignored.
   *
   * @throws RuntimeException if the request is non-null and not an EchoServerRequest
   */
  @Override
  public void onRequest(REQID rid, SocketChannel channel, Message request) {
    isOnRequestInvoked = true;

    if (request == null) {
      // A null request only checks that onRequest() can be invoked at all
      return;
    }

    if (!(request instanceof Tests.EchoServerRequest)) {
      LOG.info("Type of request: " + request);
      throw new RuntimeException("Unknown type of request received");
    }

    final Tests.EchoServerRequest incoming = (Tests.EchoServerRequest) request;
    final Tests.EchoServerResponse.Builder echo =
        Tests.EchoServerResponse.newBuilder().setEchoResponse(incoming.getEchoRequest());

    // The response is sent only by tests that exercise sendResponse or sendMessage
    if (isRequestNeed) {
      sendResponse(rid, channel, echo.build());
    }
  }

  /**
   * Records the call; for an EchoServerResponse, bumps the received counter and checks
   * the payload. A null message is accepted and ignored; any other type fails the test.
   */
  @Override
  public void onMessage(SocketChannel socketChannel, Message message) {
    isOnMessageInvoked = true;

    if (message == null) {
      // A null message only checks that onMessage() can be invoked at all
      return;
    }

    if (!(message instanceof Tests.EchoServerResponse)) {
      Assert.fail("Unknown message received");
    } else {
      messagesReceieved++;
      final Tests.EchoServerResponse received = (Tests.EchoServerResponse) message;
      Assert.assertEquals(MESSAGE, received.getEchoResponse());
    }
  }

  /** Records that the close callback ran. */
  @Override
  public void onClose(SocketChannel socketChannel) {
    isOnCloseInvoked = true;
  }
}
/**
 * Test double for HeronClient. On a successful connect it optionally sends one echo
 * request; on the echo response it records receipt and optionally sends N messages.
 */
private class SimpleHeronClient extends HeronClient {
  SimpleHeronClient(NIOLooper looper, String host, int port) {
    super(looper, host, port, TEST_SOCKET_OPTIONS);
  }

  @Override
  public void onError() {
    // Intentionally empty: these tests do not react to client-side errors
  }

  @Override
  public void onClose() {
    // Intentionally empty: these tests do not react to the client closing
  }

  /**
   * Fails the test on any status other than OK; otherwise sends the echo request when
   * the request switch is on.
   */
  @Override
  public void onConnect(StatusCode statusCode) {
    if (statusCode != StatusCode.OK) {
      Assert.fail("Connection with server failed");
    } else {
      LOG.info("Connected with server");
      // The request is sent only by tests that exercise sendResponse or sendMessage
      if (isRequestNeed) {
        sendEchoRequest();
      }
    }
  }

  /** Sends an EchoServerRequest carrying "Dummy" and expects an EchoServerResponse. */
  private void sendEchoRequest() {
    Tests.EchoServerRequest.Builder r = Tests.EchoServerRequest.newBuilder();
    r.setEchoRequest("Dummy");
    sendRequest(r.build(), Tests.EchoServerResponse.newBuilder());
  }

  /**
   * Records that a response arrived and checks that it echoes the request payload.
   * When the message switch is on, N messages carrying MESSAGE are then sent.
   * Any response that is not an EchoServerResponse fails the test.
   */
  @Override
  public void onResponse(StatusCode statusCode, Object o, Message response) {
    if (response instanceof Tests.EchoServerResponse) {
      Tests.EchoServerResponse r = (Tests.EchoServerResponse) response;
      isClientReceivedResponse = true;
      // Expected value goes first so that a failure reports expected/actual correctly
      Assert.assertEquals("Dummy", r.getEchoResponse());
      // Messages are sent only by tests that exercise sendMessage
      if (isMessageNeed) {
        Tests.EchoServerResponse.Builder message =
            Tests.EchoServerResponse.newBuilder().setEchoResponse(MESSAGE);
        for (int i = 0; i < N; i++) {
          sendMessage(message.build());
        }
      }
    } else {
      Assert.fail("Unknown type of response received");
    }
  }

  /** Logs any message pushed to the client. */
  @Override
  public void onIncomingMessage(Message message) {
    LOG.info("OnIncoming Message: " + message);
  }
}
}