package uk.ac.imperial.lsds.seepworker.comm;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.UnresolvedAddressException;
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import uk.ac.imperial.lsds.seep.api.DataStoreType;
import uk.ac.imperial.lsds.seep.api.data.Type;
import uk.ac.imperial.lsds.seep.comm.Connection;
import uk.ac.imperial.lsds.seep.comm.OutgoingConnectionRequest;
import uk.ac.imperial.lsds.seep.core.DataStoreSelector;
import uk.ac.imperial.lsds.seep.core.EventAPI;
import uk.ac.imperial.lsds.seep.core.IBuffer;
import uk.ac.imperial.lsds.seep.core.OBuffer;
import uk.ac.imperial.lsds.seep.infrastructure.SeepEndPointType;
import uk.ac.imperial.lsds.seepworker.WorkerConfig;
/**
 * Network implementation of {@link DataStoreSelector} / {@link EventAPI} for one worker.
 *
 * <p>It runs three kinds of threads:
 * <ul>
 *   <li>one acceptor thread ({@link AcceptorWorker}) that accepts incoming data connections and
 *       hands them round-robin to the readers;</li>
 *   <li>a pool of {@link Reader}s that first read a 4-byte connection identifier from each incoming
 *       connection, bind it to the matching {@link IBuffer}, and from then on let that buffer read
 *       from the channel;</li>
 *   <li>a pool of {@link Writer}s that open outgoing connections, send the output buffer id as a
 *       4-byte identifier, and then drain the matching {@link OBuffer} into the channel.</li>
 * </ul>
 * State shared between these threads is held in concurrent collections or volatile fields.
 */
public class NetworkSelector implements EventAPI, DataStoreSelector {

    final private static Logger LOG = LoggerFactory.getLogger(NetworkSelector.class);

    /** Size in bytes of the connection identifier exchanged when a data connection is set up. */
    private static final int CONN_ID_BYTES = Integer.SIZE / Byte.SIZE;

    private ServerSocketChannel listenerSocket;
    private Selector acceptorSelector;
    // Written by the controlling thread, read by the acceptor thread.
    private volatile boolean acceptorWorking = false;
    private Thread acceptorWorker;

    private Reader[] readers;
    private Writer[] writers;
    // Replaced by configureOutgoingConnection, counted down by writer threads, awaited in startSelector.
    private volatile CountDownLatch writersConfiguredLatch;
    private int numReaderWorkers;
    private int totalNumberPendingConnectionsPerThread;
    private Thread[] readerWorkers;
    private Thread[] writerWorkers;
    private int numWriterWorkers;

    private int myId;

    // Output buffer id -> selection key. Written by writer threads, read by readyForWrite callers.
    private Map<Integer, SelectionKey> writerKeys;
    // Selection key -> incoming connection id. Written by (possibly several) reader threads.
    private Map<SelectionKey, Integer> readerKeys;

    // incoming id -> local input buffer
    private volatile Map<Integer, IBuffer> ibMap;
    // Upstream connections that have not identified themselves yet; decremented by reader threads.
    private final AtomicInteger numUpstreamConnections = new AtomicInteger();

    public NetworkSelector(WorkerConfig wc, int opId) {
        this.myId = opId;
        this.writersConfiguredLatch = new CountDownLatch(0); // Initially non-defined, nobody waits here
        this.numReaderWorkers = wc.getInt(WorkerConfig.NUM_NETWORK_READER_THREADS);
        this.numWriterWorkers = wc.getInt(WorkerConfig.NUM_NETWORK_WRITER_THREADS);
        this.totalNumberPendingConnectionsPerThread = wc.getInt(WorkerConfig.MAX_PENDING_NETWORK_CONNECTION_PER_THREAD);
        LOG.info("Configuring NetworkSelector with: {} readers, {} workers and {} maxPendingNetworkConn",
                numReaderWorkers, numWriterWorkers, totalNumberPendingConnectionsPerThread);

        // Create pool of reader threads
        readers = new Reader[numReaderWorkers];
        readerWorkers = new Thread[numReaderWorkers];
        for (int i = 0; i < numReaderWorkers; i++) {
            readers[i] = new Reader(i, totalNumberPendingConnectionsPerThread);
            Thread reader = new Thread(readers[i]);
            reader.setName("Network-Reader-" + i);
            readerWorkers[i] = reader;
        }
        // Create pool of writer threads
        writers = new Writer[numWriterWorkers];
        writerWorkers = new Thread[numWriterWorkers];
        for (int i = 0; i < numWriterWorkers; i++) {
            writers[i] = new Writer(i);
            Thread writer = new Thread(writers[i]);
            writer.setName("Network-Writer-" + i);
            writerWorkers[i] = writer;
        }
        this.writerKeys = new ConcurrentHashMap<>();
        this.readerKeys = new ConcurrentHashMap<>();
        // Create the acceptorSelector
        try {
            this.acceptorSelector = Selector.open();
        } catch (IOException e) {
            LOG.error("Could not open the acceptor selector", e);
        }
    }

