package org.infinispan.client.hotrod.impl.transport.tcp;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import javax.net.ssl.SSLContext;

import org.apache.commons.pool.KeyedObjectPool;
import org.apache.commons.pool.impl.GenericKeyedObjectPool;
import org.infinispan.client.hotrod.CacheTopologyInfo;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.configuration.Configuration;
import org.infinispan.client.hotrod.configuration.ServerConfiguration;
import org.infinispan.client.hotrod.configuration.SslConfiguration;
import org.infinispan.client.hotrod.event.ClientListenerNotifier;
import org.infinispan.client.hotrod.exceptions.TransportException;
import org.infinispan.client.hotrod.impl.TopologyInfo;
import org.infinispan.client.hotrod.impl.consistenthash.ConsistentHash;
import org.infinispan.client.hotrod.impl.consistenthash.ConsistentHashFactory;
import org.infinispan.client.hotrod.impl.operations.AddClientListenerOperation;
import org.infinispan.client.hotrod.impl.protocol.Codec;
import org.infinispan.client.hotrod.impl.protocol.HotRodConstants;
import org.infinispan.client.hotrod.impl.transport.Transport;
import org.infinispan.client.hotrod.impl.transport.TransportFactory;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.marshall.Marshaller;
import org.infinispan.commons.marshall.WrappedByteArray;
import org.infinispan.commons.util.SslContextFactory;
import org.infinispan.commons.util.Util;

import net.jcip.annotations.GuardedBy;
import net.jcip.annotations.ThreadSafe;

/**
 * TCP-based {@link TransportFactory} implementation. Manages a keyed connection
 * pool (one sub-pool per server address), per-cache request balancing strategies,
 * the current cluster topology and fail-over between configured backup clusters.
 *
 * @author Mircea.Markus@jboss.com
 * @since 4.1
 */
@ThreadSafe
public class TcpTransportFactory implements TransportFactory {

   private static final Log log = LogFactory.getLog(TcpTransportFactory.class, Log.class);
   private static final boolean trace = log.isTraceEnabled();

   public static final String DEFAULT_CLUSTER_NAME = "___DEFAULT-CLUSTER___";

   /**
    * We need synchronization as the thread that calls {@link TransportFactory#start(org.infinispan.client.hotrod.impl.protocol.Codec,
    * org.infinispan.client.hotrod.configuration.Configuration, java.util.concurrent.atomic.AtomicInteger,
    * org.infinispan.client.hotrod.event.ClientListenerNotifier)}
    * might (and likely will) be different from the thread(s) that calls {@link TransportFactory#getTransport(Object,
    * java.util.Set, byte[])} or other methods.
    */
   private final Object lock = new Object();

   // The connection pool implementation is assumed to be thread-safe, so we need to synchronize
   // just the access to this field and not the method calls
   private GenericKeyedObjectPool<SocketAddress, TcpTransport> connectionPool;

   // Per cache request balancing strategy
   private Map<WrappedByteArray, FailoverRequestBalancingStrategy> balancers;

   private Configuration configuration;
   private Collection<SocketAddress> initialServers;

   // the primitive fields are often accessed separately from the rest so it makes sense
   // not to require synchronization for them
   private volatile boolean tcpNoDelay;
   private volatile boolean tcpKeepAlive;
   private volatile int soTimeout;
   private volatile int connectTimeout;
   private volatile int maxRetries;
   private volatile SSLContext sslContext;
   private volatile String sniHostName;
   private volatile ClientListenerNotifier listenerNotifier;

   @GuardedBy("lock")
   private volatile TopologyInfo topologyInfo;

   private volatile String currentClusterName;
   private List<ClusterInfo> clusters = new ArrayList<>();

   // Topology age provides a way to avoid concurrent cluster view changes,
   // affecting a cluster switch. After a cluster switch, the topology age is
   // increased and so any old requests that might have received topology
   // updates won't be allowed to apply since they refer to older views.
   private final AtomicInteger topologyAge = new AtomicInteger(0);

   private final BlockingQueue<AddClientListenerOperation> disconnectedListeners = new LinkedBlockingQueue<>();

