/* Copyright (C) 2011 monte This file is part of PSP NetParty. PSP NetParty is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. */ package pspnetparty.lib.socket; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.ClosedSelectorException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.Iterator; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import pspnetparty.lib.ILogger; import pspnetparty.lib.Utility; public class AsyncTcpClient implements IClient { private final int initialReadBufferSize; private final int maxPacketSize; private ILogger logger; private Selector selector; private ConcurrentLinkedQueue<Connection> newConnectionQueue = new ConcurrentLinkedQueue<Connection>(); private ConcurrentHashMap<Connection, Object> establishedConnections; public AsyncTcpClient(ILogger logger, int maxPacketSize, final int selectTimeout) { this.logger = logger; this.maxPacketSize = maxPacketSize; this.initialReadBufferSize = Math.min(maxPacketSize, 2000); establishedConnections = new ConcurrentHashMap<AsyncTcpClient.Connection, Object>(16, 0.75f, 3); try { selector = Selector.open(); } catch (IOException e) { throw new RuntimeException(e); } Thread selectorThread = new Thread(new Runnable() { @Override public void run() { selectorLoop(selectTimeout); } }, getClass().getName() + " Selector"); selectorThread.setDaemon(true); selectorThread.start(); Thread keepAliveThread = new Thread(new Runnable() { @Override public void run() { try { keepAliveLoop(); } catch (InterruptedException e) { AsyncTcpClient.this.logger.log(Utility.stackTraceToString(e)); } } }, getClass().getName() + " KeepAlive"); keepAliveThread.setDaemon(true); keepAliveThread.start(); } private void selectorLoop(int timeout) { try { while (selector.isOpen()) { if (selector.select(timeout) > 0) { for (Iterator<SelectionKey> it = selector.selectedKeys().iterator(); it.hasNext();) { SelectionKey key = it.next(); it.remove(); Connection conn = null; try { conn = (Connection) key.attachment(); if (key.isReadable()) { if (conn.doRead()) { } else if (conn.sendBufferQueue.isEmpty() || !conn.channel.isOpen()) { conn.disconnect(); key.cancel(); } else { conn.toBeClosed = true; } } else if (key.isWritable()) { SendBufferQueue<Connection>.Allotment allot = conn.sendBufferQueue.poll(); if (allot == null) { if (conn.toBeClosed) { conn.disconnect(); key.cancel(); } else { key.interestOps(SelectionKey.OP_READ); } } else { conn.channel.write(allot.getBuffer()); } } } catch (CancelledKeyException e) { } catch (IOException e) { if (conn != null) conn.disconnect(); key.cancel(); } catch (RuntimeException e) { if (conn != null) conn.disconnect(); key.cancel(); } } } Connection conn; while ((conn = newConnectionQueue.poll()) != null) conn.selectionKey = conn.channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE, conn); } } catch (ClosedSelectorException e) { } catch (IOException e) { } } private void keepAliveLoop() throws InterruptedException { ByteBuffer keepAliveBuffer = ByteBuffer.allocate(IProtocol.HEADER_BYTE_SIZE); keepAliveBuffer.putInt(0); while (selector.isOpen()) { long deadline = System.currentTimeMillis() - IProtocol.KEEPALIVE_DEADLINE; for (Connection conn : establishedConnections.keySet()) { try { if (conn.lastKeepAliveReceived < deadline) { logger.log(Utility.makeKeepAliveDisconnectLog("TCP", conn.remoteAddress, deadline, conn.lastKeepAliveReceived)); conn.disconnect(); } else { keepAliveBuffer.position(0); conn.addToSendQueue(keepAliveBuffer, false); } } catch (RuntimeException e) { logger.log(Utility.stackTraceToString(e)); } catch (Exception e) { logger.log(Utility.stackTraceToString(e)); } } Thread.sleep(IProtocol.KEEPALIVE_INTERVAL); } } @Override public void connect(InetSocketAddress address, int timeout, IProtocol protocol) throws IOException { if (address == null || protocol == null) throw new IllegalArgumentException(); try { SocketChannel channel = SocketChannel.open(); channel.socket().connect(address, timeout); channel.configureBlocking(false); Connection conn = new Connection(channel); ByteBuffer buf = Utility.encode(protocol.getProtocol() + IProtocol.SEPARATOR + IProtocol.NUMBER); conn.send(buf); conn.driver = protocol.createDriver(conn); if (conn.driver == null) { channel.close(); return; } establishedConnections.put(conn, conn); newConnectionQueue.add(conn); selector.wakeup(); } catch (RuntimeException e) { protocol.log(Utility.stackTraceToString(e)); throw new IOException(e); } } @Override public void dispose() { if (!selector.isOpen()) return; try { selector.close(); } catch (IOException e) { e.printStackTrace(); } for (Connection conn : establishedConnections.keySet()) { conn.disconnect(); } } @Override protected void finalize() throws Throwable { super.finalize(); dispose(); } private class Connection implements ISocketConnection { private SocketChannel channel; private SelectionKey selectionKey; private InetSocketAddress remoteAddress; private IProtocolDriver driver; private long lastKeepAliveReceived; private boolean toBeClosed = false; private ByteBuffer headerReadBuffer = ByteBuffer.allocate(IProtocol.HEADER_BYTE_SIZE); private ByteBuffer dataReadBuffer = ByteBuffer.allocateDirect(initialReadBufferSize); private PacketData packetData = new PacketData(dataReadBuffer); private boolean protocolMatched = false; private SendBufferQueue<Connection> sendBufferQueue = new SendBufferQueue<AsyncTcpClient.Connection>(20000); Connection(SocketChannel channel) { this.channel = channel; this.remoteAddress = (InetSocketAddress) channel.socket().getRemoteSocketAddress(); lastKeepAliveReceived = System.currentTimeMillis(); } private boolean doRead() throws IOException { if (toBeClosed) { int readBytes = channel.read(dataReadBuffer); return readBytes != -1; } if (headerReadBuffer.remaining() != 0) { if (channel.read(headerReadBuffer) < 0) return false; if (headerReadBuffer.remaining() != 0) return true; int dataSize = headerReadBuffer.getInt(0); if (dataSize == 0) { lastKeepAliveReceived = System.currentTimeMillis(); headerReadBuffer.position(0); return true; } if (dataSize < 1 || dataSize > maxPacketSize) { /* Invalid data size */ // readHeaderBuffer.position(0); // System.out.println(Utility.decode(readHeaderBuffer)); return false; } if (dataSize > dataReadBuffer.capacity()) { dataReadBuffer = ByteBuffer.allocateDirect(dataSize); packetData.replaceBuffer(dataReadBuffer); } else { dataReadBuffer.limit(dataSize); } } if (channel.read(dataReadBuffer) < 0) return false; if (dataReadBuffer.remaining() != 0) return true; dataReadBuffer.position(0); if (protocolMatched) { if (!driver.process(packetData)) return false; } else { String message = packetData.getMessage(); if (IProtocol.PROTOCOL_OK.equals(message)) { protocolMatched = true; } else if (IProtocol.PROTOCOL_NG.equals(message)) { return false; } else { driver.errorProtocolNumber(message); return false; } } headerReadBuffer.position(0); dataReadBuffer.clear(); return true; } @Override protected void finalize() throws Throwable { super.finalize(); disconnect(); } @Override public boolean isConnected() { return channel != null && channel.isConnected(); } @Override public void disconnect() { if (establishedConnections.remove(this) == null) return; try { if (driver != null) { driver.connectionDisconnected(); driver = null; } } catch (RuntimeException re) { } selectionKey = null; try { if (channel.isOpen()) channel.close(); } catch (IOException e) { } } @Override public InetSocketAddress getRemoteAddress() { return remoteAddress; } @Override public InetSocketAddress getLocalAddress() { return (InetSocketAddress) channel.socket().getLocalSocketAddress(); } @Override public void send(ByteBuffer buffer) { if (!isConnected()) return; addToSendQueue(buffer, true); } private void addToSendQueue(ByteBuffer buffer, boolean prependSizeHeader) { sendBufferQueue.queue(buffer, prependSizeHeader, this); try { if (selectionKey != null) { selector.wakeup(); selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE); } } catch (CancelledKeyException e) { } } } public static void main(String[] args) throws Exception { final AsyncTcpClient client = new AsyncTcpClient(new ILogger() { @Override public void log(String message) { } }, 100000, 0); InetSocketAddress address = new InetSocketAddress("localhost", 30000); client.connect(address, 1000, new IProtocol() { @Override public void log(String message) { System.out.println(message); } @Override public String getProtocol() { return "TEST"; } @Override public IProtocolDriver createDriver(final ISocketConnection connection) { System.out.println("接続しました: " + connection.getRemoteAddress()); Thread sendThread = new Thread(new Runnable() { private String makeLongString(char c, int length) { StringBuilder sb = new StringBuilder(); for (int i = 0; i < length; i++) { sb.append(c); } return sb.toString(); } @Override public void run() { try { Thread.sleep(500); if (connection.isConnected()) { String text = "S" + makeLongString('T', 39998) + "E"; System.out.println("length: " + text.length()); for (int i = 0; i < 3; i++) { if (!connection.isConnected()) break; connection.send(Utility.encode(text)); Thread.sleep(1000); } Thread.sleep(100); } } catch (InterruptedException e) { } connection.disconnect(); client.dispose(); } }); sendThread.start(); return new IProtocolDriver() { @Override public ISocketConnection getConnection() { return connection; } @Override public boolean process(PacketData data) { String msg = data.getMessage(); System.out.println("受信(" + msg.length() + ")"); return true; } @Override public void connectionDisconnected() { System.out.println("切断しました"); } @Override public void errorProtocolNumber(String number) { System.out.println("プロトコルエラー: " + number); } }; } }); } }