package org.threadly.litesockets; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedChannelException; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import org.threadly.concurrent.SubmitterScheduler; import org.threadly.concurrent.future.ListenableFuture; import org.threadly.concurrent.future.WatchdogCache; import org.threadly.litesockets.utils.SimpleByteStats; import org.threadly.util.AbstractService; import org.threadly.util.ArgumentVerifier; import org.threadly.util.ExceptionUtils; /** * This is a common base class for the Threaded and NoThread SocketExecuters. */ abstract class SocketExecuterCommonBase extends AbstractService implements SocketExecuter { protected static final int WATCHDOG_CLEANUP_TIME = 30000; protected final SubmitterScheduler acceptScheduler; protected final SubmitterScheduler readScheduler; protected final SubmitterScheduler writeScheduler; protected final SubmitterScheduler schedulerPool; protected final ConcurrentHashMap<SocketChannel, Client> clients = new ConcurrentHashMap<SocketChannel, Client>(); protected final ConcurrentHashMap<SelectableChannel, Server> servers = new ConcurrentHashMap<SelectableChannel, Server>(); protected final SocketExecuterByteStats stats = new SocketExecuterByteStats(); protected final WatchdogCache dogCache; protected Selector readSelector; protected Selector writeSelector; protected Selector acceptSelector; SocketExecuterCommonBase(final SubmitterScheduler scheduler) { this(scheduler,scheduler,scheduler,scheduler); } SocketExecuterCommonBase(final SubmitterScheduler acceptScheduler, final SubmitterScheduler readScheduler, final SubmitterScheduler writeScheduler, final SubmitterScheduler ssi) { ArgumentVerifier.assertNotNull(ssi, "ThreadScheduler"); ArgumentVerifier.assertNotNull(acceptScheduler, "Accept Scheduler"); ArgumentVerifier.assertNotNull(readScheduler, "Read Scheduler"); ArgumentVerifier.assertNotNull(writeScheduler, "Write Scheduler"); schedulerPool = ssi; dogCache = new WatchdogCache(ssi, true); this.acceptScheduler = acceptScheduler; this.readScheduler = readScheduler; this.writeScheduler = writeScheduler; } protected void checkRunning() { if(!isRunning()) { throw new IllegalStateException("SocketExecuter is not running!"); } } @Override public TCPClient createTCPClient(final String host, final int port) throws IOException { checkRunning(); TCPClient tc = new TCPClient(this, host, port); clients.put(((Client)tc).getChannel(), tc); return tc; } @Override public TCPClient createTCPClient(final SocketChannel sc) throws IOException { checkRunning(); final TCPClient tc = new TCPClient(this, sc); clients.put(((Client)tc).getChannel(), tc); this.setClientOperations(tc); return tc; } @Override public TCPServer createTCPServer(final String host, final int port) throws IOException { checkRunning(); TCPServer ts = new TCPServer(this, host, port); servers.put(ts.getSelectableChannel(), ts); return ts; } @Override public TCPServer createTCPServer(final ServerSocketChannel ssc) throws IOException { checkRunning(); TCPServer ts = new TCPServer(this, ssc); servers.put(ts.getSelectableChannel(), ts); return ts; } @Override public UDPServer createUDPServer(final String host, final int port) throws IOException { checkRunning(); UDPServer us = new UDPServer(this, host, port); servers.put(us.getSelectableChannel(), us); return us; } protected boolean checkServer(final Server server) { if(!isRunning() || server.isClosed() || server.getSocketExecuter() != this || !servers.containsKey(server.getSelectableChannel())) { servers.remove(server.getSelectableChannel()); return false; } return true; } @Override public void startListening(final Server server) { if(!checkServer(server)) { return; } else { if(server.getServerType() == WireProtocol.TCP) { acceptScheduler.execute(new AddToSelector(acceptScheduler, server, acceptSelector, SelectionKey.OP_ACCEPT)); acceptSelector.wakeup(); } else { throw new UnsupportedOperationException("Unknown Server WireProtocol!"+ server.getServerType()); } } } @Override public void stopListening(final Server server) { if(!checkServer(server)) { return; } else { if(server.getServerType() == WireProtocol.TCP) { acceptScheduler.execute(new AddToSelector(acceptScheduler, server, acceptSelector, 0)); acceptSelector.wakeup(); } else if(server.getServerType() == WireProtocol.UDP) { readScheduler.execute(new AddToSelector(readScheduler, server, readSelector, 0)); writeScheduler.execute(new AddToSelector(writeScheduler, server, writeSelector, 0)); readSelector.wakeup(); writeSelector.wakeup(); } else { throw new UnsupportedOperationException("Unknown Server WireProtocol!"+ server.getServerType()); } } } @Override public int getClientCount() { return clients.size(); } @Override public int getServerCount() { return servers.size(); } @Override public SubmitterScheduler getThreadScheduler() { return schedulerPool; } @Override public SimpleByteStats getStats() { return stats; } protected SocketExecuterByteStats writeableStats() { return stats; } @Override public void watchFuture(final ListenableFuture<?> lf, final long delay) { dogCache.watch(lf, delay); } protected static Selector openSelector() { try { return Selector.open(); } catch (IOException e) { throw new StartupException(e); } } protected static void closeSelector(final SubmitterScheduler scheduler, final Selector selector) { scheduler.execute(new Runnable() { @Override public void run() { try { selector.close(); } catch (IOException e) { ExceptionUtils.handleException(e); } }}); selector.wakeup(); } protected static void doServerAccept(final Server server) { if(server != null) { try { final SocketChannel client = ((ServerSocketChannel)server.getSelectableChannel()).accept(); if(client != null) { client.configureBlocking(false); server.acceptChannel(client); } } catch (IOException e) { server.close(); ExceptionUtils.handleException(e); } } } protected static void doClientConnect(final Client client, final Selector selector) { if(client == null) { return; } try { if(client.getChannel().finishConnect()) { client.setConnectionStatus(null); } } catch(IOException e) { client.close(); client.setConnectionStatus(e); ExceptionUtils.handleException(e); } } protected static int doClientWrite(final Client client, final Selector selector) { int wrote = 0; if(client != null) { try { wrote = client.getChannel().write(client.getWriteBuffer()); if(wrote > 0) { client.reduceWrite(wrote); } final SelectionKey sk = client.getChannel().keyFor(selector); if(! client.canWrite() && (sk.interestOps() & SelectionKey.OP_WRITE) == SelectionKey.OP_WRITE) { client.getChannel().register(selector, sk.interestOps() - SelectionKey.OP_WRITE); } else { client.getChannel().register(selector, sk.interestOps()); } } catch(Exception e) { client.close(); ExceptionUtils.handleException(e); } } return wrote; } private static int doRead(ByteBuffer bb, SocketChannel sc) throws IOException { return sc.read(bb); } protected static int doClientRead(final Client client, final Selector selector) { int read = 0; if(client != null) { try { final ByteBuffer readByteBuffer = client.provideReadByteBuffer(); final int origPos = readByteBuffer.position(); read = doRead(readByteBuffer, client.getChannel()); if(read < 0) { client.close(); } else if( read > 0){ readByteBuffer.position(origPos); final ByteBuffer resultBuffer = readByteBuffer.slice(); readByteBuffer.position(origPos+read); resultBuffer.limit(read); client.addReadBuffer(resultBuffer); final SelectionKey sk = client.getChannel().keyFor(selector); if(! client.canRead() && (sk.interestOps() & SelectionKey.OP_READ) == SelectionKey.OP_READ) { client.getChannel().register(selector, sk.interestOps() - SelectionKey.OP_READ); } else { client.getChannel().register(selector, sk.interestOps()); } } } catch(Exception e) { client.close(); ExceptionUtils.handleException(e); } } if(read >= 0) { return read; } else { return 0; } } /** * This class is a helper runnable to generically remove SelectableChannels/SelectionKeys from a selector. * */ protected static class RemoveFromSelector implements Runnable { private final Selector selector; private final Client client; public RemoveFromSelector(Selector selector, Client client) { this.client = client; this.selector = selector; } @Override public void run() { SelectionKey sk = client.getChannel().keyFor(selector); if(sk != null) { sk.cancel(); } } } /** * This class is a helper runnable to generically add SelectableChannels to a selector for certain operations. * */ protected static class AddToSelector implements Runnable { final Client localClient; final Server localServer; final Selector localSelector; final int registerType; final Executor exec; public AddToSelector(final Executor exec, final Client client, final Selector selector, final int registerType) { this.exec = exec; localClient = client; localServer = null; localSelector = selector; this.registerType = registerType; } public AddToSelector(final Executor exec, final Server server, final Selector selector, final int registerType) { this.exec = exec; localClient = null; localServer = server; localSelector = selector; this.registerType = registerType; } private void runClient() { if(!localClient.isClosed()) { try { localSelector.wakeup(); localClient.getChannel().register(localSelector, registerType); } catch (CancelledKeyException e) { exec.execute(this); } catch (ClosedChannelException e) { localClient.close(); } } } private void runServer() { if(!localServer.isClosed()) { try { localServer.getSelectableChannel().register(localSelector, registerType); } catch (ClosedChannelException e) { localServer.close(); } } } @Override public void run() { if(localSelector.isOpen()) { if(localClient == null && localServer != null) { runServer(); } else if (localClient != null) { runClient(); } localSelector.wakeup(); } } } /** * Implementation of the SimpleByteStats. */ protected static class SocketExecuterByteStats extends SimpleByteStats { @Override protected void addWrite(final int size) { super.addWrite(size); } @Override protected void addRead(final int size) { super.addRead(size); } } }