package org.infinispan.client.hotrod.event; import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedChannelException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.infinispan.client.hotrod.annotation.ClientCacheEntryCreated; import org.infinispan.client.hotrod.annotation.ClientCacheEntryExpired; import org.infinispan.client.hotrod.annotation.ClientCacheEntryModified; import org.infinispan.client.hotrod.annotation.ClientCacheEntryRemoved; import org.infinispan.client.hotrod.annotation.ClientCacheFailover; import org.infinispan.client.hotrod.exceptions.TransportException; import org.infinispan.client.hotrod.impl.operations.AddClientListenerOperation; import org.infinispan.client.hotrod.impl.protocol.Codec; import org.infinispan.client.hotrod.impl.transport.Transport; import org.infinispan.client.hotrod.impl.transport.TransportFactory; import org.infinispan.client.hotrod.logging.Log; import org.infinispan.client.hotrod.logging.LogFactory; import org.infinispan.commons.marshall.Marshaller; import org.infinispan.commons.marshall.WrappedByteArray; import org.infinispan.commons.util.Util; /** * @author Galder ZamarreƱo */ public class ClientListenerNotifier { private static final Log log = LogFactory.getLog(ClientListenerNotifier.class, Log.class); private static final boolean trace = log.isTraceEnabled(); private static final Map<Class<? extends Annotation>, Class<?>[]> allowedListeners = new HashMap<>(4); static { allowedListeners.put(ClientCacheEntryCreated.class, new Class[]{ClientCacheEntryCreatedEvent.class, ClientCacheEntryCustomEvent.class}); allowedListeners.put(ClientCacheEntryModified.class, new Class[]{ClientCacheEntryModifiedEvent.class, ClientCacheEntryCustomEvent.class}); allowedListeners.put(ClientCacheEntryRemoved.class, new Class[]{ClientCacheEntryRemovedEvent.class, ClientCacheEntryCustomEvent.class}); allowedListeners.put(ClientCacheEntryExpired.class, new Class[]{ClientCacheEntryExpiredEvent.class, ClientCacheEntryCustomEvent.class}); allowedListeners.put(ClientCacheFailover.class, new Class[]{ClientCacheFailoverEvent.class}); } private final ConcurrentMap<WrappedByteArray, EventDispatcher> clientListeners = new ConcurrentHashMap<>(); private final ExecutorService executor; private final Codec codec; private final Marshaller marshaller; private final TransportFactory transportFactory; private final Consumer<WrappedByteArray> failoverClientListener = this::failoverClientListener; protected ClientListenerNotifier( ExecutorService executor, Codec codec, Marshaller marshaller, TransportFactory transportFactory) { this.executor = executor; this.codec = codec; this.marshaller = marshaller; this.transportFactory = transportFactory; } public static ClientListenerNotifier create(Codec codec, Marshaller marshaller, TransportFactory transportFactory) { ExecutorService executor = Executors.newCachedThreadPool(getRestoreThreadNameThreadFactory()); return new ClientListenerNotifier(executor, codec, marshaller, transportFactory); } private static ThreadFactory getRestoreThreadNameThreadFactory() { return r -> new Thread(() -> { final String originalName = Thread.currentThread().getName(); try { r.run(); } finally { Thread.currentThread().setName(originalName); } }); } public Marshaller getMarshaller() { return marshaller; } public void addClientListener(AddClientListenerOperation op) { Map<Class<? extends Annotation>, List<ClientListenerInvocation>> invocables = findMethods(op.listener); EventDispatcher eventDispatcher = new EventDispatcher(op, invocables, op.getCacheName()); clientListeners.put(new WrappedByteArray(op.listenerId), eventDispatcher); if (trace) log.tracef("Add client listener with id %s, for listener %s and invocable methods %s", Util.printArray(op.listenerId), op.listener, invocables); } public void failoverClientListeners(Set<SocketAddress> failedServers) { // Compile all listener ids that need failing over List<WrappedByteArray> failoverListenerIds = new ArrayList<>(); for (Map.Entry<WrappedByteArray, EventDispatcher> entry : clientListeners.entrySet()) { EventDispatcher dispatcher = entry.getValue(); if (failedServers.contains(dispatcher.transport.getRemoteSocketAddress())) failoverListenerIds.add(entry.getKey()); } if (trace && failoverListenerIds.isEmpty()) log.tracef("No event listeners registered in faild servers: %s", failedServers); // Remove tracking listeners and read to the fallback transport failoverListenerIds.forEach(failoverClientListener); } public void failoverClientListener(byte[] listenerId) { failoverClientListener(new WrappedByteArray(listenerId)); } private void failoverClientListener(WrappedByteArray listenerId) { EventDispatcher dispatcher = clientListeners.get(listenerId); removeClientListener(listenerId); // Invoke failover event callback, if presents invokeFailoverEvent(dispatcher); // Re-execute adding client listener in one of the remaining nodes dispatcher.op.execute(); if (trace) { SocketAddress failedServerAddress = dispatcher.transport.getRemoteSocketAddress(); log.tracef("Fallback listener id %s from a failed server %s to %s", Util.printArray(listenerId.getBytes()), failedServerAddress, dispatcher.op.getDedicatedTransport().getRemoteSocketAddress()); } } private void invokeFailoverEvent(EventDispatcher dispatcher) { List<ClientListenerInvocation> callbacks = dispatcher.invocables.get(ClientCacheFailover.class); if (callbacks != null) { for (ClientListenerInvocation callback : callbacks) callback.invoke(ClientEvents.mkCachefailoverEvent()); } } public void startClientListener(byte[] listenerId) { EventDispatcher eventDispatcher = clientListeners.get(new WrappedByteArray(listenerId)); executor.submit(eventDispatcher); } public void removeClientListener(byte[] listenerId) { removeClientListener(new WrappedByteArray(listenerId)); } private void removeClientListener(WrappedByteArray listenerId) { EventDispatcher dispatcher = clientListeners.remove(listenerId); dispatcher.transport.release(); // force shutting it if (trace) log.tracef("Remove client listener with id %s", Util.printArray(listenerId.getBytes())); } public byte[] findListenerId(Object listener) { for (EventDispatcher dispatcher : clientListeners.values()) { if (dispatcher.op.listener.equals(listener)) return dispatcher.op.listenerId; } return null; } public boolean isListenerConnected(byte[] listenerId) { EventDispatcher dispatcher = clientListeners.get(new WrappedByteArray(listenerId)); // If listener not present, is not active return dispatcher != null && !dispatcher.stopped; } public Transport findTransport(byte[] listenerId) { EventDispatcher dispatcher = clientListeners.get(new WrappedByteArray(listenerId)); if (dispatcher != null) return dispatcher.transport; return null; } public Map<Class<? extends Annotation>, List<ClientListenerInvocation>> findMethods(Object listener) { Map<Class<? extends Annotation>, List<ClientListenerInvocation>> listenerMethodMap = new HashMap<>(4, 0.99f); for (Method m : listener.getClass().getMethods()) { // loop through all valid method annotations for (Map.Entry<Class<? extends Annotation>, Class<?>[]> entry : allowedListeners.entrySet()) { Class<? extends Annotation> annotationType = entry.getKey(); Class<?>[] eventTypes = entry.getValue(); if (m.isAnnotationPresent(annotationType)) { testListenerMethodValidity(m, eventTypes, annotationType.getName()); SecurityActions.setAccessible(m); ClientListenerInvocation invocation = new ClientListenerInvocation(listener, m); List<ClientListenerInvocation> invocables = listenerMethodMap.get(annotationType); if (invocables == null) { invocables = new ArrayList<>(); listenerMethodMap.put(annotationType, invocables); } invocables.add(invocation); } } } return listenerMethodMap; } private void testListenerMethodValidity(Method m, Class<?>[] allowedParameters, String annotationName) { boolean isAllowed = false; for (Class<?> allowedParameter : allowedParameters) { if (m.getParameterTypes().length == 1 && m.getParameterTypes()[0].isAssignableFrom(allowedParameter)) { isAllowed = true; break; } } if (!isAllowed) throw log.incorrectClientListener(annotationName, Arrays.asList(allowedParameters)); if (!m.getReturnType().equals(void.class)) throw log.incorrectClientListener(annotationName); } public Set<Object> getListeners(byte[] cacheName) { Set<Object> ret = new HashSet<>(clientListeners.size()); for (EventDispatcher dispatcher : clientListeners.values()) { if (Arrays.equals(dispatcher.op.cacheName, cacheName)) ret.add(dispatcher.op.listener); } return ret; } public void stop() { for (WrappedByteArray listenerId : clientListeners.keySet()) { if (trace) log.tracef("Remote cache manager stopping, remove client listener id %s", Util.printArray(listenerId.getBytes())); removeClientListener(listenerId); } executor.shutdown(); try { executor.awaitTermination(5, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } public void invokeEvent(byte[] listenerId, ClientEvent clientEvent) { EventDispatcher eventDispatcher = clientListeners.get(new WrappedByteArray(listenerId)); eventDispatcher.invokeClientEvent(clientEvent); } private final class EventDispatcher implements Runnable { final Map<Class<? extends Annotation>, List<ClientListenerInvocation>> invocables; final AddClientListenerOperation op; final Transport transport; final String cacheName; volatile boolean stopped = false; private EventDispatcher(AddClientListenerOperation op, Map<Class<? extends Annotation>, List<ClientListenerInvocation>> invocables, String cacheName) { this.op = op; this.transport = op.getDedicatedTransport(); this.invocables = invocables; this.cacheName = cacheName; } @Override public void run() { Thread.currentThread().setName(getThreadName()); while (!Thread.currentThread().isInterrupted()) { ClientEvent clientEvent = null; try { clientEvent = codec.readEvent(transport, op.listenerId, marshaller); invokeClientEvent(clientEvent); // Nullify event, makes it easier to identify network vs invocation error messages clientEvent = null; } catch (TransportException e) { Throwable cause = e.getCause(); if (cause instanceof ClosedChannelException || (cause instanceof SocketException && !transport.isValid())) { // Channel closed, ignore and exit log.debug("Channel closed, exiting event reader thread"); stopped = true; return; } else if (cause instanceof SocketTimeoutException) { log.debug("Timed out reading event, retry"); } else if (clientEvent != null) { log.unexpectedErrorConsumingEvent(clientEvent, e); } else if (cause instanceof IOException && cause.getMessage().contains("Connection reset by peer")) { tryFailoverClientListener(); stopped = true; return; } else { log.unrecoverableErrorReadingEvent(e, transport.getRemoteSocketAddress()); stopped = true; return; // Server is likely gone! } } catch (CancelledKeyException e) { // Cancelled key exceptions are also thrown when the channel has been closed log.debug("Key cancelled, most likely channel closed, exiting event reader thread"); stopped = true; return; } catch (Throwable t) { if (clientEvent != null) log.unexpectedErrorConsumingEvent(clientEvent, t); else log.unableToReadEventFromServer(t, transport.getRemoteSocketAddress()); if (!transport.isValid()) { stopped = true; return; } } } } private void tryFailoverClientListener() { try { log.debug("Connection reset by peer, so failover client listener"); failoverClientListener(op.listenerId); } catch (TransportException e) { log.debug("Unable to failover client listener, so ignore connection reset"); try { transportFactory.addDisconnectedListener(op); } catch (InterruptedException e1) { Thread.currentThread().interrupt(); } } } String getThreadName() { String listenerId = Util.toHexString(op.listenerId, 8); return cacheName.isEmpty() ? "Client-Listener-" + listenerId : "Client-Listener-" + cacheName + "-" + listenerId; } void invokeClientEvent(ClientEvent clientEvent) { if (trace) log.tracef("Event %s received for listener with id=%s", clientEvent, Util.printArray(op.listenerId)); switch (clientEvent.getType()) { case CLIENT_CACHE_ENTRY_CREATED: invokeCallbacks(clientEvent, ClientCacheEntryCreated.class); break; case CLIENT_CACHE_ENTRY_MODIFIED: invokeCallbacks(clientEvent, ClientCacheEntryModified.class); break; case CLIENT_CACHE_ENTRY_REMOVED: invokeCallbacks(clientEvent, ClientCacheEntryRemoved.class); break; case CLIENT_CACHE_ENTRY_EXPIRED: invokeCallbacks(clientEvent, ClientCacheEntryExpired.class); break; } } private void invokeCallbacks(ClientEvent event, Class<? extends Annotation> type) { List<ClientListenerInvocation> callbacks = invocables.get(type); if (callbacks != null) { for (ClientListenerInvocation callback : callbacks) callback.invoke(event); } } } private static final class ClientListenerInvocation { private static final Log log = LogFactory.getLog(ClientListenerInvocation.class, Log.class); final Object listener; final Method method; private ClientListenerInvocation(Object listener, Method method) { this.listener = listener; this.method = method; } public void invoke(ClientEvent event) { try { method.invoke(listener, event); } catch (Exception e) { throw log.exceptionInvokingListener( e.getClass().getName(), method, listener, e); } } } }