/*
Copyright (C) 2011 monte
This file is part of PSP NetParty.
PSP NetParty is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pspnetparty.lib.socket;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.HashMap;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import pspnetparty.lib.Utility;
import pspnetparty.lib.constants.AppConstants;
/**
 * Non-blocking TCP server driven by one selector thread and one keep-alive thread.
 *
 * <p>Every packet starts with a header of {@code IProtocol.HEADER_BYTE_SIZE} bytes whose leading
 * {@code int} is the payload size. A size of zero is a keep-alive. The first payload received on
 * a connection is a handshake of the form {@code protocol SEPARATOR number}; it selects one of
 * the handlers registered through {@link #addProtocol(IProtocol)}, which then supplies the
 * {@link IProtocolDriver} that processes all later payloads.
 *
 * <p>Threading: all channel reads and writes happen on the selector thread. Other threads (the
 * keep-alive thread, protocol drivers) only enqueue outgoing data and adjust key interest sets.
 */
public class AsyncTcpServer implements IServer {
    /** Initial capacity of a connection's payload buffer; a larger one is allocated on demand. */
    private static final int INITIAL_READ_BUFFER_SIZE = 2000;

    /** Largest payload size accepted from a packet header. */
    private int maxPacketSize;

    /** Selector and listening channel of the most recent {@link #startListening} call. */
    private Selector selector;
    private ServerSocketChannel serverChannel;

    /** Used as a concurrent set; the mapped values are meaningless. */
    private ConcurrentHashMap<IServerListener, Object> serverListeners;
    /** Handshake protocol name to handler. */
    private HashMap<String, IProtocol> protocolHandlers = new HashMap<String, IProtocol>();
    /** Used as a concurrent set of live connections; every value is {@link #valueObject}. */
    private ConcurrentHashMap<Connection, Object> establishedConnections;
    private final Object valueObject = new Object();

    /** Pre-encoded handshake replies. They are rewound before each use on the selector thread. */
    private ByteBuffer bufferProtocolOK = AppConstants.CHARSET.encode(IProtocol.PROTOCOL_OK);
    private ByteBuffer bufferProtocolNG = AppConstants.CHARSET.encode(IProtocol.PROTOCOL_NG);
    private ByteBuffer bufferProtocolNumber = AppConstants.CHARSET.encode(IProtocol.NUMBER);

    private Thread selectorThread;
    private Thread keepAliveThread;

    /**
     * Creates a server that is not yet listening.
     *
     * @param maxPacketSize largest payload size, in bytes, a client may announce in a header
     */
    public AsyncTcpServer(int maxPacketSize) {
        this.maxPacketSize = maxPacketSize;
        serverListeners = new ConcurrentHashMap<IServerListener, Object>();
        establishedConnections = new ConcurrentHashMap<Connection, Object>(30, 0.75f, 3);
    }

    /** Registers a listener for log messages and startup/shutdown notifications. */
    @Override
    public void addServerListener(IServerListener listener) {
        serverListeners.put(listener, this);
    }

    /** Registers a protocol handler under the name returned by {@code handler.getProtocol()}. */
    @Override
    public void addProtocol(IProtocol handler) {
        protocolHandlers.put(handler.getProtocol(), handler);
    }

    /** Returns {@code true} while the selector created by {@link #startListening} is open. */
    @Override
    public boolean isListening() {
        return selector != null && selector.isOpen();
    }

    /** Forwards a message to every registered listener. */
    private void log(String message) {
        for (IServerListener listener : serverListeners.keySet()) {
            listener.log(message);
        }
    }

    /** Closes a channel, ignoring {@code null} and any failure; used on best-effort cleanup paths. */
    private static void closeChannelQuietly(java.nio.channels.Channel channel) {
        if (channel == null) {
            return;
        }
        try {
            channel.close();
        } catch (IOException ignored) {
            // Nothing useful can be done if closing fails.
        }
    }

    /** Closes a selector, ignoring {@code null} and any failure. */
    private static void closeSelectorQuietly(Selector target) {
        if (target == null) {
            return;
        }
        try {
            target.close();
        } catch (IOException ignored) {
            // Nothing useful can be done if closing fails.
        }
    }