   @Override
   public void start(Codec codec, Configuration configuration, AtomicInteger defaultCacheTopologyId,
                     ClientListenerNotifier listenerNotifier) {
      synchronized (lock) {
         this.listenerNotifier = listenerNotifier;
         this.configuration = configuration;
         Collection<SocketAddress> servers = new ArrayList<>();
         initialServers = new ArrayList<>();
         for (ServerConfiguration server : configuration.servers()) {
            servers.add(new InetSocketAddress(server.host(), server.port()));
         }
         initialServers.addAll(servers);
         if (!configuration.clusters().isEmpty()) {
            configuration.clusters().stream().forEach(cluster -> {
               Collection<SocketAddress> clusterAddresses = cluster.getCluster().stream()
                     .map(server -> new InetSocketAddress(server.host(), server.port()))
                     .collect(Collectors.toList());
               ClusterInfo clusterInfo = new ClusterInfo(cluster.getClusterName(), clusterAddresses);
               log.debugf("Add secondary cluster: %s", clusterInfo);
               clusters.add(clusterInfo);
            });
            // The statically configured servers become the "main" cluster candidate
            // so that a later switch can fall back to them.
            clusters.add(new ClusterInfo(DEFAULT_CLUSTER_NAME, initialServers));
         }
         currentClusterName = DEFAULT_CLUSTER_NAME;
         topologyInfo = new TopologyInfo(defaultCacheTopologyId, Collections.unmodifiableCollection(servers), configuration);
         tcpNoDelay = configuration.tcpNoDelay();
         tcpKeepAlive = configuration.tcpKeepAlive();
         soTimeout = configuration.socketTimeout();
         connectTimeout = configuration.connectionTimeout();
         maxRetries = configuration.maxRetries();
         if (configuration.security().ssl().enabled()) {
            SslConfiguration ssl = configuration.security().ssl();
            if (ssl.sslContext() != null) {
               sslContext = ssl.sslContext();
            } else {
               sslContext = SslContextFactory.getContext(
                     ssl.keyStoreFileName(), ssl.keyStoreType(), ssl.keyStorePassword(),
                     ssl.keyStoreCertificatePassword(), ssl.keyAlias(),
                     ssl.trustStoreFileName(), ssl.trustStoreType(), ssl.trustStorePassword(),
                     ssl.protocol(), configuration.classLoader());
            }
            sniHostName = ssl.sniHostName();
         }
         if (log.isDebugEnabled()) {
            log.debugf("Statically configured servers: %s", servers);
            log.debugf("Load balancer class: %s", configuration.balancingStrategyClass().getName());
            log.debugf("Tcp no delay = %b; client socket timeout = %d ms; connect timeout = %d ms",
                  tcpNoDelay, soTimeout, connectTimeout);
         }
         TransportObjectFactory connectionFactory;
         if (configuration.security().authentication().enabled()) {
            connectionFactory = new SaslTransportObjectFactory(codec, this, defaultCacheTopologyId, configuration);
         } else {
            connectionFactory = new TransportObjectFactory(codec, this, defaultCacheTopologyId, configuration);
         }
         PropsKeyedObjectPoolFactory<SocketAddress, TcpTransport> poolFactory =
               new PropsKeyedObjectPoolFactory<>(connectionFactory, configuration.connectionPool());
         createAndPreparePool(poolFactory);
         balancers = new HashMap<>();
         addBalancer(new WrappedByteArray(RemoteCacheManager.cacheNameBytes()));
         pingServersIgnoreException();
      }
   }

   /**
    * Creates a balancing strategy for the given cache (either the user-supplied
    * instance or a fresh one from the configured class), registers it and seeds
    * it with the cache's current server list.
    */
   private FailoverRequestBalancingStrategy addBalancer(WrappedByteArray cacheName) {
      FailoverRequestBalancingStrategy balancer;
      FailoverRequestBalancingStrategy cfgBalancerInstance = configuration.balancingStrategy();
      if (cfgBalancerInstance != null) {
         balancer = cfgBalancerInstance;
      } else {
         balancer = Util.getInstance(configuration.balancingStrategyClass());
      }
      balancers.put(cacheName, balancer);
      balancer.setServers(topologyInfo.getServers(cacheName));
      return balancer;
   }