    public static NetworkSelector makeNetworkSelectorWithMap(int myId) {
        Properties p = new Properties();
        p.setProperty(WorkerConfig.MASTER_IP, "127.0.0.1");
        p.setProperty(WorkerConfig.PROPERTIES_FILE, "");
        p.setProperty(WorkerConfig.NUM_NETWORK_READER_THREADS, "1");
        p.setProperty(WorkerConfig.NUM_NETWORK_WRITER_THREADS, "1");
        p.setProperty(WorkerConfig.MAX_PENDING_NETWORK_CONNECTION_PER_THREAD, "1");
        WorkerConfig wc = new WorkerConfig(p);
        return new NetworkSelector(wc, myId);
    }

    /**
     * Configures a server in myIp and dataPort. There is one per worker node.
     * The acceptor thread is created here and started by {@link #initSelector()}.
     * @param myIp local address to bind to
     * @param dataPort local port to bind to
     */
    public void configureServerToListen(InetAddress myIp, int dataPort) {
        ServerSocketChannel channel = null;
        try {
            channel = ServerSocketChannel.open();
            SocketAddress sa = new InetSocketAddress(myIp, dataPort);
            channel.configureBlocking(false);
            channel.bind(sa);
            channel.register(acceptorSelector, SelectionKey.OP_ACCEPT);
            LOG.info("Configured Acceptor thread to listen at: {}", sa.toString());
        } catch (IOException e) {
            LOG.error("Could not configure the server socket at {}:{}", myIp, dataPort, e);
            // Do not leak a half-configured server socket
            closeQuietly(channel);
        }
        this.listenerSocket = channel;
        this.acceptorWorker = new Thread(new AcceptorWorker());
        this.acceptorWorker.setName("Network-Acceptor");
    }

    /**
     * Used to notify of new outgoing connection requests. Assigns requests round-robin to writer
     * threads, which open and configure the connections.
     * @param outgoingConnectionRequest the connections to open
     */
    public void configureOutgoingConnection(Set<OutgoingConnectionRequest> outgoingConnectionRequest) {
        LOG.info("Request to configure {} outgoing connections", outgoingConnectionRequest.size());
        // Initialize countDown with num of outputConns BEFORE handing requests out, so that an already
        // running writer can never count down a stale latch.
        this.writersConfiguredLatch = new CountDownLatch(outgoingConnectionRequest.size());
        int writerIdx = 0;
        int totalWriters = writers.length;
        for (OutgoingConnectionRequest ocr : outgoingConnectionRequest) {
            writers[writerIdx].newConnection(ocr);
            writerIdx = (writerIdx + 1) % totalWriters;
        }
    }

    /**
     * Used to configure incoming connections. Readers use this map to find the IBuffer responsible
     * for each incoming connection once it has sent its identifier.
     * @param ibMap incoming connection id -> local input buffer
     */
    public void configureExpectedIncomingConnection(Map<Integer, IBuffer> ibMap) {
        this.ibMap = ibMap;
        this.numUpstreamConnections.set(ibMap.size());
        LOG.info("Expecting {} upstream connections", numUpstreamConnections.get());
    }

    @Override
    public DataStoreType type() {
        return DataStoreType.NETWORK;
    }

    /**
     * Starts reader and writer threads and blocks until every requested output connection has sent
     * its identifier.
     */
    @Override
    public boolean startSelector() {
        // Start readers
        for (Thread r : readerWorkers) {
            LOG.info("Starting reader: {}", r.getName());
            r.start();
        }
        // Start writers
        for (Thread w : writerWorkers) {
            LOG.info("Starting writer: {}", w.getName());
            w.start();
        }
        CountDownLatch latch = this.writersConfiguredLatch;
        try {
            LOG.trace("Waiting for all output connections to configure. Remaining: {}", latch.getCount());
            latch.await();
            LOG.trace("All output connections are now configured");
        } catch (InterruptedException e) {
            LOG.warn("Interrupted while waiting for output connections to configure", e);
            // Preserve the interrupt status for the caller
            Thread.currentThread().interrupt();
        }
        return true;
    }

