/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.transport; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.apache.lucene.util.CollectionUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.IOUtils; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import 
org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.mocksocket.MockServerSocket; import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.Before; import java.io.IOException; import java.io.UncheckedIOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions private static final Version CURRENT_VERSION = Version.fromString(String.valueOf(Version.CURRENT.major) + ".0.0"); protected static final Version version0 = 
CURRENT_VERSION.minimumCompatibilityVersion();
    private ClusterSettings clusterSettings;

    protected volatile DiscoveryNode nodeA;
    protected volatile MockTransportService serviceA;

    /** Version used for service B: the id directly following {@code CURRENT_VERSION}. */
    protected static final Version version1 = Version.fromId(CURRENT_VERSION.id + 1);
    protected volatile DiscoveryNode nodeB;
    protected volatile MockTransportService serviceB;

    /**
     * Creates the transport service under test. Implemented by concrete subclasses; within this
     * class it is only called from {@code buildService}.
     *
     * @param settings        node settings for the new service
     * @param version         version the new service runs with
     * @param clusterSettings cluster settings for the service; {@code setUp()} passes {@code null} for service B
     * @param doHandshake     handshake flag handed through from {@code buildService}
     * @return the newly built service
     */
    protected abstract MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings,
                                                  boolean doHandshake);

    /**
     * Builds services A and B, connects them to each other and waits until both connection
     * events have been delivered before any test runs.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        threadPool = new TestThreadPool(getClass().getName());
        clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
        serviceA = buildService("TS_A", version0, clusterSettings); // this one supports dynamic tracer updates
        nodeA = serviceA.getLocalNode();
        serviceB = buildService("TS_B", version1, null); // this one doesn't support dynamic tracer updates
        nodeB = serviceB.getLocalNode();
        // wait till all nodes are properly connected and the event has been sent, so tests in this class
        // will not get this callback called on the connections done in this setup
        final CountDownLatch latch = new CountDownLatch(2);
        TransportConnectionListener waitForConnection = new TransportConnectionListener() {
            @Override
            public void onNodeConnected(DiscoveryNode node) {
                latch.countDown();
            }

            @Override
            public void onNodeDisconnected(DiscoveryNode node) {
                fail("disconnect should not be called " + node);
            }
        };
        serviceA.addConnectionListener(waitForConnection);
        serviceB.addConnectionListener(waitForConnection);
        int numHandshakes = 1;
        serviceA.connectToNode(nodeB);
        serviceB.connectToNode(nodeA);
        assertNumHandshakes(numHandshakes, serviceA.getOriginalTransport());
        assertNumHandshakes(numHandshakes, serviceB.getOriginalTransport());
        assertThat("failed to wait for all nodes to connect", latch.await(5, TimeUnit.SECONDS), equalTo(true));
        serviceA.removeConnectionListener(waitForConnection);
        serviceB.removeConnectionListener(waitForConnection);
    }

    /**
     * Builds a service through {@code build(...)} with the node name and trace-log settings
     * layered on top of {@code settings}.
     *
     * @param acceptRequests when {@code true}, {@code acceptIncomingRequests()} is called on the new service
     */
    private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings,
                                              Settings settings, boolean acceptRequests, boolean doHandshake) {
        MockTransportService service = build(
            Settings.builder()
                .put(settings)
                .put(Node.NODE_NAME_SETTING.getKey(), name)
                .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
                .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
                .build(),
            version,
            clusterSettings, doHandshake);
        if (acceptRequests) {
            service.acceptIncomingRequests();
        }
        return service;
    }

    /** Builds a service with empty extra settings that accepts requests and has the handshake flag set. */
    private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings) {
        return buildService(name, version, clusterSettings, Settings.EMPTY, true, true);
    }

    /**
     * Checks that neither transport has a pending handshake, then closes both services and
     * terminates the thread pool. The cleanup runs even when the handshake assertion fails.
     */
    @Override
    @After
    public void tearDown() throws Exception {
        super.tearDown();
        try {
            assertNoPendingHandshakes(serviceA.getOriginalTransport());
            assertNoPendingHandshakes(serviceB.getOriginalTransport());
        } finally {
            IOUtils.close(serviceA, serviceB, () -> {
                try {
                    terminate(threadPool);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
    }

    /** Asserts the number of handshakes performed by {@code transport}; a no-op unless it is a {@code TcpTransport}. */
    public void assertNumHandshakes(long expected, Transport transport) {
        if (transport instanceof TcpTransport) {
            assertEquals(expected, ((TcpTransport) transport).getNumHandshakes());
        }
    }

    /** Asserts that {@code transport} has no pending handshakes; a no-op unless it is a {@code TcpTransport}. */
    public void assertNoPendingHandshakes(Transport transport) {
        if (transport instanceof TcpTransport) {
            assertEquals(0, ((TcpTransport) transport).getNumPendingHandshakes());
        }
    }

    /**
     * Sends "moshe" from service B to the "sayHello" handler on service A and expects "hello moshe"
     * back, first with default request options and then with request compression enabled.
     */
    public void testHelloWorld() {
        serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC,
            (request, channel) -> {
                assertThat(request.message, equalTo("moshe"));
                try {
                    channel.sendResponse(new StringMessageResponse("hello " + request.message));
                } catch (IOException e) {
                    logger.error("Unexpected failure", e);
                    fail(e.getMessage());
                }
            });

        TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA,
            "sayHello", new StringMessageRequest("moshe"), new TransportResponseHandler<StringMessageResponse>() {
                @Override
                public StringMessageResponse newInstance() {
                    return new StringMessageResponse();
                }

                @Override
                public String executor() {
                    return ThreadPool.Names.GENERIC;
                }

                @Override
                public void handleResponse(StringMessageResponse response) {
                    assertThat(response.message, equalTo("hello moshe"));
                }

                @Override
                public void handleException(TransportException exp) {
                    logger.error("Unexpected failure", exp);
                    fail("got exception instead of a response: " + exp.getMessage());
                }
            });

        try {
            StringMessageResponse message = res.get();
            assertThat(message.message, equalTo("hello moshe"));
        } catch (Exception e) {
            // keep the original exception as the cause so its stack trace shows up in the test report
            throw new AssertionError("failed to get the response to the sayHello request", e);
        }

        // same round trip again, this time with request compression enabled
        res = serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"),
            TransportRequestOptions.builder().withCompress(true).build(), new TransportResponseHandler<StringMessageResponse>() {
                @Override
                public StringMessageResponse newInstance() {
                    return new StringMessageResponse();
                }

                @Override
                public String executor() {
                    return ThreadPool.Names.GENERIC;
                }

                @Override
                public void handleResponse(StringMessageResponse response) {
                    assertThat(response.message, equalTo("hello moshe"));
                }

                @Override
                public void handleException(TransportException exp) {
                    logger.error("Unexpected failure", exp);
                    fail("got exception instead of a response: " + exp.getMessage());
                }
            });

        try {
            StringMessageResponse message = res.get();
            assertThat(message.message, equalTo("hello moshe"));
        } catch (Exception e) {
            // keep the original exception as the cause so its stack trace shows up in the test report
            throw new AssertionError("failed to get the response to the compressed sayHello request", e);
        }
    }

    /**
     * Checks how the thread context travels with a request: the "test.ping.user" header set by the
     * caller is visible in the request handler and in the response handler, the caller's transient
     * value is not visible in the request handler, and headers added inside either handler are not
     * visible to the caller afterwards.
     */
    public void testThreadContext() throws ExecutionException, InterruptedException {

        serviceA.registerRequestHandler("ping_pong", StringMessageRequest::new, ThreadPool.Names.GENERIC, (request, channel) -> {
            assertEquals("ping_user", threadPool.getThreadContext().getHeader("test.ping.user"));
            assertNull(threadPool.getThreadContext().getTransient("my_private_context"));
            try {
                StringMessageResponse response = new StringMessageResponse("pong");
                threadPool.getThreadContext().putHeader("test.pong.user", "pong_user");
                channel.sendResponse(response);
            } catch (IOException e) {
                logger.error("Unexpected failure", e);
                fail(e.getMessage());
            }
        });
        final Object context = new Object();
        // run the response handler on a randomly chosen thread pool
        final String executor = randomFrom(ThreadPool.THREAD_POOL_TYPES.keySet().toArray(new String[0]));
        TransportResponseHandler<StringMessageResponse> responseHandler = new TransportResponseHandler<StringMessageResponse>() {
            @Override
            public StringMessageResponse newInstance() {
                return new StringMessageResponse();
            }

            @Override
            public String executor() {
                return executor;
            }

            @Override
            public void handleResponse(StringMessageResponse response) {
                assertThat(response.message, equalTo("pong"));
                assertEquals("ping_user", threadPool.getThreadContext().getHeader("test.ping.user"));
                assertNull(threadPool.getThreadContext().getHeader("test.pong.user"));
                assertSame(context, threadPool.getThreadContext().getTransient("my_private_context"));
                threadPool.getThreadContext().putHeader("some.temp.header", "booooom");
            }

            @Override
            public void handleException(TransportException exp) {
                logger.error("Unexpected failure", exp);
                fail("got exception instead of a response: " + exp.getMessage());
            }
        };
        StringMessageRequest ping = new StringMessageRequest("ping");
        threadPool.getThreadContext().putHeader("test.ping.user", "ping_user");
        threadPool.getThreadContext().putTransient("my_private_context", context);

        TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "ping_pong", ping, responseHandler);

        StringMessageResponse message = res.get();
        assertThat(message.message, equalTo("pong"));
        assertEquals("ping_user", threadPool.getThreadContext().getHeader("test.ping.user"));
        assertSame(context, threadPool.getThreadContext().getTransient("my_private_context"));
        assertNull("this header is only visible in the handler context", threadPool.getThreadContext().getHeader("some.temp.header"));
    }

    /** Sends a request from service A to its own node and expects the message to be echoed back. */
    public void