   private void pingServersIgnoreException() {
      GenericKeyedObjectPool<SocketAddress, TcpTransport> pool = getConnectionPool();
      Collection<SocketAddress> servers = topologyInfo.getServers();
      for (SocketAddress addr : servers) {
         try {
            // Go through all statically configured nodes and force a
            // connection to be established and a ping message to be sent.
            pool.returnObject(addr, pool.borrowObject(addr));
         } catch (Exception e) {
            // Ping's objective is to retrieve a potentially newer
            // version of the Hot Rod cluster topology, so ignore
            // exceptions from nodes that might not be up any more.
            if (trace)
               log.tracef(e, "Ignoring exception pinging configured servers %s to establish a connection", servers);
         }
      }
   }

   /**
    * This makes sure that, when the evictor thread kicks in, the minIdle is set. We don't want to
    * do this in the caller's thread, as this is the user.
    */
   private void createAndPreparePool(PropsKeyedObjectPoolFactory<SocketAddress, TcpTransport> poolFactory) {
      connectionPool = (GenericKeyedObjectPool<SocketAddress, TcpTransport>) poolFactory.createPool();
      Collection<SocketAddress> servers = topologyInfo.getServers();
      for (SocketAddress addr : servers) {
         connectionPool.preparePool(addr, false);
      }
   }

   @Override
   public void destroy() {
      synchronized (lock) {
         connectionPool.clear();
         try {
            connectionPool.close();
         } catch (Exception e) {
            log.warn("Exception while shutting down the connection pool.", e);
         }
      }
   }

   @Override
   public CacheTopologyInfo getCacheTopologyInfo(byte[] cacheName) {
      synchronized (lock) {
         return topologyInfo.getCacheTopologyInfo(cacheName);
      }
   }

   @Override
   public void updateHashFunction(Map<SocketAddress, Set<Integer>> servers2Hash, int numKeyOwners,
                                  short hashFunctionVersion, int hashSpace, byte[] cacheName, AtomicInteger topologyId) {
      synchronized (lock) {
         topologyInfo.updateTopology(servers2Hash, numKeyOwners, hashFunctionVersion, hashSpace, cacheName, topologyId);
      }
   }

   @Override
   public void updateHashFunction(SocketAddress[][] segmentOwners, int numSegments, short hashFunctionVersion,
                                  byte[] cacheName, AtomicInteger topologyId) {
      synchronized (lock) {
         topologyInfo.updateTopology(segmentOwners, numSegments, hashFunctionVersion, cacheName, topologyId);
      }
   }

   @Override
   public Transport getTransport(Set<SocketAddress> failedServers, byte[] cacheName) {
      SocketAddress server;
      synchronized (lock) {
         server = getNextServer(failedServers, cacheName);
      }
      return borrowTransportFromPool(server);
   }

   @GuardedBy("lock")
   private SocketAddress getNextServer(Set<SocketAddress> failedServers, byte[] cacheName) {
      FailoverRequestBalancingStrategy balancer = getOrCreateIfAbsentBalancer(cacheName);
      SocketAddress server = balancer.nextServer(failedServers);
      if (trace)
         log.tracef("Using the balancer for determining the server: %s", server);
      return server;
   }

   private FailoverRequestBalancingStrategy getOrCreateIfAbsentBalancer(byte[] cacheName) {
      WrappedByteArray key = new WrappedByteArray(cacheName);
      FailoverRequestBalancingStrategy balancer = balancers.get(key);
      if (balancer == null)
         balancer = addBalancer(key);
      return balancer;
   }

   @Override
   public Transport getAddressTransport(SocketAddress server) {
      return borrowTransportFromPool(server);
   }

   @Override
   public SocketAddress getSocketAddress(Object key, byte[] cacheName) {
      return topologyInfo.getHashAwareServer(key, cacheName).orElse(null);
   }

   public Transport getTransport(Object key, Set<SocketAddress> failedServers, byte[] cacheName) {
      SocketAddress server;
      synchronized (lock) {
         Optional<SocketAddress> hashAwareServer = topologyInfo.getHashAwareServer(key, cacheName);
         Optional<SocketAddress> filtered =
               hashAwareServer.filter(a -> failedServers == null || !failedServers.contains(a));
         // orElseGet, not orElse: orElse would evaluate getNextServer() eagerly,
         // advancing the round-robin balancer's state even when the hash-aware
         // server is present and usable.
         server = filtered.orElseGet(() -> getNextServer(failedServers, cacheName));
      }
      return borrowTransportFromPool(server);
   }