    @Override
    public boolean initSelector() {
        this.acceptorWorking = true;
        // Check whether there is a network acceptor worker. There won't be one if there are no input network connections.
        if (acceptorWorker != null) {
            LOG.info("Starting acceptor thread: {}", acceptorWorker.getName());
            this.acceptorWorker.start();
        }
        return true;
    }

    /**
     * Asks the acceptor, reader and writer threads to stop. Each blocked selector is woken up so the
     * thread notices the request and releases its resources.
     */
    @Override
    public boolean stopSelector() {
        this.acceptorWorking = false;
        if (acceptorSelector != null) {
            acceptorSelector.wakeup();
        }
        for (Reader r : readers) {
            r.stop();
        }
        for (Writer w : writers) {
            w.stop();
        }
        LOG.info("Stopped reader, writers and acceptor workers");
        return true;
    }

    /**
     * Wakes up the writer owning output {@code id} so it polls its buffers. Ids whose connection is
     * not configured yet are ignored: the writer polls all its buffers once the connection exists.
     */
    @Override
    public void readyForWrite(int id) {
        SelectionKey key = writerKeys.get(id);
        if (key == null) {
            LOG.debug("readyForWrite for not yet configured output id: {}", id);
            return;
        }
        key.selector().wakeup();
    }

    @Override
    public void readyForWrite(List<Integer> ids) {
        for (Integer id : ids) {
            readyForWrite(id);
        }
    }

