/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.transport;
import org.elasticsearch.Version;
import org.elasticsearch.action.GenericAction;
import org.elasticsearch.action.admin.cluster.node.liveness.LivenessResponse;
import org.elasticsearch.action.admin.cluster.node.liveness.TransportLivenessAction;
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.client.AbstractClientHeadersTestCase;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.MockTransportClient;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class TransportClientHeadersTests extends AbstractClientHeadersTestCase {
private MockTransportService transportService;
@Override
public void tearDown() throws Exception {
try {
// stop this first before we bubble up since
// transportService uses the threadpool that super.tearDown will close
transportService.stop();
transportService.close();
} finally {
super.tearDown();
}
}
@Override
protected Client buildClient(Settings headersSettings, GenericAction[] testedActions) {
transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null);
transportService.start();
transportService.acceptIncomingRequests();
TransportClient client = new MockTransportClient(Settings.builder()
.put("client.transport.sniff", false)
.put("cluster.name", "cluster1")
.put("node.name", "transport_client_" + this.getTestName())
.put(headersSettings)
.build(), InternalTransportServiceInterceptor.TestPlugin.class);
InternalTransportServiceInterceptor.TestPlugin plugin = client.injector.getInstance(PluginsService.class)
.filterPlugins(InternalTransportServiceInterceptor.TestPlugin.class).stream().findFirst().get();
plugin.instance.threadPool = client.threadPool();
plugin.instance.address = transportService.boundAddress().publishAddress();
client.addTransportAddress(transportService.boundAddress().publishAddress());
return client;
}
public void testWithSniffing() throws Exception {
try (TransportClient client = new MockTransportClient(
Settings.builder()
.put("client.transport.sniff", true)
.put("cluster.name", "cluster1")
.put("node.name", "transport_client_" + this.getTestName() + "_1")
.put("client.transport.nodes_sampler_interval", "1s")
.put(HEADER_SETTINGS)
.put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build(),
InternalTransportServiceInterceptor.TestPlugin.class)) {
InternalTransportServiceInterceptor.TestPlugin plugin = client.injector.getInstance(PluginsService.class)
.filterPlugins(InternalTransportServiceInterceptor.TestPlugin.class).stream().findFirst().get();
plugin.instance.threadPool = client.threadPool();
plugin.instance.address = transportService.boundAddress().publishAddress();
client.addTransportAddress(transportService.boundAddress().publishAddress());
if (!plugin.instance.clusterStateLatch.await(5, TimeUnit.SECONDS)) {
fail("takes way too long to get the cluster state");
}
assertEquals(1, client.connectedNodes().size());
assertEquals(client.connectedNodes().get(0).getAddress(), transportService.boundAddress().publishAddress());
}
}
public static class InternalTransportServiceInterceptor implements TransportInterceptor {
ThreadPool threadPool;
TransportAddress address;
public static class TestPlugin extends Plugin implements NetworkPlugin {
private InternalTransportServiceInterceptor instance = new InternalTransportServiceInterceptor();
@Override
public List<TransportInterceptor> getTransportInterceptors(NamedWriteableRegistry namedWriteableRegistry,
ThreadContext threadContext) {
return Collections.singletonList(new TransportInterceptor() {
@Override
public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(String action, String executor,
boolean forceExecution,
TransportRequestHandler<T> actualHandler) {
return instance.interceptHandler(action, executor, forceExecution, actualHandler);
}
@Override
public AsyncSender interceptSender(AsyncSender sender) {
return instance.interceptSender(sender);
}
});
}
}
final CountDownLatch clusterStateLatch = new CountDownLatch(1);
@Override
public AsyncSender interceptSender(AsyncSender sender) {
return new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(Transport.Connection connection, String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler<T> handler) {
final ClusterName clusterName = new ClusterName("cluster1");
if (TransportLivenessAction.NAME.equals(action)) {
assertHeaders(threadPool);
((TransportResponseHandler<LivenessResponse>) handler).handleResponse(
new LivenessResponse(clusterName, connection.getNode()));
} else if (ClusterStateAction.NAME.equals(action)) {
assertHeaders(threadPool);
ClusterName cluster1 = clusterName;
ClusterState.Builder builder = ClusterState.builder(cluster1);
//the sniffer detects only data nodes
builder.nodes(DiscoveryNodes.builder().add(new DiscoveryNode("node_id", "someId", "some_ephemeralId_id",
address.address().getHostString(), address.getAddress(), address, Collections.emptyMap(),
Collections.singleton(DiscoveryNode.Role.DATA), Version.CURRENT)));
((TransportResponseHandler<ClusterStateResponse>) handler)
.handleResponse(new ClusterStateResponse(cluster1, builder.build(), 0L));
clusterStateLatch.countDown();
} else if (TransportService.HANDSHAKE_ACTION_NAME .equals(action)) {
((TransportResponseHandler<TransportService.HandshakeResponse>) handler).handleResponse(
new TransportService.HandshakeResponse(connection.getNode(), clusterName, connection.getNode().getVersion()));
} else {
handler.handleException(new TransportException("", new InternalException(action)));
}
}
};
}
}
}