   @Override
   public void releaseTransport(Transport transport) {
      if (transport.isBusy()) {
         if (trace) {
            log.tracef("Not releasing transport since it is in use: %s", transport);
         }
         return;
      }
      // The invalidateObject()/returnObject() calls could take a long time, so we hold the lock
      // only until we get the connection pool reference
      KeyedObjectPool<SocketAddress, TcpTransport> pool = getConnectionPool();
      TcpTransport tcpTransport = (TcpTransport) transport;
      if (!tcpTransport.isValid()) {
         try {
            if (trace) {
               log.tracef("Dropping connection as it is no longer valid: %s", tcpTransport);
            }
            pool.invalidateObject(tcpTransport.getServerAddress(), tcpTransport);
         } catch (Exception e) {
            log.couldNoInvalidateConnection(tcpTransport, e);
         }
      } else {
         try {
            pool.returnObject(tcpTransport.getServerAddress(), tcpTransport);
         } catch (Exception e) {
            log.couldNotReleaseConnection(tcpTransport, e);
         } finally {
            logConnectionInfo(tcpTransport.getServerAddress());
         }
      }
   }

   @Override
   public void invalidateTransport(SocketAddress serverAddress, Transport transport) {
      transport.invalidate();
   }

   @Override
   public void updateServers(Collection<SocketAddress> newServers, byte[] cacheName, boolean quiet) {
      synchronized (lock) {
         Collection<SocketAddress> servers = updateTopologyInfo(cacheName, newServers, quiet);
         if (!servers.isEmpty()) {
            FailoverRequestBalancingStrategy balancer = getOrCreateIfAbsentBalancer(cacheName);
            balancer.setServers(servers);
         }
      }
   }

   // Cluster-wide variant: refreshes every registered balancer.
   private void updateServers(Collection<SocketAddress> newServers, boolean quiet) {
      synchronized (lock) {
         Collection<SocketAddress> servers = updateTopologyInfo(null, newServers, quiet);
         if (!servers.isEmpty()) {
            for (FailoverRequestBalancingStrategy balancer : balancers.values())
               balancer.setServers(servers);
         }
      }
   }

   /**
    * Diffs the current server list against {@code newServers}, warms up pool entries
    * for added servers, clears pool entries for removed ones and updates the topology.
    *
    * @return the new (unmodifiable) server list, or an empty list if nothing changed
    */
   @GuardedBy("lock")
   private Collection<SocketAddress> updateTopologyInfo(byte[] cacheName, Collection<SocketAddress> newServers, boolean quiet) {
      Collection<SocketAddress> servers = topologyInfo.getServers();
      Set<SocketAddress> addedServers = new HashSet<>(newServers);
      addedServers.removeAll(servers);
      Set<SocketAddress> failedServers = new HashSet<>(servers);
      failedServers.removeAll(newServers);
      if (trace) {
         log.tracef("Current list: %s", servers);
         log.tracef("New list: %s", newServers);
         log.tracef("Added servers: %s", addedServers);
         log.tracef("Removed servers: %s", failedServers);
      }
      if (failedServers.isEmpty() && addedServers.isEmpty()) {
         log.debug("Same list of servers, not changing the pool");
         return Collections.emptyList();
      }
      //1. first add new servers. For servers that went down, the returned transport will fail for now
      for (SocketAddress server : addedServers) {
         log.newServerAdded(server);
         try {
            connectionPool.addObject(server);
         } catch (Exception e) {
            if (!quiet)
               log.failedAddingNewServer(server, e);
         }
      }
      //2. Remove failed servers
      for (SocketAddress server : failedServers) {
         log.removingServer(server);
         connectionPool.clear(server);
      }
      servers = Collections.unmodifiableList(new ArrayList<>(newServers));
      topologyInfo.updateServers(cacheName, servers);
      if (!failedServers.isEmpty()) {
         listenerNotifier.failoverClientListeners(failedServers);
      }
      return servers;
   }

   public Collection<SocketAddress> getServers() {
      synchronized (lock) {
         return topologyInfo.getServers();
      }
   }