testLocalNodeConnection() throws InterruptedException { assertTrue("serviceA is not connected to nodeA", serviceA.nodeConnected(nodeA)); // this should be a noop serviceA.disconnectFromNode(nodeA); final AtomicReference<Exception> exception = new AtomicReference<>(); serviceA.registerRequestHandler("localNode", StringMessageRequest::new, ThreadPool.Names.GENERIC, (request, channel) -> { try { channel.sendResponse(new StringMessageResponse(request.message)); } catch (IOException e) { exception.set(e); } }); final AtomicReference<String> responseString = new AtomicReference<>(); final CountDownLatch responseLatch = new CountDownLatch(1); serviceA.sendRequest(nodeA, "localNode", new StringMessageRequest("test"), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public void handleResponse(StringMessageResponse response) { responseString.set(response.message); responseLatch.countDown(); } @Override public void handleException(TransportException exp) { exception.set(exp); responseLatch.countDown(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } }); responseLatch.await(); assertNull(exception.get()); assertThat(responseString.get(), equalTo("test")); } public void testAdapterSendReceiveCallbacks() throws Exception { final TransportRequestHandler<TransportRequest.Empty> requestHandler = (request, channel) -> { try { if (randomBoolean()) { channel.sendResponse(TransportResponse.Empty.INSTANCE); } else { channel.sendResponse(new ElasticsearchException("simulated")); } } catch (IOException e) { logger.error("Unexpected failure", e); fail(e.getMessage()); } }; final String ACTION = "action"; serviceA.registerRequestHandler(ACTION, TransportRequest.Empty::new, ThreadPool.Names.GENERIC, requestHandler); serviceB.registerRequestHandler(ACTION, TransportRequest.Empty::new, ThreadPool.Names.GENERIC, requestHandler); class CountingTracer extends 
MockTransportService.Tracer { AtomicInteger requestsReceived = new AtomicInteger(); AtomicInteger requestsSent = new AtomicInteger(); AtomicInteger responseReceived = new AtomicInteger(); AtomicInteger responseSent = new AtomicInteger(); @Override public void receivedRequest(long requestId, String action) { if (action.equals(ACTION)) { requestsReceived.incrementAndGet(); } } @Override public void responseSent(long requestId, String action) { if (action.equals(ACTION)) { responseSent.incrementAndGet(); } } @Override public void responseSent(long requestId, String action, Throwable t) { if (action.equals(ACTION)) { responseSent.incrementAndGet(); } } @Override public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) { if (action.equals(ACTION)) { responseReceived.incrementAndGet(); } } @Override public void requestSent(DiscoveryNode node, long requestId, String action, TransportRequestOptions options) { if (action.equals(ACTION)) { requestsSent.incrementAndGet(); } } } final CountingTracer tracerA = new CountingTracer(); final CountingTracer tracerB = new CountingTracer(); serviceA.addTracer(tracerA); serviceB.addTracer(tracerB); try { serviceA .submitRequest(nodeB, ACTION, TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); } catch (ExecutionException e) { assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); } // use assert busy as call backs are sometime called after the response have been sent assertBusy(() -> { assertThat(tracerA.requestsReceived.get(), equalTo(0)); assertThat(tracerA.requestsSent.get(), equalTo(1)); assertThat(tracerA.responseReceived.get(), equalTo(1)); assertThat(tracerA.responseSent.get(), equalTo(0)); assertThat(tracerB.requestsReceived.get(), equalTo(1)); assertThat(tracerB.requestsSent.get(), equalTo(0)); assertThat(tracerB.responseReceived.get(), equalTo(0)); 
assertThat(tracerB.responseSent.get(), equalTo(1)); }); try { serviceA .submitRequest(nodeA, ACTION, TransportRequest.Empty.INSTANCE, EmptyTransportResponseHandler.INSTANCE_SAME).get(); } catch (ExecutionException e) { assertThat(e.getCause(), instanceOf(ElasticsearchException.class)); assertThat(ExceptionsHelper.unwrapCause(e.getCause()).getMessage(), equalTo("simulated")); } // use assert busy as call backs are sometime called after the response have been sent assertBusy(() -> { assertThat(tracerA.requestsReceived.get(), equalTo(1)); assertThat(tracerA.requestsSent.get(), equalTo(2)); assertThat(tracerA.responseReceived.get(), equalTo(2)); assertThat(tracerA.responseSent.get(), equalTo(1)); assertThat(tracerB.requestsReceived.get(), equalTo(1)); assertThat(tracerB.requestsSent.get(), equalTo(0)); assertThat(tracerB.responseReceived.get(), equalTo(0)); assertThat(tracerB.responseSent.get(), equalTo(1)); }); } public void testVoidMessageCompressed() { serviceA.registerRequestHandler("sayHello", TransportRequest.Empty::new, ThreadPool.Names.GENERIC, (request, channel) -> { try { TransportResponseOptions responseOptions = TransportResponseOptions.builder().withCompress(true).build(); channel.sendResponse(TransportResponse.Empty.INSTANCE, responseOptions); } catch (IOException e) { logger.error("Unexpected failure", e); fail(e.getMessage()); } }); TransportFuture<TransportResponse.Empty> res = serviceB.submitRequest(nodeA, "sayHello", TransportRequest.Empty.INSTANCE, TransportRequestOptions.builder().withCompress(true).build(), new TransportResponseHandler<TransportResponse.Empty>() { @Override public TransportResponse.Empty newInstance() { return TransportResponse.Empty.INSTANCE; } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(TransportResponse.Empty response) { } @Override public void handleException(TransportException exp) { logger.error("Unexpected failure", exp); fail("got exception instead of a 
response: " + exp.getMessage()); } }); try { TransportResponse.Empty message = res.get(); assertThat(message, notNullValue()); } catch (Exception e) { assertThat(e.getMessage(), false, equalTo(true)); } } public void testHelloWorldCompressed() { serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, TransportChannel channel) { assertThat("moshe", equalTo(request.message)); try { TransportResponseOptions responseOptions = TransportResponseOptions.builder().withCompress(true).build(); channel.sendResponse(new StringMessageResponse("hello " + request.message), responseOptions); } catch (IOException e) { logger.error("Unexpected failure", e); fail(e.getMessage()); } } }); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), TransportRequestOptions.builder().withCompress(true).build(), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { assertThat("hello moshe", equalTo(response.message)); } @Override public void handleException(TransportException exp) { logger.error("Unexpected failure", exp); fail("got exception instead of a response: " + exp.getMessage()); } }); try { StringMessageResponse message = res.get(); assertThat("hello moshe", equalTo(message.message)); } catch (Exception e) { assertThat(e.getMessage(), false, equalTo(true)); } } public void testErrorMessage() { serviceA.registerRequestHandler("sayHelloException", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, 
TransportChannel channel) throws Exception { assertThat("moshe", equalTo(request.message)); throw new RuntimeException("bad message !!!"); } }); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHelloException", new StringMessageRequest("moshe"), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { fail("got response instead of exception"); } @Override public void handleException(TransportException exp) { assertThat("runtime_exception: bad message !!!", equalTo(exp.getCause().getMessage())); } }); try { res.txGet(); fail("exception should be thrown"); } catch (Exception e) { assertThat(e.getCause().getMessage(), equalTo("runtime_exception: bad message !!!")); } } public void testDisconnectListener() throws Exception { final CountDownLatch latch = new CountDownLatch(1); TransportConnectionListener disconnectListener = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { fail("node connected should not be called, all connection have been done previously, node: " + node); } @Override public void onNodeDisconnected(DiscoveryNode node) { latch.countDown(); } }; serviceA.addConnectionListener(disconnectListener); serviceB.close(); assertThat(latch.await(5, TimeUnit.SECONDS), equalTo(true)); } public void testConcurrentSendRespondAndDisconnect() throws BrokenBarrierException, InterruptedException { Set<Exception> sendingErrors = ConcurrentCollections.newConcurrentSet(); Set<Exception> responseErrors = ConcurrentCollections.newConcurrentSet(); serviceA.registerRequestHandler("test", TestRequest::new, randomBoolean() ? 
ThreadPool.Names.SAME : ThreadPool.Names.GENERIC, (request, channel) -> { try { channel.sendResponse(new TestResponse()); } catch (Exception e) { logger.info("caught exception while responding", e); responseErrors.add(e); } }); final TransportRequestHandler<TestRequest> ignoringRequestHandler = (request, channel) -> { try { channel.sendResponse(new TestResponse()); } catch (Exception e) { // we don't really care what's going on B, we're testing through A logger.trace("caught exception while responding from node B", e); } }; serviceB.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); int halfSenders = scaledRandomIntBetween(3, 10); final CyclicBarrier go = new CyclicBarrier(halfSenders * 2 + 1); final CountDownLatch done = new CountDownLatch(halfSenders * 2); for (int i = 0; i < halfSenders; i++) { // B senders just generated activity so serciveA can respond, we don't test what's going on there final int sender = i; threadPool.executor(ThreadPool.Names.GENERIC).execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { logger.trace("caught exception while sending from B", e); } @Override protected void doRun() throws Exception { go.await(); for (int iter = 0; iter < 10; iter++) { PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); final String info = sender + "_B_" + iter; serviceB.sendRequest(nodeA, "test", new TestRequest(info), new ActionListenerResponseHandler<>(listener, TestResponse::new)); try { listener.actionGet(); } catch (Exception e) { logger.trace( (Supplier<?>) () -> new ParameterizedMessage("caught exception while sending to node {}", nodeA), e); } } } @Override public void onAfter() { done.countDown(); } }); } for (int i = 0; i < halfSenders; i++) { final int sender = i; threadPool.executor(ThreadPool.Names.GENERIC).execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { logger.error("unexpected error", e); sendingErrors.add(e); } @Override 
protected void doRun() throws Exception { go.await(); for (int iter = 0; iter < 10; iter++) { PlainActionFuture<TestResponse> listener = new PlainActionFuture<>(); final String info = sender + "_" + iter; final DiscoveryNode node = nodeB; // capture now try { serviceA.sendRequest(node, "test", new TestRequest(info), new ActionListenerResponseHandler<>(listener, TestResponse::new)); try { listener.actionGet(); } catch (ConnectTransportException e) { // ok! } catch (Exception e) { logger.error( (Supplier<?>) () -> new ParameterizedMessage("caught exception while sending to node {}", node), e); sendingErrors.add(e); } } catch (NodeNotConnectedException ex) { // ok } } } @Override public void onAfter() { done.countDown(); } }); } go.await(); for (int i = 0; i <= 10; i++) { if (i % 3 == 0) { // simulate restart of nodeB serviceB.close(); MockTransportService newService = buildService("TS_B_" + i, version1, null); newService.registerRequestHandler("test", TestRequest::new, ThreadPool.Names.SAME, ignoringRequestHandler); serviceB = newService; nodeB = newService.getLocalDiscoNode(); serviceB.connectToNode(nodeA); serviceA.connectToNode(nodeB); } else if (serviceA.nodeConnected(nodeB)) { serviceA.disconnectFromNode(nodeB); } else { serviceA.connectToNode(nodeB); } } done.await(); assertThat("found non connection errors while sending", sendingErrors, empty()); assertThat("found non connection errors while responding", responseErrors, empty()); } public void testNotifyOnShutdown() throws Exception { final CountDownLatch latch2 = new CountDownLatch(1); serviceA.registerRequestHandler("foobar", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, TransportChannel channel) { try { latch2.await(); logger.info("Stop ServiceB now"); serviceB.stop(); } catch (Exception e) { fail(e.getMessage()); } } }); TransportFuture<TransportResponse.Empty> foobar = 
serviceB.submitRequest(nodeA, "foobar", new StringMessageRequest(""), TransportRequestOptions.EMPTY, EmptyTransportResponseHandler.INSTANCE_SAME); latch2.countDown(); try { foobar.txGet(); fail("TransportException expected"); } catch (TransportException ex) { } } public void testTimeoutSendExceptionWithNeverSendingBackResponse() throws Exception { serviceA.registerRequestHandler("sayHelloTimeoutNoResponse", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, TransportChannel channel) { assertThat("moshe", equalTo(request.message)); // don't send back a response } }); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHelloTimeoutNoResponse", new StringMessageRequest("moshe"), TransportRequestOptions.builder().withTimeout(100).build(), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { fail("got response instead of exception"); } @Override public void handleException(TransportException exp) { assertThat(exp, instanceOf(ReceiveTimeoutTransportException.class)); } }); try { StringMessageResponse message = res.txGet(); fail("exception should be thrown"); } catch (Exception e) { assertThat(e, instanceOf(ReceiveTimeoutTransportException.class)); } } public void testTimeoutSendExceptionWithDelayedResponse() throws Exception { CountDownLatch waitForever = new CountDownLatch(1); CountDownLatch doneWaitingForever = new CountDownLatch(1); Semaphore inFlight = new Semaphore(Integer.MAX_VALUE); serviceA.registerRequestHandler("sayHelloTimeoutDelayedResponse", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void 
messageReceived(StringMessageRequest request, TransportChannel channel) throws InterruptedException { String message = request.message; inFlight.acquireUninterruptibly(); try { if ("forever".equals(message)) { waitForever.await(); } else { TimeValue sleep = TimeValue.parseTimeValue(message, null, "sleep"); Thread.sleep(sleep.millis()); } try { channel.sendResponse(new StringMessageResponse("hello " + request.message)); } catch (IOException e) { logger.error("Unexpected failure", e); fail(e.getMessage()); } } finally { inFlight.release(); if ("forever".equals(message)) { doneWaitingForever.countDown(); } } } }); final CountDownLatch latch = new CountDownLatch(1); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse", new StringMessageRequest("forever"), TransportRequestOptions.builder().withTimeout(100).build(), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { latch.countDown(); fail("got response instead of exception"); } @Override public void handleException(TransportException exp) { latch.countDown(); assertThat(exp, instanceOf(ReceiveTimeoutTransportException.class)); } }); try { res.txGet(); fail("exception should be thrown"); } catch (Exception e) { assertThat(e, instanceOf(ReceiveTimeoutTransportException.class)); } latch.await(); List<Runnable> assertions = new ArrayList<>(); for (int i = 0; i < 10; i++) { final int counter = i; // now, try and send another request, this times, with a short timeout TransportFuture<StringMessageResponse> result = serviceB.submitRequest(nodeA, "sayHelloTimeoutDelayedResponse", new StringMessageRequest(counter + "ms"), TransportRequestOptions.builder().withTimeout(3000).build(), new TransportResponseHandler<StringMessageResponse>() { 
@Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { assertThat("hello " + counter + "ms", equalTo(response.message)); } @Override public void handleException(TransportException exp) { logger.error("Unexpected failure", exp); fail("got exception instead of a response for " + counter + ": " + exp.getDetailedMessage()); } }); assertions.add(() -> { StringMessageResponse message = result.txGet(); assertThat(message.message, equalTo("hello " + counter + "ms")); }); } for (Runnable runnable : assertions) { runnable.run(); } waitForever.countDown(); doneWaitingForever.await(); assertTrue(inFlight.tryAcquire(Integer.MAX_VALUE, 10, TimeUnit.SECONDS)); } public void testTracerLog() throws InterruptedException { TransportRequestHandler handler = (request, channel) -> channel.sendResponse(new StringMessageResponse("")); TransportRequestHandler handlerWithError = new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception { if (request.timeout() > 0) { Thread.sleep(request.timeout); } channel.sendResponse(new RuntimeException("")); } }; final Semaphore requestCompleted = new Semaphore(0); TransportResponseHandler noopResponseHandler = new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public void handleResponse(StringMessageResponse response) { requestCompleted.release(); } @Override public void handleException(TransportException exp) { requestCompleted.release(); } @Override public String executor() { return ThreadPool.Names.SAME; } }; serviceA.registerRequestHandler("test", StringMessageRequest::new, ThreadPool.Names.SAME, handler); serviceA.registerRequestHandler("testError", 
StringMessageRequest::new, ThreadPool.Names.SAME, handlerWithError); serviceB.registerRequestHandler("test", StringMessageRequest::new, ThreadPool.Names.SAME, handler); serviceB.registerRequestHandler("testError", StringMessageRequest::new, ThreadPool.Names.SAME, handlerWithError); final Tracer tracer = new Tracer(new HashSet<>(Arrays.asList("test", "testError"))); // the tracer is invoked concurrently after the actual action is executed. that means a Tracer#requestSent can still be in-flight // from a handshake executed on connect in the setup method. this might confuse this test since it expects exact number of // invocations. To prevent any unrelated events messing with this test we filter on the actions we execute in this test. serviceA.addTracer(tracer); serviceB.addTracer(tracer); tracer.reset(4); boolean timeout = randomBoolean(); TransportRequestOptions options = timeout ? TransportRequestOptions.builder().withTimeout(1).build() : TransportRequestOptions.EMPTY; serviceA.sendRequest(nodeB, "test", new StringMessageRequest("", 10), options, noopResponseHandler); requestCompleted.acquire(); tracer.expectedEvents.get().await(); assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true)); assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true)); assertThat("didn't see response sent", tracer.sawResponseSent, equalTo(true)); assertThat("didn't see response received", tracer.sawResponseReceived, equalTo(true)); assertThat("saw error sent", tracer.sawErrorSent, equalTo(false)); tracer.reset(4); serviceA.sendRequest(nodeB, "testError", new StringMessageRequest(""), noopResponseHandler); requestCompleted.acquire(); tracer.expectedEvents.get().await(); assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true)); assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true)); assertThat("saw response sent", tracer.sawResponseSent, equalTo(false)); assertThat("didn't see response received", 
tracer.sawResponseReceived, equalTo(true)); assertThat("didn't see error sent", tracer.sawErrorSent, equalTo(true)); String includeSettings; String excludeSettings; if (randomBoolean()) { // sometimes leave include empty (default) includeSettings = randomBoolean() ? "*" : ""; excludeSettings = "*Error"; } else { includeSettings = "test"; excludeSettings = "DOESN'T_MATCH"; } clusterSettings.applySettings(Settings.builder() .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), includeSettings) .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), excludeSettings) .build()); tracer.reset(4); serviceA.sendRequest(nodeB, "test", new StringMessageRequest(""), noopResponseHandler); requestCompleted.acquire(); tracer.expectedEvents.get().await(); assertThat("didn't see request sent", tracer.sawRequestSent, equalTo(true)); assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true)); assertThat("didn't see response sent", tracer.sawResponseSent, equalTo(true)); assertThat("didn't see response received", tracer.sawResponseReceived, equalTo(true)); assertThat("saw error sent", tracer.sawErrorSent, equalTo(false)); tracer.reset(2); serviceA.sendRequest(nodeB, "testError", new StringMessageRequest(""), noopResponseHandler); requestCompleted.acquire(); tracer.expectedEvents.get().await(); assertThat("saw request sent", tracer.sawRequestSent, equalTo(false)); assertThat("didn't see request received", tracer.sawRequestReceived, equalTo(true)); assertThat("saw response sent", tracer.sawResponseSent, equalTo(false)); assertThat("saw response received", tracer.sawResponseReceived, equalTo(false)); assertThat("didn't see error sent", tracer.sawErrorSent, equalTo(true)); } private static class Tracer extends MockTransportService.Tracer { private final Set<String> actions; public volatile boolean sawRequestSent; public volatile boolean sawRequestReceived; public volatile boolean sawResponseSent; public volatile boolean sawErrorSent; public volatile boolean 
sawResponseReceived; public AtomicReference<CountDownLatch> expectedEvents = new AtomicReference<>(); Tracer(Set<String> actions) { this.actions = actions; } @Override public void receivedRequest(long requestId, String action) { super.receivedRequest(requestId, action); if (actions.contains(action)) { sawRequestReceived = true; expectedEvents.get().countDown(); } } @Override public void requestSent(DiscoveryNode node, long requestId, String action, TransportRequestOptions options) { super.requestSent(node, requestId, action, options); if (actions.contains(action)) { sawRequestSent = true; expectedEvents.get().countDown(); } } @Override public void responseSent(long requestId, String action) { super.responseSent(requestId, action); if (actions.contains(action)) { sawResponseSent = true; expectedEvents.get().countDown(); } } @Override public void responseSent(long requestId, String action, Throwable t) { super.responseSent(requestId, action, t); if (actions.contains(action)) { sawErrorSent = true; expectedEvents.get().countDown(); } } @Override public void receivedResponse(long requestId, DiscoveryNode sourceNode, String action) { super.receivedResponse(requestId, sourceNode, action); if (actions.contains(action)) { sawResponseReceived = true; expectedEvents.get().countDown(); } } public void reset(int expectedCount) { sawRequestSent = false; sawRequestReceived = false; sawResponseSent = false; sawErrorSent = false; sawResponseReceived = false; expectedEvents.set(new CountDownLatch(expectedCount)); } } public static class StringMessageRequest extends TransportRequest { private String message; private long timeout; StringMessageRequest(String message, long timeout) { this.message = message; this.timeout = timeout; } public StringMessageRequest() { } public StringMessageRequest(String message) { this(message, -1); } public long timeout() { return timeout; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); message = 
in.readString(); timeout = in.readLong(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(message); out.writeLong(timeout); } } static class StringMessageResponse extends TransportResponse { private String message; StringMessageResponse(String message) { this.message = message; } StringMessageResponse() { } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); message = in.readString(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(message); } } public static class Version0Request extends TransportRequest { int value1; @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); value1 = in.readInt(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeInt(value1); } } public static class Version1Request extends Version0Request { int value2; @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); if (in.getVersion().onOrAfter(version1)) { value2 = in.readInt(); } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); if (out.getVersion().onOrAfter(version1)) { out.writeInt(value2); } } } static class Version0Response extends TransportResponse { int value1; @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); value1 = in.readInt(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeInt(value1); } } static class Version1Response extends Version0Response { int value2; @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); if (in.getVersion().onOrAfter(version1)) { value2 = in.readInt(); } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); if (out.getVersion().onOrAfter(version1)) { out.writeInt(value2); } } } public void 
testVersionFrom0to1() throws Exception {
        // serviceA sends a Version0Request to serviceB; the handler on B expects value2 to be unset and the
        // channel version to be version0, and answers with a Version1Response
        serviceB.registerRequestHandler("/version", Version1Request::new, ThreadPool.Names.SAME, (incoming, channel) -> {
            assertThat(incoming.value1, equalTo(1));
            assertThat(incoming.value2, equalTo(0)); // not set, coming from service A
            Version1Response reply = new Version1Response();
            reply.value1 = 1;
            reply.value2 = 2;
            channel.sendResponse(reply);
            assertEquals(version0, channel.getVersion());
        });

        Version0Request outgoing = new Version0Request();
        outgoing.value1 = 1;
        TransportFuture<Version0Response> pending = serviceA.submitRequest(nodeB, "/version", outgoing,
            new TransportResponseHandler<Version0Response>() {
                @Override
                public Version0Response newInstance() {
                    return new Version0Response();
                }

                @Override
                public void handleResponse(Version0Response response) {
                    assertThat(response.value1, equalTo(1));
                }

                @Override
                public void handleException(TransportException exp) {
                    logger.error("Unexpected failure", exp);
                    fail("got exception instead of a response: " + exp.getMessage());
                }

                @Override
                public String executor() {
                    return ThreadPool.Names.SAME;
                }
            });
        Version0Response received = pending.txGet();

        assertThat(received.value1, equalTo(1));
    }

    /**
     * serviceB sends a Version1Request to serviceA, whose handler works with the Version0 types; the caller
     * expects {@code value2} of the response to keep its initial value.
     */
    public void testVersionFrom1to0() throws Exception {
        serviceA.registerRequestHandler("/version", Version0Request::new, ThreadPool.Names.SAME, (incoming, channel) -> {
            assertThat(incoming.value1, equalTo(1));
            Version0Response reply = new Version0Response();
            reply.value1 = 1;
            channel.sendResponse(reply);
            assertEquals(version0, channel.getVersion());
        });

        Version1Request outgoing = new Version1Request();
        outgoing.value1 = 1;
        outgoing.value2 = 2;
        TransportFuture<Version1Response> pending = serviceB.submitRequest(nodeA, "/version", outgoing,
            new TransportResponseHandler<Version1Response>() {
                @Override
                public Version1Response newInstance() {
                    return new Version1Response();
                }

                @Override
                public void handleResponse(Version1Response response) {
                    assertThat(response.value1, equalTo(1));
                    assertThat(response.value2, equalTo(0)); // initial values, cause its serialized from version 0
                }

                @Override
                public void handleException(TransportException exp) {
                    logger.error("Unexpected failure", exp);
                    fail("got exception instead of a response: " + exp.getMessage());
                }

                @Override
                public String executor() {
                    return ThreadPool.Names.SAME;
                }
            });
        Version1Response received = pending.txGet();

        assertThat(received.value1, equalTo(1));
        assertThat(received.value2, equalTo(0));
    }

    /**
     * serviceB sends a Version1Request to itself (nodeB); both values must survive the round trip and the
     * channel version seen by the handler must be version1.
     */
    public void testVersionFrom1to1() throws Exception {
        serviceB.registerRequestHandler("/version", Version1Request::new, ThreadPool.Names.SAME, (incoming, channel) -> {
            assertThat(incoming.value1, equalTo(1));
            assertThat(incoming.value2, equalTo(2));
            Version1Response reply = new Version1Response();
            reply.value1 = 1;
            reply.value2 = 2;
            channel.sendResponse(reply);
            assertEquals(version1, channel.getVersion());
        });

        Version1Request outgoing = new Version1Request();
        outgoing.value1 = 1;
        outgoing.value2 = 2;
        TransportFuture<Version1Response> pending = serviceB.submitRequest(nodeB, "/version", outgoing,
            new TransportResponseHandler<Version1Response>() {
                @Override
                public Version1Response newInstance() {
                    return new Version1Response();
                }

                @Override
                public void handleResponse(Version1Response response) {
                    assertThat(response.value1, equalTo(1));
                    assertThat(response.value2, equalTo(2));
                }

                @Override
                public void handleException(TransportException exp) {
                    logger.error("Unexpected failure", exp);
                    fail("got exception instead of a response: " + exp.getMessage());
                }

                @Override
                public String executor() {
                    return ThreadPool.Names.SAME;
                }
            });
        Version1Response received = pending.txGet();

        assertThat(received.value1, equalTo(1));
        assertThat(received.value2, equalTo(2));
    }

    public void
