/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.network;
import com.codahale.metrics.MetricRegistry;
import com.github.ambry.utils.SystemTime;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static java.util.Arrays.*;
import static org.junit.Assert.*;
/**
 * A set of tests for the selector. These use a test harness that runs a simple socket server that echoes back
 * responses.
 */
public class SelectorTest {
  private static final int BUFFER_SIZE = 4 * 1024;
  // Upper bound on how long testConnectionRefused waits for the refused connection to be reported as disconnected.
  private static final long DISCONNECT_WAIT_MS = 30 * 1000L;
  // Alphabet that randomString draws from.
  private static final String LETTERS_AND_DIGITS =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789";

  private EchoServer server;
  private Selector selector;

  /**
   * Starts an {@link EchoServer} on a fixed port and creates the {@link Selector} under test.
   */
  @Before
  public void setup() throws Exception {
    this.server = new EchoServer(18283);
    this.server.start();
    this.selector = new Selector(new NetworkMetrics(new MetricRegistry()), SystemTime.getInstance(), null);
  }

  /**
   * Closes the selector first and then the echo server.
   */
  @After
  public void teardown() throws Exception {
    this.selector.close();
    this.server.close();
  }

  /**
   * Validate that when the server disconnects, a client send ends up with that node in the disconnected list.
   */
  @Test
  public void testServerDisconnect() throws Exception {
    // connect and do a simple request
    String connectionId = blockingConnect();
    assertEquals("hello", blockingRequest(connectionId, "hello"));

    // disconnect from the server side and poll until the selector reports it
    this.server.closeConnections();
    while (!selector.disconnected().contains(connectionId)) {
      selector.poll(1000L);
    }

    // reconnect and do another request
    connectionId = blockingConnect();
    assertEquals("hello", blockingRequest(connectionId, "hello"));
  }

  /**
   * Validate that the client can intentionally disconnect and reconnect
   */
  @Test
  public void testClientDisconnect() throws Exception {
    String connectionId = blockingConnect();
    selector.disconnect(connectionId);
    selector.poll(10, asList(createSend(connectionId, "hello1")));
    assertEquals("Request should not have succeeded", 0, selector.completedSends().size());
    assertEquals("There should be a disconnect", 1, selector.disconnected().size());
    assertTrue("The disconnect should be from our node", selector.disconnected().contains(connectionId));

    connectionId = blockingConnect();
    assertEquals("hello2", blockingRequest(connectionId, "hello2"));
  }

  /**
   * Validate that a closed connectionId is returned via disconnected list after close
   */
  @Test
  public void testDisconnectedListOnClose() throws Exception {
    String connectionId = blockingConnect();
    assertEquals("Disconnect list should be empty", 0, selector.disconnected().size());
    selector.close(connectionId);
    selector.poll(0);
    assertEquals("There should be a disconnect", 1, selector.disconnected().size());
    assertTrue("Expected connectionId " + connectionId + " missing from selector's disconnected list ",
        selector.disconnected().contains(connectionId));
    // make sure that the connection id is not returned via disconnected list after another poll()
    selector.poll(0);
    assertEquals("Disconnect list should be empty", 0, selector.disconnected().size());
  }

  /**
   * Sending a request with one already in flight should result in an exception
   */
  @Test(expected = IllegalStateException.class)
  public void testCantSendWithInProgress() throws Exception {
    String connectionId = blockingConnect();
    selector.poll(1000L, asList(createSend(connectionId, "test1"), createSend(connectionId, "test2")));
  }

  /**
   * Sending a request to a node without an existing connection should result in an exception
   */
  @Test(expected = IllegalStateException.class)
  public void testCantSendWithoutConnecting() throws Exception {
    selector.poll(1000L, asList(createSend("testCantSendWithoutConnecting_test", "test")));
  }

  /**
   * Sending a request to a node with a bad hostname should result in an exception during connect
   */
  @Test(expected = IOException.class)
  public void testNoRouteToHost() throws Exception {
    selector.connect(new InetSocketAddress("asdf.asdf.dsc", server.port), BUFFER_SIZE, BUFFER_SIZE, PortType.PLAINTEXT);
  }

  /**
   * Sending a request to a node not listening on that port should result in disconnection
   */
  @Test
  public void testConnectionRefused() throws Exception {
    // NOTE(review): assumes nothing is listening on localhost:6668 on the machine running the test.
    String connectionId =
        selector.connect(new InetSocketAddress("localhost", 6668), BUFFER_SIZE, BUFFER_SIZE, PortType.PLAINTEXT);
    // Poll until the selector reports the connection as disconnected; fail instead of hanging if it never does.
    long deadline = System.currentTimeMillis() + DISCONNECT_WAIT_MS;
    while (!selector.disconnected().contains(connectionId)) {
      assertTrue("Connection " + connectionId + " was not reported as disconnected within " + DISCONNECT_WAIT_MS
          + " ms", System.currentTimeMillis() < deadline);
      selector.poll(1000L);
    }
  }

