/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import com.google.common.collect.ImmutableMap;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.collection.IsEmptyCollection;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
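/**
 * Base class for transport tests. Before each test two {@link MockTransportService} instances are started and
 * connected to each other: {@code serviceA} (node {@code TS_A}, running {@code version0}) and {@code serviceB}
 * (node {@code TS_B}, running {@code version1}). Subclasses provide the concrete service through
 * {@code build(Settings, Version, NamedWriteableRegistry)}.
 */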
public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
/** Version assigned to node A: the minimum compatibility version of {@link Version#CURRENT}. */
protected static final Version version0 = Version.CURRENT.minimumCompatibilityVersion();
/** Version assigned to node B: the version whose id is one above that of {@link Version#CURRENT}. */
protected static final Version version1 = Version.fromId(Version.CURRENT.id+1);

/** Thread pool created in {@code setUp()} and terminated in {@code tearDown()}. */
protected ThreadPool threadPool;

/** Node "TS_A" and the transport service it is published by. */
protected DiscoveryNode nodeA;
protected MockTransportService serviceA;

/** Node "TS_B" and the transport service it is published by. */
protected DiscoveryNode nodeB;
protected MockTransportService serviceB;

/**
 * Creates the transport service under test.
 *
 * @param settings               settings for the service (this class passes the node name and trace-log filters)
 * @param version                version the service should run as
 * @param namedWriteableRegistry registry handed to the service
 * @return the service; {@code setUp()} calls {@code acceptIncomingRequests()} and {@code boundAddress()} on it
 */
protected abstract MockTransportService build(Settings settings, Version version, NamedWriteableRegistry namedWriteableRegistry);
/**
 * Starts the thread pool and both transport services, connects the two nodes to each other (and to themselves,
 * either through the local-node setting or through an explicit connection) and waits until every resulting
 * connect event has been delivered.
 */
@Override
@Before
public void setUp() throws Exception {
    super.setUp();
    threadPool = new ThreadPool(getClass().getName());

    // Node A runs on version0.
    final Settings settingsA = Settings.builder()
            .put("name", "TS_A", TransportService.SETTING_TRACE_LOG_INCLUDE, "", TransportService.SETTING_TRACE_LOG_EXCLUDE, "NOTHING")
            .build();
    serviceA = build(settingsA, version0, new NamedWriteableRegistry());
    serviceA.acceptIncomingRequests();
    nodeA = new DiscoveryNode("TS_A", "TS_A", serviceA.boundAddress().publishAddress(), ImmutableMap.<String, String>of(), version0);

    // Node B runs on version1.
    final Settings settingsB = Settings.builder()
            .put("name", "TS_B", TransportService.SETTING_TRACE_LOG_INCLUDE, "", TransportService.SETTING_TRACE_LOG_EXCLUDE, "NOTHING")
            .build();
    serviceB = build(settingsB, version1, new NamedWriteableRegistry());
    nodeB = new DiscoveryNode("TS_B", "TS_B", serviceB.boundAddress().publishAddress(), ImmutableMap.<String, String>of(), version1);
    serviceB.acceptIncomingRequests();

    // Wait until all nodes are connected and the events were delivered, so that tests in this class
    // never receive connect callbacks caused by the connections made here.
    final boolean useLocalNode = randomBoolean();
    final int expectedConnectEvents;
    if (useLocalNode) {
        expectedConnectEvents = 2;
    } else {
        expectedConnectEvents = 4;
    }
    final CountDownLatch allConnected = new CountDownLatch(expectedConnectEvents);
    final TransportConnectionListener connectWatcher = new TransportConnectionListener() {
        @Override
        public void onNodeConnected(DiscoveryNode node) {
            allConnected.countDown();
        }

        @Override
        public void onNodeDisconnected(DiscoveryNode node) {
            fail("disconnect should not be called " + node);
        }
    };
    serviceA.addConnectionListener(connectWatcher);
    serviceB.addConnectionListener(connectWatcher);

    if (useLocalNode) {
        logger.info("--> using local node optimization");
        serviceA.setLocalNode(nodeA);
        serviceB.setLocalNode(nodeB);
    } else {
        logger.info("--> actively connecting to local node");
        serviceA.connectToNode(nodeA);
        serviceB.connectToNode(nodeB);
    }
    serviceA.connectToNode(nodeB);
    serviceB.connectToNode(nodeA);

    final boolean connectedInTime = allConnected.await(5, TimeUnit.SECONDS);
    assertThat("failed to wait for all nodes to connect", connectedInTime, equalTo(true));

    serviceA.removeConnectionListener(connectWatcher);
    serviceB.removeConnectionListener(connectWatcher);
}
/**
 * Closes both transport services and terminates the thread pool.
 * <p>
 * Each resource is released in its own {@code finally} step so that a failure while closing one service does not
 * leak the other service or the thread pool. The services are null-checked because {@code setUp()} may have failed
 * before assigning them.
 */
@Override
@After
public void tearDown() throws Exception {
    super.tearDown();
    try {
        if (serviceA != null) {
            serviceA.close();
        }
    } finally {
        try {
            if (serviceB != null) {
                serviceB.close();
            }
        } finally {
            terminate(threadPool);
        }
    }
}
/**
 * Sends "moshe" from B to A twice, once with default options and once with request compression, and expects
 * "hello moshe" back both times.
 */
@Test
public void testHelloWorld() {
    // A replies to "sayHello" with "hello " + the received message.
    final TransportRequestHandler<StringMessageRequest> greeter = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) {
            assertThat("moshe", equalTo(request.message));
            try {
                channel.sendResponse(new StringMessageResponse("hello " + request.message));
            } catch (IOException e) {
                e.printStackTrace();
                assertThat(e.getMessage(), false, equalTo(true));
            }
        }
    };
    serviceA.registerRequestHandler("sayHello", StringMessageRequest.class, ThreadPool.Names.GENERIC, greeter);

    // First round: default request options.
    final TransportResponseHandler<StringMessageResponse> plainReplyHandler = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            assertThat("hello moshe", equalTo(response.message));
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            assertThat("got exception instead of a response: " + exp.getMessage(), false, equalTo(true));
        }
    };
    TransportFuture<StringMessageResponse> pending =
            serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), plainReplyHandler);
    try {
        StringMessageResponse reply = pending.get();
        assertThat("hello moshe", equalTo(reply.message));
    } catch (Exception e) {
        assertThat(e.getMessage(), false, equalTo(true));
    }

    // Second round: same request, compression enabled on the request.
    final TransportResponseHandler<StringMessageResponse> compressedReplyHandler = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            assertThat("hello moshe", equalTo(response.message));
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            assertThat("got exception instead of a response: " + exp.getMessage(), false, equalTo(true));
        }
    };
    pending = serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"),
            TransportRequestOptions.builder().withCompress(true).build(), compressedReplyHandler);
    try {
        StringMessageResponse reply = pending.get();
        assertThat("hello moshe", equalTo(reply.message));
    } catch (Exception e) {
        assertThat(e.getMessage(), false, equalTo(true));
    }

    serviceA.removeHandler("sayHello");
}
/**
 * Checks that service A is connected to its own node and that a request sent from A to node A is answered with
 * the echoed message.
 */
@Test
public void testLocalNodeConnection() throws InterruptedException {
    assertTrue("serviceA is not connected to nodeA", serviceA.nodeConnected(nodeA));
    final boolean hasLocalNode = ((TransportService) serviceA).getLocalNode() != null;
    if (hasLocalNode) {
        // this should be a noop
        serviceA.disconnectFromNode(nodeA);
    }

    // Holds whatever went wrong on either side of the exchange.
    final AtomicReference<Exception> failure = new AtomicReference<>();
    final TransportRequestHandler<StringMessageRequest> echo = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) {
            try {
                channel.sendResponse(new StringMessageResponse(request.message));
            } catch (IOException e) {
                failure.set(e);
            }
        }
    };
    serviceA.registerRequestHandler("localNode", StringMessageRequest.class, ThreadPool.Names.GENERIC, echo);

    final AtomicReference<String> echoed = new AtomicReference<>();
    final CountDownLatch answered = new CountDownLatch(1);
    final TransportResponseHandler<StringMessageResponse> collector = new TransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            echoed.set(response.message);
            answered.countDown();
        }

        @Override
        public void handleException(TransportException exp) {
            failure.set(exp);
            answered.countDown();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }
    };
    serviceA.sendRequest(nodeA, "localNode", new StringMessageRequest("test"), collector);

    answered.await();
    assertNull(failure.get());
    assertThat(echoed.get(), equalTo("test"));
}
/**
 * Exchanges empty request and response objects between B and A with compression enabled in both directions and
 * expects a non-null response.
 */
@Test
public void testVoidMessageCompressed() {
    final TransportRequestHandler<TransportRequest.Empty> emptyReplier = new TransportRequestHandler<TransportRequest.Empty>() {
        @Override
        public void messageReceived(TransportRequest.Empty request, TransportChannel channel) {
            try {
                channel.sendResponse(TransportResponse.Empty.INSTANCE, TransportResponseOptions.builder().withCompress(true).build());
            } catch (IOException e) {
                e.printStackTrace();
                assertThat(e.getMessage(), false, equalTo(true));
            }
        }
    };
    serviceA.registerRequestHandler("sayHello", TransportRequest.Empty.class, ThreadPool.Names.GENERIC, emptyReplier);

    final TransportResponseHandler<TransportResponse.Empty> replyHandler = new BaseTransportResponseHandler<TransportResponse.Empty>() {
        @Override
        public TransportResponse.Empty newInstance() {
            return TransportResponse.Empty.INSTANCE;
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(TransportResponse.Empty response) {
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            assertThat("got exception instead of a response: " + exp.getMessage(), false, equalTo(true));
        }
    };
    TransportFuture<TransportResponse.Empty> pending = serviceB.submitRequest(nodeA, "sayHello",
            TransportRequest.Empty.INSTANCE, TransportRequestOptions.builder().withCompress(true).build(), replyHandler);
    try {
        TransportResponse.Empty reply = pending.get();
        assertThat(reply, notNullValue());
    } catch (Exception e) {
        assertThat(e.getMessage(), false, equalTo(true));
    }

    serviceA.removeHandler("sayHello");
}
/**
 * Same exchange as the plain hello-world test, but with compression enabled on both the request and the response.
 */
@Test
public void testHelloWorldCompressed() {
    final TransportRequestHandler<StringMessageRequest> compressedGreeter = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) {
            assertThat("moshe", equalTo(request.message));
            try {
                channel.sendResponse(new StringMessageResponse("hello " + request.message), TransportResponseOptions.builder().withCompress(true).build());
            } catch (IOException e) {
                e.printStackTrace();
                assertThat(e.getMessage(), false, equalTo(true));
            }
        }
    };
    serviceA.registerRequestHandler("sayHello", StringMessageRequest.class, ThreadPool.Names.GENERIC, compressedGreeter);

    final TransportResponseHandler<StringMessageResponse> replyHandler = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            assertThat("hello moshe", equalTo(response.message));
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            assertThat("got exception instead of a response: " + exp.getMessage(), false, equalTo(true));
        }
    };
    TransportFuture<StringMessageResponse> pending = serviceB.submitRequest(nodeA, "sayHello",
            new StringMessageRequest("moshe"), TransportRequestOptions.builder().withCompress(true).build(), replyHandler);
    try {
        StringMessageResponse reply = pending.get();
        assertThat("hello moshe", equalTo(reply.message));
    } catch (Exception e) {
        assertThat(e.getMessage(), false, equalTo(true));
    }

    serviceA.removeHandler("sayHello");
}
/**
 * The handler on A throws a {@link RuntimeException}; the caller on B must see it as the cause of the failure,
 * both in the response handler and when blocking on the future.
 */
@Test
public void testErrorMessage() {
    final TransportRequestHandler<StringMessageRequest> alwaysFailing = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
            assertThat("moshe", equalTo(request.message));
            throw new RuntimeException("bad message !!!");
        }
    };
    serviceA.registerRequestHandler("sayHelloException", StringMessageRequest.class, ThreadPool.Names.GENERIC, alwaysFailing);

    final TransportResponseHandler<StringMessageResponse> expectFailure = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            fail("got response instead of exception");
        }

        @Override
        public void handleException(TransportException exp) {
            assertThat("runtime_exception: bad message !!!", equalTo(exp.getCause().getMessage()));
        }
    };
    TransportFuture<StringMessageResponse> pending =
            serviceB.submitRequest(nodeA, "sayHelloException", new StringMessageRequest("moshe"), expectFailure);
    try {
        pending.txGet();
        fail("exception should be thrown");
    } catch (Exception e) {
        assertThat(e.getCause().getMessage(), equalTo("runtime_exception: bad message !!!"));
    }

    serviceA.removeHandler("sayHelloException");
}
/**
 * Closing service B must trigger a disconnect notification on a listener registered with service A within
 * five seconds; no connect notification is expected.
 */
@Test
public void testDisconnectListener() throws Exception {
    final CountDownLatch disconnected = new CountDownLatch(1);
    serviceA.addConnectionListener(new TransportConnectionListener() {
        @Override
        public void onNodeConnected(DiscoveryNode node) {
            fail("node connected should not be called, all connection have been done previously, node: " + node);
        }

        @Override
        public void onNodeDisconnected(DiscoveryNode node) {
            disconnected.countDown();
        }
    });

    serviceB.close();

    final boolean notified = disconnected.await(5, TimeUnit.SECONDS);
    assertThat(notified, equalTo(true));
}
/**
 * B sends a request to A; the handler on A waits for the latch and then stops service B instead of answering.
 * The future held by the caller must then fail with a {@link TransportException}.
 */
@Test
public void testNotifyOnShutdown() throws Exception {
    final CountDownLatch latch2 = new CountDownLatch(1);
    serviceA.registerRequestHandler("foobar", StringMessageRequest.class, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) {
            try {
                // wait until the request has been submitted before shutting the sender down
                latch2.await();
                logger.info("Stop ServiceB now");
                serviceB.stop();
            } catch (Exception e) {
                fail(e.getMessage());
            }
        }
    });
    TransportFuture<TransportResponse.Empty> foobar = serviceB.submitRequest(nodeA, "foobar",
            new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME);
    latch2.countDown();
    try {
        foobar.txGet();
        fail("TransportException expected");
    } catch (TransportException ex) {
        // expected: the pending request is failed when service B is stopped
    }
    // remove the handler this test registered (previously a different, unrelated action name was removed)
    serviceA.removeHandler("foobar");
}
/**
 * The handler on A never answers; a request sent from B with a 100ms timeout must fail with a
 * {@link ReceiveTimeoutTransportException}.
 */
@Test
public void testTimeoutSendExceptionWithNeverSendingBackResponse() throws Exception {
    final TransportRequestHandler<StringMessageRequest> silentHandler = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) {
            assertThat("moshe", equalTo(request.message));
            // deliberately do not send back a response
        }
    };
    serviceA.registerRequestHandler("sayHelloTimeoutNoResponse", StringMessageRequest.class, ThreadPool.Names.GENERIC, silentHandler);

    final TransportResponseHandler<StringMessageResponse> expectTimeout = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            fail("got response instead of exception");
        }

        @Override
        public void handleException(TransportException exp) {
            assertThat(exp, instanceOf(ReceiveTimeoutTransportException.class));
        }
    };
    TransportFuture<StringMessageResponse> pending = serviceB.submitRequest(nodeA, "sayHelloTimeoutNoResponse",
            new StringMessageRequest("moshe"), TransportRequestOptions.builder().withTimeout(100).build(), expectTimeout);
    try {
        pending.txGet();
        fail("exception should be thrown");
    } catch (Exception e) {
        assertThat(e, instanceOf(ReceiveTimeoutTransportException.class));
    }

    serviceA.removeHandler("sayHelloTimeoutNoResponse");
}
/**
 * The handler on A delays its answer by the duration encoded in the request message (or until {@code doneLatch}
 * is released). A first request asking for a 2 minute delay with a 100ms timeout must time out; ten follow-up
 * requests with delays of 0-9ms and a 3 second timeout must all be answered normally.
 * <p>
 * {@code doneLatch} is released in a {@code finally} block so the handler thread still parked on the first
 * request is freed even when an assertion fails part-way through the test.
 */
@Test
public void testTimeoutSendExceptionWithDelayedResponse() throws Exception {
    final CountDownLatch doneLatch = new CountDownLatch(1);
    try {
        serviceA.registerRequestHandler("sayHelloTimeoutDelayedResponse", StringMessageRequest.class, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() {
            @Override
            public void messageReceived(StringMessageRequest request, TransportChannel channel) {
                TimeValue sleep = TimeValue.parseTimeValue(request.message, null, "sleep");
                try {
                    // returns early once the test is finished and releases the latch
                    doneLatch.await(sleep.millis(), TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    // ignore
                }
                try {
                    channel.sendResponse(new StringMessageResponse("hello " + request.message));
                } catch (IOException e) {
                    e.printStackTrace();
                    assertThat(e.getMessage(), false, equalTo(true));
                }
            }
        });
        final CountDownLatch latch = new CountDownLatch(1);
        TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse",
                new StringMessageRequest("2m"), TransportRequestOptions.builder().withTimeout(100).build(), new BaseTransportResponseHandler<StringMessageResponse>() {
                    @Override
                    public StringMessageResponse newInstance() {
                        return new StringMessageResponse();
                    }

                    @Override
                    public String executor() {
                        return ThreadPool.Names.GENERIC;
                    }

                    @Override
                    public void handleResponse(StringMessageResponse response) {
                        latch.countDown();
                        fail("got response instead of exception");
                    }

                    @Override
                    public void handleException(TransportException exp) {
                        latch.countDown();
                        assertThat(exp, instanceOf(ReceiveTimeoutTransportException.class));
                    }
                });
        try {
            res.txGet();
            fail("exception should be thrown");
        } catch (Exception e) {
            assertThat(e, instanceOf(ReceiveTimeoutTransportException.class));
        }
        // make sure the response handler of the timed-out request has run before continuing
        latch.await();

        for (int i = 0; i < 10; i++) {
            final int counter = i;
            // now send further requests that only ask for a few milliseconds of delay, well within their timeout
            res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse",
                    new StringMessageRequest(counter + "ms"), TransportRequestOptions.builder().withTimeout(3000).build(), new BaseTransportResponseHandler<StringMessageResponse>() {
                        @Override
                        public StringMessageResponse newInstance() {
                            return new StringMessageResponse();
                        }

                        @Override
                        public String executor() {
                            return ThreadPool.Names.GENERIC;
                        }

                        @Override
                        public void handleResponse(StringMessageResponse response) {
                            assertThat("hello " + counter + "ms", equalTo(response.message));
                        }

                        @Override
                        public void handleException(TransportException exp) {
                            exp.printStackTrace();
                            fail("got exception instead of a response for " + counter + ": " + exp.getDetailedMessage());
                        }
                    });
            StringMessageResponse message = res.txGet();
            assertThat(message.message, equalTo("hello " + counter + "ms"));
        }
        serviceA.removeHandler("sayHelloTimeoutDelayedResponse");
    } finally {
        // always free the handler thread that is still waiting on behalf of the first ("2m") request
        doneLatch.countDown();
    }
}
/**
 * Sends one request that is answered normally and one that is answered with an error from A to B, and checks
 * that the tracer on A never reports an "unresolved response" for a request id whose response it has already
 * seen being received.
 * <p>
 * Both requests are awaited before the final assertion; waiting for only one completion would let the check run
 * while the second response is still being processed.
 */
@Test
public void testNoUnresolvedResponses() throws InterruptedException {
    TransportRequestHandler<StringMessageRequest> handler = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
            channel.sendResponse(new StringMessageResponse(""));
        }
    };
    TransportRequestHandler<StringMessageRequest> handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
            channel.sendResponse(new RuntimeException(""));
        }
    };
    // released once per completed request, regardless of success or failure
    final Semaphore requestCompleted = new Semaphore(0);
    TransportResponseHandler<StringMessageResponse> noopResponseHandler = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            requestCompleted.release();
        }

        @Override
        public void handleException(TransportException exp) {
            requestCompleted.release();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }
    };
    final AtomicReference<Set<Long>> unresolvedResponses = new AtomicReference<>();
    unresolvedResponses.set(new HashSet<Long>());
    // NOTE(review): the sets below are plain HashSets touched from tracer callbacks; confirm those callbacks
    // cannot run concurrently for the transport under test.
    MockTransportService.Tracer tracer = new MockTransportService.Tracer() {
        final Set<Long> requests = new HashSet<>();

        @Override
        public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) {
            // every request id must be reported as received at most once
            assertTrue(requests.add(requestId));
            super.receivedResponse(requestId, sourceNode, action);
        }

        @Override
        public void unresolvedResponse(long requestId) {
            // only ids whose response was already received are of interest here
            if (requests.contains(requestId)) {
                unresolvedResponses.get().add(requestId);
            }
        }
    };
    serviceA.addTracer(tracer);
    serviceB.registerRequestHandler("test", StringMessageRequest.class, ThreadPool.Names.SAME, handler);
    serviceB.registerRequestHandler("testError", StringMessageRequest.class, ThreadPool.Names.SAME, handlerWithError);
    serviceA.sendRequest(nodeB, "test", new StringMessageRequest("", 10), TransportRequestOptions.EMPTY, noopResponseHandler);
    serviceA.sendRequest(nodeB, "testError", new StringMessageRequest("", 10), TransportRequestOptions.EMPTY, noopResponseHandler);
    // two requests were sent, so wait for two completions (bounded, so a lost response fails instead of hanging)
    assertTrue("timed out waiting for both requests to complete", requestCompleted.tryAcquire(2, 10, TimeUnit.SECONDS));
    assertThat(unresolvedResponses.get(), IsEmptyCollection.emptyCollectionOf(Long.class));
}
/**
 * Verifies which tracer callbacks fire for a successful and for a failing request from A to B, first with the
 * trace filters configured in {@code setUp()} and then after new include/exclude settings were applied to
 * service A.
 * <p>
 * The same {@code Tracer} instance is registered on both services, so events from the sending and the receiving
 * side are counted on one latch.
 */
@Test
@TestLogging(value = "test. transport.tracer:TRACE")
public void testTracerLog() throws InterruptedException {
    TransportRequestHandler<StringMessageRequest> handler = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
            channel.sendResponse(new StringMessageResponse(""));
        }
    };
    TransportRequestHandler<StringMessageRequest> handlerWithError = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception {
            if (request.timeout() > 0) {
                Thread.sleep(request.timeout);
            }
            channel.sendResponse(new RuntimeException(""));
        }
    };
    // released once per completed request, regardless of success or failure
    final Semaphore requestCompleted = new Semaphore(0);
    TransportResponseHandler<StringMessageResponse> noopResponseHandler = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public void handleResponse(StringMessageResponse response) {
            requestCompleted.release();
        }

        @Override
        public void handleException(TransportException exp) {
            requestCompleted.release();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }
    };
    serviceA.registerRequestHandler("test", StringMessageRequest.class, ThreadPool.Names.SAME, handler);
    serviceA.registerRequestHandler("testError", StringMessageRequest.class, ThreadPool.Names.SAME, handlerWithError);
    serviceB.registerRequestHandler("test", StringMessageRequest.class, ThreadPool.Names.SAME, handler);
    serviceB.registerRequestHandler("testError", StringMessageRequest.class, ThreadPool.Names.SAME, handlerWithError);
    final Tracer tracer = new Tracer();
    serviceA.addTracer(tracer);
    serviceB.addTracer(tracer);

    // successful request: request sent/received and response sent/received are expected
    tracer.reset(4);
    boolean timeout = randomBoolean();
    // NOTE(review): with the 1ms timeout the request may time out before the response arrives; confirm that all
    // four trace events are still produced in that case, otherwise the await below can block.
    TransportRequestOptions options = timeout ? TransportRequestOptions.builder().withTimeout(1).build() : TransportRequestOptions.EMPTY;
    serviceA.sendRequest(nodeB, "test", new StringMessageRequest("", 10), options, noopResponseHandler);
    requestCompleted.acquire();
    tracer.expectedEvents.get().await();
    assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true));
    assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true));
    assertThat("didn't see response sent", tracer.sawResponseSent, equalTo(true));
    assertThat("didn't see response received", tracer.sawResponseReceived, equalTo(true));
    assertThat("saw error sent", tracer.sawErrorSent, equalTo(false));

    // failing request: an error is sent instead of a regular response
    tracer.reset(4);
    serviceA.sendRequest(nodeB, "testError", new StringMessageRequest(""), noopResponseHandler);
    requestCompleted.acquire();
    tracer.expectedEvents.get().await();
    assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true));
    assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true));
    assertThat("saw response sent", tracer.sawResponseSent, equalTo(false));
    assertThat("didn't see response received", tracer.sawResponseReceived, equalTo(true));
    assertThat("didn't see error sent", tracer.sawErrorSent, equalTo(true));

    // reconfigure service A with one of two filter combinations
    String includeSettings;
    String excludeSettings;
    if (randomBoolean()) {
        // sometimes leave include empty (default)
        includeSettings = randomBoolean() ? "*" : "";
        excludeSettings = "*Error";
    } else {
        includeSettings = "test";
        excludeSettings = "DOESN'T_MATCH";
    }
    serviceA.applySettings(Settings.builder()
            .put(TransportService.SETTING_TRACE_LOG_INCLUDE, includeSettings, TransportService.SETTING_TRACE_LOG_EXCLUDE, excludeSettings)
            .build());

    // "test" must still be traced on both sides
    tracer.reset(4);
    serviceA.sendRequest(nodeB, "test", new StringMessageRequest(""), noopResponseHandler);
    requestCompleted.acquire();
    tracer.expectedEvents.get().await();
    assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true));
    assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true));
    assertThat("didn't see response sent", tracer.sawResponseSent, equalTo(true));
    assertThat("didn't see response received", tracer.sawResponseReceived, equalTo(true));
    assertThat("saw error sent", tracer.sawErrorSent, equalTo(false));

    // "testError": only the two events on service B (request received, error sent) are expected now
    tracer.reset(2);
    serviceA.sendRequest(nodeB, "testError", new StringMessageRequest(""), noopResponseHandler);
    requestCompleted.acquire();
    tracer.expectedEvents.get().await();
    assertThat("saw request sent", tracer.sawRequestSent, equalTo(false));
    assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true));
    assertThat("saw response sent", tracer.sawResponseSent, equalTo(false));
    assertThat("saw response received", tracer.sawResponseReceived, equalTo(false));
    assertThat("didn't see error sent", tracer.sawErrorSent, equalTo(true));
}
/**
 * Tracer that remembers which kinds of events it has seen since the last {@link #reset(int)} and counts every
 * event down on a latch, so a test can wait for an expected number of events.
 */
private static class Tracer extends MockTransportService.Tracer {
    public volatile boolean sawRequestSent;
    public volatile boolean sawRequestReceived;
    public volatile boolean sawResponseSent;
    public volatile boolean sawErrorSent;
    public volatile boolean sawResponseReceived;

    /** Latch for the current round; replaced on every {@link #reset(int)}. */
    public AtomicReference<CountDownLatch> expectedEvents = new AtomicReference<>();

    /**
     * Clears all flags and installs a fresh latch.
     *
     * @param expectedCount number of events the new latch waits for
     */
    public void reset(int expectedCount) {
        sawRequestSent = false;
        sawRequestReceived = false;
        sawResponseSent = false;
        sawErrorSent = false;
        sawResponseReceived = false;
        expectedEvents.set(new CountDownLatch(expectedCount));
    }

    @Override
    public void requestSent(DiscoveryNode node, long requestId, String action, TransportRequestOptions options) {
        super.requestSent(node, requestId, action, options);
        sawRequestSent = true;
        signalEvent();
    }

    @Override
    public void receivedRequest(long requestId, String action) {
        super.receivedRequest(requestId, action);
        sawRequestReceived = true;
        signalEvent();
    }

    @Override
    public void responseSent(long requestId, String action) {
        super.responseSent(requestId, action);
        sawResponseSent = true;
        signalEvent();
    }

    @Override
    public void responseSent(long requestId, String action, Throwable t) {
        super.responseSent(requestId, action, t);
        sawErrorSent = true;
        signalEvent();
    }

    @Override
    public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) {
        super.receivedResponse(requestId, sourceNode, action);
        sawResponseReceived = true;
        signalEvent();
    }

    /** Counts one event down on the latch of the current round. */
    private void signalEvent() {
        expectedEvents.get().countDown();
    }
}
/**
 * Request carrying a string message and a numeric {@code timeout} value; both are written to and read from the
 * stream, in that order.
 */
public static class StringMessageRequest extends TransportRequest {
    private String message;
    private long timeout;

    /** Leaves both fields at their default values; they are filled in by {@link #readFrom(StreamInput)}. */
    public StringMessageRequest() {
    }

    /** Creates a request with the given message and a timeout of {@code -1}. */
    public StringMessageRequest(String message) {
        this(message, -1);
    }

    StringMessageRequest(String message, long timeout) {
        this.message = message;
        this.timeout = timeout;
    }

    /** @return the timeout value carried by this request */
    public long timeout() {
        return this.timeout;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(this.message);
        out.writeLong(this.timeout);
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        this.message = in.readString();
        this.timeout = in.readLong();
    }
}
/** Response carrying a single string message. */
static class StringMessageResponse extends TransportResponse {
    private String message;

    /** Leaves the message unset; it is filled in by {@link #readFrom(StreamInput)}. */
    StringMessageResponse() {
    }

    StringMessageResponse(String message) {
        this.message = message;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(this.message);
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        this.message = in.readString();
    }
}
/** Request in its "version 0" shape: a single int, {@code value1}, on the wire. */
public static class Version0Request extends TransportRequest {
    int value1;

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeInt(this.value1);
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        this.value1 = in.readInt();
    }
}
/**
 * Request in its "version 1" shape: adds {@code value2}, which is only written or read when the stream's
 * version is on or after {@code version1}.
 */
public static class Version1Request extends Version0Request {
    int value2;

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        final boolean streamCarriesValue2 = out.getVersion().onOrAfter(version1);
        if (streamCarriesValue2) {
            out.writeInt(this.value2);
        }
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        final boolean streamCarriesValue2 = in.getVersion().onOrAfter(version1);
        if (streamCarriesValue2) {
            this.value2 = in.readInt();
        }
    }
}
/** Response in its "version 0" shape: a single int, {@code value1}, on the wire. */
static class Version0Response extends TransportResponse {
    int value1;

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeInt(this.value1);
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        this.value1 = in.readInt();
    }
}
/**
 * Response in its "version 1" shape: adds {@code value2}, which is only written or read when the stream's
 * version is on or after {@code version1}.
 */
static class Version1Response extends Version0Response {
    int value2;

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        final boolean streamCarriesValue2 = out.getVersion().onOrAfter(version1);
        if (streamCarriesValue2) {
            out.writeInt(this.value2);
        }
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        final boolean streamCarriesValue2 = in.getVersion().onOrAfter(version1);
        if (streamCarriesValue2) {
            this.value2 = in.readInt();
        }
    }
}
/**
 * A (version0) sends a {@code Version0Request} to B (version1). B reads it as a {@code Version1Request} whose
 * {@code value2} stays 0, and answers with a {@code Version1Response} that A reads as a {@code Version0Response}.
 */
@Test
public void testVersion_from0to1() throws Exception {
    final TransportRequestHandler<Version1Request> newerSideHandler = new TransportRequestHandler<Version1Request>() {
        @Override
        public void messageReceived(Version1Request request, TransportChannel channel) throws Exception {
            assertThat(request.value1, equalTo(1));
            assertThat(request.value2, equalTo(0)); // not set, coming from service A
            Version1Response reply = new Version1Response();
            reply.value1 = 1;
            reply.value2 = 2;
            channel.sendResponse(reply);
        }
    };
    serviceB.registerRequestHandler("/version", Version1Request.class, ThreadPool.Names.SAME, newerSideHandler);

    Version0Request outgoing = new Version0Request();
    outgoing.value1 = 1;

    final TransportResponseHandler<Version0Response> replyChecker = new BaseTransportResponseHandler<Version0Response>() {
        @Override
        public Version0Response newInstance() {
            return new Version0Response();
        }

        @Override
        public void handleResponse(Version0Response response) {
            assertThat(response.value1, equalTo(1));
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            fail();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }
    };
    Version0Response received = serviceA.submitRequest(nodeB, "/version", outgoing, replyChecker).txGet();
    assertThat(received.value1, equalTo(1));
}
/**
 * B (version1) sends a {@code Version1Request} to A (version0). A handles it as a {@code Version0Request} and
 * answers with a {@code Version0Response}; B reads that as a {@code Version1Response} whose {@code value2}
 * keeps its initial value 0.
 */
@Test
public void testVersion_from1to0() throws Exception {
    final TransportRequestHandler<Version0Request> olderSideHandler = new TransportRequestHandler<Version0Request>() {
        @Override
        public void messageReceived(Version0Request request, TransportChannel channel) throws Exception {
            assertThat(request.value1, equalTo(1));
            Version0Response reply = new Version0Response();
            reply.value1 = 1;
            channel.sendResponse(reply);
        }
    };
    serviceA.registerRequestHandler("/version", Version0Request.class, ThreadPool.Names.SAME, olderSideHandler);

    Version1Request outgoing = new Version1Request();
    outgoing.value1 = 1;
    outgoing.value2 = 2;

    final TransportResponseHandler<Version1Response> replyChecker = new BaseTransportResponseHandler<Version1Response>() {
        @Override
        public Version1Response newInstance() {
            return new Version1Response();
        }

        @Override
        public void handleResponse(Version1Response response) {
            assertThat(response.value1, equalTo(1));
            assertThat(response.value2, equalTo(0)); // initial values, cause its serialized from version 0
        }

        @Override
        public void handleException(TransportException exp) {
            exp.printStackTrace();
            fail();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }
    };
    Version1Response received = serviceB.submitRequest(nodeA, "/version", outgoing, replyChecker).txGet();
    assertThat(received.value1, equalTo(1));
    assertThat(received.value2, equalTo(0));
}
@Test
public void testVersion_from1to1() throws Exception {
    // Service B both serves and issues the request, so both values must survive the exchange.
    final TransportRequestHandler<Version1Request> versionHandler = new TransportRequestHandler<Version1Request>() {
        @Override
        public void messageReceived(Version1Request incoming, TransportChannel replyChannel) throws Exception {
            assertThat(incoming.value1, equalTo(1));
            assertThat(incoming.value2, equalTo(2));
            Version1Response reply = new Version1Response();
            reply.value1 = 1;
            reply.value2 = 2;
            replyChannel.sendResponse(reply);
        }
    };
    serviceB.registerRequestHandler("/version", Version1Request.class, ThreadPool.Names.SAME, versionHandler);

    final Version1Request outgoing = new Version1Request();
    outgoing.value1 = 1;
    outgoing.value2 = 2;

    final BaseTransportResponseHandler<Version1Response> replyChecker = new BaseTransportResponseHandler<Version1Response>() {
        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }

        @Override
        public Version1Response newInstance() {
            return new Version1Response();
        }

        @Override
        public void handleException(TransportException failure) {
            failure.printStackTrace();
            fail();
        }

        @Override
        public void handleResponse(Version1Response reply) {
            assertThat(reply.value1, equalTo(1));
            assertThat(reply.value2, equalTo(2));
        }
    };

    final TransportFuture<Version1Response> pending = serviceB.submitRequest(nodeB, "/version", outgoing, replyChecker);
    final Version1Response received = pending.txGet();

    assertThat(received.value1, equalTo(1));
    assertThat(received.value2, equalTo(2));
}
@Test
public void testVersion_from0to0() throws Exception {
    // Service A both serves and issues the request; only value1 is exchanged.
    final TransportRequestHandler<Version0Request> versionHandler = new TransportRequestHandler<Version0Request>() {
        @Override
        public void messageReceived(Version0Request incoming, TransportChannel replyChannel) throws Exception {
            assertThat(incoming.value1, equalTo(1));
            Version0Response reply = new Version0Response();
            reply.value1 = 1;
            replyChannel.sendResponse(reply);
        }
    };
    serviceA.registerRequestHandler("/version", Version0Request.class, ThreadPool.Names.SAME, versionHandler);

    final Version0Request outgoing = new Version0Request();
    outgoing.value1 = 1;

    final BaseTransportResponseHandler<Version0Response> replyChecker = new BaseTransportResponseHandler<Version0Response>() {
        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }

        @Override
        public Version0Response newInstance() {
            return new Version0Response();
        }

        @Override
        public void handleException(TransportException failure) {
            failure.printStackTrace();
            fail();
        }

        @Override
        public void handleResponse(Version0Response reply) {
            assertThat(reply.value1, equalTo(1));
        }
    };

    final TransportFuture<Version0Response> pending = serviceA.submitRequest(nodeA, "/version", outgoing, replyChecker);
    final Version0Response received = pending.txGet();

    assertThat(received.value1, equalTo(1));
}
@Test
public void testMockFailToSendNoConnectRule() {
    // Handler on A that would fail loudly if the request were ever delivered.
    final TransportRequestHandler<StringMessageRequest> helloHandler = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest incoming, TransportChannel replyChannel) throws Exception {
            assertThat("moshe", equalTo(incoming.message));
            throw new RuntimeException("bad message !!!");
        }
    };
    serviceA.registerRequestHandler("sayHello", StringMessageRequest.class, ThreadPool.Names.GENERIC, helloHandler);

    // From here on the test expects B's sends and connects towards A to fail.
    serviceB.addFailToSendNoConnectRule(serviceA);

    final BaseTransportResponseHandler<StringMessageResponse> replyChecker = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public void handleException(TransportException failure) {
            assertThat(failure.getCause().getMessage(), endsWith("DISCONNECT: simulated"));
        }

        @Override
        public void handleResponse(StringMessageResponse reply) {
            fail("got response instead of exception");
        }
    };
    final TransportFuture<StringMessageResponse> pending =
        serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), replyChecker);

    // Waiting on the future must surface the simulated disconnect.
    try {
        pending.txGet();
        fail("exception should be thrown");
    } catch (Exception sendFailure) {
        assertThat(sendFailure.getCause().getMessage(), endsWith("DISCONNECT: simulated"));
    }

    // A full connection attempt must be rejected ...
    try {
        serviceB.connectToNode(nodeA);
        fail("exception should be thrown");
    } catch (ConnectTransportException expected) {
        // all is well
    }

    // ... and so must a light connection attempt.
    try {
        serviceB.connectToNodeLight(nodeA);
        fail("exception should be thrown");
    } catch (ConnectTransportException expected) {
        // all is well
    }

    serviceA.removeHandler("sayHello");
}
@Test
public void testMockUnresponsiveRule() {
    // Handler on A that would fail loudly if the request were ever delivered.
    final TransportRequestHandler<StringMessageRequest> helloHandler = new TransportRequestHandler<StringMessageRequest>() {
        @Override
        public void messageReceived(StringMessageRequest incoming, TransportChannel replyChannel) throws Exception {
            assertThat("moshe", equalTo(incoming.message));
            throw new RuntimeException("bad message !!!");
        }
    };
    serviceA.registerRequestHandler("sayHello", StringMessageRequest.class, ThreadPool.Names.GENERIC, helloHandler);

    // From here on the test expects requests from B to A to go unanswered.
    serviceB.addUnresponsiveRule(serviceA);

    final BaseTransportResponseHandler<StringMessageResponse> replyChecker = new BaseTransportResponseHandler<StringMessageResponse>() {
        @Override
        public String executor() {
            return ThreadPool.Names.GENERIC;
        }

        @Override
        public StringMessageResponse newInstance() {
            return new StringMessageResponse();
        }

        @Override
        public void handleException(TransportException failure) {
            assertThat(failure, instanceOf(ReceiveTimeoutTransportException.class));
        }

        @Override
        public void handleResponse(StringMessageResponse reply) {
            fail("got response instead of exception");
        }
    };
    // Short request timeout so the missing reply is reported quickly.
    final TransportRequestOptions shortTimeout = TransportRequestOptions.builder().withTimeout(100).build();
    final TransportFuture<StringMessageResponse> pending =
        serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), shortTimeout, replyChecker);

    // Waiting on the future must end in a receive timeout.
    try {
        pending.txGet();
        fail("exception should be thrown");
    } catch (Exception timeout) {
        assertThat(timeout, instanceOf(ReceiveTimeoutTransportException.class));
    }

    // A full connection attempt must be rejected ...
    try {
        serviceB.connectToNode(nodeA);
        fail("exception should be thrown");
    } catch (ConnectTransportException expected) {
        // all is well
    }

    // ... and so must a light connection attempt.
    try {
        serviceB.connectToNodeLight(nodeA);
        fail("exception should be thrown");
    } catch (ConnectTransportException expected) {
        // all is well
    }

    serviceA.removeHandler("sayHello");
}
/**
 * Verifies that messages expose the remote address they arrived from: the request seen by
 * service B must report node A's host, and the response seen by service A must report
 * node B's host.
 *
 * @throws InterruptedException if the test thread is interrupted while waiting for the round trip
 */
@Test
public void testHostOnMessages() throws InterruptedException {
    // One count for the request being handled on B, one for the reply (or failure) arriving on A.
    final CountDownLatch latch = new CountDownLatch(2);
    final AtomicReference<TransportAddress> addressA = new AtomicReference<>();
    final AtomicReference<TransportAddress> addressB = new AtomicReference<>();
    // Captures a transport failure so it can be reported with its cause instead of being swallowed.
    final AtomicReference<TransportException> requestFailure = new AtomicReference<>();
    serviceB.registerRequestHandler("action1", TestRequest.class, ThreadPool.Names.SAME, new TransportRequestHandler<TestRequest>() {
        @Override
        public void messageReceived(TestRequest request, TransportChannel channel) throws Exception {
            addressA.set(request.remoteAddress());
            try {
                channel.sendResponse(new TestResponse());
            } finally {
                // Count down even if sending the response throws, so the test does not
                // sit in the full timeout before reporting the problem.
                latch.countDown();
            }
        }
    });
    serviceA.sendRequest(nodeB, "action1", new TestRequest(), new TransportResponseHandler<TestResponse>() {
        @Override
        public TestResponse newInstance() {
            return new TestResponse();
        }

        @Override
        public void handleResponse(TestResponse response) {
            addressB.set(response.remoteAddress());
            latch.countDown();
        }

        @Override
        public void handleException(TransportException exp) {
            requestFailure.set(exp);
            latch.countDown();
        }

        @Override
        public String executor() {
            return ThreadPool.Names.SAME;
        }
    });
    final boolean completed = latch.await(10, TimeUnit.SECONDS);
    // Report a transport failure first: it is the root cause of both a timeout and a missing address.
    final TransportException failure = requestFailure.get();
    if (failure != null) {
        throw new AssertionError("request to node B failed instead of completing the round trip", failure);
    }
    if (!completed) {
        fail("message round trip did not complete within a sensible time frame");
    }
    assertTrue(nodeA.address().sameHost(addressA.get()));
    assertTrue(nodeB.address().sameHost(addressB.get()));
}
/**
 * Verifies that a freshly built transport service does not process incoming requests until
 * {@code acceptIncomingRequests()} has been called on it: a request sent beforehand must be
 * left unprocessed, and must be processed once the call has been made.
 *
 * @throws Exception if building the service fails or the test thread is interrupted while waiting
 */
@Test
public void testBlockingIncomingRequests() throws Exception {
    TransportService service = build(
        Settings.builder().put("name", "TS_TEST", TransportService.SETTING_TRACE_LOG_INCLUDE, "", TransportService.SETTING_TRACE_LOG_EXCLUDE, "NOTHING").build(),
        version0, new NamedWriteableRegistry()
    );
    // Close the extra service even when an assertion below fails, so it does not leak.
    try {
        final AtomicBoolean requestProcessed = new AtomicBoolean();
        service.registerRequestHandler("action", TestRequest.class, ThreadPool.Names.SAME,
            new TransportRequestHandler<TestRequest>() {
                @Override
                public void messageReceived(TestRequest request, TransportChannel channel) throws Exception {
                    requestProcessed.set(true);
                    channel.sendResponse(TransportResponse.Empty.INSTANCE);
                }
            });
        DiscoveryNode node = new DiscoveryNode("TS_TEST", "TS_TEST", service.boundAddress().publishAddress(), ImmutableMap.<String, String>of(), version0);
        serviceA.connectToNode(node);
        final CountDownLatch latch = new CountDownLatch(1);
        serviceA.sendRequest(node, "action", new TestRequest(), new TransportResponseHandler<TestResponse>() {
            @Override
            public TestResponse newInstance() {
                return new TestResponse();
            }

            @Override
            public void handleResponse(TestResponse response) {
                latch.countDown();
            }

            @Override
            public void handleException(TransportException exp) {
                // Either outcome ends the wait; the test only checks when the request is processed.
                latch.countDown();
            }

            @Override
            public String executor() {
                return ThreadPool.Names.SAME;
            }
        });
        // The request has been sent, but the service has not been told to accept requests yet.
        assertFalse(requestProcessed.get());
        service.acceptIncomingRequests();
        assertBusy(new Runnable() {
            @Override
            public void run() {
                assertTrue(requestProcessed.get());
            }
        });
        // Bounded wait so a lost reply fails the test instead of hanging the suite.
        assertTrue("no response or failure was delivered within a sensible time frame",
            latch.await(10, TimeUnit.SECONDS));
    } finally {
        service.close();
    }
}
/**
 * Request type that declares no fields of its own. The tests in this file use it for the
 * "action1" and "action" handlers, where only delivery of the message matters.
 */
public static class TestRequest extends TransportRequest {
}
/**
 * Response type that declares no fields of its own; the counterpart of {@link TestRequest}
 * in the tests in this file.
 */
private static class TestResponse extends TransportResponse {
}
}