/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.client.socket;
import static org.apache.nifi.remote.util.EventReportUtil.error;
import static org.apache.nifi.remote.util.EventReportUtil.warn;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.nio.channels.SocketChannel;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import org.apache.nifi.events.EventReporter;
import org.apache.nifi.remote.Peer;
import org.apache.nifi.remote.PeerDescription;
import org.apache.nifi.remote.PeerStatus;
import org.apache.nifi.remote.RemoteDestination;
import org.apache.nifi.remote.RemoteResourceInitiator;
import org.apache.nifi.remote.TransferDirection;
import org.apache.nifi.remote.client.PeerSelector;
import org.apache.nifi.remote.client.PeerStatusProvider;
import org.apache.nifi.remote.client.SiteInfoProvider;
import org.apache.nifi.remote.client.SiteToSiteClientConfig;
import org.apache.nifi.remote.codec.FlowFileCodec;
import org.apache.nifi.remote.exception.HandshakeException;
import org.apache.nifi.remote.exception.PortNotRunningException;
import org.apache.nifi.remote.exception.TransmissionDisabledException;
import org.apache.nifi.remote.exception.UnknownPortException;
import org.apache.nifi.remote.io.socket.SocketChannelCommunicationsSession;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannel;
import org.apache.nifi.remote.io.socket.ssl.SSLSocketChannelCommunicationsSession;
import org.apache.nifi.remote.protocol.CommunicationsSession;
import org.apache.nifi.remote.protocol.socket.SocketClientProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Pool of RAW-socket site-to-site connections to the peers of a remote NiFi instance.
 * <p>
 * Idle connections are kept in one queue per {@link PeerDescription}; callers borrow a connection with
 * {@code getEndpointConnection}, and give it back with {@code offer} or discard it with {@code terminate}.
 * The pool also acts as the {@link PeerStatusProvider} for its own {@link PeerSelector}.
 */