testVersionFrom0to0() throws Exception { serviceA.registerRequestHandler("/version", Version0Request::new, ThreadPool.Names.SAME, (request, channel) -> { assertThat(request.value1, equalTo(1)); Version0Response response = new Version0Response(); response.value1 = 1; channel.sendResponse(response); assertEquals(version0, channel.getVersion()); }); Version0Request version0Request = new Version0Request(); version0Request.value1 = 1; Version0Response version0Response = serviceA.submitRequest(nodeA, "/version", version0Request, new TransportResponseHandler<Version0Response>() { @Override public Version0Response newInstance() { return new Version0Response(); } @Override public void handleResponse(Version0Response response) { assertThat(response.value1, equalTo(1)); } @Override public void handleException(TransportException exp) { logger.error("Unexpected failure", exp); fail("got exception instead of a response: " + exp.getMessage()); } @Override public String executor() { return ThreadPool.Names.SAME; } }).txGet(); assertThat(version0Response.value1, equalTo(1)); } public void testMockFailToSendNoConnectRule() throws IOException { serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC, (request, channel) -> { assertThat("moshe", equalTo(request.message)); throw new RuntimeException("bad message !!!"); }); serviceB.addFailToSendNoConnectRule(serviceA); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { fail("got response instead of exception"); } @Override public void handleException(TransportException exp) { Throwable cause = ExceptionsHelper.unwrapCause(exp); 
assertThat(cause, instanceOf(ConnectTransportException.class)); assertThat(((ConnectTransportException)cause).node(), equalTo(nodeA)); } }); try { res.txGet(); fail("exception should be thrown"); } catch (Exception e) { Throwable cause = ExceptionsHelper.unwrapCause(e); assertThat(cause, instanceOf(ConnectTransportException.class)); assertThat(((ConnectTransportException)cause).node(), equalTo(nodeA)); } try { serviceB.connectToNode(nodeA); fail("exception should be thrown"); } catch (ConnectTransportException e) { // all is well } try (Transport.Connection connection = serviceB.openConnection(nodeA, MockTcpTransport.LIGHT_PROFILE)) { serviceB.handshake(connection, 100); fail("exception should be thrown"); } catch (IllegalStateException e) { // all is well } } public void testMockUnresponsiveRule() throws IOException { serviceA.registerRequestHandler("sayHello", StringMessageRequest::new, ThreadPool.Names.GENERIC, new TransportRequestHandler<StringMessageRequest>() { @Override public void messageReceived(StringMessageRequest request, TransportChannel channel) throws Exception { assertThat("moshe", equalTo(request.message)); throw new RuntimeException("bad message !!!"); } }); serviceB.addUnresponsiveRule(serviceA); TransportFuture<StringMessageResponse> res = serviceB.submitRequest(nodeA, "sayHello", new StringMessageRequest("moshe"), TransportRequestOptions.builder().withTimeout(100).build(), new TransportResponseHandler<StringMessageResponse>() { @Override public StringMessageResponse newInstance() { return new StringMessageResponse(); } @Override public String executor() { return ThreadPool.Names.GENERIC; } @Override public void handleResponse(StringMessageResponse response) { fail("got response instead of exception"); } @Override public void handleException(TransportException exp) { assertThat(exp, instanceOf(ReceiveTimeoutTransportException.class)); } }); try { res.txGet(); fail("exception should be thrown"); } catch (Exception e) { assertThat(e, 
instanceOf(ReceiveTimeoutTransportException.class)); } try { serviceB.disconnectFromNode(nodeA); serviceB.connectToNode(nodeA); fail("exception should be thrown"); } catch (ConnectTransportException e) { // all is well } try (Transport.Connection connection = serviceB.openConnection(nodeA, MockTcpTransport.LIGHT_PROFILE)) { serviceB.handshake(connection, 100); fail("exception should be thrown"); } catch (IllegalStateException e) { // all is well } } public void testHostOnMessages() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(2); final AtomicReference<TransportAddress> addressA = new AtomicReference<>(); final AtomicReference<TransportAddress> addressB = new AtomicReference<>(); serviceB.registerRequestHandler("action1", TestRequest::new, ThreadPool.Names.SAME, new TransportRequestHandler<TestRequest>() { @Override public void messageReceived(TestRequest request, TransportChannel channel) throws Exception { addressA.set(request.remoteAddress()); channel.sendResponse(new TestResponse()); latch.countDown(); } }); serviceA.sendRequest(nodeB, "action1", new TestRequest(), new TransportResponseHandler<TestResponse>() { @Override public TestResponse newInstance() { return new TestResponse(); } @Override public void handleResponse(TestResponse response) { addressB.set(response.remoteAddress()); latch.countDown(); } @Override public void handleException(TransportException exp) { latch.countDown(); } @Override public String executor() { return ThreadPool.Names.SAME; } }); if (!latch.await(10, TimeUnit.SECONDS)) { fail("message round trip did not complete within a sensible time frame"); } assertTrue(nodeA.getAddress().getAddress().equals(addressA.get().getAddress())); assertTrue(nodeB.getAddress().getAddress().equals(addressB.get().getAddress())); } public void testBlockingIncomingRequests() throws Exception { try (TransportService service = buildService("TS_TEST", version0, null, Settings.EMPTY, false, false)) { AtomicBoolean 
requestProcessed = new AtomicBoolean(false); service.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, (request, channel) -> { requestProcessed.set(true); channel.sendResponse(TransportResponse.Empty.INSTANCE); }); DiscoveryNode node = service.getLocalNode(); serviceA.close(); serviceA = buildService("TS_A", version0, null, Settings.EMPTY, true, false); try (Transport.Connection connection = serviceA.openConnection(node, null)) { CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(connection, "action", new TestRequest(), TransportRequestOptions.EMPTY, new TransportResponseHandler<TestResponse>() { @Override public TestResponse newInstance() { return new TestResponse(); } @Override public void handleResponse(TestResponse response) { latch.countDown(); } @Override public void handleException(TransportException exp) { latch.countDown(); } @Override public String executor() { return ThreadPool.Names.SAME; } }); assertFalse(requestProcessed.get()); service.acceptIncomingRequests(); assertBusy(() -> assertTrue(requestProcessed.get())); latch.await(); } } } public static class TestRequest extends TransportRequest { String info; int resendCount; public TestRequest() { } public TestRequest(String info) { this.info = info; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); info = in.readOptionalString(); resendCount = in.readInt(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(info); out.writeInt(resendCount); } @Override public String toString() { return "TestRequest{" + "info='" + info + '\'' + '}'; } } private static class TestResponse extends TransportResponse { String info; TestResponse() { } TestResponse(String info) { this.info = info; } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); info = in.readOptionalString(); } @Override public void writeTo(StreamOutput out) throws IOException 
{ super.writeTo(out); out.writeOptionalString(info); } @Override public String toString() { return "TestResponse{" + "info='" + info + '\'' + '}'; } } public void testSendRandomRequests() throws InterruptedException { TransportService serviceC = build( Settings.builder() .put("name", "TS_TEST") .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") .build(), version0, null, true); DiscoveryNode nodeC = serviceC.getLocalNode(); serviceC.acceptIncomingRequests(); final CountDownLatch latch = new CountDownLatch(4); TransportConnectionListener waitForConnection = new TransportConnectionListener() { @Override public void onNodeConnected(DiscoveryNode node) { latch.countDown(); } @Override public void onNodeDisconnected(DiscoveryNode node) { fail("disconnect should not be called " + node); } }; serviceA.addConnectionListener(waitForConnection); serviceB.addConnectionListener(waitForConnection); serviceC.addConnectionListener(waitForConnection); serviceC.connectToNode(nodeA); serviceC.connectToNode(nodeB); serviceA.connectToNode(nodeC); serviceB.connectToNode(nodeC); latch.await(); serviceA.removeConnectionListener(waitForConnection); serviceB.removeConnectionListener(waitForConnection); serviceC.removeConnectionListener(waitForConnection); Map<TransportService, DiscoveryNode> toNodeMap = new HashMap<>(); toNodeMap.put(serviceA, nodeA); toNodeMap.put(serviceB, nodeB); toNodeMap.put(serviceC, nodeC); AtomicBoolean fail = new AtomicBoolean(false); class TestRequestHandler implements TransportRequestHandler<TestRequest> { private final TransportService service; TestRequestHandler(TransportService service) { this.service = service; } @Override public void messageReceived(TestRequest request, TransportChannel channel) throws Exception { if (randomBoolean()) { Thread.sleep(randomIntBetween(10, 50)); } if (fail.get()) { throw new IOException("forced failure"); } if (randomBoolean() && 
request.resendCount++ < 20) { DiscoveryNode node = randomFrom(nodeA, nodeB, nodeC); logger.debug("send secondary request from {} to {} - {}", toNodeMap.get(service), node, request.info); service.sendRequest(node, "action1", new TestRequest("secondary " + request.info), TransportRequestOptions.builder().withCompress(randomBoolean()).build(), new TransportResponseHandler<TestResponse>() { @Override public TestResponse newInstance() { return new TestResponse(); } @Override public void handleResponse(TestResponse response) { try { if (randomBoolean()) { Thread.sleep(randomIntBetween(10, 50)); } logger.debug("send secondary response {}", response.info); channel.sendResponse(response); } catch (Exception e) { throw new RuntimeException(e); } } @Override public void handleException(TransportException exp) { try { logger.debug("send secondary exception response for request {}", request.info); channel.sendResponse(exp); } catch (Exception e) { throw new RuntimeException(e); } } @Override public String executor() { return randomBoolean() ? 
ThreadPool.Names.SAME : ThreadPool.Names.GENERIC; } }); } else { logger.debug("send response for {}", request.info); channel.sendResponse(new TestResponse("Response for: " + request.info)); } } } serviceB.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), new TestRequestHandler(serviceB)); serviceC.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), new TestRequestHandler(serviceC)); serviceA.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), new TestRequestHandler(serviceA)); int iters = randomIntBetween(30, 60); CountDownLatch allRequestsDone = new CountDownLatch(iters); class TestResponseHandler implements TransportResponseHandler<TestResponse> { private final int id; TestResponseHandler(int id) { this.id = id; } @Override public TestResponse newInstance() { return new TestResponse(); } @Override public void handleResponse(TestResponse response) { logger.debug("---> received response: {}", response.info); allRequestsDone.countDown(); } @Override public void handleException(TransportException exp) { logger.debug((Supplier<?>) () -> new ParameterizedMessage("---> received exception for id {}", id), exp); allRequestsDone.countDown(); Throwable unwrap = ExceptionsHelper.unwrap(exp, IOException.class); assertNotNull(unwrap); assertEquals(IOException.class, unwrap.getClass()); assertEquals("forced failure", unwrap.getMessage()); } @Override public String executor() { return randomBoolean() ? 
ThreadPool.Names.SAME : ThreadPool.Names.GENERIC; } } for (int i = 0; i < iters; i++) { TransportService service = randomFrom(serviceC, serviceB, serviceA); DiscoveryNode node = randomFrom(nodeC, nodeB, nodeA); logger.debug("send from {} to {}", toNodeMap.get(service), node); service.sendRequest(node, "action1", new TestRequest("REQ[" + i + "]"), TransportRequestOptions.builder().withCompress(randomBoolean()).build(), new TestResponseHandler(i)); } logger.debug("waiting for response"); fail.set(randomBoolean()); boolean await = allRequestsDone.await(5, TimeUnit.SECONDS); if (await == false) { logger.debug("now failing forcefully"); fail.set(true); assertTrue(allRequestsDone.await(5, TimeUnit.SECONDS)); } logger.debug("DONE"); serviceC.close(); } public void testRegisterHandlerTwice() { serviceB.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), (request, message) -> {throw new AssertionError("boom");}); expectThrows(IllegalArgumentException.class, () -> serviceB.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), (request, message) -> {throw new AssertionError("boom");}) ); serviceA.registerRequestHandler("action1", TestRequest::new, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC), (request, message) -> {throw new AssertionError("boom");}); } public void testTimeoutPerConnection() throws IOException { assumeTrue("Works only on BSD network stacks and apparently windows", Constants.MAC_OS_X || Constants.FREE_BSD || Constants.WINDOWS); try (ServerSocket socket = new MockServerSocket()) { // note - this test uses backlog=1 which is implementation specific ie. 
it might not work on some TCP/IP stacks // on linux (at least newer ones) the listen(addr, backlog=1) should just ignore new connections if the queue is full which // means that once we received an ACK from the client we just drop the packet on the floor (which is what we want) and we run // into a connection timeout quickly. Yet other implementations can for instance can terminate the connection within the 3 way // handshake which I haven't tested yet. socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1); socket.setReuseAddress(true); DiscoveryNode first = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(), socket.getLocalPort()), emptyMap(), emptySet(), version0); DiscoveryNode second = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(), socket.getLocalPort()), emptyMap(), emptySet(), version0); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); builder.addConnections(1, TransportRequestOptions.Type.BULK, TransportRequestOptions.Type.PING, TransportRequestOptions.Type.RECOVERY, TransportRequestOptions.Type.REG, TransportRequestOptions.Type.STATE); // connection with one connection and a large timeout -- should consume the one spot in the backlog queue try (TransportService service = buildService("TS_TPC", Version.CURRENT, null, Settings.EMPTY, true, false)) { IOUtils.close(service.openConnection(first, builder.build())); builder.setConnectTimeout(TimeValue.timeValueMillis(1)); final ConnectionProfile profile = builder.build(); // now with the 1ms timeout we got and test that is it's applied long startTime = System.nanoTime(); ConnectTransportException ex = expectThrows(ConnectTransportException.class, () -> service.openConnection(second, profile)); final long now = System.nanoTime(); final long timeTaken = TimeValue.nsecToMSec(now - startTime); assertTrue("test didn't timeout quick enough, time taken: [" + timeTaken + "]", timeTaken < TimeValue.timeValueSeconds(5).millis()); 
assertEquals(ex.getMessage(), "[][" + second.getAddress() + "] connect_timeout[1ms]"); } } }

    /**
     * Opens a connection from {@code serviceA} to a {@link MockTcpTransport} constructed with version 2.0.0
     * (presumably the version it reports during the handshake -- confirm against MockTcpTransport) and expects
     * the attempt to fail with a {@link ConnectTransportException}.
     */
    public void testHandshakeWithIncompatVersion() {
        assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport);
        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
        try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
            new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Settings.EMPTY, Collections.emptyList()),
            Version.fromString("2.0.0"))) {
            transport.transportServiceAdapter(serviceA.new Adapter());
            transport.start();
            DiscoveryNode node =
                new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(), version0);
            ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
            builder.addConnections(1,
                TransportRequestOptions.Type.BULK,
                TransportRequestOptions.Type.PING,
                TransportRequestOptions.Type.RECOVERY,
                TransportRequestOptions.Type.REG,
                TransportRequestOptions.Type.STATE);
            expectThrows(ConnectTransportException.class, () -> serviceA.openConnection(node, builder.build()));
        }
    }

    /**
     * Connects to a {@link MockTcpTransport} constructed with a random compatible version while the
     * {@link DiscoveryNode} used to reach it is declared as version 2.0.0, and checks that the opened
     * connection reports the transport's version rather than the one declared on the node.
     */
    public void testHandshakeUpdatesVersion() throws IOException {
        assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport);
        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
        Version version = VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), Version.CURRENT);
        try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
            new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Settings.EMPTY, Collections.emptyList()),
            version)) {
            transport.transportServiceAdapter(serviceA.new Adapter());
            transport.start();
            DiscoveryNode node =
                new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(),
                    Version.fromString("2.0.0"));
            ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
            builder.addConnections(1,
                TransportRequestOptions.Type.BULK,
                TransportRequestOptions.Type.PING,
                TransportRequestOptions.Type.RECOVERY,
                TransportRequestOptions.Type.REG,
                TransportRequestOptions.Type.STATE);
            try (Transport.Connection connection = serviceA.openConnection(node, builder.build())) {
                // expected value first so that a failure message reads correctly
                assertEquals(version, connection.getVersion());
            }
        }
    }

    /**
     * Exercises the TCP handshake in two scenarios:
     * <ol>
     *   <li>the remote transport clears bit 3 of the status byte before handling a request, so the handshake
     *       request is not treated as one; {@code connectToNode} must fail with a
     *       {@link ConnectTransportException} caused by an {@link IllegalStateException} "handshake failed";</li>
     *   <li>a regular service built with {@link Version#CURRENT}; executing the handshake on its PING channel
     *       must return {@link Version#CURRENT}.</li>
     * </ol>
     */
    public void testTcpHandshake() throws IOException, InterruptedException {
        assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport);
        TcpTransport originalTransport = (TcpTransport) serviceA.getOriginalTransport();
        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());

        try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
            new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Settings.EMPTY, Collections.emptyList())) {
            @Override
            protected String handleRequest(MockChannel mockChannel, String profileName, StreamInput stream, long requestId,
                                           int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status)
                throws IOException {
                return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress,
                    (byte)(status & ~(1<<3))); // we flip the isHandshake bit back and act like the handler is not found
            }
        }) {
            transport.transportServiceAdapter(serviceA.new Adapter());
            transport.start();
            // this acts like a node that doesn't have support for handshakes
            DiscoveryNode node =
                new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(), version0);
            ConnectTransportException exception = expectThrows(ConnectTransportException.class, () -> serviceA.connectToNode(node));
            // assertThat reports the actual cause type on failure, unlike assertTrue(... instanceof ...)
            assertThat(exception.getCause(), instanceOf(IllegalStateException.class));
            assertEquals("handshake failed", exception.getCause().getMessage());
        }

        try (TransportService service = buildService("TS_TPC", Version.CURRENT, null);
             TcpTransport.NodeChannels connection = originalTransport.openConnection(
                 new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0),
                 null
             )
        ) {
            Version version = originalTransport.executeHandshake(connection.getNode(),
                connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10));
            // expected value first so that a failure message reads correctly
            assertEquals(Version.CURRENT, version);
        }
    }

    /**
     * Connects to a plain server socket on which nothing calls {@code accept()} within this test, using a 1ms
     * handshake timeout, and expects a {@link ConnectTransportException} whose message reports the handshake timeout.
     */
    public void testTcpHandshakeTimeout() throws IOException {
        try (ServerSocket socket = new MockServerSocket()) {
            socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1);
            socket.setReuseAddress(true);
            DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
                socket.getLocalPort()), emptyMap(), emptySet(), version0);
            ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
            builder.addConnections(1,
                TransportRequestOptions.Type.BULK,
                TransportRequestOptions.Type.PING,
                TransportRequestOptions.Type.RECOVERY,
                TransportRequestOptions.Type.REG,
                TransportRequestOptions.Type.STATE);
            builder.setHandshakeTimeout(TimeValue.timeValueMillis(1));
            ConnectTransportException ex = expectThrows(ConnectTransportException.class,
                () -> serviceA.connectToNode(dummy, builder.build()));
            assertEquals("[][" + dummy.getAddress() + "] handshake_timeout[1ms]", ex.getMessage());
        }
    }

    /**
     * Connects to a server socket whose accepting thread closes the accepted connection (optionally after
     * reading a single byte) and expects the connect attempt to fail with a "general node connection failure"
     * whose cause message starts with "handshake failed". The handshake timeout is set to one hour so that the
     * failure cannot come from the timeout.
     */
    public void testTcpHandshakeConnectionReset() throws IOException, InterruptedException {
        try (ServerSocket socket = new MockServerSocket()) {
            socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1);
            socket.setReuseAddress(true);
            DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
                socket.getLocalPort()), emptyMap(), emptySet(), version0);
            Thread t = new Thread() {
                @Override
                public void run() {
                    // leaving the try-with-resources closes the accepted socket
                    try (Socket accept = socket.accept()) {
                        if (randomBoolean()) { // sometimes wait until the other side sends the message
                            accept.getInputStream().read();
                        }
                    } catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                }
            };
            t.start();
            ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
            builder.addConnections(1,
                TransportRequestOptions.Type.BULK,
                TransportRequestOptions.Type.PING,
                TransportRequestOptions.Type.RECOVERY,
                TransportRequestOptions.Type.REG,
                TransportRequestOptions.Type.STATE);
            builder.setHandshakeTimeout(TimeValue.timeValueHours(1));
            ConnectTransportException ex = expectThrows(ConnectTransportException.class,
                () -> serviceA.connectToNode(dummy, builder.build()));
            // expected value first so that a failure message reads correctly
            assertEquals("[][" + dummy.getAddress() + "] general node connection failure", ex.getMessage());
            assertThat(ex.getCause().getMessage(), startsWith("handshake failed"));
            t.join();
        }
    }

    /**
     * Registers a handler on {@code serviceA} that adds a transient "boom" and a response header
     * "foo.bar" = "baz" to the thread context and then either responds or throws. Requests are sent from
     * {@code serviceB} and from {@code serviceA} itself; in both the response and the exception callback the
     * response header must be present exactly once and the transient must be absent.
     */
    public void testResponseHeadersArePreserved() throws InterruptedException {
        List<String> executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet());
        CollectionUtil.timSort(executors); // makes sure it's reproducible
        serviceA.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, (request, channel) -> {

            threadPool.getThreadContext().putTransient("boom", new Object());
            threadPool.getThreadContext().addResponseHeader("foo.bar", "baz");
            if ("fail".equals(request.info)) {
                throw new RuntimeException("boom");
            } else {
                channel.sendResponse(TransportResponse.Empty.INSTANCE);
            }
        });

        // one count per request sent below
        CountDownLatch latch = new CountDownLatch(2);

        TransportResponseHandler<TransportResponse> transportResponseHandler = new TransportResponseHandler<TransportResponse>() {
            @Override
            public TransportResponse newInstance() {
                return TransportResponse.Empty.INSTANCE;
            }

            @Override
            public void handleResponse(TransportResponse response) {
                try {
                    assertSame(response, TransportResponse.Empty.INSTANCE);
                    assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
                    assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
                    assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
                    assertNull(threadPool.getThreadContext().getTransient("boom"));
                } finally {
                    latch.countDown();
                }
            }

            @Override
            public void handleException(TransportException exp) {
                try {
                    assertTrue(threadPool.getThreadContext().getResponseHeaders().containsKey("foo.bar"));
                    assertEquals(1, threadPool.getThreadContext().getResponseHeaders().get("foo.bar").size());
                    assertEquals("baz", threadPool.getThreadContext().getResponseHeaders().get("foo.bar").get(0));
                    assertNull(threadPool.getThreadContext().getTransient("boom"));
                } finally {
                    latch.countDown();
                }
            }

            @Override
            public String executor() {
                return randomFrom(executors);
            }
        };

        serviceB.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
        serviceA.sendRequest(nodeA, "action", new TestRequest(randomFrom("fail", "pass")), transportResponseHandler);
        latch.await();
    }

    /**
     * Sends a request over an explicitly opened connection to a service whose handler never responds, closes
     * the connection, and expects the response handler to be notified with a {@link NodeDisconnectedException}.
     */
    public void testHandlerIsInvokedOnConnectionClose() throws IOException, InterruptedException {
        List<String> executors = new ArrayList<>(ThreadPool.THREAD_POOL_TYPES.keySet());
        CollectionUtil.timSort(executors); // makes sure it's reproducible
        TransportService serviceC = build(Settings.builder().put("name", "TS_TEST").build(), version0, null, true);
        serviceC.registerRequestHandler("action", TestRequest::new, ThreadPool.Names.SAME, (request, channel) -> {
            // do nothing
        });
        serviceC.start();
        // from here on serviceC must be closed even if the test fails, otherwise it leaks into later tests
        try {
            serviceC.acceptIncomingRequests();
            CountDownLatch latch = new CountDownLatch(1);
            TransportResponseHandler<TransportResponse> transportResponseHandler = new TransportResponseHandler<TransportResponse>() {
                @Override
                public TransportResponse newInstance() {
                    return TransportResponse.Empty.INSTANCE;
                }

                @Override
                public void handleResponse(TransportResponse response) {
                    try {
                        fail("no response expected");
                    } finally {
                        latch.countDown();
                    }
                }

                @Override
                public void handleException(TransportException exp) {
                    try {
                        assertTrue(exp.getClass().toString(), exp instanceof NodeDisconnectedException);
                    } finally {
                        latch.countDown();
                    }
                }

                @Override
                public String executor() {
                    return randomFrom(executors);
                }
            };
            ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
            builder.addConnections(1,
                TransportRequestOptions.Type.BULK,
                TransportRequestOptions.Type.PING,
                TransportRequestOptions.Type.RECOVERY,
                TransportRequestOptions.Type.REG,
                TransportRequestOptions.Type.STATE);
            // closing the connection (also when sendRequest throws) is what triggers the handler notification
            try (Transport.Connection connection = serviceB.openConnection(serviceC.getLocalNode(), builder.build())) {
                serviceB.sendRequest(connection, "action", new TestRequest(randomFrom("fail", "pass")), TransportRequestOptions.EMPTY,
                    transportResponseHandler);
            }
            latch.await();
        } finally {
            serviceC.close();
        }
    }

}