/*
Copyright (C) 2011 monte
This file is part of PSP NetParty.
PSP NetParty is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pspnetparty.lib.socket;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import pspnetparty.lib.ILogger;
import pspnetparty.lib.Utility;
import pspnetparty.lib.constants.AppConstants;
/**
 * UDP implementation of {@link IClient}.
 * <p>
 * {@link #connect} performs a blocking handshake on a connected {@link DatagramChannel}. It sends
 * {@code protocol + SEPARATOR + NUMBER} and waits (up to {@code timeout}) for a reply of at most 5 bytes,
 * which is compared against {@code PROTOCOL_OK} / {@code PROTOCOL_NG}. On success the channel is switched
 * to non-blocking mode and handed to a daemon selector thread that reads incoming datagrams and drains the
 * per-connection send queue.
 * <p>
 * A second daemon thread sends a one-byte keep-alive datagram ({@code 1}) to every established connection
 * every {@code IProtocol.KEEPALIVE_INTERVAL} ms. It disconnects any connection whose last keep-alive was
 * received more than {@code IProtocol.KEEPALIVE_DEADLINE} ms ago. A one-byte datagram {@code 0} signals that
 * a side is terminating the connection.
 */
public class AsyncUdpClient implements IClient {
    /** Size of each connection's direct read buffer (maximum datagram size handled). */
    private static final int BUFFER_SIZE = 20000;

    private final ILogger logger;
    private final Selector selector;
    /** Connections whose handshake succeeded but which the selector thread has not registered yet. */
    private final ConcurrentLinkedQueue<Connection> newConnectionQueue = new ConcurrentLinkedQueue<Connection>();
    /** Live connections, used as a concurrent set (key == value). */
    private final ConcurrentHashMap<Connection, Object> establishedConnections;

