/* * Quasar: lightweight threads and actors for the JVM. * Copyright (c) 2013-2015, Parallel Universe Software Co. All rights reserved. * * This program and the accompanying materials are dual-licensed under * either the terms of the Eclipse Public License v1.0 as published by * the Eclipse Foundation * * or (per the licensee's choosing) * * under the terms of the GNU Lesser General Public License version 3.0 * as published by the Free Software Foundation. */ package co.paralleluniverse.remote.galaxy; import co.paralleluniverse.concurrent.util.MapUtil; import co.paralleluniverse.fibers.SuspendExecution; import co.paralleluniverse.galaxy.MessageListener; import co.paralleluniverse.galaxy.cluster.NodeChangeListener; import co.paralleluniverse.galaxy.quasar.Grid; import co.paralleluniverse.io.serialization.Serialization; import co.paralleluniverse.strands.channels.SendPort; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * This class listens to messages received from remote ends of a channel, and forwards them to the right channel. * */ public class RemoteChannelReceiver<Message> implements MessageListener { private static final Logger LOG = LoggerFactory.getLogger(RemoteChannelReceiver.class); private static final ConcurrentMap<SendPort<?>, RemoteChannelReceiver<?>> receivers = MapUtil.newConcurrentHashMap(); private static final AtomicLong topicGen = new AtomicLong(1000); public static <Message> RemoteChannelReceiver<Message> getReceiver(SendPort<Message> channel) { RemoteChannelReceiver<Message> receiver = (RemoteChannelReceiver<Message>) receivers.get(channel); if (receiver == null) { receiver = new RemoteChannelReceiver<Message>(channel); RemoteChannelReceiver<Message> tmp = (RemoteChannelReceiver<Message>) receivers.putIfAbsent(channel, receiver); if (tmp == null) receiver.subscribe(); else receiver = tmp; } return receiver; } void shutdown() { unsubscribe(); receivers.remove(this.channel); } public interface MessageFilter<Message> { boolean shouldForwardMessage(Message msg); } ////////////////////////////// private final SendPort<Message> channel; private final long topic; private volatile MessageFilter<Message> filter; private final Map<Short, Integer> references = new ConcurrentHashMap<>(); private RemoteChannelReceiver(SendPort<Message> channel) { this.channel = channel; this.topic = topicGen.incrementAndGet(); try { Grid.getInstance().cluster().addNodeChangeListener(new NodeChangeListener() { @Override public void nodeAdded(short id) { } @Override public void nodeSwitched(short id) { } @Override public void nodeRemoved(short id) { LOG.debug("decrease RefCount for {} from node {}", this, id); references.remove(id); if (references.isEmpty()) { LOG.debug("Shutting down receiver due to zero references" + this); shutdown(); } } }); } catch (InterruptedException ex) { LOG.error(ex.toString()); } } public void setFilter(MessageFilter<Message> filter) { this.filter = filter; } @Override public void messageReceived(short fromNode, byte[] message) { Object m1 = Serialization.getInstance().read(message); LOG.debug("Received: " + m1); if (m1 instanceof GlxRemoteChannel.CloseMessage) { Throwable t = ((GlxRemoteChannel.CloseMessage) m1).getException(); if (t != null) channel.close(t); else channel.close(); unsubscribe(); return; } else if (m1 instanceof GlxRemoteChannel.RefMessage) { handleRefMessage((GlxRemoteChannel.RefMessage) m1); return; } final Message m = (Message) m1; if (filter == null || filter.shouldForwardMessage(m)) { try { channel.send(m); // TODO: this may potentially block the whole messenger thread!!! } catch (SuspendExecution e) { throw new AssertionError(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } } private void subscribe() { GlxRemoteChannel.getMessenger().addMessageListener((Long) topic, this); } private void unsubscribe() { GlxRemoteChannel.getMessenger().removeMessageListener(topic, this); } public long getTopic() { return topic; } void handleRefMessage(GlxRemoteChannel.RefMessage msg) throws RuntimeException { LOG.debug("handling: " + msg); if (msg.isAdd()) { Integer refCount = references.get(msg.getNodeId()); if (refCount == null) { references.put(msg.getNodeId(), 1); } else references.put(msg.getNodeId(), refCount + 1); } else { Integer refCount = references.get(msg.getNodeId()); if (refCount == null) { throw new RuntimeException("decrease reference counter message received for unknown cluster node"); } else { if (--refCount > 0) references.put(msg.getNodeId(), refCount); else { references.remove(msg.getNodeId()); if (references.isEmpty()) shutdown(); } } } } }