   private void logConnectionInfo(SocketAddress server) {
      if (trace) {
         KeyedObjectPool<SocketAddress, TcpTransport> pool = getConnectionPool();
         log.tracef("For server %s: active = %d; idle = %d", server,
               pool.getNumActive(server), pool.getNumIdle(server));
      }
   }

   private Transport borrowTransportFromPool(SocketAddress server) {
      // The borrowObject() call could take a long time, so we hold the lock only until
      // we get the connection pool reference
      KeyedObjectPool<SocketAddress, TcpTransport> pool = getConnectionPool();
      try {
         TcpTransport tcpTransport = pool.borrowObject(server);
         reconnectListenersIfNeeded();
         return tcpTransport;
      } catch (Exception e) {
         String message = "Could not fetch transport";
         log.debug(message, e);
         throw new TransportException(message, e, server);
      } finally {
         logConnectionInfo(server);
      }
   }

   // Re-registers any client listeners that were queued while disconnected.
   private void reconnectListenersIfNeeded() {
      if (!disconnectedListeners.isEmpty()) {
         List<AddClientListenerOperation> drained = new ArrayList<>();
         disconnectedListeners.drainTo(drained);
         for (AddClientListenerOperation op : drained) {
            if (trace) {
               log.tracef("Reconnecting client listener with id %s", Util.printArray(op.listenerId));
            }
            op.execute();
         }
      }
   }

   /**
    * Note that the returned <code>ConsistentHash</code> may not be thread-safe.
    */
   @Override
   public ConsistentHash getConsistentHash(byte[] cacheName) {
      synchronized (lock) {
         return topologyInfo.getConsistentHash(cacheName);
      }
   }

   @Override
   public ConsistentHashFactory getConsistentHashFactory() {
      return topologyInfo.getConsistentHashFactory();
   }

   @Override
   public boolean isTcpNoDelay() {
      return tcpNoDelay;
   }

   @Override
   public boolean isTcpKeepAlive() {
      return tcpKeepAlive;
   }

   @Override
   public int getMaxRetries() {
      // An interrupted thread should not keep retrying.
      if (Thread.currentThread().isInterrupted()) {
         return -1;
      }
      return maxRetries;
   }

   @Override
   public int getSoTimeout() {
      return soTimeout;
   }

   @Override
   public int getConnectTimeout() {
      return connectTimeout;
   }

   @Override
   public SSLContext getSSLContext() {
      return sslContext;
   }

   @Override
   public String getSniHostName() {
      return sniHostName;
   }

   @Override
   public void addDisconnectedListener(AddClientListenerOperation listener) throws InterruptedException {
      disconnectedListeners.put(listener);
   }

   @Override
   public void reset(byte[] cacheName) {
      updateServers(initialServers, cacheName, true);
      topologyInfo.setTopologyId(cacheName, HotRodConstants.DEFAULT_CACHE_TOPOLOGY);
   }

   @Override
   public AtomicInteger createTopologyId(byte[] cacheName) {
      synchronized (lock) {
         return topologyInfo.createTopologyId(cacheName, -1);
      }
   }

   @Override
   public int getTopologyId(byte[] cacheName) {
      synchronized (lock) {
         return topologyInfo.getTopologyId(cacheName);
      }
   }