    /** Closes {@code c} if non-null, logging (not propagating) any failure. */
    private static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (IOException e) {
            LOG.warn("Failed to close {}", c, e);
        }
    }

    /** Remote address of the key's socket for logging; "null" when it is not available. */
    private static String remoteAddressOf(SelectionKey key) {
        return String.valueOf(((SocketChannel) key.channel()).socket().getRemoteSocketAddress());
    }

    /**
     * This class is the server thread that accepts new connections and assigns them to reader threads.
     * @author ra
     *
     */
    class AcceptorWorker implements Runnable {

        @Override
        public void run() {
            LOG.info("Started Acceptor worker: {}", Thread.currentThread().getName());
            int readerIdx = 0;
            int totalReaders = readers.length;
            while (acceptorWorking) {
                try {
                    int readyChannels = acceptorSelector.select();
                    if (readyChannels == 0) {
                        // Woken up without ready keys (e.g. by stopSelector): re-check acceptorWorking
                        continue;
                    }
                    Iterator<SelectionKey> keyIt = acceptorSelector.selectedKeys().iterator();
                    while (keyIt.hasNext()) {
                        SelectionKey key = keyIt.next();
                        // Every handled key must leave the selected set, otherwise it is reported again
                        keyIt.remove();
                        if (!key.isValid()) {
                            LOG.error("Acceptor key is disconnected !");
                            System.exit(1);
                        }
                        // accept events
                        if (key.isAcceptable()) {
                            // A non-blocking server channel returns null when no connection is pending
                            SocketChannel incomingCon = listenerSocket.accept();
                            if (incomingCon == null) {
                                continue;
                            }
                            // Assign in a round robin fashion to readers
                            Reader chosen = readers[readerIdx];
                            readerIdx = (readerIdx + 1) % totalReaders;
                            chosen.newConnection(incomingCon);
                            chosen.wakeUp();
                        }
                    }
                } catch (IOException e) {
                    LOG.error("Error while accepting incoming connections", e);
                }
            }
            // Release the listening port and the selector once the acceptor has been stopped
            closeQuietly(listenerSocket);
            closeQuietly(acceptorSelector);
        }
    }

    /**
     * This class reads from a collection of incoming connections and writes to IBuffer that are the entrance to the system.
     * @author ra
     *
     */
    class Reader implements Runnable {

        private int id;
        private volatile boolean working;
        // Filled by the acceptor thread, drained by this reader thread
        private Queue<SocketChannel> pendingConnections;
        private Selector readSelector;

        /**
         * @param id reader index
         * @param totalNumberOfPendingConnectionsPerThread previously only an initial capacity hint;
         *        the concurrent queue used now grows on demand
         */
        Reader(int id, int totalNumberOfPendingConnectionsPerThread) {
            this.id = id;
            this.working = true;
            this.pendingConnections = new ConcurrentLinkedQueue<SocketChannel>();
            try {
                this.readSelector = Selector.open();
            } catch (IOException e) {
                LOG.error("Reader {} could not open its selector", id, e);
            }
        }

        public int id() {
            return id;
        }

        public void stop() {
            this.working = false; // let thread die
            wakeUp(); // unblock select() so the flag is observed
        }

        public void newConnection(SocketChannel incomingChannel) {
            this.pendingConnections.add(incomingChannel);
            LOG.info("New pending connection for Reader to configure");
        }

        public void wakeUp() {
            if (readSelector != null) {
                this.readSelector.wakeup();
            }
        }

        @Override
        public void run() {
            LOG.info("Started Reader worker: {}", Thread.currentThread().getName());
            while (working) {
                // First handle potential new connections that have been queued up
                this.handleNewConnections();
                try {
                    int readyChannels = readSelector.select();
                    if (readyChannels == 0) {
                        continue;
                    }
                    Iterator<SelectionKey> keyIt = readSelector.selectedKeys().iterator();
                    while (keyIt.hasNext()) {
                        SelectionKey key = keyIt.next();
                        keyIt.remove();
                        // read (validity checked first: isReadable throws on a cancelled key)
                        if (key.isValid() && key.isReadable()) {
                            if (needsToConfigureConnection(key)) {
                                handleConnectionIdentifier(key);
                            } else {
                                IBuffer ib = (IBuffer) key.attachment();
                                ib.readFrom((SocketChannel) key.channel());
                            }
                        }
                        if (!key.isValid()) {
                            LOG.warn("Invalid incoming data connection to: {}", remoteAddressOf(key));
                        }
                    }
                } catch (IOException ioe) {
                    LOG.error("Reader {} failed while reading", id, ioe);
                }
            }
            this.closeReader();
        }

        /** A key still needs configuring until its attachment has been replaced by an IBuffer. */
        private boolean needsToConfigureConnection(SelectionKey key) {
            return !(key.attachment() instanceof IBuffer);
        }

        /**
         * Reads (possibly across several OP_READ events) the 4-byte identifier that the upstream
         * sends first and binds the key to the IBuffer registered for that identifier. Only the
         * identifier bytes are consumed; any data that follows stays in the socket for the IBuffer.
         * @return false once every expected upstream connection has identified itself
         */
        private boolean handleConnectionIdentifier(SelectionKey key) {
            SocketChannel channel = (SocketChannel) key.channel();
            ByteBuffer idBuffer = (ByteBuffer) key.attachment();
            int readBytes;
            try {
                readBytes = channel.read(idBuffer);
            } catch (IOException e) {
                LOG.error("Failed to read connection identifier from: {}", remoteAddressOf(key), e);
                readBytes = -1;
            }
            if (readBytes < 0) {
                LOG.warn("Incoming connection closed before sending its identifier: {}", remoteAddressOf(key));
                key.cancel();
                closeQuietly(channel);
                return numUpstreamConnections.get() > 0;
            }
            if (idBuffer.hasRemaining()) {
                // Non-blocking read returned only part of the identifier; the rest comes with a later OP_READ
                return true;
            }
            idBuffer.flip();
            int id = idBuffer.getInt();
            LOG.info("Received conn identifier: {}", id);
            LOG.info("Configuring IBuffer for received conn identifier: {}", id);
            Map<Integer, IBuffer> buffers = ibMap;
            IBuffer responsibleForThisChannel = (buffers == null) ? null : buffers.get(id);
            if (responsibleForThisChannel == null) {
                // TODO: throw exception
                LOG.error("Problem here, no existent IBuffer for id: {}", id);
                System.exit(1);
            }
            // TODO: could we keep numUpstreamConnections internal to inputAdapter? probably not...
            boolean moreConnectionsPending = numUpstreamConnections.decrementAndGet() > 0;
            // Once we've identified the IBuffer responsible for this channel we attach the new object
            key.attach(responsibleForThisChannel);
            readerKeys.put(key, id);
            return moreConnectionsPending;
        }

        /** Registers queued incoming connections with this reader's selector. */
        private void handleNewConnections() {
            SocketChannel incomingCon;
            while ((incomingCon = this.pendingConnections.poll()) != null) {
                try {
                    incomingCon.configureBlocking(false);
                    incomingCon.socket().setTcpNoDelay(true);
                    // register new incoming connection in the thread-local selector. Until the peer
                    // identifies itself the key carries a buffer that collects the identifier bytes.
                    incomingCon.register(readSelector, SelectionKey.OP_READ, ByteBuffer.allocate(CONN_ID_BYTES));
                    LOG.info("Configured new incoming connection at: {}", incomingCon.toString());
                } catch (IOException ioe) {
                    LOG.error("Could not configure incoming connection: {}", incomingCon, ioe);
                    closeQuietly(incomingCon);
                }
            }
        }

        /** Closes every registered channel, every still-pending connection and the selector. */
        private void closeReader() {
            for (SelectionKey sk : readSelector.keys()) {
                sk.cancel();
                closeQuietly(sk.channel());
            }
            SocketChannel pending;
            while ((pending = pendingConnections.poll()) != null) {
                closeQuietly(pending);
            }
            closeQuietly(readSelector);
        }
    }

    /**
     * This class manages a collection of OBuffer that it drains to write to channels that are output connections.
     * @author ra
     *
     */
    class Writer implements Runnable {

        private int id;
        private volatile boolean working;
        // Filled by the configuring thread, drained by this writer thread
        private Queue<OutgoingConnectionRequest> pendingConnections;
        // buffer id - outputbuffer
        private Map<Integer, OBuffer> outputBufferMap;
        private Map<Integer, Boolean> needsConfigureOutputConnection;
        private Selector writeSelector;

        Writer(int id) {
            this.id = id;
            this.working = true;
            this.outputBufferMap = new HashMap<>();
            this.needsConfigureOutputConnection = new HashMap<>();
            this.pendingConnections = new ConcurrentLinkedQueue<OutgoingConnectionRequest>();
            try {
                this.writeSelector = Selector.open();
            } catch (IOException e) {
                LOG.error("Writer {} could not open its selector", id, e);
            }
        }

        public int id() {
            return id;
        }

        public void stop() {
            this.working = false;
            wakeUp(); // unblock select() so the flag is observed
        }

        public void newConnection(OutgoingConnectionRequest ocr) {
            LOG.trace("Writer: {} has a pending connection to: {}", id, ocr.connection);
            this.pendingConnections.add(ocr);
            // If this writer is already blocked in select(), make it pick the request up
            wakeUp();
        }

        private void wakeUp() {
            if (writeSelector != null) {
                writeSelector.wakeup();
            }
        }

        @Override
        public void run() {
            LOG.info("Started Writer worker: {}", Thread.currentThread().getName());
            while (working) {
                // First handle potential new connections that have been queued up
                handleNewConnections();
                pollBuffers();
                try {
                    int readyChannels = writeSelector.select();
                    if (readyChannels == 0) {
                        continue;
                    }
                    Iterator<SelectionKey> keyIt = writeSelector.selectedKeys().iterator();
                    while (keyIt.hasNext()) {
                        SelectionKey key = keyIt.next();
                        keyIt.remove();
                        // Validity is re-checked before each readiness test: they throw on a cancelled key
                        if (key.isValid() && key.isConnectable()) {
                            handleConnect(key);
                        }
                        if (key.isValid() && key.isWritable()) {
                            handleWrite(key);
                        }
                        if (!key.isValid()) {
                            LOG.warn("Invalid outgoing data connection to: {}", remoteAddressOf(key));
                        }
                    }
                } catch (IOException ioe) {
                    LOG.error("Writer {} failed while writing", id, ioe);
                }
            }
            this.closeWriter();
        }

        /**
         * Completes a pending non-blocking connect. Once connected the key switches to OP_WRITE so
         * the identifier can be sent. On failure the key is cancelled and the channel closed.
         */
        private void handleConnect(SelectionKey key) {
            SocketChannel sc = (SocketChannel) key.channel();
            try {
                if (sc.isConnectionPending()) {
                    LOG.info("Attempting to finish conn to: {}", sc);
                    if (!sc.finishConnect()) {
                        return; // not connected yet; OP_CONNECT stays registered
                    }
                }
            } catch (IOException e) {
                LOG.error("Could not establish output connection to: {}", sc, e);
                key.cancel();
                closeQuietly(sc);
                return;
            }
            key.interestOps(SelectionKey.OP_WRITE); // as soon as it connects it can write the init protocol
            LOG.info("Finished establishing output connection to: {}", sc.toString());
        }

        /** Sends the identifier on first write, afterwards drains the attached OBuffer. */
        private void handleWrite(SelectionKey key) throws IOException {
            OBuffer ob = (OBuffer) key.attachment();
            SocketChannel channel = (SocketChannel) key.channel();
            if (Boolean.TRUE.equals(needsConfigureOutputConnection.get(ob.id()))) {
                handleSendIdentifier(ob.id(), channel);
                unsetWritable(key);
                needsConfigureOutputConnection.put(ob.id(), false);
                // Notify of a new configured connection
                CountDownLatch latch = writersConfiguredLatch;
                latch.countDown();
                LOG.trace("CountDown to configure all output conns: {}", latch.getCount());
            } else {
                // write batch
                boolean fullyWritten = ob.drainTo(channel);
                if (fullyWritten) {
                    unsetWritable(key);
                }
            }
        }

        /** Adds OP_WRITE interest for every live connection whose buffer has data to send. */
        private void pollBuffers() {
            for (OBuffer ob : outputBufferMap.values()) {
                if (ob.readyToWrite()) {
                    SelectionKey key = writerKeys.get(ob.id());
                    if (key != null && key.isValid()) {
                        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
                    }
                }
            }
        }

        /**
         * Writes the 4-byte identifier (big-endian int, matching the reader's getInt) to the channel.
         */
        private void handleSendIdentifier(int oBufferId, SocketChannel channel) {
            ByteBuffer bb = ByteBuffer.allocate(CONN_ID_BYTES);
            bb.putInt(oBufferId);
            bb.flip();
            try {
                // A non-blocking write may be partial; the socket buffer is empty right after
                // connecting, so this completes immediately in practice.
                while (bb.hasRemaining()) {
                    channel.write(bb);
                }
            } catch (IOException e) {
                LOG.error("Failed to send connection identifier: {}", oBufferId, e);
                return;
            }
            LOG.info("Sent connection identifier: {}", oBufferId);
        }

        private void unsetWritable(SelectionKey key) {
            final int newOps = key.interestOps() & ~SelectionKey.OP_WRITE;
            key.interestOps(newOps);
        }

        /** Opens every queued outgoing connection; a failure affects only that connection. */
        private void handleNewConnections() {
            OutgoingConnectionRequest ocr;
            while ((ocr = this.pendingConnections.poll()) != null) {
                openConnection(ocr);
            }
        }

        /** Opens a non-blocking connection for {@code ocr} and registers it with this writer. */
        private void openConnection(OutgoingConnectionRequest ocr) {
            OBuffer ob = ocr.oBuffer;
            Connection c = ocr.connection;
            InetSocketAddress address = c.getInetSocketAddress(SeepEndPointType.DATA);
            SocketChannel channel = null;
            try {
                channel = SocketChannel.open();
                Socket socket = channel.socket();
                socket.setKeepAlive(true); // Unlikely in non-production scenarios we'll be up for more than 2 hours but...
                socket.setTcpNoDelay(true); // Disabling Nagle's algorithm
                channel.configureBlocking(false);
                // A non-blocking connect may complete immediately; then OP_CONNECT would never fire
                boolean connected = channel.connect(address);
                int interestSet = connected ? SelectionKey.OP_WRITE : SelectionKey.OP_CONNECT;
                SelectionKey key = channel.register(writeSelector, interestSet, ob);
                outputBufferMap.put(ob.id(), ob);
                needsConfigureOutputConnection.put(ob.id(), true);
                LOG.info("Configured new output connection with OP: {} at {}", ob.id(), String.valueOf(address));
                // Associate id - key in the networkSelectorMap
                writerKeys.put(ob.id(), key);
            } catch (IOException | UnresolvedAddressException e) {
                LOG.error("Could not open output connection with OP: {} to {}", ob.id(), address, e);
                closeQuietly(channel);
            }
        }

        /** Closes every registered channel and the selector. */
        private void closeWriter() {
            for (SelectionKey sk : writeSelector.keys()) {
                sk.cancel();
                closeQuietly(sk.channel());
            }
            closeQuietly(writeSelector);
        }
    }
}