/*
* Galaxy
* Copyright (c) 2012-2014, Parallel Universe Software Co. All rights reserved.
*
* This program and the accompanying materials are dual-licensed under
* either the terms of the Eclipse Public License v1.0 as published by
* the Eclipse Foundation
*
* or (per the licensee's choosing)
*
* under the terms of the GNU Lesser General Public License version 3.0
* as published by the Free Software Foundation.
*/
package co.paralleluniverse.galaxy.netty;
import co.paralleluniverse.galaxy.Cluster;
import co.paralleluniverse.galaxy.cluster.NodeChangeListener;
import co.paralleluniverse.galaxy.cluster.NodeInfo;
import co.paralleluniverse.galaxy.cluster.ReaderWriters;
import co.paralleluniverse.galaxy.core.Comm;
import co.paralleluniverse.galaxy.core.Message;
import co.paralleluniverse.galaxy.core.MessageReceiver;
import static co.paralleluniverse.galaxy.netty.IpConstants.*;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Maps;
import java.beans.ConstructorProperties;
import java.net.InetAddress;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandler;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ServerChannel;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link Comm} implementation built on {@link AbstractTcpServer} that sends and receives
 * {@link Message}s over channels indexed by the id of the node on the other end.
 *
 * <p>On construction it registers and sets the {@code IP_ADDRESS} and {@code IP_SERVER_PORT}
 * node properties on the cluster, and installs a {@link NodeChangeListener} that closes the
 * channel of any node that is switched or removed.
 *
 * @author pron
 */
final class TcpServerServerComm extends AbstractTcpServer implements Comm {
    private static final Logger LOG = LoggerFactory.getLogger(TcpServerServerComm.class);

    /** Destination for incoming messages; assigned through {@link #setReceiver(MessageReceiver)}. */
    private MessageReceiver receiver;

    /**
     * Creates the component without a test handler.
     *
     * @param name    the component name, passed to the superclass
     * @param cluster the cluster whose node properties and node changes this component uses
     * @param port    the server port, published as the {@code IP_SERVER_PORT} node property
     * @throws Exception if the delegated constructor fails
     */
    @ConstructorProperties({"name", "cluster", "port"})
    public TcpServerServerComm(String name, Cluster cluster, int port) throws Exception {
        this(name, cluster, port, null);
    }

    /**
     * Creates the component, publishes the local address and port as node properties, and
     * registers a listener that closes channels of switched or removed nodes.
     *
     * @param name        the component name, passed to the superclass
     * @param cluster     the cluster whose node properties and node changes this component uses
     * @param port        the server port, published as the {@code IP_SERVER_PORT} node property
     * @param testHandler an extra handler passed to the superclass; may be {@code null}
     * @throws Exception if superclass construction or local host resolution fails
     */
    TcpServerServerComm(String name, final Cluster cluster, int port, final ChannelHandler testHandler) throws Exception {
        super(name, cluster, new ChannelGroup(), port, testHandler);

        cluster.addNodeProperty(IP_ADDRESS, true, true, INET_ADDRESS_READER_WRITER);
        cluster.setNodeProperty(IP_ADDRESS, InetAddress.getLocalHost());
        cluster.addNodeProperty(IP_SERVER_PORT, false, true, ReaderWriters.INTEGER);
        cluster.setNodeProperty(IP_SERVER_PORT, port);

        cluster.addNodeChangeListener(new NodeChangeListener() {
            @Override
            public void nodeAdded(short id) {
                // Nothing to do here: channels are registered in ChannelGroup.add.
            }

            @Override
            public void nodeSwitched(short id) {
                closeNodeChannel(id, "switched");
            }

            @Override
            public void nodeRemoved(short id) {
                closeNodeChannel(id, "removed");
            }
        });
    }

    /**
     * Closes the channel currently registered for the given node, if there is one.
     *
     * @param id     the node id
     * @param reason the node event that triggered the close; used only in the log message
     */
    private void closeNodeChannel(short id, String reason) {
        final Channel channel = getChannels().get(id);
        if (channel != null) {
            LOG.info("Closing channel for {} node {}", reason, id);
            channel.close();
        }
    }

    /**
     * Starts the component by calling the inherited {@code bind()}.
     *
     * @param master not used by this implementation
     */
    @Override
    public void start(boolean master) {
        bind();
    }

    /**
     * Returns the superclass's channel group narrowed to {@link ChannelGroup}; the cast relies on
     * the constructor having passed a {@link ChannelGroup} instance to the superclass.
     */
    @Override
    protected ChannelGroup getChannels() {
        return (ChannelGroup) super.getChannels();
    }

