// Copyright 2016 Twitter. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.twitter.heron.common.network; import java.io.IOException; import java.nio.channels.SocketChannel; import java.time.Duration; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.logging.Logger; import com.google.protobuf.Message; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import com.twitter.heron.common.basics.ByteAmount; import com.twitter.heron.common.basics.NIOLooper; import com.twitter.heron.common.basics.SysUtils; import com.twitter.heron.proto.testing.Tests; public class EchoTest { private static int serverPort; private ExecutorService threadsPool; private static final HeronSocketOptions TEST_SOCKET_OPTIONS = new HeronSocketOptions( ByteAmount.fromMegabytes(100), Duration.ofMillis(100), ByteAmount.fromMegabytes(100), Duration.ofMillis(100), ByteAmount.fromMegabytes(5), ByteAmount.fromMegabytes(5)); @BeforeClass public static void beforeClass() throws Exception { } @AfterClass public static void afterClass() throws Exception { } @Before public void before() throws Exception { threadsPool = Executors.newSingleThreadExecutor(); // Get an available port serverPort = SysUtils.getFreePort(); } @After public void after() throws Exception { threadsPool.shutdownNow(); threadsPool = null; } @Test public void testStart() throws Exception { runServer(); // We'll sleep to give the server a chance to bind and start listening Thread.sleep(1000); runClient(); } private void runServer() { Runnable server = new Runnable() { @Override public void run() { NIOLooper looper; try { looper = new NIOLooper(); EchoServer s = new EchoServer(looper, serverPort, 1000); s.start(); } catch (IOException e) { throw new RuntimeException("Some error instantiating server"); } looper.loop(); } }; threadsPool.execute(server); } private void runClient() { NIOLooper looper; try { looper = new NIOLooper(); EchoClient c = new EchoClient(looper, serverPort, 1000); c.start(); } catch (IOException e) { throw new RuntimeException("Some error instantiating client"); } looper.loop(); } private static class EchoServer extends HeronServer { private static final Logger LOG = Logger.getLogger(EchoServer.class.getName()); private int nRequests; private int maxRequests; EchoServer(NIOLooper looper, int port, int maxRequests) { super(looper, "localhost", port, TEST_SOCKET_OPTIONS); nRequests = 0; this.maxRequests = maxRequests; registerOnRequest(Tests.EchoServerRequest.newBuilder()); } @Override public void onConnect(SocketChannel channel) { LOG.info("A new client connected with us"); } @Override public void onClose(SocketChannel channel) { LOG.info("A client closed connection"); } @Override public void onRequest(REQID rid, SocketChannel channel, Message request) { if (request instanceof Tests.EchoServerRequest) { Tests.EchoServerResponse.Builder response = Tests.EchoServerResponse.newBuilder(); Tests.EchoServerRequest req = (Tests.EchoServerRequest) request; response.setEchoResponse(req.getEchoRequest()); sendResponse(rid, channel, response.build()); nRequests++; if (nRequests % 10 == 0) { LOG.info("Processed " + nRequests + " requests"); } if (nRequests >= maxRequests) { // We wait for 1 second to let client to receive the request and then exit registerTimerEvent(Duration.ofSeconds(1), new Runnable() { @Override public void run() { EchoServer.this.stop(); getNIOLooper().exitLoop(); return; } }); } } else { throw new RuntimeException("Unknown type of request received"); } } @Override public void onMessage(SocketChannel channel, Message request) { throw new RuntimeException("Expected message from client"); } } private static class EchoClient extends HeronClient { private static final Logger LOG = Logger.getLogger(EchoClient.class.getName()); private HeronSocketOptions socketOptions; private int nRequests; private int maxRequests; EchoClient(NIOLooper looper, int port, int maxRequests) { super(looper, "localhost", port, TEST_SOCKET_OPTIONS); nRequests = 0; this.maxRequests = maxRequests; } @Override public void onConnect(StatusCode status) { if (status != StatusCode.OK) { Assert.fail("Connection with server failed"); } else { LOG.info("Connected with server"); sendRequest(); } } @Override public void onError() { Assert.fail("Error in client while talking to server"); } @Override public void onClose() { } private void sendRequest() { if (nRequests > maxRequests) { this.stop(); getNIOLooper().exitLoop(); return; } Tests.EchoServerRequest.Builder r = Tests.EchoServerRequest.newBuilder(); r.setEchoRequest("Dummy"); sendRequest(r.build(), Tests.EchoServerResponse.newBuilder()); nRequests++; } @Override public void onResponse(StatusCode status, Object ctx, Message response) { if (response instanceof Tests.EchoServerResponse) { Tests.EchoServerResponse r = (Tests.EchoServerResponse) response; Assert.assertEquals(r.getEchoResponse(), "Dummy"); sendRequest(); } else { Assert.fail("Unknown type of response received"); } } @Override public void onIncomingMessage(Message request) { Assert.fail("Expected message from client"); } } }