  /**
   * Send multiple requests to several connections in parallel. Validate that responses are received in the order that
   * requests were sent.
   */
  @Test
  public void testNormalOperation() throws Exception {
    int conns = 5;
    int reqs = 500;

    // create connections
    InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
    List<String> connectionIds = new ArrayList<String>();
    for (int i = 0; i < conns; i++) {
      String connectionId = selector.connect(addr, BUFFER_SIZE, BUFFER_SIZE, PortType.PLAINTEXT);
      connectionIds.add(connectionId);
    }

    // send echo requests and receive responses
    // requests[i] / responses[i] count the sends completed and the responses received for the connection that
    // maps to slot i.
    int[] requests = new int[conns];
    int[] responses = new int[conns];
    int responseCount = 0;
    List<NetworkSend> sends = new ArrayList<NetworkSend>();
    for (int i = 0; i < conns; i++) {
      String connectionId = connectionIds.get(i);
      sends.add(createSend(connectionId, connectionId + "&" + 0));
    }

    // loop until we complete all requests
    while (responseCount < conns * reqs) {
      // do the i/o
      selector.poll(0L, sends);
      assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());

      // handle any responses we may have gotten
      for (NetworkReceive receive : selector.completedReceives()) {
        String[] pieces = asString(receive).split("&");
        assertEquals("Should be in the form 'conn-counter'", 2, pieces.length);
        assertEquals("Check the source", receive.getConnectionId(), pieces[0]);
        assertEquals("Check that the receive has kindly been rewound", 0,
            receive.getReceivedBytes().getPayload().position());
        // NOTE(review): assumes connection ids look like "<prefix>_<n>" with n in [0, conns) - confirm against
        // the id format produced by Selector.connect().
        int index = Integer.parseInt(receive.getConnectionId().split("_")[1]);
        assertEquals("Check the request counter", responses[index], Integer.parseInt(pieces[1]));
        responses[index]++; // increment the expected counter
        responseCount++;
      }

      // prepare new sends for the next round
      sends.clear();
      for (NetworkSend send : selector.completedSends()) {
        String dest = send.getConnectionId();
        String[] pieces = dest.split("_");
        int index = Integer.parseInt(pieces[1]);
        requests[index]++;
        if (requests[index] < reqs) {
          sends.add(createSend(dest, dest + "&" + requests[index]));
        }
      }
    }
  }

  /**
   * Validate that we can send and receive a message larger than the receive and send buffer size
   */
  @Test
  public void testSendLargeRequest() throws Exception {
    String connectionId = blockingConnect();
    String big = randomString(10 * BUFFER_SIZE, new Random());
    assertEquals(big, blockingRequest(connectionId, big));
  }

  /**
   * Test sending an empty string
   */
  @Test
  public void testEmptyRequest() throws Exception {
    String connectionId = blockingConnect();
    assertEquals("", blockingRequest(connectionId, ""));
  }

  /**
   * Sends {@code s} on the given connection and polls until a receive for that connection completes. There is no
   * timeout: the method keeps polling until such a receive shows up.
   *
   * @param connectionId the connection to send on
   * @param s the payload to send
   * @return the completed receive for {@code connectionId}, converted with {@link #asString(NetworkReceive)}
   */
  private String blockingRequest(String connectionId, String s) throws Exception {
    selector.poll(1000L, asList(createSend(connectionId, s)));
    while (true) {
      selector.poll(1000L);
      for (NetworkReceive receive : selector.completedReceives()) {
        // Compare by value: the selector is not guaranteed to hand back the very same String instance.
        if (connectionId.equals(receive.getConnectionId())) {
          return asString(receive);
        }
      }
    }
  }

  /**
   * Connects to the echo server and polls until the selector lists the new connection as connected.
   *
   * @return the id of the established connection
   */
  private String blockingConnect() throws IOException {
    String connectionId =
        selector.connect(new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE, PortType.PLAINTEXT);
    while (!selector.connected().contains(connectionId)) {
      selector.poll(10000L);
    }
    return connectionId;
  }

  /**
   * Builds a {@link NetworkSend} whose buffer holds an 8-byte length followed by the bytes of {@code s} (platform
   * default charset). The length value written counts the 8 header bytes plus the payload.
   *
   * @param connectionId the connection the send is addressed to
   * @param s the payload
   * @return the send wrapping the framed buffer
   */
  static NetworkSend createSend(String connectionId, String s) {
    // Encode once so the allocated size, the length header and the payload all come from the same array.
    byte[] payload = s.getBytes();
    ByteBuffer buf = ByteBuffer.allocate(8 + payload.length);
    buf.putLong(payload.length + 8);
    buf.put(payload);
    buf.flip();
    return new NetworkSend(connectionId, new BoundedByteBufferSend(buf), null, SystemTime.getInstance());
  }

  /**
   * Decodes the whole backing array of the receive's payload buffer with the platform default charset.
   *
   * @param receive the completed receive
   * @return the payload as a string
   */
  static String asString(NetworkReceive receive) {
    // NOTE(review): presumably the backing array holds exactly the payload bytes - confirm against NetworkReceive.
    return new String(receive.getReceivedBytes().getPayload().array());
  }

  /**
   * Generate a random string of letters and digits of the given length
   *
   * @param len The length of the string
   * @param random The source of randomness used to pick each character
   * @return The random string
   */
  static String randomString(int len, Random random) {
    // Clamp the initial capacity so a non-positive len still yields an empty string rather than an exception.
    StringBuilder b = new StringBuilder(Math.max(len, 0));
    for (int i = 0; i < len; i++) {
      b.append(LETTERS_AND_DIGITS.charAt(random.nextInt(LETTERS_AND_DIGITS.length())));
    }
    return b.toString();
  }
}