package com.github.kmkt.util;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.nio.channels.NetworkChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads UTF-8 text from accepted sockets line by line (lines are terminated by CRLF)
 * and hands every line to a listener callback.
 *
 * License : MIT License
 */
public class UTF8StringReceiver {
    private static final Logger logger = LoggerFactory.getLogger(UTF8StringReceiver.class);

    /** Charset used to decode the received bytes. */
    private static final Charset UTF8 = StandardCharsets.UTF_8;
    /** Line terminator recognised in the received text. */
    private static final String LINE_SEPARATOR = "\r\n";
    /** Size in bytes of the per-connection read buffer. */
    private static final int READ_BUFFER_SIZE = 1024;

    private final ExecutorService execPool;
    /** {@code true} when {@link #execPool} was created by this instance and is shut down by {@link #stop()}. */
    private final boolean ownExecPool;
    private final InetSocketAddress listenEndpoint;
    private volatile ListenCompletionListener listenCallback;

    /** Channels currently being read by a worker; closed by {@link #stop()}. */
    private final Set<NetworkChannel> activeChannels =
            Collections.synchronizedSet(new HashSet<NetworkChannel>());
    private AsynchronousServerSocketChannel assc = null;

    /**
     * Creates a receiver that runs its per-connection workers on an internally created cached thread pool.
     *
     * @param listen endpoint to listen on; must not be {@code null}
     * @throws IllegalArgumentException if {@code listen} is {@code null}
     */
    public UTF8StringReceiver(InetSocketAddress listen) {
        this(listen, null);
    }

    /**
     * Creates a receiver that runs its per-connection workers on the given pool.
     *
     * @param listen endpoint to listen on; must not be {@code null}
     * @param pool   executor for the per-connection workers; when {@code null} a cached thread pool
     *               is created and owned (shut down on {@link #stop()}) by this instance
     * @throws IllegalArgumentException if {@code listen} is {@code null}
     */
    public UTF8StringReceiver(InetSocketAddress listen, ExecutorService pool) {
        if (listen == null) {
            throw new IllegalArgumentException("listen should not be null");
        }
        if (pool != null) {
            ownExecPool = false;
            execPool = pool;
        } else {
            ownExecPool = true;
            execPool = Executors.newCachedThreadPool(Executors.defaultThreadFactory());
        }
        listenEndpoint = listen;
    }

    /**
     * Callback invoked when a socket is accepted.
     */
    public interface ListenCompletionListener {
        /**
         * Called when a connection has been accepted.
         * Returns the {@link ReceiveListener} that receives the data read from the accepted socket.
         * When {@code null} is returned the connection is closed immediately and nothing is read from it.
         *
         * @param remote remote address of the accepted connection
         * @return listener for the lines received on this connection, or {@code null} to reject it
         */
        ReceiveListener<?> accepted(SocketAddress remote);

        /**
         * Called when accepting a connection failed.
         *
         * @param e the exception that caused the failure
         */
        void failed(Throwable e);
    }

    /**
     * Callback invoked when a UTF-8 line is received on a socket.
     *
     * @param <T> type of the attachment handed back on every callback
     */
    public static abstract class ReceiveListener<T> {
        private final T attachment;

        /**
         * Constructor.
         *
         * @param attachment arbitrary object passed along with every callback
         */
        public ReceiveListener(T attachment) {
            this.attachment = attachment;
        }

        /**
         * Called for every line received on the socket.
         *
         * @param line       the received line, without its CRLF terminator
         * @param attachment the object given to the constructor
         */
        public abstract void onReceive(String line, T attachment);

        /**
         * Called when the socket is closed.
         *
         * @param attachment the object given to the constructor
         */
        public abstract void onClose(T attachment);

        void onReceive(String line) {
            this.onReceive(line, this.attachment);
        }

        void onClose() {
            this.onClose(this.attachment);
        }
    }

    /**
     * Sets the callback notified about accepted connections and accept failures.
     * When no callback is set, accepted connections are read and their data is discarded.
     *
     * @param callback the callback, may be {@code null}
     */
    public void setCallback(ListenCompletionListener callback) {
        this.listenCallback = callback;
    }