    /**
     * Creates the client and starts its selector and keep-alive daemon threads.
     *
     * @param logger receives diagnostic messages from the background threads
     * @throws RuntimeException wrapping the {@link IOException} if the selector cannot be opened
     */
    public AsyncUdpClient(ILogger logger) {
        this.logger = logger;
        establishedConnections = new ConcurrentHashMap<AsyncUdpClient.Connection, Object>(16, 0.75f, 2);
        try {
            selector = Selector.open();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        Thread selectorThread = new Thread(new Runnable() {
            @Override
            public void run() {
                selectorLoop();
            }
        }, getClass().getName() + " Selector");
        selectorThread.setDaemon(true);
        selectorThread.start();

        Thread keepAliveThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    keepAliveLoop();
                } catch (InterruptedException e) {
                    AsyncUdpClient.this.logger.log(Utility.stackTraceToString(e));
                    // Preserve the interrupt status for anything further up this thread's stack.
                    Thread.currentThread().interrupt();
                }
            }
        }, getClass().getName() + " KeepAlive");
        keepAliveThread.setDaemon(true);
        keepAliveThread.start();
    }

    /**
     * Selector-thread main loop: dispatches ready keys, then registers newly established connections.
     * Runs until the selector is closed by {@link #dispose()}.
     */
    private void selectorLoop() {
        try {
            while (selector.isOpen()) {
                if (selector.select() > 0) {
                    Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        handleSelectedKey(key);
                    }
                }
                registerPendingConnections();
            }
        } catch (ClosedSelectorException e) {
            // Expected: dispose() closed the selector, so the loop simply ends.
        } catch (IOException e) {
            logger.log(Utility.stackTraceToString(e));
        }
    }

    /**
     * Handles one ready key: a readable key is read (and possibly closed), and a writable key sends the next
     * queued buffer or drops back to read-only interest once the queue is empty.
     */
    private void handleSelectedKey(SelectionKey key) {
        Connection conn = null;
        try {
            conn = (Connection) key.attachment();
            if (key.isReadable()) {
                if (!conn.doRead()) {
                    // The driver or peer asked to close. Close now if nothing is left to send (or the peer is
                    // already gone); otherwise let the write side drain the queue first.
                    if (conn.sendBufferQueue.isEmpty() || conn.serverDisconnected) {
                        closeConnection(conn, key);
                    } else {
                        conn.toBeClosed = true;
                    }
                }
            } else if (key.isWritable()) {
                SendBufferQueue<Connection>.Allotment allot = conn.sendBufferQueue.poll();
                if (allot == null) {
                    if (conn.toBeClosed) {
                        closeConnection(conn, key);
                    } else {
                        key.interestOps(SelectionKey.OP_READ);
                    }
                } else if (!conn.serverDisconnected) {
                    conn.channel.write(allot.getBuffer());
                }
            }
        } catch (CancelledKeyException ignored) {
            // The key was cancelled concurrently (connection already closed); nothing to do.
        } catch (IOException e) {
            closeConnection(conn, key);
        } catch (RuntimeException e) {
            closeConnection(conn, key);
        }
    }

    /** Disconnects {@code conn} (if known) and cancels its selection key. */
    private void closeConnection(Connection conn, SelectionKey key) {
        if (conn != null)
            conn.disconnect();
        key.cancel();
    }

    /**
     * Registers every queued connection with the selector for read and write interest. A connection whose
     * channel was closed before registration is dropped instead of terminating the selector thread.
     */
    private void registerPendingConnections() {
        Connection conn;
        while ((conn = newConnectionQueue.poll()) != null) {
            try {
                conn.selectionKey = conn.channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE, conn);
            } catch (IOException e) {
                // ClosedChannelException: the connection was disconnected before it could be registered.
                conn.disconnect();
            }
        }
    }

    /**
     * Keep-alive thread main loop: disconnects stale connections and sends a keep-alive byte to the others,
     * then sleeps {@code IProtocol.KEEPALIVE_INTERVAL} ms. Runs until the selector is closed.
     *
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    private void keepAliveLoop() throws InterruptedException {
        ByteBuffer keepAliveBuffer = ByteBuffer.wrap(new byte[] { 1 });
        while (selector.isOpen()) {
            long deadline = System.currentTimeMillis() - IProtocol.KEEPALIVE_DEADLINE;
            for (Connection conn : establishedConnections.keySet()) {
                try {
                    long lastReceived = conn.lastKeepAliveReceived;
                    if (lastReceived < deadline) {
                        logger.log(Utility.makeKeepAliveDisconnectLog("UDP", conn.remoteAddress, deadline, lastReceived));
                        conn.disconnect();
                    } else {
                        // NOTE(review): the same buffer instance is queued for every connection; this assumes
                        // SendBufferQueue.queue copies its bytes — confirm.
                        keepAliveBuffer.clear();
                        conn.send(keepAliveBuffer);
                    }
                } catch (Exception e) {
                    logger.log(Utility.stackTraceToString(e));
                }
            }
            Thread.sleep(IProtocol.KEEPALIVE_INTERVAL);
        }
    }

    /**
     * Opens a UDP connection to {@code address} and performs the protocol handshake synchronously.
     * <p>
     * If the driver created by {@code protocol} is {@code null}, or the server answers NG or an unknown
     * protocol number, the channel is closed and the method returns normally (the driver is notified in the
     * latter two cases). On any exception the channel is closed before the exception propagates.
     *
     * @param address remote server address
     * @param timeout handshake reply timeout in milliseconds (socket {@code SO_TIMEOUT})
     * @param protocol protocol that names itself in the handshake and creates the connection's driver
     * @throws IllegalArgumentException if {@code address} or {@code protocol} is {@code null}
     * @throws IOException on I/O failure or handshake timeout; a {@link RuntimeException} raised during the
     *         handshake is logged through {@code protocol} and rethrown wrapped in an {@link IOException}
     */
    @Override
    public void connect(InetSocketAddress address, int timeout, IProtocol protocol) throws IOException {
        if (address == null || protocol == null)
            throw new IllegalArgumentException();
        DatagramChannel channel = null;
        // True once the channel is owned by the selector thread; until then this method must close it.
        boolean handedOff = false;
        try {
            channel = DatagramChannel.open();
            channel.connect(address);
            Connection conn = new Connection(channel);
            ByteBuffer data = AppConstants.CHARSET.encode(protocol.getProtocol() + IProtocol.SEPARATOR + IProtocol.NUMBER);
            channel.write(data);

            channel.socket().setSoTimeout(timeout);
            byte[] buffer = new byte[5];
            DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
            channel.socket().receive(packet);

            conn.driver = protocol.createDriver(conn);
            if (conn.driver == null) {
                channel.close();
                return;
            }

            String message = Utility.decode(ByteBuffer.wrap(buffer, 0, packet.getLength()));
            if (IProtocol.PROTOCOL_OK.equals(message)) {
                // Switch modes before publishing so a failure here cannot leave a dead entry in the map.
                channel.configureBlocking(false);
                establishedConnections.put(conn, conn);
                newConnectionQueue.offer(conn);
                handedOff = true;
                selector.wakeup();
            } else if (IProtocol.PROTOCOL_NG.equals(message)) {
                channel.close();
                conn.driver.connectionDisconnected();
            } else {
                channel.close();
                conn.driver.errorProtocolNumber(message);
                conn.driver.connectionDisconnected();
            }
        } catch (RuntimeException e) {
            protocol.log(Utility.stackTraceToString(e));
            throw new IOException(e);
        } finally {
            if (!handedOff)
                closeQuietly(channel);
        }
    }

    /** Closes {@code channel} if non-null, ignoring any failure (best-effort cleanup). */
    private static void closeQuietly(DatagramChannel channel) {
        if (channel == null)
            return;
        try {
            channel.close();
        } catch (IOException ignored) {
            // Best-effort cleanup on an already failing path.
        }
    }

    /**
     * Closes the selector (which stops both background loops) and disconnects every established connection.
     * Subsequent calls do nothing.
     */
    @Override
    public void dispose() {
        if (!selector.isOpen())
            return;
        try {
            selector.close();
        } catch (IOException e) {
            logger.log(Utility.stackTraceToString(e));
        }
        for (Connection conn : establishedConnections.keySet()) {
            conn.disconnect();
        }
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            dispose();
        } finally {
            super.finalize();
        }
    }

    /** One UDP connection: its channel, protocol driver, read buffer and pending send queue. */
    private class Connection implements ISocketConnection {
        private final DatagramChannel channel;
        /** Set by the selector thread on registration; read by threads calling {@link #send}. */
        private volatile SelectionKey selectionKey;
        private final InetSocketAddress remoteAddress;
        private IProtocolDriver driver;
        /** Written by the selector thread, read by the keep-alive thread; volatile so the long is read atomically. */
        private volatile long lastKeepAliveReceived;
        /** Selector-thread only: close once the send queue has drained. */
        private boolean toBeClosed = false;
        /** Selector-thread only: the peer sent the one-byte terminate datagram. */
        private boolean serverDisconnected = false;

        private final ByteBuffer readBuffer = ByteBuffer.allocateDirect(BUFFER_SIZE);
        private final PacketData packetData = new PacketData(readBuffer);
        private final SendBufferQueue<Connection> sendBufferQueue = new SendBufferQueue<Connection>(20000);

        public Connection(DatagramChannel channel) {
            this.channel = channel;
            this.remoteAddress = (InetSocketAddress) channel.socket().getRemoteSocketAddress();
            lastKeepAliveReceived = System.currentTimeMillis();
        }

        /**
         * Reads one datagram and dispatches it.
         * <p>
         * Single-byte datagrams are control messages: {@code 0} marks the peer as disconnected, and {@code 1}
         * refreshes the keep-alive timestamp. Anything else goes to the driver. Once the connection is closing,
         * datagrams are read and discarded.
         *
         * @return {@code false} if the connection should be closed, {@code true} otherwise
         * @throws IOException if reading from the channel fails
         */
        private boolean doRead() throws IOException {
            readBuffer.clear();
            channel.read(readBuffer);
            if (toBeClosed || serverDisconnected)
                return true;

            readBuffer.flip();
            if (readBuffer.limit() == 1) {
                switch (readBuffer.get(0)) {
                case 0:
                    serverDisconnected = true;
                    return false;
                case 1:
                    lastKeepAliveReceived = System.currentTimeMillis();
                    return true;
                }
            }
            return driver.process(packetData);
        }

        /**
         * Removes this connection from the established set, notifies the driver once, and sends the terminate
         * byte before closing the channel. Only the first call has any effect.
         */
        @Override
        public void disconnect() {
            if (establishedConnections.remove(this) == null)
                return;

            try {
                if (driver != null) {
                    driver.connectionDisconnected();
                    driver = null;
                }
            } catch (RuntimeException re) {
                // Driver callbacks must not prevent the channel from being closed; log and continue.
                logger.log(Utility.stackTraceToString(re));
            }
            try {
                if (channel.isOpen()) {
                    ByteBuffer terminateBuffer = ByteBuffer.wrap(new byte[] { 0 });
                    channel.send(terminateBuffer, remoteAddress);
                    channel.close();
                }
            } catch (IOException ignored) {
                // Best-effort notification; the connection is already considered closed.
            }
        }

        @Override
        public boolean isConnected() {
            return channel != null && channel.isConnected();
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return remoteAddress;
        }

        @Override
        public InetSocketAddress getLocalAddress() {
            return (InetSocketAddress) channel.socket().getLocalSocketAddress();
        }

        /** Queues {@code buffer} for sending; silently ignored when the channel is no longer connected. */
        @Override
        public void send(ByteBuffer buffer) {
            if (!isConnected())
                return;
            addToSendQueue(buffer);
        }

        /** Queues the buffer and re-enables write interest so the selector thread drains the queue. */
        private void addToSendQueue(ByteBuffer buffer) {
            sendBufferQueue.queue(buffer, false, this);
            try {
                SelectionKey key = selectionKey;
                if (key != null) {
                    // Wake the selector first so interestOps is not delayed by a blocked select().
                    selector.wakeup();
                    key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
                }
            } catch (CancelledKeyException ignored) {
                // The connection was closed concurrently; the queued data is simply dropped.
            }
        }
    }

    /** Manual demo: connects to a local test server on port 30000, sends three messages, then disconnects. */
    public static void main(String[] args) throws IOException {
        final AsyncUdpClient client = new AsyncUdpClient(new ILogger() {
            @Override
            public void log(String message) {
            }
        });
        InetSocketAddress address = new InetSocketAddress("localhost", 30000);
        client.connect(address, 5000, new IProtocol() {
            @Override
            public void log(String message) {
                System.out.println(message);
            }

            @Override
            public String getProtocol() {
                return "TEST";
            }

            @Override
            public IProtocolDriver createDriver(final ISocketConnection connection) {
                System.out.println("接続しました: " + connection.getRemoteAddress());
                Thread sendThread = new Thread(new Runnable() {
                    @Override
                    public void run() {
                        for (int i = 0; i < 3; i++) {
                            try {
                                Thread.sleep(500);
                                connection.send(Utility.encode("TEST " + i));
                                Thread.sleep(500);
                            } catch (InterruptedException ignored) {
                                // Demo only: continue with the next message.
                            }
                        }
                        connection.disconnect();
                        client.dispose();
                    }
                });
                sendThread.start();

                return new IProtocolDriver() {
                    @Override
                    public ISocketConnection getConnection() {
                        return connection;
                    }

                    @Override
                    public boolean process(PacketData data) {
                        String msg = data.getMessage();
                        System.out.println("受信(" + msg.length() + "): " + msg);
                        return true;
                    }

                    @Override
                    public void connectionDisconnected() {
                        System.out.println("切断しました");
                    }

                    @Override
                    public void errorProtocolNumber(String number) {
                        System.out.println("プロトコルエラー: " + number);
                    }
                };
            }
        });
    }
}