/* Copyright (C) 2011 monte This file is part of PSP NetParty. PSP NetParty is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. */ package pspnetparty.lib.socket; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedSelectorException; import java.nio.channels.DatagramChannel; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.util.HashMap; import java.util.Iterator; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import pspnetparty.lib.Utility; import pspnetparty.lib.constants.AppConstants; public class AsyncUdpServer implements IServer { private static final int READ_BUFFER_SIZE = 20000; private Selector selector; private ConcurrentHashMap<IServerListener, Object> serverListeners; private DatagramChannel serverChannel; private ByteBuffer readBuffer = ByteBuffer.allocateDirect(READ_BUFFER_SIZE); private PacketData data = new PacketData(readBuffer); private SelectionKey selectionKey; private SendBufferQueue<InetSocketAddress> sendBufferQueue = new SendBufferQueue<InetSocketAddress>(100000); private HashMap<String, IProtocol> protocolHandlers = new HashMap<String, IProtocol>(); private ConcurrentHashMap<InetSocketAddress, Connection> establishedConnections; private ConcurrentHashMap<Connection, Connection> toBeClosedConnections; private ByteBuffer bufferProtocolOK = AppConstants.CHARSET.encode(IProtocol.PROTOCOL_OK); private ByteBuffer bufferProtocolNG = AppConstants.CHARSET.encode(IProtocol.PROTOCOL_NG); private ByteBuffer bufferProtocolNumber = AppConstants.CHARSET.encode(IProtocol.NUMBER); private ByteBuffer terminateBuffer = ByteBuffer.wrap(new byte[] { 0 }); private Thread selectorThread; private Thread keepAliveThread; public AsyncUdpServer() { serverListeners = new ConcurrentHashMap<IServerListener, Object>(); establishedConnections = new ConcurrentHashMap<InetSocketAddress, Connection>(30, 0.75f, 3); toBeClosedConnections = new ConcurrentHashMap<Connection, Connection>(); } @Override public void addServerListener(IServerListener listener) { serverListeners.put(listener, this); } @Override public void addProtocol(IProtocol handler) { protocolHandlers.put(handler.getProtocol(), handler); } @Override public boolean isListening() { return selector != null && selector.isOpen(); } private void log(String message) { for (IServerListener listener : serverListeners.keySet()) listener.log(message); } @Override public void startListening(InetSocketAddress bindAddress) throws IOException { if (isListening()) stopListening(); selector = Selector.open(); serverChannel = DatagramChannel.open(); serverChannel.configureBlocking(false); serverChannel.socket().bind(bindAddress); selectionKey = serverChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE); log("UDP: Listening on " + bindAddress); selectorThread = new Thread(new Runnable() { @Override public void run() { for (IServerListener listener : serverListeners.keySet()) listener.serverStartupFinished(); selectorLoop(); for (IServerListener listener : serverListeners.keySet()) { listener.log("UDP: Now shuting down..."); listener.serverShutdownFinished(); } keepAliveThread.interrupt(); } }); selectorThread.setName(getClass().getName() + " Selector"); selectorThread.setDaemon(true); selectorThread.start(); keepAliveThread = new Thread(new Runnable() { @Override public void run() { try { keepAliveLoop(); } catch (InterruptedException e) { } } }); keepAliveThread.setName(getClass().getName() + " KeepAlive"); keepAliveThread.setDaemon(true); keepAliveThread.start(); } private void selectorLoop() { try { while (serverChannel.isOpen()) while (selector.select() > 0) { for (Iterator<SelectionKey> it = selector.selectedKeys().iterator(); it.hasNext();) { SelectionKey key = it.next(); it.remove(); try { if (key.isReadable()) { readBuffer.clear(); InetSocketAddress remoteAddress = (InetSocketAddress) serverChannel.receive(readBuffer); if (remoteAddress == null) { continue; } readBuffer.flip(); Connection conn = establishedConnections.get(remoteAddress); if (conn == null) { String message = data.getMessage(); String[] tokens = message.split(IProtocol.SEPARATOR); if (tokens.length != 2) { bufferProtocolNG.position(0); addToSendQueue(bufferProtocolNG, remoteAddress); continue; } String protocol = tokens[0]; String number = tokens[1]; IProtocol handler = protocolHandlers.get(protocol); if (handler == null) { bufferProtocolNG.position(0); addToSendQueue(bufferProtocolNG, remoteAddress); continue; } conn = new Connection(remoteAddress); if (!number.equals(IProtocol.NUMBER)) { bufferProtocolNumber.position(0); conn.send(bufferProtocolNumber); continue; } bufferProtocolOK.position(0); conn.send(bufferProtocolOK); conn.driver = handler.createDriver(conn); if (conn.driver == null) { terminateBuffer.position(0); conn.send(terminateBuffer); continue; } establishedConnections.put(remoteAddress, conn); } else if (!toBeClosedConnections.containsKey(conn)) { conn.processData(); } } else if (key.isWritable()) { SendBufferQueue<InetSocketAddress>.Allotment allot = sendBufferQueue.poll(); if (allot == null) { for (Connection conn : toBeClosedConnections.keySet()) { conn.disconnect(); } toBeClosedConnections.clear(); key.interestOps(SelectionKey.OP_READ); } else { serverChannel.send(allot.getBuffer(), allot.getAttachment()); } } } catch (Exception e) { } } } } catch (CancelledKeyException e) { } catch (ClosedSelectorException e) { } catch (IOException e) { } catch (RuntimeException e) { for (IServerListener listener : serverListeners.keySet()) listener.log(Utility.stackTraceToString(e)); } } private void keepAliveLoop() throws InterruptedException { ByteBuffer keepAliveBuffer = ByteBuffer.wrap(new byte[] { 1 }); while (isListening()) { long deadline = System.currentTimeMillis() - IProtocol.KEEPALIVE_DEADLINE; for (Entry<InetSocketAddress, Connection> entry : establishedConnections.entrySet()) { try { Connection conn = entry.getValue(); if (conn.lastKeepAliveReceived < deadline) { log(Utility.makeKeepAliveDisconnectLog("UDP", conn.remoteAddress, deadline, conn.lastKeepAliveReceived)); conn.disconnect(); } else { keepAliveBuffer.clear(); conn.send(keepAliveBuffer); } } catch (RuntimeException e) { log(Utility.stackTraceToString(e)); } catch (Exception e) { log(Utility.stackTraceToString(e)); } } Thread.sleep(IProtocol.KEEPALIVE_INTERVAL); } } @Override public void stopListening() { if (!isListening()) return; try { selector.close(); } catch (IOException e) { } if (serverChannel != null && serverChannel.isOpen()) { try { serverChannel.close(); } catch (IOException e) { } } for (Entry<InetSocketAddress, Connection> entry : establishedConnections.entrySet()) { Connection conn = entry.getValue(); conn.disconnect(); } } private void addToSendQueue(ByteBuffer buffer, InetSocketAddress address) { sendBufferQueue.queue(buffer, false, address); try { if (selectionKey != null) { selector.wakeup(); selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE); } } catch (CancelledKeyException e) { } } private class Connection implements ISocketConnection { private InetSocketAddress remoteAddress; private IProtocolDriver driver; private long lastKeepAliveReceived; public Connection(InetSocketAddress remoteAddress) { this.remoteAddress = remoteAddress; lastKeepAliveReceived = System.currentTimeMillis(); } @Override public InetSocketAddress getRemoteAddress() { return remoteAddress; } @Override public InetSocketAddress getLocalAddress() { return (InetSocketAddress) serverChannel.socket().getLocalSocketAddress(); } @Override public boolean isConnected() { return serverChannel.isConnected(); } @Override public void disconnect() { if (establishedConnections.remove(remoteAddress) == null) return; ByteBuffer terminateBuffer = ByteBuffer.wrap(new byte[] { 0 }); send(terminateBuffer); if (driver != null) { driver.connectionDisconnected(); driver = null; } lastKeepAliveReceived = 0; } private void processData() { if (readBuffer.limit() == 1) { switch (readBuffer.get(0)) { case 0: if (sendBufferQueue.isEmpty()) disconnect(); else toBeClosedConnections.put(this, this); return; case 1: lastKeepAliveReceived = System.currentTimeMillis(); return; } } boolean sessionContinue = false; try { sessionContinue = driver.process(data); } catch (Exception e) { } if (!sessionContinue) { disconnect(); } } @Override public void send(ByteBuffer buffer) { addToSendQueue(buffer, remoteAddress); } } public static void main(String[] args) throws IOException { InetSocketAddress address = new InetSocketAddress(30000); AsyncUdpServer server = new AsyncUdpServer(); server.addServerListener(new IServerListener() { @Override public void log(String message) { System.out.println(message); } @Override public void serverStartupFinished() { } @Override public void serverShutdownFinished() { } }); server.addProtocol(new IProtocol() { @Override public void log(String message) { System.out.println(message); } @Override public String getProtocol() { return "TEST"; } @Override public IProtocolDriver createDriver(final ISocketConnection connection) { System.out.println(connection.getRemoteAddress() + " [接続されました]"); Thread pingThread = new Thread(new Runnable() { @Override public void run() { try { Thread.sleep(20000); System.out.println("Send PING"); connection.send(Utility.encode("PING")); } catch (InterruptedException e) { } } }); pingThread.setDaemon(true); // pingThread.start(); return new IProtocolDriver() { @Override public boolean process(PacketData data) { String remoteAddress = connection.getRemoteAddress().toString(); String message = data.getMessage(); System.out.println(remoteAddress + " >" + message); connection.send(Utility.encode(message)); return true; } @Override public ISocketConnection getConnection() { return connection; } @Override public void connectionDisconnected() { System.out.println(connection.getRemoteAddress() + " [切断されました]"); } @Override public void errorProtocolNumber(String number) { } }; } }); server.startListening(address); while (System.in.read() != '\n') { } server.stopListening(); } }