    /**
     * Binds the server socket to the configured endpoint and starts accepting connections.
     * Does nothing when the server socket has already been created by an earlier call
     * (including after {@link #stop()}: the closed channel is not re-created).
     *
     * @throws IOException if the server socket cannot be opened or bound
     */
    public void start() throws IOException {
        if (assc != null) {
            return;
        }
        AsynchronousServerSocketChannel server = AsynchronousServerSocketChannel.open();
        try {
            server.bind(listenEndpoint);
        } catch (IOException | RuntimeException e) {
            // Do not leak the freshly opened channel when the bind fails.
            closeQuietly(server);
            throw e;
        }
        assc = server;
        assc.accept(null, new CompletionHandler<AsynchronousSocketChannel, Void>() {
            @Override
            public void completed(final AsynchronousSocketChannel result, Void attachment) {
                // Re-arm the accept first so further connections are not delayed by this one.
                assc.accept(null, this);
                handleAccepted(result);
            }

            @Override
            public void failed(Throwable e, Void attachment) {
                reportAcceptFailure(e);
            }
        });
    }

    /**
     * Stops accepting connections and closes every connection currently being read.
     * The worker pool is shut down only when it was created by this instance.
     *
     * @throws IOException if closing the server socket fails; active connections are closed regardless
     */
    public void stop() throws IOException {
        if (ownExecPool) {
            execPool.shutdown();
        }
        try {
            if (assc != null) {
                assc.close();
            }
        } finally {
            closeActiveChannels();
        }
    }

    /** Asks the callback for a listener and starts a worker that reads the accepted channel. */
    private void handleAccepted(final AsynchronousSocketChannel channel) {
        try {
            SocketAddress remote = channel.getRemoteAddress();
            ListenCompletionListener callback = listenCallback;
            final ReceiveListener<?> receiver;
            if (callback != null) {
                receiver = callback.accepted(remote);
                if (receiver == null) {
                    logger.debug("Ignore and close connection from {}", remote);
                    channel.close();
                    return;
                }
            } else {
                receiver = null;
            }
            execPool.execute(new Runnable() {
                @Override
                public void run() {
                    receiveLoop(channel, receiver);
                }
            });
        } catch (IOException | RuntimeException e) {
            // RuntimeException covers a throwing callback and a pool that rejects the task
            // (e.g. after stop()); in every case the accepted channel must not be leaked.
            closeQuietly(channel);
            reportAcceptFailure(e);
        }
    }

    /** Forwards an accept failure to the callback, ignoring the exception caused by closing the server socket. */
    private void reportAcceptFailure(Throwable e) {
        // AsynchronousCloseException is raised on a pending accept when the server socket is closed.
        if (e instanceof AsynchronousCloseException) {
            return;
        }
        ListenCompletionListener callback = listenCallback;
        if (callback != null) {
            callback.failed(e);
        } else {
            logger.error("Accept failed and no callback is registered", e);
        }
    }

    /** Reads the channel until end of stream or close, delivering every complete line to the receiver. */
    private void receiveLoop(AsynchronousSocketChannel channel, ReceiveListener<?> receiver) {
        ByteBuffer buffer = ByteBuffer.allocate(READ_BUFFER_SIZE);
        activeChannels.add(channel);
        try {
            StringBuilder pending = new StringBuilder();
            while (channel.isOpen()) {
                if (channel.read(buffer).get() < 0) {
                    break; // end of stream
                }
                buffer.flip();
                pending.append(getStr(buffer));
                deliverLines(pending, receiver);
            }
        } catch (InterruptedException e) {
            // Restore the flag so the owner of the thread can observe the interruption.
            Thread.currentThread().interrupt();
            logger.debug("Exception occured when SocketChannel reading", e);
        } catch (ExecutionException e) {
            // Read failures caused by stop() closing the channels are expected; stay quiet then.
            if (assc.isOpen()) {
                logger.error("Exception occured when SocketChannel reading", e);
            }
        } finally {
            activeChannels.remove(channel);
            closeQuietly(channel);
            if (receiver != null) {
                receiver.onClose();
            }
        }
    }