    /**
     * Sets the receiver that incoming messages are handed to. Calls
     * {@code assertDuringInitialization()} before assigning.
     *
     * @param receiver the receiver for incoming messages
     */
    @Override
    public void setReceiver(MessageReceiver receiver) {
        assertDuringInitialization();
        this.receiver = receiver;
    }

    /**
     * Writes the message to the channel registered for {@code message.getNode()}. Messages that are
     * not responses are assigned a fresh message id first. If no channel is registered for the
     * node, a warning is logged and the message is dropped.
     *
     * @param message the message to send
     */
    @Override
    public void send(Message message) {
        if (!message.isResponse()) {
            // TODO: possible pitfall: b/c this method is not synchronized, two threads may run it
            // concurrently, one would get a smaller id but the other would put the message in a
            // queue first - broken invariant!
            message.setMessageId(nextMessageId());
        }
        LOG.debug("Send {}", message);

        final short node = message.getNode();
        final Channel ch = getChannels().get(node);
        if (ch == null) {
            LOG.warn("No open channel found for node {}", node);
            return;
        }
        ch.write(message);
    }

    /**
     * Forwards an incoming message to the configured receiver.
     *
     * @param ctx     the handler context the message arrived on (unused)
     * @param message the incoming message
     */
    @Override
    protected void receive(ChannelHandlerContext ctx, Message message) {
        receiver.receive(message);
    }

    /**
     * A channel group that additionally indexes its non-server channels by node id, allowing at
     * most one channel per node.
     */
    private static class ChannelGroup extends DefaultChannelGroup {
        /** Node id to channel; kept in step with the group's membership by add/remove/clear. */
        private final BiMap<Short, Channel> channels = Maps.synchronizedBiMap(HashBiMap.<Short, Channel>create());

        public ChannelGroup(String name) {
            super(name);
        }

        public ChannelGroup() {
        }

        /**
         * Adds a channel to the group. Server channels are added as-is. Any other channel must have
         * a {@link NodeInfo} in {@code ChannelNodeInfo.nodeInfo}, and its node must not already
         * have a registered channel.
         *
         * @param channel the channel to add
         * @return whether the underlying group accepted the channel
         * @throws RuntimeException if the channel's node is unknown or already connected
         */
        @Override
        public boolean add(Channel channel) {
            if (channel instanceof ServerChannel) {
                return super.add(channel);
            }

            final NodeInfo node = ChannelNodeInfo.nodeInfo.get(channel);
            if (node == null) {
                LOG.warn("Received connection from an unknown address {}.", channel.getRemoteAddress());
                throw new RuntimeException("Unknown node for address " + channel.getRemoteAddress());
            }

            final short nodeId = node.getNodeId();
            // Single lookup: a containsKey()/get() pair could see the entry vanish in between and
            // fail with a NullPointerException while building the log message.
            final Channel existing = channels.get(nodeId);
            if (existing != null) {
                LOG.warn("Received connection from address {} of node {}, but this node is already connected from address {}.",
                        channel.getRemoteAddress(), nodeId, existing.getRemoteAddress());
                throw new RuntimeException("Node " + nodeId + " already connected.");
            }

            final boolean added = super.add(channel);
            if (added) {
                channels.put(nodeId, channel);
            }
            return added;
        }

        /**
         * Removes a channel from the group and from the node-id index. Accepts a {@link Channel}
         * or an {@code Integer} channel id; any other argument is passed to the superclass
         * unchanged and triggers no additional cleanup.
         *
         * @param o the channel, or its {@code Integer} id
         * @return whether the underlying group contained and removed the channel
         */
        @Override
        public boolean remove(Object o) {
            // Resolve the channel before removal; after super.remove() an id can no longer be looked up.
            final Channel channel;
            if (o instanceof Channel) {
                channel = (Channel) o;
            } else if (o instanceof Integer) {
                channel = find((Integer) o);
            } else {
                channel = null;
            }

            final boolean removed = super.remove(o);
            if (channel != null) {
                if (removed) {
                    channels.inverse().remove(channel);
                }
                // The node info is discarded even when the channel was not a member of the group.
                ChannelNodeInfo.nodeInfo.remove(channel);
            }
            return removed;
        }

        /** Empties both the underlying group and the node-id index. */
        @Override
        public void clear() {
            super.clear();
            channels.clear();
        }

        /**
         * Tests membership. A {@code Short} argument is interpreted as a node id and checked
         * against the node-id index; anything else is delegated to the superclass.
         */
        @Override
        public boolean contains(Object o) {
            if (o instanceof Short) {
                return channels.containsKey(o);
            }
            return super.contains(o);
        }

        /**
         * Returns the channel registered for the given node.
         *
         * @param node the node id
         * @return the node's channel, or {@code null} if none is registered
         */
        public Channel get(short node) {
            return channels.get(node);
        }
    }
}