package lsr.paxos.network;
import static lsr.common.ProcessDescriptor.processDescriptor;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import lsr.paxos.messages.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link Network} implementation built on Java NIO.
 *
 * <p>A selector loop (see {@link #run()}) accepts incoming TCP connections on the local replica
 * port and reads an 8-byte handshake from each of them. The second int of that handshake is taken
 * as the id of the connecting replica. Once it is known, the channel is cancelled from the selector
 * and handed to a new {@link NioInputConnection} / {@link NioOutputConnection} pair.
 *
 * <p>The constructor creates an output connection for every other replica, but starts only those
 * towards replicas with a higher id than the local one.
 */
public class NioNetwork extends Network implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(Network.class);

    /** Size in bytes of the handshake read from every accepted channel (two ints). */
    private static final int HANDSHAKE_SIZE = 8;

    /** Column of the input connection in a {@link #connections} row. */
    private static final int INPUT = 0;
    /** Column of the output connection in a {@link #connections} row. */
    private static final int OUTPUT = 1;

    private final Selector selector;

    /** Connections indexed by replica id; each row holds {input, output}. */
    NioConnection[][] connections;

    /** Partially read handshake buffers of accepted channels whose sender is not yet known. */
    private final HashMap<SocketChannel, ByteBuffer> tmpBuffers =
            new HashMap<SocketChannel, ByteBuffer>();

    /**
     * Opens the selector and binds a non-blocking server socket on the local replica port. It then
     * creates an output connection for every other replica.
     *
     * @throws IOException if the selector or the server socket cannot be opened, bound or
     *         registered
     */
    public NioNetwork() throws IOException {
        selector = SelectorProvider.provider().openSelector();

        ServerSocketChannel serverChannel = ServerSocketChannel.open();
        serverChannel.configureBlocking(false);
        // A null address binds the wildcard address.
        serverChannel.socket().bind(new InetSocketAddress((InetAddress) null,
                processDescriptor.getLocalProcess().getReplicaPort()));
        serverChannel.register(selector, SelectionKey.OP_ACCEPT);

        connections = new NioConnection[processDescriptor.numReplicas][2];
        for (int i = 0; i < processDescriptor.numReplicas; i++) {
            if (i == localId) {
                continue;
            }
            connections[i][OUTPUT] = new NioOutputConnection(this,
                    processDescriptor.config.getProcess(i), null);
            // Only connections towards higher-id replicas are started from here.
            if (i > localId) {
                connections[i][OUTPUT].start();
            }
        }
    }

    /** Runs the selector loop on a new daemon thread. */
    @Override
    public void start() {
        Thread selectorThread = new Thread(this);
        selectorThread.setDaemon(true);
        selectorThread.start();
    }

    /**
     * Selector loop: accepts new connections and reads their handshakes, forever. I/O errors are
     * logged and the loop continues with the next selection.
     */
    @Override
    public void run() {
        while (true) {
            try {
                selector.select();
                Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator();
                while (selectedKeys.hasNext()) {
                    SelectionKey key = selectedKeys.next();
                    selectedKeys.remove();
                    if (!key.isValid()) {
                        continue;
                    }
                    // Server keys are registered for OP_ACCEPT only, client keys for OP_READ only.
                    if (key.isAcceptable()) {
                        accept(key);
                    } else if (key.isReadable()) {
                        read(key);
                    }
                }
            } catch (ClosedChannelException e) {
                logger.debug("other process closed a channel when attepting to connect here", e);
            } catch (IOException e) {
                logger.error("I/O error in selector loop", e);
            }
        }
    }

    /**
     * Accepts a pending connection on the server channel and registers it for reading the
     * handshake.
     *
     * <p>If configuring or registering the new channel fails, the channel is closed before the
     * exception is propagated.
     *
     * @param key selection key of the server socket channel
     * @throws IOException if the accepted channel cannot be configured
     * @throws ClosedChannelException if the accepted channel was closed before registration
     */
    protected void accept(SelectionKey key) throws IOException, ClosedChannelException {
        ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();
        SocketChannel channel = serverSocketChannel.accept();
        // In non-blocking mode accept() returns null when no connection is actually pending.
        if (channel == null) {
            return;
        }
        try {
            configureChannel(channel);
            tmpBuffers.put(channel, ByteBuffer.allocate(HANDSHAKE_SIZE));
            channel.register(selector, SelectionKey.OP_READ);
        } catch (IOException e) {
            // Do not leak the socket (or its buffer) when setup fails.
            tmpBuffers.remove(channel);
            closeChannelQuietly(channel);
            throw e;
        }
        logger.trace("socket accept: {}", channel);
    }

    /**
     * Puts {@code channel} in non-blocking mode, disables Nagle's algorithm and sets the TCP send
     * and receive buffer sizes to {@link NioConnection#TCP_BUFFER_SIZE}.
     *
     * @param channel channel to configure
     * @throws IOException if the blocking mode cannot be changed
     * @throws SocketException if a socket option cannot be set
     */
    protected void configureChannel(SocketChannel channel) throws IOException, SocketException {
        channel.configureBlocking(false);
        channel.socket().setTcpNoDelay(true);
        channel.socket().setSendBufferSize(NioConnection.TCP_BUFFER_SIZE);
        channel.socket().setReceiveBufferSize(NioConnection.TCP_BUFFER_SIZE);
    }

    /**
     * Reads (part of) the handshake from an accepted channel.
     *
     * <p>Once all {@value #HANDSHAKE_SIZE} bytes have arrived, the channel's key is cancelled. The
     * channel is then handed to a new input/output connection pair for the sending replica. A
     * channel that disconnects, fails to read, or announces an out-of-range sender id is closed.
     *
     * @param key selection key of an accepted client channel
     * @throws IOException declared for callers; read errors themselves are handled here
     */
    private void read(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();
        ByteBuffer buffer = tmpBuffers.get(channel);
        if (logger.isTraceEnabled()) {
            logger.trace("socket read: {} {}", channel, (buffer != null));
        }

        int numRead;
        try {
            numRead = channel.read(buffer);
        } catch (IOException e) {
            // A read error is treated the same way as an orderly disconnect.
            numRead = -1;
        }
        if (numRead == -1) {
            logger.trace("client disconnected");
            discardHandshakeChannel(key, channel);
            return;
        }
        if (buffer.hasRemaining()) {
            return; // handshake not complete yet
        }

        buffer.flip();
        buffer.getInt(); // first handshake int is not used here
        int senderId = buffer.getInt();
        // From now on the channel is served by the connection objects, not by this selector.
        key.cancel();
        tmpBuffers.remove(channel);

        // The id comes from the network: reject anything that is not a valid array index.
        if (senderId < 0 || senderId >= processDescriptor.numReplicas) {
            logger.error("Invalid senderId: {}", senderId);
            closeChannelQuietly(channel);
            return;
        }

        NioInputConnection inputConnection = new NioInputConnection(this,
                processDescriptor.config.getProcess(senderId), channel);
        connections[senderId][INPUT] = inputConnection;
        inputConnection.start();

        NioOutputConnection outputConnection = new NioOutputConnection(this,
                processDescriptor.config.getProcess(senderId), channel);
        connections[senderId][OUTPUT] = outputConnection;
        outputConnection.start();

        logger.debug("input connection established with: {}", senderId);
    }

    /** Drops an accepted channel whose handshake failed: cancels its key, forgets it, closes it. */
    private void discardHandshakeChannel(SelectionKey key, SocketChannel channel) {
        key.cancel();
        tmpBuffers.remove(channel);
        closeChannelQuietly(channel);
    }

    /** Closes {@code channel}, logging (not propagating) any I/O error. */
    private static void closeChannelQuietly(SocketChannel channel) {
        try {
            channel.close();
        } catch (IOException e) {
            logger.debug("failed to close channel {}", channel, e);
        }
    }

    /**
     * Submits {@code message} to the output connection of replica {@code destination}.
     *
     * <p>NOTE(review): {@code connections[localId]} is never populated in this class, so sending to
     * the local id would throw NPE. The assumption is that the caller never does that; confirm.
     */
    @Override
    protected void send(Message message, int destination) {
        boolean sent = ((NioOutputConnection) connections[destination][OUTPUT])
                .send(message.toByteArray());
        if (logger.isTraceEnabled()) {
            logger.trace("send message to: {} - {}", destination, sent ? "submitted" : "rejected");
        }
    }

    /** Sends {@code message} to every replica whose bit is set in {@code destinations}. */
    @Override
    protected void send(Message message, BitSet destinations) {
        for (int i = destinations.nextSetBit(0); i >= 0; i = destinations.nextSetBit(i + 1)) {
            send(message, i);
        }
    }

    /**
     * Reacts to a closed connection between {@code senderId} and {@code receiverId}.
     *
     * <p>When the local replica is the receiver, the output connection towards {@code senderId}
     * is notified that its input side disconnected.
     */
    protected void removeConnection(int senderId, int receiverId) {
        logger.trace("connection closed with: {} --> {}", senderId, receiverId);
        if (receiverId == localId) {
            ((NioOutputConnection) connections[senderId][OUTPUT]).notifyAboutInputDisconnected();
        }
    }
}