    /**
     * Delivers every CRLF-terminated line found in {@code pending} to the receiver (if any)
     * and removes the delivered text, leaving only the trailing unterminated part.
     */
    private static void deliverLines(StringBuilder pending, ReceiveListener<?> receiver) {
        int start = 0;
        int end;
        // ">= 0": a terminator at index 0 is an empty line, not "not found".
        while ((end = pending.indexOf(LINE_SEPARATOR, start)) >= 0) {
            if (receiver != null) {
                receiver.onReceive(pending.substring(start, end));
            }
            start = end + LINE_SEPARATOR.length();
        }
        if (start > 0) {
            pending.delete(0, start);
        }
    }

    /** Closes all channels registered by the workers and empties the registry. */
    private void closeActiveChannels() {
        NetworkChannel[] snapshot;
        synchronized (activeChannels) {
            snapshot = activeChannels.toArray(new NetworkChannel[0]);
            activeChannels.clear();
        }
        for (NetworkChannel soc : snapshot) {
            try {
                soc.close();
            } catch (IOException e) {
                logger.error("Error at processing a socket of " + describeLocalAddress(soc), e);
            }
        }
    }

    /** Returns the channel's local address as text, falling back to the channel itself when unavailable. */
    private static String describeLocalAddress(NetworkChannel channel) {
        try {
            return String.valueOf(channel.getLocalAddress());
        } catch (IOException e) {
            // getLocalAddress() fails on a closed channel; never let logging abort the caller.
            return channel.toString();
        }
    }

    /** Closes the channel, logging instead of propagating a failure. */
    private static void closeQuietly(NetworkChannel channel) {
        try {
            channel.close();
        } catch (IOException e) {
            logger.error("Exception occured when SocketChannel closing", e);
        }
    }

    /**
     * Extracts the UTF-8 text from an already flipped buffer.
     * A multi-byte character that is cut off at the end of the buffer is not decoded; its bytes are
     * put back at the start of the (cleared) buffer so that the next read completes the character.
     * Malformed input is decoded with the charset's replacement character.
     *
     * @param buffer flipped buffer holding the received bytes; on return it is cleared and positioned
     *               after any carried-over bytes, ready for the next read
     * @return the decoded text, possibly empty
     */
    private String getStr(ByteBuffer buffer) {
        if (!buffer.hasRemaining()) {
            buffer.clear();
            return "";
        }
        int limit = buffer.limit();
        int carry = incompleteTailLength(buffer);
        byte[] tail = new byte[carry];
        for (int i = 0; i < carry; i++) {
            tail[i] = buffer.get(limit - carry + i);
        }
        buffer.limit(limit - carry);
        String text = UTF8.decode(buffer).toString();
        buffer.clear();
        buffer.put(tail);
        return text;
    }

    /**
     * Returns the number of bytes at the end of the flipped buffer that form the beginning of a
     * multi-byte UTF-8 character whose remaining bytes have not been received yet, or 0 when the
     * buffer ends on a character boundary (or on malformed data that waiting cannot repair).
     */
    private static int incompleteTailLength(ByteBuffer buffer) {
        int first = buffer.position();
        int last = buffer.limit() - 1;
        if ((buffer.get(last) & 0x80) == 0) {
            return 0; // ends with an ASCII byte
        }
        // A sequence is at most 4 bytes long, so its lead byte is among the last 4 bytes.
        int lowest = Math.max(first, last - 3);
        int pos = last;
        while (pos >= lowest && (buffer.get(pos) & 0xC0) == 0x80) {
            pos--; // skip continuation bytes (10xxxxxx)
        }
        if (pos < lowest) {
            return 0; // only continuation bytes: malformed, let the decoder substitute
        }
        int lead = buffer.get(pos) & 0xFF;
        int expected;
        if (0xF0 <= lead && lead <= 0xF7) {
            expected = 4;
        } else if (0xE0 <= lead && lead <= 0xEF) {
            expected = 3;
        } else if (0xC2 <= lead && lead <= 0xDF) {
            expected = 2;
        } else {
            return 0; // ASCII or invalid lead byte: nothing to wait for
        }
        int present = last - pos + 1;
        return present < expected ? present : 0;
    }
}