public class EndpointConnectionPool implements PeerStatusProvider {
private static final Logger logger = LoggerFactory.getLogger(EndpointConnectionPool.class);
// Idle (pooled) connections, one queue per peer.
private final ConcurrentMap<PeerDescription, BlockingQueue<EndpointConnection>> connectionQueueMap = new ConcurrentHashMap<>();
// Connections that have been handed out and not yet offered back or terminated.
private final Set<EndpointConnection> activeConnections = Collections.synchronizedSet(new HashSet<>());
private final EventReporter eventReporter;
// May be null; a secure connection cannot be established without it.
private final SSLContext sslContext;
// Single-threaded scheduler running the periodic peer refresh and idle-connection cleanup.
private final ScheduledExecutorService taskExecutor;
// Pooled connections unused for longer than this are shut down by the background cleanup task.
private final int idleExpirationMillis;
private final RemoteDestination remoteDestination;
// Used as socket connect/read timeout, as protocol timeout when fetching peer statuses, and as the
// maximum idle age of a pooled connection that may still be handed out.
private volatile int commsTimeout;
// Set once shutdown() has been called; connections offered afterwards are terminated instead of pooled.
private volatile boolean shutdown = false;
private final SiteInfoProvider siteInfoProvider;
private final PeerSelector peerSelector;
// Optional local address to bind outgoing sockets to; null means no explicit bind.
private final InetAddress localAddress;
/**
 * Creates the pool and starts its background maintenance.
 * <p>
 * A single daemon thread refreshes the list of peers (immediately, then with a 5 second delay between
 * runs) and shuts down expired idle connections (first run after 5 seconds, then with a 5 second delay).
 *
 * @param remoteDestination the remote port to communicate with; must not be null
 * @param commsTimeoutMillis communications timeout in milliseconds
 * @param idleExpirationMillis how long a pooled connection may stay unused before it is shut down
 * @param sslContext SSL context used for secure communications; may be null
 * @param eventReporter reporter that receives warning/error events
 * @param persistenceFile file handed to the {@link PeerSelector}
 * @param siteInfoProvider source of information about the remote instance
 * @param localAddress local address to bind outgoing sockets to; may be null
 */
public EndpointConnectionPool(final RemoteDestination remoteDestination, final int commsTimeoutMillis, final int idleExpirationMillis,
        final SSLContext sslContext, final EventReporter eventReporter, final File persistenceFile, final SiteInfoProvider siteInfoProvider,
        final InetAddress localAddress) {
    Objects.requireNonNull(remoteDestination, "Remote Destination/Port Identifier cannot be null");

    this.remoteDestination = remoteDestination;
    this.sslContext = sslContext;
    this.eventReporter = eventReporter;
    this.commsTimeout = commsTimeoutMillis;
    this.idleExpirationMillis = idleExpirationMillis;
    this.localAddress = localAddress;
    this.siteInfoProvider = siteInfoProvider;

    peerSelector = new PeerSelector(this, persistenceFile);
    peerSelector.setEventReporter(eventReporter);

    // One daemon thread handles all maintenance work so that it never keeps the JVM alive.
    final ThreadFactory delegateFactory = Executors.defaultThreadFactory();
    taskExecutor = Executors.newScheduledThreadPool(1, task -> {
        final Thread maintenanceThread = delegateFactory.newThread(task);
        maintenanceThread.setName("NiFi Site-to-Site Connection Pool Maintenance");
        maintenanceThread.setDaemon(true);
        return maintenanceThread;
    });

    // Keep the list of peers up-to-date.
    taskExecutor.scheduleWithFixedDelay(peerSelector::refreshPeers, 0, 5, TimeUnit.SECONDS);

    // Kill off old, unused connections.
    taskExecutor.scheduleWithFixedDelay(this::cleanupExpiredSockets, 5, 5, TimeUnit.SECONDS);
}
/**
 * Resolves the identifier of the remote port.
 *
 * @param transferDirection direction used for the lookup when the destination carries no identifier
 * @return the destination's own identifier if it has one, otherwise whatever the site info provider
 *         returns for the destination's name and the given direction
 * @throws IOException if the lookup through the site info provider fails
 */
private String getPortIdentifier(final TransferDirection transferDirection) throws IOException {
    if (remoteDestination.getIdentifier() == null) {
        // No explicit identifier configured: look the port up by name.
        return siteInfoProvider.getPortIdentifier(remoteDestination.getName(), transferDirection);
    }

    return remoteDestination.getIdentifier();
}
/**
 * Obtains a connection for the given direction without any client configuration.
 *
 * @param direction the transfer direction
 * @return the result of {@link #getEndpointConnection(TransferDirection, SiteToSiteClientConfig)} with a null config
 * @throws IOException if the connection cannot be obtained
 */
public EndpointConnection getEndpointConnection(final TransferDirection direction) throws IOException {
    final SiteToSiteClientConfig noConfig = null;
    return getEndpointConnection(direction, noConfig);
}
/**
 * Obtains a connection to one of the remote peers, reusing a pooled connection when one is available and
 * otherwise establishing, negotiating and handshaking a new one.
 * <p>
 * The returned connection is registered as active; callers are expected to give it back through
 * {@link #offer(EndpointConnection)} or discard it through {@link #terminate(EndpointConnection)}.
 *
 * @param direction the transfer direction, used to select a peer and to resolve the port identifier
 * @param config optional client configuration applied to newly created connections; may be null
 * @return a usable connection, or null if the peer selector has no peer to offer
 * @throws PortNotRunningException if the peer reports the port as invalid
 * @throws UnknownPortException if the peer reports the port as unknown
 * @throws IOException if a connection cannot be established, negotiated or handshaked
 */
public EndpointConnection getEndpointConnection(final TransferDirection direction, final SiteToSiteClientConfig config) throws IOException {
    //
    // Attempt to get a connection state that already exists for this URL.
    //
    FlowFileCodec codec = null;
    CommunicationsSession commsSession = null;
    SocketClientProtocol protocol = null;
    EndpointConnection connection;
    Peer peer = null;

    final URI clusterUrl = siteInfoProvider.getActiveClusterUrl();
    do {
        // NOTE(review): this list is re-created on every iteration, so it is always empty when the
        // "all connections are penalized" check below runs; that check can currently never fire.
        // Hoisting the list would require returning each connection to its own peer's queue - confirm intent.
        final List<EndpointConnection> addBack = new ArrayList<>();

        logger.debug("{} getting next peer status", this);
        final PeerStatus peerStatus = peerSelector.getNextPeerStatus(direction);
        logger.debug("{} next peer status = {}", this, peerStatus);
        if (peerStatus == null) {
            return null;
        }

        final PeerDescription peerDescription = peerStatus.getPeerDescription();
        BlockingQueue<EndpointConnection> connectionQueue = connectionQueueMap.get(peerDescription);
        if (connectionQueue == null) {
            connectionQueue = new LinkedBlockingQueue<>();
            // Another thread may have registered a queue for this peer in the meantime; use theirs if so.
            final BlockingQueue<EndpointConnection> existing = connectionQueueMap.putIfAbsent(peerDescription, connectionQueue);
            if (existing != null) {
                connectionQueue = existing;
            }
        }

        try {
            connection = connectionQueue.poll();
            logger.debug("{} Connection State for {} = {}", this, clusterUrl, connection);
            final String portId = getPortIdentifier(direction);

            if (connection == null && !addBack.isEmpty()) {
                // all available connections have been penalized.
                logger.debug("{} all Connections for {} are penalized; returning no Connection", this, portId);
                return null;
            }

            if (connection != null && connection.getPeer().isPenalized(portId)) {
                // we have a connection, but it's penalized. We want to add it back to the queue
                // when we've found one to use.
                addBack.add(connection);
                continue;
            }

            // if we can't get an existing Connection, create one
            if (connection == null) {
                logger.debug("{} No Connection available for Port {}; creating new Connection", this, portId);
                protocol = new SocketClientProtocol();
                protocol.setDestination(new IdEnrichedRemoteDestination(remoteDestination, portId));
                protocol.setEventReporter(eventReporter);

                final long penalizationMillis = remoteDestination.getYieldPeriod(TimeUnit.MILLISECONDS);
                try {
                    logger.debug("{} Establishing site-to-site connection with {}", this, peerStatus);
                    commsSession = establishSiteToSiteConnection(peerStatus);
                } catch (final IOException ioe) {
                    peerSelector.penalize(peerStatus.getPeerDescription(), penalizationMillis);
                    throw ioe;
                }

                final DataInputStream dis = new DataInputStream(commsSession.getInput().getInputStream());
                final DataOutputStream dos = new DataOutputStream(commsSession.getOutput().getOutputStream());
                try {
                    logger.debug("{} Negotiating protocol", this);
                    RemoteResourceInitiator.initiateResourceNegotiation(protocol, dis, dos);
                } catch (final HandshakeException e) {
                    // A failed negotiation leaves the session unusable. Always surface the failure instead of
                    // carrying on with a closed session, and penalize the peer just like a failed connect.
                    peerSelector.penalize(peerStatus.getPeerDescription(), penalizationMillis);
                    try {
                        commsSession.close();
                    } catch (final IOException ioe) {
                        e.addSuppressed(ioe);
                    }
                    throw e;
                }

                final String peerUrl = "nifi://" + peerDescription.getHostname() + ":" + peerDescription.getPort();
                peer = new Peer(peerDescription, commsSession, peerUrl, clusterUrl.toString());

                // set properties based on config
                if (config != null) {
                    protocol.setTimeout((int) config.getTimeout(TimeUnit.MILLISECONDS));
                    protocol.setPreferredBatchCount(config.getPreferredBatchCount());
                    protocol.setPreferredBatchSize(config.getPreferredBatchSize());
                    protocol.setPreferredBatchDuration(config.getPreferredBatchDuration(TimeUnit.MILLISECONDS));
                }

                // perform handshake
                try {
                    logger.debug("{} performing handshake", this);
                    protocol.handshake(peer);

                    // handle error cases
                    if (protocol.isDestinationFull()) {
                        // config is null when called through the single-argument overload, so fall back
                        // to the resolved port identifier rather than dereferencing it.
                        final String portLabel;
                        if (config == null) {
                            portLabel = portId;
                        } else {
                            portLabel = config.getPortName() == null ? config.getPortIdentifier() : config.getPortName();
                        }
                        logger.warn("{} {} indicates that port {}'s destination is full; penalizing peer",
                                this, peer, portLabel);

                        peerSelector.penalize(peer, penalizationMillis);
                        try {
                            peer.close();
                        } catch (final IOException ioe) {
                            // best effort: the peer is being abandoned anyway
                        }

                        continue;
                    } else if (protocol.isPortInvalid()) {
                        peerSelector.penalize(peer, penalizationMillis);
                        cleanup(protocol, peer);
                        throw new PortNotRunningException(peer.toString() + " indicates that port " + portId + " is not running");
                    } else if (protocol.isPortUnknown()) {
                        peerSelector.penalize(peer, penalizationMillis);
                        cleanup(protocol, peer);
                        throw new UnknownPortException(peer.toString() + " indicates that port " + portId + " is not known");
                    }

                    // negotiate the FlowFileCodec to use
                    logger.debug("{} negotiating codec", this);
                    codec = protocol.negotiateCodec(peer);
                    logger.debug("{} negotiated codec is {}", this, codec);
                } catch (final PortNotRunningException | UnknownPortException e) {
                    // already penalized and cleaned up above
                    throw e;
                } catch (final Exception e) {
                    peerSelector.penalize(peer, penalizationMillis);
                    cleanup(protocol, peer);

                    final String message = String.format("%s failed to communicate with %s due to %s", this, peer == null ? clusterUrl : peer, e.toString());
                    error(logger, eventReporter, message);
                    if (logger.isDebugEnabled()) {
                        logger.error("", e);
                    }
                    throw e;
                }

                connection = new EndpointConnection(peer, protocol, codec);
            } else {
                final long lastTimeUsed = connection.getLastTimeUsed();
                final long millisSinceLastUse = System.currentTimeMillis() - lastTimeUsed;

                if (commsTimeout > 0L && millisSinceLastUse >= commsTimeout) {
                    // pooled connection has been idle too long to be trusted; drop it and try again
                    cleanup(connection.getSocketClientProtocol(), connection.getPeer());
                    connection = null;
                } else {
                    codec = connection.getCodec();
                    peer = connection.getPeer();
                    commsSession = peer.getCommunicationsSession();
                    protocol = connection.getSocketClientProtocol();
                }
            }
        } catch (final Throwable t) {
            if (commsSession != null) {
                try {
                    commsSession.close();
                } catch (final IOException ioe) {
                    // best effort: the original failure is what gets reported
                }
            }

            throw t;
        } finally {
            if (!addBack.isEmpty()) {
                connectionQueue.addAll(addBack);
                addBack.clear();
            }
        }
    } while (connection == null || codec == null || commsSession == null || protocol == null);

    activeConnections.add(connection);
    return connection;
}
/**
 * Gives a borrowed connection back to the pool.
 *
 * @param endpointConnection the connection to return
 * @return true if the connection was placed on its peer's idle queue; false if it has no peer, if no
 *         queue exists for its peer, or if the pool has been shut down (in which case the connection
 *         is terminated)
 */
public boolean offer(final EndpointConnection endpointConnection) {
    final Peer owner = endpointConnection.getPeer();
    final BlockingQueue<EndpointConnection> idleQueue = owner == null ? null : connectionQueueMap.get(owner.getDescription());
    if (idleQueue == null) {
        return false;
    }

    activeConnections.remove(endpointConnection);

    if (!shutdown) {
        endpointConnection.setLastTimeUsed();
        return idleQueue.offer(endpointConnection);
    }

    // The pool is shutting down: do not keep the connection around.
    terminate(endpointConnection);
    return false;
}
/**
 * Best-effort release of a peer: shuts the protocol down (when one is given) and then closes the peer.
 * Failures never propagate to the caller; they are logged at debug level.
 *
 * @param protocol the protocol to shut down; may be null, in which case only the peer is closed
 * @param peer the peer to close; if null, nothing is done
 */
private void cleanup(final SocketClientProtocol protocol, final Peer peer) {
    if (peer == null) {
        return;
    }

    if (protocol != null) {
        try {
            protocol.shutdown(peer);
        } catch (final TransmissionDisabledException e) {
            // User disabled transmission.... do nothing.
            logger.debug("{} Transmission Disabled by User", this);
        } catch (final IOException e) {
            // Still attempt to close the peer below.
            logger.debug("{} Failed to shut down protocol for {} due to {}", this, peer, e.toString());
        }
    }

    try {
        peer.close();
    } catch (final TransmissionDisabledException e) {
        // User disabled transmission.... do nothing.
        logger.debug("{} Transmission Disabled by User", this);
    } catch (final IOException e) {
        logger.debug("{} Failed to close {} due to {}", this, peer, e.toString());
    }
}
/**
 * Describes the peer used for bootstrapping, built from the active cluster URL's host, the
 * site-to-site port and the secure flag reported by the site info provider.
 *
 * @return the bootstrap peer description
 * @throws IOException if the site info provider reports no site-to-site port, or a lookup fails
 */
@Override
public PeerDescription getBootstrapPeerDescription() throws IOException {
    final URI activeClusterUrl = siteInfoProvider.getActiveClusterUrl();
    final String bootstrapHost = activeClusterUrl.getHost();

    final Integer siteToSitePort = siteInfoProvider.getSiteToSitePort();
    if (siteToSitePort != null) {
        return new PeerDescription(bootstrapHost, siteToSitePort, siteInfoProvider.isSecure());
    }

    throw new IOException("Remote instance of NiFi is not configured to allow RAW Socket site-to-site communications");
}
/**
 * Connects to the given peer and asks it for the statuses of all peers it knows about.
 * <p>
 * The temporary connection is always released: on success the protocol is shut down and the peer
 * closed (failures there are only reported as warnings); on failure the peer is closed before the
 * exception is rethrown.
 *
 * @param peerDescription the peer to query
 * @return the peer statuses reported by the remote peer
 * @throws IOException if the connection, negotiation, handshake or query fails, or if the port
 *         identifier is needed (negotiated protocol version below 5) but cannot be determined
 */
@Override
public Set<PeerStatus> fetchRemotePeerStatuses(final PeerDescription peerDescription) throws IOException {
    final String hostname = peerDescription.getHostname();
    final int port = peerDescription.getPort();
    final URI clusterUrl = siteInfoProvider.getActiveClusterUrl();
    final PeerDescription clusterPeerDescription = new PeerDescription(hostname, port, clusterUrl.toString().startsWith("https://"));

    final CommunicationsSession commsSession = establishSiteToSiteConnection(hostname, port);
    final Peer peer = new Peer(clusterPeerDescription, commsSession, "nifi://" + hostname + ":" + port, clusterUrl.toString());
    final SocketClientProtocol clientProtocol = new SocketClientProtocol();

    final Set<PeerStatus> peerStatuses;
    try {
        final DataInputStream dis = new DataInputStream(commsSession.getInput().getInputStream());
        final DataOutputStream dos = new DataOutputStream(commsSession.getOutput().getOutputStream());
        RemoteResourceInitiator.initiateResourceNegotiation(clientProtocol, dis, dos);
        clientProtocol.setTimeout(commsTimeout);

        if (clientProtocol.getVersionNegotiator().getVersion() < 5) {
            // Below version 5 the handshake is given an explicit port identifier.
            String portId = getPortIdentifier(TransferDirection.RECEIVE);
            if (portId == null) {
                portId = getPortIdentifier(TransferDirection.SEND);
            }

            if (portId == null) {
                throw new IOException("Failed to determine the identifier of port " + remoteDestination.getName());
            }
            clientProtocol.handshake(peer, portId);
        } else {
            clientProtocol.handshake(peer, null);
        }

        peerStatuses = clientProtocol.getPeerStatuses(peer);
    } catch (final IOException | RuntimeException e) {
        // Do not leak the socket when the exchange fails part-way through.
        try {
            peer.close();
        } catch (final IOException closeFailure) {
            e.addSuppressed(closeFailure);
        }
        throw e;
    }

    try {
        clientProtocol.shutdown(peer);
    } catch (final IOException e) {
        final String message = String.format("%s Failed to shutdown protocol when updating list of peers due to %s", this, e.toString());
        warn(logger, eventReporter, message);
        if (logger.isDebugEnabled()) {
            logger.warn("", e);
        }
    }

    try {
        peer.close();
    } catch (final IOException e) {
        final String message = String.format("%s Failed to close resources when updating list of peers due to %s", this, e.toString());
        warn(logger, eventReporter, message);
        if (logger.isDebugEnabled()) {
            logger.warn("", e);
        }
    }

    return peerStatuses;
}
/**
 * Opens a site-to-site session to the peer described by the given status.
 *
 * @param peerStatus status whose peer description supplies the hostname and port
 * @return the established session
 * @throws IOException if the session cannot be established
 */
private CommunicationsSession establishSiteToSiteConnection(final PeerStatus peerStatus) throws IOException {
    final PeerDescription target = peerStatus.getPeerDescription();
    final String targetHost = target.getHostname();
    final int targetPort = target.getPort();
    return establishSiteToSiteConnection(targetHost, targetPort);
}
/**
 * Opens a socket to the given host and port, wraps it in a communications session (SSL or plain,
 * depending on what the site info provider reports) and writes the protocol's magic bytes.
 * <p>
 * On failure the session is closed before the exception is rethrown; on the plain path the raw
 * channel is closed as well if the session was never created.
 *
 * @param hostname the remote host
 * @param port the remote site-to-site port
 * @return the established session
 * @throws IOException if secure communication is required but no SSL context is configured, or if
 *         connecting, certificate inspection or the initial write fails
 */
private CommunicationsSession establishSiteToSiteConnection(final String hostname, final int port) throws IOException {
    final boolean siteToSiteSecure = siteInfoProvider.isSecure();

    CommunicationsSession commsSession = null;
    try {
        if (siteToSiteSecure) {
            if (sslContext == null) {
                throw new IOException("Unable to communicate with " + hostname + ":" + port
                        + " because it requires Secure Site-to-Site communications, but this instance is not configured for secure communications");
            }

            // NOTE(review): if connect() fails, this channel is not closed here; whether SSLSocketChannel
            // releases its socket itself in that case is not visible from this file - confirm.
            final SSLSocketChannel socketChannel = new SSLSocketChannel(sslContext, hostname, port, localAddress, true);
            socketChannel.connect();

            commsSession = new SSLSocketChannelCommunicationsSession(socketChannel);

            try {
                commsSession.setUserDn(socketChannel.getDn());
            } catch (final CertificateException ex) {
                throw new IOException(ex);
            }
        } else {
            final SocketChannel socketChannel = SocketChannel.open();
            try {
                if (localAddress != null) {
                    final SocketAddress localSocketAddress = new InetSocketAddress(localAddress, 0);
                    socketChannel.socket().bind(localSocketAddress);
                }

                socketChannel.socket().connect(new InetSocketAddress(hostname, port), commsTimeout);
                socketChannel.socket().setSoTimeout(commsTimeout);

                commsSession = new SocketChannelCommunicationsSession(socketChannel);
            } catch (final IOException | RuntimeException e) {
                // No session owns the channel yet, so it has to be closed here.
                try {
                    socketChannel.close();
                } catch (final IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        }

        commsSession.getOutput().getOutputStream().write(CommunicationsSession.MAGIC_BYTES);
    } catch (final IOException ioe) {
        if (commsSession != null) {
            try {
                commsSession.close();
            } catch (final IOException closeFailure) {
                // Keep the original failure as the primary exception.
                ioe.addSuppressed(closeFailure);
            }
        }

        throw ioe;
    }

    return commsSession;
}
/**
 * Drains every idle queue, shuts down and terminates the connections that have not been used for
 * more than {@code idleExpirationMillis}, and puts the remaining ones back on their queue.
 */
private void cleanupExpiredSockets() {
    for (final BlockingQueue<EndpointConnection> idleQueue : connectionQueueMap.values()) {
        final List<EndpointConnection> survivors = new ArrayList<>();

        for (EndpointConnection candidate = idleQueue.poll(); candidate != null; candidate = idleQueue.poll()) {
            final long lastUsed = candidate.getLastTimeUsed();
            final boolean stillFresh = lastUsed >= System.currentTimeMillis() - idleExpirationMillis;
            if (stillFresh) {
                survivors.add(candidate);
                continue;
            }

            // Idle for too long: shut the protocol down, then release the connection.
            try {
                candidate.getSocketClientProtocol().shutdown(candidate.getPeer());
            } catch (final Exception e) {
                logger.debug("Failed to shut down {} using {} due to {}",
                        candidate.getSocketClientProtocol(), candidate.getPeer(), e);
            }

            terminate(candidate);
        }

        idleQueue.addAll(survivors);
    }
}
/**
 * Shuts the pool down: stops the maintenance tasks, clears the peer selector, interrupts the
 * communications session of every connection currently handed out, and terminates every idle
 * connection. Connections offered back afterwards are terminated rather than pooled.
 */
public void shutdown() {
    shutdown = true;
    taskExecutor.shutdown();
    peerSelector.clear();

    // A synchronized set must be locked manually while it is traversed. Take a snapshot under the lock
    // so that concurrent offer()/terminate() calls cannot cause a ConcurrentModificationException and
    // interrupt() is not invoked while holding the lock.
    final List<EndpointConnection> inUse;
    synchronized (activeConnections) {
        inUse = new ArrayList<>(activeConnections);
    }

    for (final EndpointConnection conn : inUse) {
        conn.getPeer().getCommunicationsSession().interrupt();
    }

    for (final BlockingQueue<EndpointConnection> connectionQueue : connectionQueueMap.values()) {
        EndpointConnection state;
        while ((state = connectionQueue.poll()) != null) {
            terminate(state);
        }
    }
}
/**
 * Discards a connection: removes it from the set of active connections, then shuts down its
 * protocol and closes its peer on a best-effort basis.
 *
 * @param connection the connection to discard
 */
public void terminate(final EndpointConnection connection) {
    activeConnections.remove(connection);

    final SocketClientProtocol protocol = connection.getSocketClientProtocol();
    final Peer peer = connection.getPeer();
    cleanup(protocol, peer);
}
/**
 * @return a description of this pool that includes the cluster URLs known to the site info provider
 */
@Override
public String toString() {
    final Object clusterUrls = siteInfoProvider.getClusterUrls();
    final StringBuilder description = new StringBuilder("EndpointConnectionPool[Cluster URL=");
    description.append(clusterUrls);
    description.append("]");
    return description.toString();
}
private class IdEnrichedRemoteDestination implements RemoteDestination {
private final RemoteDestination original;
private final String identifier;
public IdEnrichedRemoteDestination(final RemoteDestination original, final String identifier) {
this.original = original;
this.identifier = identifier;
}
@Override
public String getIdentifier() {
return identifier;
}
@Override
public String getName() {
return original.getName();
}
@Override
public long getYieldPeriod(final TimeUnit timeUnit) {
return original.getYieldPeriod(timeUnit);
}
@Override
public boolean isUseCompression() {
return original.isUseCompression();
}
}
}