/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntSet;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.compress.Compressor;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.compress.NotCompressedException;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.network.NetworkAddress;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.PortsRange;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.AbstractLifecycleRunnable;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.KeyedLock;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
import java.io.StreamCorruptedException;
import java.net.BindException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.channels.CancelledKeyException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.util.Collections.unmodifiableMap;
import static org.elasticsearch.common.settings.Setting.boolSetting;
import static org.elasticsearch.common.settings.Setting.intSetting;
import static org.elasticsearch.common.settings.Setting.timeSetting;
import static org.elasticsearch.common.transport.NetworkExceptionHelper.isCloseConnectionException;
import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
public abstract class TcpTransport<Channel> extends AbstractLifecycleComponent implements Transport {
public static final String TRANSPORT_SERVER_WORKER_THREAD_NAME_PREFIX = "transport_server_worker";
public static final String TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX = "transport_client_boss";
// the scheduled internal ping interval setting, defaults to disabled (-1)
public static final Setting<TimeValue> PING_SCHEDULE =
timeSetting("transport.ping_schedule", TimeValue.timeValueSeconds(-1), Setting.Property.NodeScope);
public static final Setting<Integer> CONNECTIONS_PER_NODE_RECOVERY =
intSetting("transport.connections_per_node.recovery", 2, 1, Setting.Property.NodeScope);
public static final Setting<Integer> CONNECTIONS_PER_NODE_BULK =
intSetting("transport.connections_per_node.bulk", 3, 1, Setting.Property.NodeScope);
public static final Setting<Integer> CONNECTIONS_PER_NODE_REG =
intSetting("transport.connections_per_node.reg", 6, 1, Setting.Property.NodeScope);
public static final Setting<Integer> CONNECTIONS_PER_NODE_STATE =
intSetting("transport.connections_per_node.state", 1, 1, Setting.Property.NodeScope);
public static final Setting<Integer> CONNECTIONS_PER_NODE_PING =
intSetting("transport.connections_per_node.ping", 1, 1, Setting.Property.NodeScope);
public static final Setting<TimeValue> TCP_CONNECT_TIMEOUT =
timeSetting("transport.tcp.connect_timeout", NetworkService.TcpSettings.TCP_CONNECT_TIMEOUT, Setting.Property.NodeScope);
public static final Setting<Boolean> TCP_NO_DELAY =
boolSetting("transport.tcp_no_delay", NetworkService.TcpSettings.TCP_NO_DELAY, Setting.Property.NodeScope);
public static final Setting<Boolean> TCP_KEEP_ALIVE =
boolSetting("transport.tcp.keep_alive", NetworkService.TcpSettings.TCP_KEEP_ALIVE, Setting.Property.NodeScope);
public static final Setting<Boolean> TCP_REUSE_ADDRESS =
boolSetting("transport.tcp.reuse_address", NetworkService.TcpSettings.TCP_REUSE_ADDRESS, Setting.Property.NodeScope);
public static final Setting<ByteSizeValue> TCP_SEND_BUFFER_SIZE =
Setting.byteSizeSetting("transport.tcp.send_buffer_size", NetworkService.TcpSettings.TCP_SEND_BUFFER_SIZE,
Setting.Property.NodeScope);
public static final Setting<ByteSizeValue> TCP_RECEIVE_BUFFER_SIZE =
Setting.byteSizeSetting("transport.tcp.receive_buffer_size", NetworkService.TcpSettings.TCP_RECEIVE_BUFFER_SIZE,
Setting.Property.NodeScope);
private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
private static final int PING_DATA_SIZE = -1;
private final CircuitBreakerService circuitBreakerService;
// package visibility for tests
protected final ScheduledPing scheduledPing;
private final TimeValue pingSchedule;
protected final ThreadPool threadPool;
private final BigArrays bigArrays;
protected final NetworkService networkService;
protected volatile TransportServiceAdapter transportServiceAdapter;
// node id to actual channel
protected final ConcurrentMap<DiscoveryNode, NodeChannels> connectedNodes = newConcurrentMap();
protected final Map<String, List<Channel>> serverChannels = newConcurrentMap();
protected final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap();
protected final KeyedLock<String> connectionLock = new KeyedLock<>();
private final NamedWriteableRegistry namedWriteableRegistry;
// this lock is here to make sure we close this transport and disconnect all the client nodes
// connections while no connect operations is going on... (this might help with 100% CPU when stopping the transport?)
protected final ReadWriteLock globalLock = new ReentrantReadWriteLock();
protected final boolean compress;
protected volatile BoundTransportAddress boundAddress;
private final String transportName;
protected final ConnectionProfile defaultConnectionProfile;
private final ConcurrentMap<Long, HandshakeResponseHandler> pendingHandshakes = new ConcurrentHashMap<>();
private final AtomicLong requestIdGenerator = new AtomicLong();
private final CounterMetric numHandshakes = new CounterMetric();
private static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake";
public TcpTransport(String transportName, Settings settings, ThreadPool threadPool, BigArrays bigArrays,
CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry,
NetworkService networkService) {
super(settings);
this.threadPool = threadPool;
this.bigArrays = bigArrays;
this.circuitBreakerService = circuitBreakerService;
this.scheduledPing = new ScheduledPing();
this.pingSchedule = PING_SCHEDULE.get(settings);
this.namedWriteableRegistry = namedWriteableRegistry;
this.compress = Transport.TRANSPORT_TCP_COMPRESS.get(settings);
this.networkService = networkService;
this.transportName = transportName;
defaultConnectionProfile = buildDefaultConnectionProfile(settings);
}
static ConnectionProfile buildDefaultConnectionProfile(Settings settings) {
int connectionsPerNodeRecovery = CONNECTIONS_PER_NODE_RECOVERY.get(settings);
int connectionsPerNodeBulk = CONNECTIONS_PER_NODE_BULK.get(settings);
int connectionsPerNodeReg = CONNECTIONS_PER_NODE_REG.get(settings);
int connectionsPerNodeState = CONNECTIONS_PER_NODE_STATE.get(settings);
int connectionsPerNodePing = CONNECTIONS_PER_NODE_PING.get(settings);
ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.setConnectTimeout(TCP_CONNECT_TIMEOUT.get(settings));
builder.setHandshakeTimeout(TCP_CONNECT_TIMEOUT.get(settings));
builder.addConnections(connectionsPerNodeBulk, TransportRequestOptions.Type.BULK);
builder.addConnections(connectionsPerNodePing, TransportRequestOptions.Type.PING);
// if we are not master eligible we don't need a dedicated channel to publish the state
builder.addConnections(DiscoveryNode.isMasterNode(settings) ? connectionsPerNodeState : 0, TransportRequestOptions.Type.STATE);
// if we are not a data-node we don't need any dedicated channels for recovery
builder.addConnections(DiscoveryNode.isDataNode(settings) ? connectionsPerNodeRecovery : 0, TransportRequestOptions.Type.RECOVERY);
builder.addConnections(connectionsPerNodeReg, TransportRequestOptions.Type.REG);
return builder.build();
}
@Override
protected void doStart() {
if (pingSchedule.millis() > 0) {
threadPool.schedule(pingSchedule, ThreadPool.Names.GENERIC, scheduledPing);
}
}
@Override
public CircuitBreaker getInFlightRequestBreaker() {
// We always obtain a fresh breaker to reflect changes to the breaker configuration.
return circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
}
@Override
public void transportServiceAdapter(TransportServiceAdapter service) {
if (service.getRequestHandler(HANDSHAKE_ACTION_NAME) != null) {
throw new IllegalStateException(HANDSHAKE_ACTION_NAME + " is a reserved request handler and must not be registered");
}
this.transportServiceAdapter = service;
}
private static class HandshakeResponseHandler<Channel> implements TransportResponseHandler<VersionHandshakeResponse> {
final AtomicReference<Version> versionRef = new AtomicReference<>();
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Exception> exceptionRef = new AtomicReference<>();
final Channel channel;
HandshakeResponseHandler(Channel channel) {
this.channel = channel;
}
@Override
public VersionHandshakeResponse newInstance() {
return new VersionHandshakeResponse();
}
@Override
public void handleResponse(VersionHandshakeResponse response) {
final boolean success = versionRef.compareAndSet(null, response.version);
latch.countDown();
assert success;
}
@Override
public void handleException(TransportException exp) {
final boolean success = exceptionRef.compareAndSet(null, exp);
latch.countDown();
assert success;
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
}
public class ScheduledPing extends AbstractLifecycleRunnable {
/**
* The magic number (must be lower than 0) for a ping message. This is handled
* specifically in {@link TcpTransport#validateMessageHeader}.
*/
private final BytesReference pingHeader;
final CounterMetric successfulPings = new CounterMetric();
final CounterMetric failedPings = new CounterMetric();
public ScheduledPing() {
super(lifecycle, logger);
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.writeByte((byte) 'E');
out.writeByte((byte) 'S');
out.writeInt(PING_DATA_SIZE);
pingHeader = out.bytes();
} catch (IOException e) {
throw new IllegalStateException(e.getMessage(), e); // won't happen
}
}
@Override
protected void doRunInLifecycle() throws Exception {
for (Map.Entry<DiscoveryNode, NodeChannels> entry : connectedNodes.entrySet()) {
DiscoveryNode node = entry.getKey();
NodeChannels channels = entry.getValue();
for (Channel channel : channels.getChannels()) {
internalSendMessage(channel, pingHeader, new NotifyOnceListener<Channel>() {
@Override
public void innerOnResponse(Channel channel) {
successfulPings.inc();
}
@Override
public void innerOnFailure(Exception e) {
if (isOpen(channel)) {
logger.debug(
(Supplier<?>) () -> new ParameterizedMessage("[{}] failed to send ping transport message", node), e);
failedPings.inc();
} else {
logger.trace((Supplier<?>) () ->
new ParameterizedMessage("[{}] failed to send ping transport message (channel closed)", node), e);
}
}
});
}
}
}
public long getSuccessfulPings() {
return successfulPings.count();
}
public long getFailedPings() {
return failedPings.count();
}
@Override
protected void onAfterInLifecycle() {
try {
threadPool.schedule(pingSchedule, ThreadPool.Names.GENERIC, this);
} catch (EsRejectedExecutionException ex) {
if (ex.isExecutorShutdown()) {
logger.debug("couldn't schedule new ping execution, executor is shutting down", ex);
} else {
throw ex;
}
}
}
@Override
public void onFailure(Exception e) {
if (lifecycle.stoppedOrClosed()) {
logger.trace("failed to send ping transport message", e);
} else {
logger.warn("failed to send ping transport message", e);
}
}
}
public final class NodeChannels implements Connection {
private final Map<TransportRequestOptions.Type, ConnectionProfile.ConnectionTypeHandle> typeMapping;
private final Channel[] channels;
private final DiscoveryNode node;
private final AtomicBoolean closed = new AtomicBoolean(false);
private final Version version;
private final Consumer<Connection> onClose;
public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile, Consumer<Connection> onClose) {
this.node = node;
this.channels = channels;
assert channels.length == connectionProfile.getNumConnections() : "expected channels size to be == "
+ connectionProfile.getNumConnections() + " but was: [" + channels.length + "]";
typeMapping = new EnumMap<>(TransportRequestOptions.Type.class);
for (ConnectionProfile.ConnectionTypeHandle handle : connectionProfile.getHandles()) {
for (TransportRequestOptions.Type type : handle.getTypes())
typeMapping.put(type, handle);
}
version = node.getVersion();
this.onClose = onClose;
}
NodeChannels(NodeChannels channels, Version handshakeVersion) {
this.node = channels.node;
this.channels = channels.channels;
this.typeMapping = channels.typeMapping;
this.version = handshakeVersion;
this.onClose = channels.onClose;
}
@Override
public Version getVersion() {
return version;
}
public boolean hasChannel(Channel channel) {
for (Channel channel1 : channels) {
if (channel.equals(channel1)) {
return true;
}
}
return false;
}
public List<Channel> getChannels() {
return Arrays.asList(channels);
}
public Channel channel(TransportRequestOptions.Type type) {
ConnectionProfile.ConnectionTypeHandle connectionTypeHandle = typeMapping.get(type);
if (connectionTypeHandle == null) {
throw new IllegalArgumentException("no type channel for [" + type + "]");
}
return connectionTypeHandle.getChannel(channels);
}
@Override
public synchronized void close() throws IOException {
if (closed.compareAndSet(false, true)) {
try {
closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList()));
} finally {
onClose.accept(this);
}
}
}
@Override
public DiscoveryNode getNode() {
return this.node;
}
@Override
public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
throws IOException, TransportException {
if (closed.get()) {
throw new NodeNotConnectedException(node, "connection already closed");
}
Channel channel = channel(options.type());
sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), (byte) 0);
}
}
@Override
public boolean nodeConnected(DiscoveryNode node) {
return connectedNodes.containsKey(node);
}
@Override
public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile,
CheckedBiConsumer<Connection, ConnectionProfile, IOException> connectionValidator)
throws ConnectTransportException {
connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile);
if (node == null) {
throw new ConnectTransportException(null, "can't connect to a null node");
}
globalLock.readLock().lock(); // ensure we don't open connections while we are closing
try {
ensureOpen();
try (Releasable ignored = connectionLock.acquire(node.getId())) {
NodeChannels nodeChannels = connectedNodes.get(node);
if (nodeChannels != null) {
return;
}
try {
try {
nodeChannels = openConnection(node, connectionProfile);
connectionValidator.accept(nodeChannels, connectionProfile);
} catch (Exception e) {
logger.trace(
(Supplier<?>) () -> new ParameterizedMessage(
"failed to connect to [{}], cleaning dangling connections", node), e);
IOUtils.closeWhileHandlingException(nodeChannels);
throw e;
}
// we acquire a connection lock, so no way there is an existing connection
connectedNodes.put(node, nodeChannels);
if (logger.isDebugEnabled()) {
logger.debug("connected to node [{}]", node);
}
transportServiceAdapter.onNodeConnected(node);
} catch (ConnectTransportException e) {
throw e;
} catch (Exception e) {
throw new ConnectTransportException(node, "general node connection failure", e);
}
}
} finally {
globalLock.readLock().unlock();
}
}
/**
* takes a {@link ConnectionProfile} that have been passed as a parameter to the public methods
* and resolves it to a fully specified (i.e., no nulls) profile
*/
static ConnectionProfile resolveConnectionProfile(@Nullable ConnectionProfile connectionProfile,
ConnectionProfile defaultConnectionProfile) {
Objects.requireNonNull(defaultConnectionProfile);
if (connectionProfile == null) {
return defaultConnectionProfile;
} else if (connectionProfile.getConnectTimeout() != null && connectionProfile.getHandshakeTimeout() != null) {
return connectionProfile;
} else {
ConnectionProfile.Builder builder = new ConnectionProfile.Builder(connectionProfile);
if (connectionProfile.getConnectTimeout() == null) {
builder.setConnectTimeout(defaultConnectionProfile.getConnectTimeout());
}
if (connectionProfile.getHandshakeTimeout() == null) {
builder.setHandshakeTimeout(defaultConnectionProfile.getHandshakeTimeout());
}
return builder.build();
}
}
@Override
public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException {
if (node == null) {
throw new ConnectTransportException(null, "can't open connection to a null node");
}
boolean success = false;
NodeChannels nodeChannels = null;
connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile);
globalLock.readLock().lock(); // ensure we don't open connections while we are closing
try {
ensureOpen();
try {
nodeChannels = connectToChannels(node, connectionProfile);
final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile
final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ?
defaultConnectionProfile.getConnectTimeout() :
connectionProfile.getConnectTimeout();
final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ?
connectTimeout : connectionProfile.getHandshakeTimeout();
final Version version = executeHandshake(node, channel, handshakeTimeout);
transportServiceAdapter.onConnectionOpened(nodeChannels);
nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version
success = true;
return nodeChannels;
} catch (ConnectTransportException e) {
throw e;
} catch (Exception e) {
// ConnectTransportExceptions are handled specifically on the caller end - we wrap the actual exception to ensure
// only relevant exceptions are logged on the caller end.. this is the same as in connectToNode
throw new ConnectTransportException(node, "general node connection failure", e);
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(nodeChannels);
}
}
} finally {
globalLock.readLock().unlock();
}
}
/**
* Disconnects from a node, only if the relevant channel is found to be part of the node channels.
*/
protected boolean disconnectFromNode(DiscoveryNode node, Channel channel, String reason) {
// this might be called multiple times from all the node channels, so do a lightweight
// check outside of the lock
NodeChannels nodeChannels = connectedNodes.get(node);
if (nodeChannels != null && nodeChannels.hasChannel(channel)) {
try (Releasable ignored = connectionLock.acquire(node.getId())) {
nodeChannels = connectedNodes.get(node);
// check again within the connection lock, if its still applicable to remove it
if (nodeChannels != null && nodeChannels.hasChannel(channel)) {
connectedNodes.remove(node);
closeAndNotify(node, nodeChannels, reason);
return true;
}
}
}
return false;
}
private void closeAndNotify(DiscoveryNode node, NodeChannels nodeChannels, String reason) {
try {
logger.debug("disconnecting from [{}], {}", node, reason);
IOUtils.closeWhileHandlingException(nodeChannels);
} finally {
logger.trace("disconnected from [{}], {}", node, reason);
transportServiceAdapter.onNodeDisconnected(node);
}
}
/**
* Disconnects from a node if a channel is found as part of that nodes channels.
*/
protected final void disconnectFromNodeChannel(final Channel channel, final Exception failure) {
threadPool.generic().execute(() -> {
try {
try {
closeChannels(Collections.singletonList(channel));
} finally {
for (DiscoveryNode node : connectedNodes.keySet()) {
if (disconnectFromNode(node, channel, ExceptionsHelper.detailedMessage(failure))) {
// if we managed to find this channel and disconnect from it, then break, no need to check on
// the rest of the nodes
break;
}
}
}
} catch (IOException e) {
logger.warn("failed to close channel", e);
} finally {
onChannelClosed(channel);
}
});
}
@Override
public NodeChannels getConnection(DiscoveryNode node) {
NodeChannels nodeChannels = connectedNodes.get(node);
if (nodeChannels == null) {
throw new NodeNotConnectedException(node, "Node not connected");
}
return nodeChannels;
}
@Override
public void disconnectFromNode(DiscoveryNode node) {
try (Releasable ignored = connectionLock.acquire(node.getId())) {
NodeChannels nodeChannels = connectedNodes.remove(node);
if (nodeChannels != null) {
closeAndNotify(node, nodeChannels, "due to explicit disconnect call");
}
}
}
protected Version getCurrentVersion() {
// this is just for tests to mock stuff like the nodes version - tests can override this internally
return Version.CURRENT;
}
@Override
public BoundTransportAddress boundAddress() {
return this.boundAddress;
}
@Override
public Map<String, BoundTransportAddress> profileBoundAddresses() {
return unmodifiableMap(new HashMap<>(profileBoundAddresses));
}
protected Map<String, Settings> buildProfileSettings() {
// extract default profile first and create standard bootstrap
Map<String, Settings> profiles = TransportSettings.TRANSPORT_PROFILES_SETTING.get(settings).getAsGroups(true);
if (!profiles.containsKey(TransportSettings.DEFAULT_PROFILE)) {
profiles = new HashMap<>(profiles);
profiles.put(TransportSettings.DEFAULT_PROFILE, Settings.EMPTY);
}
Settings defaultSettings = profiles.get(TransportSettings.DEFAULT_PROFILE);
Map<String, Settings> result = new HashMap<>();
// loop through all profiles and start them up, special handling for default one
for (Map.Entry<String, Settings> entry : profiles.entrySet()) {
Settings profileSettings = entry.getValue();
String name = entry.getKey();
if (!Strings.hasLength(name)) {
logger.info("transport profile configured without a name. skipping profile with settings [{}]",
profileSettings.toDelimitedString(','));
continue;
} else if (TransportSettings.DEFAULT_PROFILE.equals(name)) {
profileSettings = Settings.builder()
.put(profileSettings)
.put("port", profileSettings.get("port", TransportSettings.PORT.get(this.settings)))
.build();
} else if (profileSettings.get("port") == null) {
// if profile does not have a port, skip it
logger.info("No port configured for profile [{}], not binding", name);
continue;
}
Settings mergedSettings = Settings.builder()
.put(defaultSettings)
.put(profileSettings)
.build();
result.put(name, mergedSettings);
}
return result;
}
@Override
public List<String> getLocalAddresses() {
List<String> local = new ArrayList<>();
local.add("127.0.0.1");
// check if v6 is supported, if so, v4 will also work via mapped addresses.
if (NetworkUtils.SUPPORTS_V6) {
local.add("[::1]"); // may get ports appended!
}
return local;
}
protected void bindServer(final String name, final Settings settings) {
// Bind and start to accept incoming connections.
InetAddress hostAddresses[];
String bindHosts[] = settings.getAsArray("bind_host", null);
try {
hostAddresses = networkService.resolveBindHostAddresses(bindHosts);
} catch (IOException e) {
throw new BindTransportException("Failed to resolve host " + Arrays.toString(bindHosts), e);
}
if (logger.isDebugEnabled()) {
String[] addresses = new String[hostAddresses.length];
for (int i = 0; i < hostAddresses.length; i++) {
addresses[i] = NetworkAddress.format(hostAddresses[i]);
}
logger.debug("binding server bootstrap to: {}", (Object) addresses);
}
assert hostAddresses.length > 0;
List<InetSocketAddress> boundAddresses = new ArrayList<>();
for (InetAddress hostAddress : hostAddresses) {
boundAddresses.add(bindToPort(name, hostAddress, settings.get("port")));
}
final BoundTransportAddress boundTransportAddress = createBoundTransportAddress(name, settings, boundAddresses);
if (TransportSettings.DEFAULT_PROFILE.equals(name)) {
this.boundAddress = boundTransportAddress;
} else {
profileBoundAddresses.put(name, boundTransportAddress);
}
}
protected InetSocketAddress bindToPort(final String name, final InetAddress hostAddress, String port) {
PortsRange portsRange = new PortsRange(port);
final AtomicReference<Exception> lastException = new AtomicReference<>();
final AtomicReference<InetSocketAddress> boundSocket = new AtomicReference<>();
boolean success = portsRange.iterate(portNumber -> {
try {
Channel channel = bind(name, new InetSocketAddress(hostAddress, portNumber));
synchronized (serverChannels) {
List<Channel> list = serverChannels.get(name);
if (list == null) {
list = new ArrayList<>();
serverChannels.put(name, list);
}
list.add(channel);
boundSocket.set(getLocalAddress(channel));
}
} catch (Exception e) {
lastException.set(e);
return false;
}
return true;
});
if (!success) {
throw new BindTransportException("Failed to bind to [" + port + "]", lastException.get());
}
if (logger.isDebugEnabled()) {
logger.debug("Bound profile [{}] to address {{}}", name, NetworkAddress.format(boundSocket.get()));
}
return boundSocket.get();
}
private BoundTransportAddress createBoundTransportAddress(String name, Settings profileSettings,
List<InetSocketAddress> boundAddresses) {
String[] boundAddressesHostStrings = new String[boundAddresses.size()];
TransportAddress[] transportBoundAddresses = new TransportAddress[boundAddresses.size()];
for (int i = 0; i < boundAddresses.size(); i++) {
InetSocketAddress boundAddress = boundAddresses.get(i);
boundAddressesHostStrings[i] = boundAddress.getHostString();
transportBoundAddresses[i] = new TransportAddress(boundAddress);
}
final String[] publishHosts;
if (TransportSettings.DEFAULT_PROFILE.equals(name)) {
publishHosts = TransportSettings.PUBLISH_HOST.get(settings).toArray(Strings.EMPTY_ARRAY);
} else {
publishHosts = profileSettings.getAsArray("publish_host", boundAddressesHostStrings);
}
final InetAddress publishInetAddress;
try {
publishInetAddress = networkService.resolvePublishHostAddresses(publishHosts);
} catch (Exception e) {
throw new BindTransportException("Failed to resolve publish address", e);
}
final int publishPort = resolvePublishPort(name, settings, profileSettings, boundAddresses, publishInetAddress);
final TransportAddress publishAddress = new TransportAddress(new InetSocketAddress(publishInetAddress, publishPort));
return new BoundTransportAddress(transportBoundAddresses, publishAddress);
}
// package private for tests
public static int resolvePublishPort(String profileName, Settings settings, Settings profileSettings,
List<InetSocketAddress> boundAddresses, InetAddress publishInetAddress) {
int publishPort;
if (TransportSettings.DEFAULT_PROFILE.equals(profileName)) {
publishPort = TransportSettings.PUBLISH_PORT.get(settings);
} else {
publishPort = profileSettings.getAsInt("publish_port", -1);
}
// if port not explicitly provided, search for port of address in boundAddresses that matches publishInetAddress
if (publishPort < 0) {
for (InetSocketAddress boundAddress : boundAddresses) {
InetAddress boundInetAddress = boundAddress.getAddress();
if (boundInetAddress.isAnyLocalAddress() || boundInetAddress.equals(publishInetAddress)) {
publishPort = boundAddress.getPort();
break;
}
}
}
// if no matching boundAddress found, check if there is a unique port for all bound addresses
if (publishPort < 0) {
final IntSet ports = new IntHashSet();
for (InetSocketAddress boundAddress : boundAddresses) {
ports.add(boundAddress.getPort());
}
if (ports.size() == 1) {
publishPort = ports.iterator().next().value;
}
}
if (publishPort < 0) {
String profileExplanation = TransportSettings.DEFAULT_PROFILE.equals(profileName) ? "" : " for profile " + profileName;
throw new BindTransportException("Failed to auto-resolve publish port" + profileExplanation + ", multiple bound addresses " +
boundAddresses + " with distinct ports and none of them matched the publish address (" + publishInetAddress + "). " +
"Please specify a unique port by setting " + TransportSettings.PORT.getKey() + " or " +
TransportSettings.PUBLISH_PORT.getKey());
}
return publishPort;
}
@Override
public TransportAddress[] addressesFromString(String address, int perAddressLimit) throws UnknownHostException {
return parse(address, settings.get("transport.profiles.default.port", TransportSettings.PORT.get(settings)), perAddressLimit);
}
// this code is a take on guava's HostAndPort, like a HostAndPortRange
// pattern for validating ipv6 bracket addresses.
// not perfect, but PortsRange should take care of any port range validation, not a regex
private static final Pattern BRACKET_PATTERN = Pattern.compile("^\\[(.*:.*)\\](?::([\\d\\-]*))?$");
/** parse a hostname+port range spec into its equivalent addresses */
static TransportAddress[] parse(String hostPortString, String defaultPortRange, int perAddressLimit) throws UnknownHostException {
Objects.requireNonNull(hostPortString);
String host;
String portString = null;
if (hostPortString.startsWith("[")) {
// Parse a bracketed host, typically an IPv6 literal.
Matcher matcher = BRACKET_PATTERN.matcher(hostPortString);
if (!matcher.matches()) {
throw new IllegalArgumentException("Invalid bracketed host/port range: " + hostPortString);
}
host = matcher.group(1);
portString = matcher.group(2); // could be null
} else {
int colonPos = hostPortString.indexOf(':');
if (colonPos >= 0 && hostPortString.indexOf(':', colonPos + 1) == -1) {
// Exactly 1 colon. Split into host:port.
host = hostPortString.substring(0, colonPos);
portString = hostPortString.substring(colonPos + 1);
} else {
// 0 or 2+ colons. Bare hostname or IPv6 literal.
host = hostPortString;
// 2+ colons and not bracketed: exception
if (colonPos >= 0) {
throw new IllegalArgumentException("IPv6 addresses must be bracketed: " + hostPortString);
}
}
}
// if port isn't specified, fill with the default
if (portString == null || portString.isEmpty()) {
portString = defaultPortRange;
}
// generate address for each port in the range
Set<InetAddress> addresses = new HashSet<>(Arrays.asList(InetAddress.getAllByName(host)));
List<TransportAddress> transportAddresses = new ArrayList<>();
int[] ports = new PortsRange(portString).ports();
int limit = Math.min(ports.length, perAddressLimit);
for (int i = 0; i < limit; i++) {
for (InetAddress address : addresses) {
transportAddresses.add(new TransportAddress(address, ports[i]));
}
}
return transportAddresses.toArray(new TransportAddress[transportAddresses.size()]);
}
@Override
protected final void doClose() {
}
@Override
protected final void doStop() {
final CountDownLatch latch = new CountDownLatch(1);
// make sure we run it on another thread than a possible IO handler thread
threadPool.generic().execute(() -> {
globalLock.writeLock().lock();
try {
// first stop to accept any incoming connections so nobody can connect to this transport
for (Map.Entry<String, List<Channel>> entry : serverChannels.entrySet()) {
try {
closeChannels(entry.getValue());
} catch (Exception e) {
logger.debug(
(Supplier<?>) () -> new ParameterizedMessage(
"Error closing serverChannel for profile [{}]", entry.getKey()), e);
}
}
for (Iterator<NodeChannels> it = connectedNodes.values().iterator(); it.hasNext();) {
NodeChannels nodeChannels = it.next();
it.remove();
IOUtils.closeWhileHandlingException(nodeChannels);
}
stopInternal();
} finally {
globalLock.writeLock().unlock();
latch.countDown();
}
});
try {
latch.await(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
// ignore
}
}
protected void onException(Channel channel, Exception e) {
if (!lifecycle.started()) {
// just close and ignore - we are already stopped and just need to make sure we release all resources
disconnectFromNodeChannel(channel, e);
return;
}
if (isCloseConnectionException(e)) {
logger.trace(
(Supplier<?>) () -> new ParameterizedMessage(
"close connection exception caught on transport layer [{}], disconnecting from relevant node",
channel),
e);
// close the channel, which will cause a node to be disconnected if relevant
disconnectFromNodeChannel(channel, e);
} else if (isConnectException(e)) {
logger.trace((Supplier<?>) () -> new ParameterizedMessage("connect exception caught on transport layer [{}]", channel), e);
// close the channel as safe measure, which will cause a node to be disconnected if relevant
disconnectFromNodeChannel(channel, e);
} else if (e instanceof BindException) {
logger.trace((Supplier<?>) () -> new ParameterizedMessage("bind exception caught on transport layer [{}]", channel), e);
// close the channel as safe measure, which will cause a node to be disconnected if relevant
disconnectFromNodeChannel(channel, e);
} else if (e instanceof CancelledKeyException) {
logger.trace(
(Supplier<?>) () -> new ParameterizedMessage(
"cancelled key exception caught on transport layer [{}], disconnecting from relevant node",
channel),
e);
// close the channel as safe measure, which will cause a node to be disconnected if relevant
disconnectFromNodeChannel(channel, e);
} else if (e instanceof TcpTransport.HttpOnTransportException) {
// in case we are able to return data, serialize the exception content and sent it back to the client
if (isOpen(channel)) {
final NotifyOnceListener<Channel> closeChannel = new NotifyOnceListener<Channel>() {
@Override
public void innerOnResponse(Channel channel) {
try {
closeChannels(Collections.singletonList(channel));
} catch (IOException e1) {
logger.debug("failed to close httpOnTransport channel", e1);
}
}
@Override
public void innerOnFailure(Exception e) {
try {
closeChannels(Collections.singletonList(channel));
} catch (IOException e1) {
e.addSuppressed(e1);
logger.debug("failed to close httpOnTransport channel", e1);
}
}
};
internalSendMessage(channel, new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)), closeChannel);
}
} else {
logger.warn(
(Supplier<?>) () -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e);
// close the channel, which will cause a node to be disconnected if relevant
disconnectFromNodeChannel(channel, e);
}
}
/**
* Returns the channels local address
*/
protected abstract InetSocketAddress getLocalAddress(Channel channel);
/**
* Binds to the given {@link InetSocketAddress}
*
* @param name the profile name
* @param address the address to bind to
*/
protected abstract Channel bind(String name, InetSocketAddress address) throws IOException;
/**
* Closes all channels in this list
*/
protected abstract void closeChannels(List<Channel> channel) throws IOException;
/**
* Sends message to channel. The listener's onResponse method will be called when the send is complete unless an exception
* is thrown during the send. If an exception is thrown, the listener's onException method will be called.
* @param channel the destination channel
* @param reference the byte reference for the message
* @param listener the listener to call when the operation has completed
*/
protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);
protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException;
/**
* Called to tear down internal resources
*/
protected void stopInternal() {}
public boolean canCompress(TransportRequest request) {
return compress && (!(request instanceof BytesTransportRequest));
}
private void sendRequestToChannel(final DiscoveryNode node, final Channel targetChannel, final long requestId, final String action,
final TransportRequest request, TransportRequestOptions options, Version channelVersion,
byte status) throws IOException,
TransportException {
if (compress) {
options = TransportRequestOptions.builder(options).withCompress(true).build();
}
status = TransportStatus.setRequest(status);
ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
boolean addedReleaseListener = false;
StreamOutput stream = Streams.flushOnCloseStream(bStream);
try {
// only compress if asked, and, the request is not bytes, since then only
// the header part is compressed, and the "body" can't be extracted as compressed
if (options.compress() && canCompress(request)) {
status = TransportStatus.setCompress(status);
stream = CompressorFactory.COMPRESSOR.streamOutput(stream);
}
// we pick the smallest of the 2, to support both backward and forward compatibility
// note, this is the only place we need to do this, since from here on, we use the serialized version
// as the version to use also when the node receiving this request will send the response with
Version version = Version.min(getCurrentVersion(), channelVersion);
stream.setVersion(version);
threadPool.getThreadContext().writeTo(stream);
stream.writeString(action);
BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream, bStream);
final TransportRequestOptions finalOptions = options;
final StreamOutput finalStream = stream;
// this might be called in a different thread
SendListener onRequestSent = new SendListener(
() -> IOUtils.closeWhileHandlingException(finalStream, bStream),
() -> transportServiceAdapter.onRequestSent(node, requestId, action, request, finalOptions));
internalSendMessage(targetChannel, message, onRequestSent);
addedReleaseListener = true;
} finally {
if (!addedReleaseListener) {
IOUtils.close(stream, bStream);
}
}
}
/**
* sends a message to the given channel, using the given callbacks.
*/
private void internalSendMessage(Channel targetChannel, BytesReference message, NotifyOnceListener<Channel> listener) {
try {
sendMessage(targetChannel, message, listener);
} catch (Exception ex) {
// call listener to ensure that any resources are released
listener.onFailure(ex);
onException(targetChannel, ex);
}
}
/**
* Sends back an error response to the caller via the given channel
*
* @param nodeVersion the caller node version
* @param channel the channel to send the response to
* @param error the error to return
* @param requestId the request ID this response replies to
* @param action the action this response replies to
*/
public void sendErrorResponse(Version nodeVersion, Channel channel, final Exception error, final long requestId,
final String action) throws IOException {
try (BytesStreamOutput stream = new BytesStreamOutput()) {
stream.setVersion(nodeVersion);
RemoteTransportException tx = new RemoteTransportException(
nodeName(), new TransportAddress(getLocalAddress(channel)), action, error);
threadPool.getThreadContext().writeTo(stream);
stream.writeException(tx);
byte status = 0;
status = TransportStatus.setResponse(status);
status = TransportStatus.setError(status);
final BytesReference bytes = stream.bytes();
final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length());
SendListener onResponseSent = new SendListener(null,
() -> transportServiceAdapter.onResponseSent(requestId, action, error));
internalSendMessage(channel, new CompositeBytesReference(header, bytes), onResponseSent);
}
}
/**
* Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller.
*
* @see #sendErrorResponse(Version, Object, Exception, long, String) for sending back errors to the caller
*/
public void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId,
final String action, TransportResponseOptions options) throws IOException {
sendResponse(nodeVersion, channel, response, requestId, action, options, (byte) 0);
}
private void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId,
final String action, TransportResponseOptions options, byte status) throws IOException {
if (compress) {
options = TransportResponseOptions.builder(options).withCompress(true).build();
}
status = TransportStatus.setResponse(status); // TODO share some code with sendRequest
ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
boolean addedReleaseListener = false;
StreamOutput stream = Streams.flushOnCloseStream(bStream);
try {
if (options.compress()) {
status = TransportStatus.setCompress(status);
stream = CompressorFactory.COMPRESSOR.streamOutput(stream);
}
threadPool.getThreadContext().writeTo(stream);
stream.setVersion(nodeVersion);
BytesReference reference = buildMessage(requestId, status, nodeVersion, response, stream, bStream);
final TransportResponseOptions finalOptions = options;
final StreamOutput finalStream = stream;
// this might be called in a different thread
SendListener listener = new SendListener(() -> IOUtils.closeWhileHandlingException(finalStream, bStream),
() -> transportServiceAdapter.onResponseSent(requestId, action, response, finalOptions));
internalSendMessage(channel, reference, listener);
addedReleaseListener = true;
} finally {
if (!addedReleaseListener) {
IOUtils.close(stream, bStream);
}
}
}
/**
* Writes the Tcp message header into a bytes reference.
*
* @param requestId the request ID
* @param status the request status
* @param protocolVersion the protocol version used to serialize the data in the message
* @param length the payload length in bytes
* @see TcpHeader
*/
final BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException {
try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) {
headerOutput.setVersion(protocolVersion);
TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion, length);
final BytesReference bytes = headerOutput.bytes();
assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: "
+ bytes.length();
return bytes;
}
}
/**
* Serializes the given message into a bytes representation
*/
private BytesReference buildMessage(long requestId, byte status, Version nodeVersion, TransportMessage message, StreamOutput stream,
ReleasableBytesStreamOutput writtenBytes) throws IOException {
final BytesReference zeroCopyBuffer;
if (message instanceof BytesTransportRequest) { // what a shitty optimization - we should use a direct send method instead
BytesTransportRequest bRequest = (BytesTransportRequest) message;
assert nodeVersion.equals(bRequest.version());
bRequest.writeThin(stream);
zeroCopyBuffer = bRequest.bytes;
} else {
message.writeTo(stream);
zeroCopyBuffer = BytesArray.EMPTY;
}
// we have to close the stream here - flush is not enough since we might be compressing the content
// and if we do that the close method will write some marker bytes (EOS marker) and otherwise
// we barf on the decompressing end when we read past EOF on purpose in the #validateRequest method.
// this might be a problem in deflate after all but it's important to close it for now.
stream.close();
final BytesReference messageBody = writtenBytes.bytes();
final BytesReference header = buildHeader(requestId, status, stream.getVersion(), messageBody.length() + zeroCopyBuffer.length());
return new CompositeBytesReference(header, messageBody, zeroCopyBuffer);
}
/**
* Validates the first N bytes of the message header and returns <code>false</code> if the message is
* a ping message and has no payload ie. isn't a real user level message.
*
* @throws IllegalStateException if the message is too short, less than the header or less that the header plus the message size
* @throws HttpOnTransportException if the message has no valid header and appears to be a HTTP message
* @throws IllegalArgumentException if the message is greater that the maximum allowed frame size. This is dependent on the available
* memory.
*/
public static boolean validateMessageHeader(BytesReference buffer) throws IOException {
final int sizeHeaderLength = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
if (buffer.length() < sizeHeaderLength) {
throw new IllegalStateException("message size must be >= to the header size");
}
int offset = 0;
if (buffer.get(offset) != 'E' || buffer.get(offset + 1) != 'S') {
// special handling for what is probably HTTP
if (bufferStartsWith(buffer, offset, "GET ") ||
bufferStartsWith(buffer, offset, "POST ") ||
bufferStartsWith(buffer, offset, "PUT ") ||
bufferStartsWith(buffer, offset, "HEAD ") ||
bufferStartsWith(buffer, offset, "DELETE ") ||
bufferStartsWith(buffer, offset, "OPTIONS ") ||
bufferStartsWith(buffer, offset, "PATCH ") ||
bufferStartsWith(buffer, offset, "TRACE ")) {
throw new HttpOnTransportException("This is not a HTTP port");
}
// we have 6 readable bytes, show 4 (should be enough)
throw new StreamCorruptedException("invalid internal transport message format, got ("
+ Integer.toHexString(buffer.get(offset) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 1) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 2) & 0xFF) + ","
+ Integer.toHexString(buffer.get(offset + 3) & 0xFF) + ")");
}
final int dataLen;
try (StreamInput input = buffer.streamInput()) {
input.skip(TcpHeader.MARKER_BYTES_SIZE);
dataLen = input.readInt();
if (dataLen == PING_DATA_SIZE) {
// discard the messages we read and continue, this is achieved by skipping the bytes
// and returning null
return false;
}
}
if (dataLen <= 0) {
throw new StreamCorruptedException("invalid data length: " + dataLen);
}
// safety against too large frames being sent
if (dataLen > NINETY_PER_HEAP_SIZE) {
throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(dataLen) + "] exceeded ["
+ new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]");
}
if (buffer.length() < dataLen + sizeHeaderLength) {
throw new IllegalStateException("buffer must be >= to the message size but wasn't");
}
return true;
}
private static boolean bufferStartsWith(BytesReference buffer, int offset, String method) {
char[] chars = method.toCharArray();
for (int i = 0; i < chars.length; i++) {
if (buffer.get(offset + i) != chars[i]) {
return false;
}
}
return true;
}
/**
* A helper exception to mark an incoming connection as potentially being HTTP
* so an appropriate error code can be returned
*/
public static class HttpOnTransportException extends ElasticsearchException {
public HttpOnTransportException(String msg) {
super(msg);
}
@Override
public RestStatus status() {
return RestStatus.BAD_REQUEST;
}
public HttpOnTransportException(StreamInput in) throws IOException {
super(in);
}
}
protected abstract boolean isOpen(Channel channel);
/**
* This method handles the message receive part for both request and responses
*/
public final void messageReceived(BytesReference reference, Channel channel, String profileName,
InetSocketAddress remoteAddress, int messageLengthBytes) throws IOException {
final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
transportServiceAdapter.addBytesReceived(totalMessageSize);
// we have additional bytes to read, outside of the header
boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0;
StreamInput streamIn = reference.streamInput();
boolean success = false;
try (ThreadContext.StoredContext tCtx = threadPool.getThreadContext().stashContext()) {
long requestId = streamIn.readLong();
byte status = streamIn.readByte();
Version version = Version.fromId(streamIn.readInt());
if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamIn.available() > 0) {
Compressor compressor;
try {
final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed));
} catch (NotCompressedException ex) {
int maxToRead = Math.min(reference.length(), 10);
StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [").append(maxToRead)
.append("] content bytes out of [").append(reference.length())
.append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are [");
for (int i = 0; i < maxToRead; i++) {
sb.append(reference.get(i)).append(",");
}
sb.append("]");
throw new IllegalStateException(sb.toString());
}
streamIn = compressor.streamInput(streamIn);
}
if (version.isCompatible(getCurrentVersion()) == false) {
throw new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" + getCurrentVersion().minimumCompatibilityVersion() + "]");
}
streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry);
streamIn.setVersion(version);
threadPool.getThreadContext().readHeaders(streamIn);
if (TransportStatus.isRequest(status)) {
handleRequest(channel, profileName, streamIn, requestId, messageLengthBytes, version, remoteAddress, status);
} else {
final TransportResponseHandler<?> handler;
if (TransportStatus.isHandshake(status)) {
handler = pendingHandshakes.remove(requestId);
} else {
TransportResponseHandler theHandler = transportServiceAdapter.onResponseReceived(requestId);
if (theHandler == null && TransportStatus.isError(status)) {
handler = pendingHandshakes.remove(requestId);
} else {
handler = theHandler;
}
}
// ignore if its null, the adapter logs it
if (handler != null) {
if (TransportStatus.isError(status)) {
handlerResponseError(streamIn, handler);
} else {
handleResponse(remoteAddress, streamIn, handler);
}
// Check the entire message has been read
final int nextByte = streamIn.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler ["
+ handler + "], error [" + TransportStatus.isError(status) + "]; resetting");
}
}
}
success = true;
} finally {
if (success) {
IOUtils.close(streamIn);
} else {
IOUtils.closeWhileHandlingException(streamIn);
}
}
}
private void handleResponse(InetSocketAddress remoteAddress, final StreamInput stream, final TransportResponseHandler handler) {
final TransportResponse response = handler.newInstance();
response.remoteAddress(new TransportAddress(remoteAddress));
try {
response.readFrom(stream);
} catch (Exception e) {
handleException(handler, new TransportSerializationException(
"Failed to deserialize response of type [" + response.getClass().getName() + "]", e));
return;
}
threadPool.executor(handler.executor()).execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
handleException(handler, new ResponseHandlerFailureTransportException(e));
}
@Override
protected void doRun() throws Exception {
handler.handleResponse(response);
}
});
}
/**
* Executed for a received response error
*/
private void handlerResponseError(StreamInput stream, final TransportResponseHandler handler) {
Exception error;
try {
error = stream.readException();
} catch (Exception e) {
error = new TransportSerializationException("Failed to deserialize exception response from stream", e);
}
handleException(handler, error);
}
private void handleException(final TransportResponseHandler handler, Throwable error) {
if (!(error instanceof RemoteTransportException)) {
error = new RemoteTransportException(error.getMessage(), error);
}
final RemoteTransportException rtx = (RemoteTransportException) error;
threadPool.executor(handler.executor()).execute(() -> {
try {
handler.handleException(rtx);
} catch (Exception e) {
logger.error((Supplier<?>) () -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e);
}
});
}
protected String handleRequest(Channel channel, String profileName, final StreamInput stream, long requestId, int messageLengthBytes,
Version version, InetSocketAddress remoteAddress, byte status) throws IOException {
final String action = stream.readString();
transportServiceAdapter.onRequestReceived(requestId, action);
TransportChannel transportChannel = null;
try {
if (TransportStatus.isHandshake(status)) {
final VersionHandshakeResponse response = new VersionHandshakeResponse(getCurrentVersion());
sendResponse(version, channel, response, requestId, HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY,
TransportStatus.setHandshake((byte) 0));
} else {
final RequestHandlerRegistry reg = transportServiceAdapter.getRequestHandler(action);
if (reg == null) {
throw new ActionNotFoundTransportException(action);
}
if (reg.canTripCircuitBreaker()) {
getInFlightRequestBreaker().addEstimateBytesAndMaybeBreak(messageLengthBytes, "<transport_request>");
} else {
getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes);
}
transportChannel = new TcpTransportChannel<>(this, channel, transportName, action, requestId, version, profileName,
messageLengthBytes);
final TransportRequest request = reg.newRequest();
request.remoteAddress(new TransportAddress(remoteAddress));
request.readFrom(stream);
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
validateRequest(stream, requestId, action);
threadPool.executor(reg.getExecutor()).execute(new RequestHandler(reg, request, transportChannel));
}
} catch (Exception e) {
// the circuit breaker tripped
if (transportChannel == null) {
transportChannel = new TcpTransportChannel<>(this, channel, transportName, action, requestId, version, profileName, 0);
}
try {
transportChannel.sendResponse(e);
} catch (IOException inner) {
inner.addSuppressed(e);
logger.warn(
(Supplier<?>) () -> new ParameterizedMessage(
"Failed to send error message back to client for action [{}]", action), inner);
}
}
return action;
}
// This template method is needed to inject custom error checking logic in tests.
protected void validateRequest(StreamInput stream, long requestId, String action) throws IOException {
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + action
+ "], available [" + stream.available() + "]; resetting");
}
}
class RequestHandler extends AbstractRunnable {
private final RequestHandlerRegistry reg;
private final TransportRequest request;
private final TransportChannel transportChannel;
RequestHandler(RequestHandlerRegistry reg, TransportRequest request, TransportChannel transportChannel) {
this.reg = reg;
this.request = request;
this.transportChannel = transportChannel;
}
@SuppressWarnings({"unchecked"})
@Override
protected void doRun() throws Exception {
reg.processMessageReceived(request, transportChannel);
}
@Override
public boolean isForceExecution() {
return reg.isForceExecution();
}
@Override
public void onFailure(Exception e) {
if (lifecycleState() == Lifecycle.State.STARTED) {
// we can only send a response transport is started....
try {
transportChannel.sendResponse(e);
} catch (Exception inner) {
inner.addSuppressed(e);
logger.warn(
(Supplier<?>) () -> new ParameterizedMessage(
"Failed to send error message back to client for action [{}]", reg.getAction()), inner);
}
}
}
}
private static final class VersionHandshakeResponse extends TransportResponse {
private Version version;
private VersionHandshakeResponse(Version version) {
this.version = version;
}
private VersionHandshakeResponse() {}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
version = Version.readVersion(in);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
assert version != null;
Version.writeVersion(version, out);
}
}
protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValue timeout) throws IOException, InterruptedException {
numHandshakes.inc();
final long requestId = newRequestId();
final HandshakeResponseHandler handler = new HandshakeResponseHandler(channel);
AtomicReference<Version> versionRef = handler.versionRef;
AtomicReference<Exception> exceptionRef = handler.exceptionRef;
pendingHandshakes.put(requestId, handler);
boolean success = false;
try {
if (isOpen(channel) == false) {
// we have to protect us here since sendRequestToChannel won't barf if the channel is closed.
// it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase.
// yet, if we don't check the state here we might have registered a pending handshake handler but the close
// listener calling #onChannelClosed might have already run and we are waiting on the latch below unitl we time out.
throw new IllegalStateException("handshake failed, channel already closed");
}
// for the request we use the minCompatVersion since we don't know what's the version of the node we talk to
// we also have no payload on the request but the response will contain the actual version of the node we talk
// to as the payload.
final Version minCompatVersion = getCurrentVersion().minimumCompatibilityVersion();
sendRequestToChannel(node, channel, requestId, HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE,
TransportRequestOptions.EMPTY, minCompatVersion, TransportStatus.setHandshake((byte) 0));
if (handler.latch.await(timeout.millis(), TimeUnit.MILLISECONDS) == false) {
throw new ConnectTransportException(node, "handshake_timeout[" + timeout + "]");
}
success = true;
if (exceptionRef.get() != null) {
throw new IllegalStateException("handshake failed", exceptionRef.get());
} else {
Version version = versionRef.get();
if (getCurrentVersion().isCompatible(version) == false) {
throw new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" + getCurrentVersion().minimumCompatibilityVersion() + "]");
}
return version;
}
} finally {
final TransportResponseHandler<?> removedHandler = pendingHandshakes.remove(requestId);
// in the case of a timeout or an exception on the send part the handshake has not been removed yet.
// but the timeout is tricky since it's basically a race condition so we only assert on the success case.
assert success && removedHandler == null || success == false : "handler for requestId [" + requestId + "] is not been removed";
}
}
final int getNumPendingHandshakes() { // for testing
return pendingHandshakes.size();
}
final long getNumHandshakes() {
return numHandshakes.count(); // for testing
}
@Override
public long newRequestId() {
return requestIdGenerator.incrementAndGet();
}
/**
* Called once the channel is closed for instance due to a disconnect or a closed socket etc.
*/
protected final void onChannelClosed(Channel channel) {
final Optional<Long> first = pendingHandshakes.entrySet().stream()
.filter((entry) -> entry.getValue().channel == channel).map((e) -> e.getKey()).findFirst();
if (first.isPresent()) {
final Long requestId = first.get();
final HandshakeResponseHandler handler = pendingHandshakes.remove(requestId);
if (handler != null) {
// there might be a race removing this or this method might be called twice concurrently depending on how
// the channel is closed ie. due to connection reset or broken pipes
handler.handleException(new TransportException("connection reset"));
}
}
}
/**
* Ensures this transport is still started / open
*
* @throws IllegalStateException if the transport is not started / open
*/
protected final void ensureOpen() {
if (lifecycle.started() == false) {
throw new IllegalStateException("transport has been stopped");
}
}
private final class SendListener extends NotifyOnceListener<Channel> {
private final Releasable optionalReleasable;
private final Runnable transportAdaptorCallback;
private SendListener(Releasable optionalReleasable, Runnable transportAdaptorCallback) {
this.optionalReleasable = optionalReleasable;
this.transportAdaptorCallback = transportAdaptorCallback;
}
@Override
public void innerOnResponse(Channel channel) {
release();
}
@Override
public void innerOnFailure(Exception e) {
release();
}
private void release() {
Releasables.close(optionalReleasable, transportAdaptorCallback::run);
}
}
}