   @Override
   public ClusterSwitchStatus trySwitchCluster(String failedClusterName, byte[] cacheName) {
      synchronized (lock) {
         if (trace)
            log.tracef("Trying to switch cluster away from '%s'", failedClusterName);
         if (clusters.isEmpty()) {
            log.debugf("No alternative clusters configured, so can't switch cluster");
            return ClusterSwitchStatus.NOT_SWITCHED;
         }
         String currentClusterName = this.currentClusterName;
         if (!isSwitchedClusterNotAvailable(failedClusterName, currentClusterName)) {
            log.debugf("Cluster already switched from failed cluster `%s` to `%s`, try again",
                  failedClusterName, currentClusterName);
            return ClusterSwitchStatus.IN_PROGRESS;
         }
         // Switch cluster if there has not been a topology id cluster switch reset recently,
         if (topologyInfo.isTopologyValid(cacheName)) {
            if (trace)
               log.tracef("Switching clusters, failed cluster is '%s' and current cluster name is '%s'",
                     failedClusterName, currentClusterName);
            List<ClusterInfo> candidateClusters = new ArrayList<>();
            for (ClusterInfo cluster : clusters) {
               String clusterName = cluster.clusterName;
               if (!clusterName.equals(failedClusterName))
                  candidateClusters.add(cluster);
            }
            for (ClusterInfo cluster : candidateClusters) {
               boolean alive = checkServersAlive(cluster.clusterAddresses);
               if (alive) {
                  topologyAge.incrementAndGet();
                  Collection<SocketAddress> servers = updateTopologyInfo(cacheName, cluster.clusterAddresses, true);
                  if (!servers.isEmpty()) {
                     FailoverRequestBalancingStrategy balancer = getOrCreateIfAbsentBalancer(cacheName);
                     balancer.setServers(servers);
                  }
                  topologyInfo.setTopologyId(cacheName, HotRodConstants.SWITCH_CLUSTER_TOPOLOGY);
                  this.currentClusterName = cluster.clusterName;
                  if (log.isInfoEnabled()) {
                     if (!cluster.clusterName.equals(DEFAULT_CLUSTER_NAME))
                        log.switchedToCluster(cluster.clusterName);
                     else
                        log.switchedBackToMainCluster();
                  }
                  return ClusterSwitchStatus.SWITCHED;
               }
            }
            log.debugf("All cluster addresses viewed and none worked: %s", clusters);
            return ClusterSwitchStatus.NOT_SWITCHED;
         }
         return ClusterSwitchStatus.IN_PROGRESS;
      }
   }

   public boolean checkServersAlive(Collection<SocketAddress> servers) {
      for (SocketAddress server : servers) {
         try {
            connectionPool.addObject(server);
         } catch (Exception e) {
            log.tracef(e, "Error checking whether this server is alive: %s", server);
            return false;
         }
      }
      return true;
   }

   private boolean isSwitchedClusterNotAvailable(String failedClusterName, String currentClusterName) {
      return currentClusterName.equals(failedClusterName);
   }

   public enum ClusterSwitchStatus {
      NOT_SWITCHED, SWITCHED, IN_PROGRESS;
   }

   @Override
   public Marshaller getMarshaller() {
      return listenerNotifier.getMarshaller();
   }

   public boolean switchToCluster(String clusterName) {
      if (clusters.isEmpty()) {
         log.debugf("No alternative clusters configured, so can't switch cluster");
         return false;
      }
      Collection<SocketAddress> addresses = findClusterInfo(clusterName);
      if (!addresses.isEmpty()) {
         updateServers(addresses, true);
         topologyInfo.setAllTopologyIds(HotRodConstants.SWITCH_CLUSTER_TOPOLOGY);
         if (log.isInfoEnabled()) {
            if (!clusterName.equals(DEFAULT_CLUSTER_NAME))
               log.manuallySwitchedToCluster(clusterName);
            else
               log.manuallySwitchedBackToMainCluster();
         }
         return true;
      }
      return false;
   }

   @Override
   public String getCurrentClusterName() {
      return currentClusterName;
   }

   @Override
   public int getTopologyAge() {
      return topologyAge.get();
   }

   private Collection<SocketAddress> findClusterInfo(String clusterName) {
      for (ClusterInfo cluster : clusters) {
         if (cluster.clusterName.equals(clusterName))
            return cluster.clusterAddresses;
      }
      return Collections.emptyList();
   }

   /**
    * Note that the returned <code>RequestBalancingStrategy</code> may not be thread-safe.
    */
   public FailoverRequestBalancingStrategy getBalancer(byte[] cacheName) {
      synchronized (lock) {
         return balancers.get(new WrappedByteArray(cacheName));
      }
   }

   public GenericKeyedObjectPool<SocketAddress, TcpTransport> getConnectionPool() {
      synchronized (lock) {
         return connectionPool;
      }
   }

   /** Immutable name/address pair describing one configured cluster. */
   private static final class ClusterInfo {
      final Collection<SocketAddress> clusterAddresses;
      final String clusterName;

      private ClusterInfo(String clusterName, Collection<SocketAddress> clusterAddresses) {
         this.clusterAddresses = clusterAddresses;
         this.clusterName = clusterName;
      }

      @Override
      public String toString() {
         return "ClusterInfo{" +
               "name='" + clusterName + '\'' +
               ", addresses=" + clusterAddresses +
               '}';
      }
   }
}