    /**
     * Binds to {@code bindAddress} and starts the selector and keep-alive threads. If the server
     * is already listening it is stopped first.
     *
     * @param bindAddress local address to listen on
     * @throws IOException if the selector or channel cannot be opened, or the bind fails; in that
     *         case everything opened by this call is closed again and the server is not listening
     */
    @Override
    public void startListening(InetSocketAddress bindAddress) throws IOException {
        if (isListening()) {
            stopListening();
        }

        Selector newSelector = null;
        ServerSocketChannel newChannel = null;
        boolean ready = false;
        try {
            newSelector = Selector.open();
            newChannel = ServerSocketChannel.open();
            newChannel.configureBlocking(false);
            newChannel.socket().bind(bindAddress);
            newChannel.register(newSelector, SelectionKey.OP_ACCEPT);
            ready = true;
        } finally {
            if (!ready) {
                // Do not leak the half-initialised resources, and do not publish them to the
                // fields: isListening() must keep returning false after a failed start.
                closeChannelQuietly(newChannel);
                closeSelectorQuietly(newSelector);
            }
        }

        // Each worker thread is bound to the resources of THIS start so that a thread left over
        // from an earlier start can never operate on (or interrupt) a later start's resources.
        final Selector activeSelector = newSelector;
        final ServerSocketChannel listenChannel = newChannel;
        selector = activeSelector;
        serverChannel = listenChannel;

        ServerSocket socket = listenChannel.socket();
        log("TCP: Listening on " + socket.getLocalSocketAddress());

        final Thread keepAlive = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    keepAliveLoop(listenChannel);
                } catch (InterruptedException e) {
                    // Interrupted by the selector thread at shutdown; simply let the thread end.
                }
            }
        }, getClass().getName() + " KeepAlive");
        keepAlive.setDaemon(true);

        Thread selectorRunner = new Thread(new Runnable() {
            @Override
            public void run() {
                for (IServerListener listener : serverListeners.keySet()) {
                    listener.serverStartupFinished();
                }

                selectorLoop(activeSelector, listenChannel);

                for (IServerListener listener : serverListeners.keySet()) {
                    listener.log("TCP: Now shuting down...");
                    listener.serverShutdownFinished();
                }
                keepAlive.interrupt();
            }
        }, getClass().getName() + " Selector");
        selectorRunner.setDaemon(true);

        selectorThread = selectorRunner;
        keepAliveThread = keepAlive;

        // Start the keep-alive thread first so the selector thread's final interrupt() always
        // targets a thread that has already been started.
        keepAlive.start();
        selectorRunner.start();
    }

    /**
     * Runs on the selector thread: dispatches ready keys until the listening channel is closed,
     * the selector is closed, or selecting fails.
     */
    private void selectorLoop(Selector activeSelector, ServerSocketChannel listenChannel) {
        try {
            while (listenChannel.isOpen()) {
                while (activeSelector.select(1000) > 0) {
                    Iterator<SelectionKey> it = activeSelector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        handleKey(key, activeSelector);
                    }
                }
            }
        } catch (ClosedSelectorException e) {
            // Expected at shutdown: stopListening() closed the selector.
        } catch (IOException e) {
            log(Utility.stackTraceToString(e));
        } catch (RuntimeException e) {
            log(Utility.stackTraceToString(e));
        }
    }

    /**
     * Handles one selected key. Read and write readiness are handled independently, so a
     * connection that is permanently readable (for example at end-of-stream) can still flush its
     * send queue.
     */
    private void handleKey(SelectionKey key, Selector activeSelector) {
        Connection conn = null;
        try {
            if (key.isAcceptable()) {
                doAccept((ServerSocketChannel) key.channel(), activeSelector);
                return;
            }

            conn = (Connection) key.attachment();
            if (key.isReadable() && !handleReadable(key, conn)) {
                return;
            }
            if (key.isWritable()) {
                handleWritable(key, conn);
            }
        } catch (CancelledKeyException e) {
            // The key was cancelled while it was being handled; there is nothing left to do.
        } catch (IOException e) {
            if (conn != null) {
                conn.disconnect();
            } else {
                // conn is only null on the accept path, so the key cancelled below is the
                // listening key.
                log("TCP: Accept failed, no further connections will be accepted: " + e);
            }
            key.cancel();
        } catch (RuntimeException e) {
            if (conn != null) {
                conn.disconnect();
            }
            key.cancel();
            log(Utility.stackTraceToString(e));
        }
    }

    /**
     * Reads from a connection whose key is readable.
     *
     * @return {@code false} if the connection was disconnected and the key must not be used again
     */
    private boolean handleReadable(SelectionKey key, Connection conn) throws IOException {
        boolean draining = conn.toBeClosed;
        if (conn.doRead()) {
            return true;
        }

        if (conn.sendBufferQueue.isEmpty() || !conn.channel.isOpen()) {
            conn.disconnect();
            key.cancel();
            return false;
        }

        if (draining) {
            // doRead() only fails in draining mode at end-of-stream. The key would now be
            // readable on every select, so keep write interest only until the queue is flushed.
            key.interestOps(SelectionKey.OP_WRITE);
        } else {
            // Outgoing data is still queued: flush it before closing.
            conn.toBeClosed = true;
        }
        return true;
    }

    /** Writes the next queued buffer of a writable connection, or retires write interest. */
    private void handleWritable(SelectionKey key, Connection conn) throws IOException {
        SendBufferQueue<Connection>.Allotment allot = conn.sendBufferQueue.poll();
        if (allot != null) {
            // NOTE(review): a partial write is not retried here; presumably SendBufferQueue
            // keeps track of the unwritten remainder - confirm.
            conn.channel.write(allot.getBuffer());
        } else if (conn.toBeClosed) {
            conn.disconnect();
            key.cancel();
        } else {
            key.interestOps(SelectionKey.OP_READ);
            // Another thread may have queued data between poll() and the line above, in which
            // case its OP_WRITE request was just overwritten; restore it.
            if (!conn.sendBufferQueue.isEmpty()) {
                key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
            }
        }
    }

    /**
     * Runs on the keep-alive thread: once per {@code IProtocol.KEEPALIVE_INTERVAL} it disconnects
     * connections whose last keep-alive is older than {@code IProtocol.KEEPALIVE_DEADLINE} and
     * queues a zero-size header for all others.
     *
     * @throws InterruptedException when the thread is interrupted while sleeping
     */
    private void keepAliveLoop(ServerSocketChannel listenChannel) throws InterruptedException {
        ByteBuffer keepAliveBuffer = ByteBuffer.allocate(IProtocol.HEADER_BYTE_SIZE);
        keepAliveBuffer.putInt(0);

        while (listenChannel.isOpen()) {
            long deadline = System.currentTimeMillis() - IProtocol.KEEPALIVE_DEADLINE;
            for (Connection conn : establishedConnections.keySet()) {
                try {
                    long lastReceived = conn.lastKeepAliveReceived;
                    if (lastReceived < deadline) {
                        log(Utility.makeKeepAliveDisconnectLog("TCP", conn.getRemoteAddress(), deadline, lastReceived));
                        conn.disconnect();
                    } else {
                        keepAliveBuffer.clear();
                        conn.addToSendQueue(keepAliveBuffer, false);
                    }
                } catch (Exception e) {
                    // One misbehaving connection must not stop keep-alives for the others.
                    log(Utility.stackTraceToString(e));
                }
            }
            Thread.sleep(IProtocol.KEEPALIVE_INTERVAL);
        }
    }

    /**
     * Closes the selector and the listening channel and disconnects every established
     * connection. Does nothing if the server is not listening.
     */
    @Override
    public void stopListening() {
        if (!isListening()) {
            return;
        }

        closeSelectorQuietly(selector);
        closeChannelQuietly(serverChannel);

        for (Connection conn : establishedConnections.keySet()) {
            conn.disconnect();
        }
    }

    /**
     * Accepts one pending client and registers it with the selector.
     *
     * @throws IOException if {@code accept()} itself fails; a failure while setting up the
     *         accepted socket is logged and only that socket is closed
     */
    private void doAccept(ServerSocketChannel serverChannel, Selector activeSelector) throws IOException {
        SocketChannel channel = serverChannel.accept();
        if (channel == null) {
            // A non-blocking accept() returns null when no connection is actually pending.
            return;
        }

        boolean registered = false;
        try {
            channel.configureBlocking(false);
            Connection conn = new Connection(channel);
            conn.selectionKey = channel.register(activeSelector, SelectionKey.OP_READ | SelectionKey.OP_WRITE, conn);
            establishedConnections.put(conn, valueObject);
            registered = true;
        } catch (IOException e) {
            log("TCP: Failed to set up accepted connection: " + e);
        } finally {
            if (!registered) {
                closeChannelQuietly(channel);
            }
        }
    }

    /**
     * State of one accepted client: its channel, read buffers, outgoing queue and, once the
     * handshake has succeeded, its protocol driver.
     */
    private class Connection implements ISocketConnection {
        private SocketChannel channel;
        /** Cleared by {@link #disconnect()}, possibly on a different thread than the readers. */
        private volatile SelectionKey selectionKey;
        /** {@code null} until the handshake succeeds and again after {@link #disconnect()}. */
        private volatile IProtocolDriver driver;
        /** Written by the selector thread, read by the keep-alive thread. */
        private volatile long lastKeepAliveReceived;
        /** Set once the connection should be closed as soon as its send queue is flushed. */
        private boolean toBeClosed = false;

        private ByteBuffer headerReadBuffer = ByteBuffer.allocate(IProtocol.HEADER_BYTE_SIZE);
        private ByteBuffer dataReadBuffer = ByteBuffer.allocateDirect(INITIAL_READ_BUFFER_SIZE);
        private PacketData packetData = new PacketData(dataReadBuffer);
        private SendBufferQueue<Connection> sendBufferQueue = new SendBufferQueue<Connection>(20000);

        Connection(SocketChannel channel) {
            this.channel = channel;
            lastKeepAliveReceived = System.currentTimeMillis();
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return (InetSocketAddress) channel.socket().getRemoteSocketAddress();
        }

        @Override
        public InetSocketAddress getLocalAddress() {
            return (InetSocketAddress) channel.socket().getLocalSocketAddress();
        }

        @Override
        public boolean isConnected() {
            return channel.isConnected();
        }

        /**
         * Removes this connection from the server, notifies its driver (if any) and closes the
         * channel. Only the first call has any effect.
         */
        @Override
        public void disconnect() {
            if (establishedConnections.remove(this) == null) {
                return;
            }

            try {
                IProtocolDriver current = driver;
                if (current != null) {
                    current.connectionDisconnected();
                    driver = null;
                }
            } catch (RuntimeException e) {
                // A failing driver callback must not prevent the channel from being closed.
                AsyncTcpServer.this.log(Utility.stackTraceToString(e));
            }

            selectionKey = null;
            closeChannelQuietly(channel);
        }

        /**
         * Reads whatever is available and, once a packet is complete, handles it: a zero-size
         * header refreshes the keep-alive timestamp, the first payload is treated as the
         * handshake, and every later payload is handed to the driver.
         *
         * @return {@code true} to keep the connection; {@code false} if the peer closed the
         *         stream, announced an invalid size, failed the handshake, or the driver
         *         rejected the packet
         * @throws IOException if reading from the channel fails
         */
        private boolean doRead() throws IOException {
            if (toBeClosed) {
                // Draining: incoming data is discarded. Reset the buffer first, otherwise a full
                // buffer makes read() return 0 forever and end-of-stream is never observed.
                dataReadBuffer.clear();
                return channel.read(dataReadBuffer) != -1;
            }

            if (headerReadBuffer.remaining() != 0) {
                if (channel.read(headerReadBuffer) < 0) {
                    return false;
                }
                if (headerReadBuffer.remaining() != 0) {
                    return true;
                }

                int dataSize = headerReadBuffer.getInt(0);
                if (dataSize == 0) {
                    lastKeepAliveReceived = System.currentTimeMillis();
                    headerReadBuffer.position(0);
                    return true;
                }
                if (dataSize < 1 || dataSize > maxPacketSize) {
                    /* Invalid data size */
                    return false;
                }

                if (dataSize > dataReadBuffer.capacity()) {
                    dataReadBuffer = ByteBuffer.allocateDirect(dataSize);
                    packetData.replaceBuffer(dataReadBuffer);
                } else {
                    dataReadBuffer.limit(dataSize);
                }
            }

            if (channel.read(dataReadBuffer) < 0) {
                /* Client has disconnected */
                return false;
            }
            if (dataReadBuffer.remaining() != 0) {
                return true;
            }

            dataReadBuffer.position(0);
            IProtocolDriver current = driver;
            if (current == null) {
                if (!negotiateProtocol()) {
                    return false;
                }
            } else if (!current.process(packetData)) {
                return false;
            }

            // Get ready for the next packet.
            headerReadBuffer.position(0);
            dataReadBuffer.clear();
            return true;
        }

        /**
         * Handles the handshake payload currently held by {@code packetData}, queues the
         * matching reply and, on success, installs the driver.
         *
         * @return {@code true} if a driver was installed
         */
        private boolean negotiateProtocol() {
            String[] tokens = packetData.getMessage().split(IProtocol.SEPARATOR);
            if (tokens.length != 2) {
                reply(bufferProtocolNG);
                return false;
            }

            IProtocol handler = protocolHandlers.get(tokens[0]);
            if (handler == null) {
                reply(bufferProtocolNG);
                return false;
            }
            if (!tokens[1].equals(IProtocol.NUMBER)) {
                reply(bufferProtocolNumber);
                return false;
            }

            reply(bufferProtocolOK);
            driver = handler.createDriver(this);
            return driver != null;
        }

        /** Rewinds one of the shared pre-encoded replies and queues it for sending. */
        private void reply(ByteBuffer cannedReply) {
            cannedReply.position(0);
            send(cannedReply);
        }

        /** Queues {@code buffer} with a size header; ignored if the channel is not connected. */
        @Override
        public void send(ByteBuffer buffer) {
            if (!channel.isConnected()) {
                return;
            }
            addToSendQueue(buffer, true);
        }

        /** Queues {@code buffer} and asks the selector to watch this connection for writability. */
        private void addToSendQueue(ByteBuffer buffer, boolean prependSizeHeader) {
            sendBufferQueue.queue(buffer, prependSizeHeader, this);

            // Read the field once: disconnect() may clear it concurrently.
            SelectionKey key = selectionKey;
            if (key == null) {
                return;
            }
            try {
                key.selector().wakeup();
                key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
            } catch (CancelledKeyException e) {
                // The connection is already being torn down; the queued data is dropped with it.
            }
        }
    }

    /**
     * Manual test: runs an echo server for the "TEST" protocol on port 30000 until a newline is
     * read from standard input.
     */
    public static void main(String[] args) throws IOException {
        InetSocketAddress address = new InetSocketAddress(30000);
        AsyncTcpServer server = new AsyncTcpServer(40000);

        server.addServerListener(new IServerListener() {
            @Override
            public void log(String message) {
                System.out.println(message);
            }

            @Override
            public void serverStartupFinished() {
            }

            @Override
            public void serverShutdownFinished() {
            }
        });

        server.addProtocol(new IProtocol() {
            @Override
            public void log(String message) {
                System.out.println(message);
            }

            @Override
            public String getProtocol() {
                return "TEST";
            }

            @Override
            public IProtocolDriver createDriver(final ISocketConnection connection) {
                System.out.println(connection.getRemoteAddress() + " [接続されました]");

                return new IProtocolDriver() {
                    @Override
                    public ISocketConnection getConnection() {
                        return connection;
                    }

                    @Override
                    public boolean process(PacketData data) {
                        String remoteAddress = connection.getRemoteAddress().toString();
                        String message = data.getMessage();

                        System.out.println(remoteAddress + " (" + message.length() + ")");
                        connection.send(Utility.encode(message));

                        return true;
                    }

                    @Override
                    public void connectionDisconnected() {
                        System.out.println(connection.getRemoteAddress() + " [切断されました]");
                    }

                    @Override
                    public void errorProtocolNumber(String number) {
                    }
                };
            }
        });

        server.startListening(address);

        while (System.in.read() != '\n') {
        }

        server.stopListening();
    }
}