package io.datakernel.net;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.ServerSocketChannel;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadFactory;
/**
 * A minimal blocking TCP server: one accept thread per listen address, with every accepted
 * {@link Socket} handed to an {@link AcceptHandler} on the supplied {@link ExecutorService}.
 * <p>
 * Listening sockets are opened through {@link ServerSocketChannel#open()} so that both the
 * listening socket and the accepted sockets have a channel to which {@link ServerSocketSettings}
 * and {@link SocketSettings} can be applied. (A plain {@code new ServerSocket(...)} has no
 * channel: {@link ServerSocket#getChannel()} returns {@code null} for it.)
 * <p>
 * Because accepted sockets are channel-backed, interrupting a thread that is blocked in I/O on
 * one of them closes that socket (see {@link java.nio.channels.InterruptibleChannel}).
 * <p>
 * The class performs no synchronization of its own. NOTE(review): presumably it is configured,
 * started and stopped from a single thread — confirm with callers.
 */
public final class BlockingSocketServer {
    /** Callback invoked, on an executor thread, for every accepted connection. */
    public interface AcceptHandler {
        void onAccept(Socket socket) throws IOException;
    }

    private static final Logger logger = LoggerFactory.getLogger(BlockingSocketServer.class);

    private ThreadFactory acceptThreadFactory;
    private final ExecutorService executor;
    private final AcceptHandler acceptHandler;
    private ServerSocketSettings serverSocketSettings;
    private SocketSettings socketSettings;

    private final List<InetSocketAddress> listenAddresses = new ArrayList<>();
    private final List<ServerSocket> serverSockets = new ArrayList<>();
    private final Map<ServerSocket, Thread> acceptThreads = new HashMap<>();

    private BlockingSocketServer(ExecutorService executor, AcceptHandler acceptHandler) {
        this.executor = executor;
        this.acceptHandler = acceptHandler;
    }

    /**
     * Creates a server that runs {@code acceptHandler} on {@code executor} for each accepted connection.
     * The constructor is private, so this is the only way to obtain an instance.
     *
     * @param executor      runs one {@link AcceptHandler#onAccept} task per accepted socket
     * @param acceptHandler serves each accepted socket
     * @return a new, not yet started server
     * @throws NullPointerException if either argument is {@code null}
     */
    public static BlockingSocketServer create(ExecutorService executor, AcceptHandler acceptHandler) {
        return new BlockingSocketServer(
                Objects.requireNonNull(executor, "executor"),
                Objects.requireNonNull(acceptHandler, "acceptHandler"));
    }

    /** Uses {@code acceptThreadFactory} to create accept threads; by default a plain {@code new Thread} is used. */
    public BlockingSocketServer withAcceptThreadFactory(ThreadFactory acceptThreadFactory) {
        this.acceptThreadFactory = acceptThreadFactory;
        return this;
    }

    /** Adds all of {@code listenAddresses} to the addresses bound by {@link #start()}. */
    public BlockingSocketServer withListenAddresses(List<InetSocketAddress> listenAddresses) {
        this.listenAddresses.addAll(listenAddresses);
        return this;
    }

    /** Varargs form of {@link #withListenAddresses(List)}. */
    public BlockingSocketServer withListenAddresses(InetSocketAddress... listenAddresses) {
        return withListenAddresses(Arrays.asList(listenAddresses));
    }

    /** Adds one address to the addresses bound by {@link #start()}. */
    public BlockingSocketServer withListenAddress(InetSocketAddress listenAddress) {
        this.listenAddresses.add(listenAddress);
        return this;
    }

    /** Adds the wildcard address on {@code port} to the addresses bound by {@link #start()}. */
    public BlockingSocketServer withListenPort(int port) {
        return withListenAddress(new InetSocketAddress(port));
    }

    /**
     * Sets the settings (including the backlog) applied to every listening channel before it is bound.
     * If never set, OS/JDK defaults are used.
     */
    public BlockingSocketServer withServerSocketSettings(ServerSocketSettings socketSettings) {
        this.serverSocketSettings = socketSettings;
        return this;
    }

    /** Sets the settings applied to every accepted connection; if never set, OS/JDK defaults are used. */
    public BlockingSocketServer withSocketSettings(SocketSettings socketSettings) {
        this.socketSettings = socketSettings;
        return this;
    }

    /**
     * Applies {@link #socketSettings} to a freshly accepted socket and schedules the handler for it.
     * If configuring or scheduling fails, the socket is closed before the failure propagates,
     * so it does not leak.
     */
    private void serveClient(final Socket socket) throws IOException {
        try {
            if (socketSettings != null) {
                // Non-null because the socket was accepted through a SocketChannel.
                socketSettings.applySettings(socket.getChannel());
            }
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        acceptHandler.onAccept(socket);
                    } catch (Exception e) {
                        logger.error("Failed to serve socket " + socket, e);
                    }
                }
            });
        } catch (IOException | RuntimeException e) {
            // e.g. RejectedExecutionException from a shut-down executor
            closeSuppressing(socket, e);
            throw e;
        }
    }

    /**
     * Binds every configured listen address and starts one daemon accept thread per address.
     * If any address fails, everything opened so far is stopped before the failure is rethrown,
     * so a failed start does not leave half-open listeners behind.
     *
     * @throws Exception if a listening channel cannot be opened, configured or bound
     */
    public void start() throws Exception {
        try {
            for (InetSocketAddress address : listenAddresses) {
                listen(address);
            }
        } catch (Exception e) {
            try {
                stop();
            } catch (Exception suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
    }

    /** Opens, configures and binds one listening channel, then starts its accept thread. */
    private void listen(InetSocketAddress address) throws Exception {
        final ServerSocketChannel serverChannel = ServerSocketChannel.open();
        final ServerSocket serverSocket = serverChannel.socket();
        try {
            int backlog = 0; // a value <= 0 makes bind() use the implementation default
            if (serverSocketSettings != null) {
                // Applied before bind: options such as SO_REUSEADDR only take effect on an unbound socket.
                serverSocketSettings.applySettings(serverChannel);
                backlog = serverSocketSettings.getBacklog();
            }
            serverChannel.bind(address, backlog);
        } catch (Exception e) {
            closeSuppressing(serverChannel, e);
            throw e;
        }
        serverSockets.add(serverSocket);

        Runnable acceptLoop = acceptLoop(serverChannel, serverSocket);
        Thread acceptThread = acceptThreadFactory == null ?
                new Thread(acceptLoop) :
                acceptThreadFactory.newThread(acceptLoop);
        acceptThread.setDaemon(true);
        acceptThreads.put(serverSocket, acceptThread);
        acceptThread.start();
    }

    /** Builds the accept loop for one listening channel; it exits once the channel is closed or the thread is interrupted. */
    private Runnable acceptLoop(final ServerSocketChannel serverChannel, final ServerSocket serverSocket) {
        return new Runnable() {
            @Override
            public void run() {
                // isInterrupted() (unlike Thread.interrupted()) keeps the flag set for the check in the catch block.
                while (!Thread.currentThread().isInterrupted()) {
                    try {
                        serveClient(serverChannel.accept().socket());
                    } catch (Exception e) {
                        // Once the channel is closed (by stop() or by an interrupt), accept() can never
                        // succeed again: exit instead of spinning and logging forever.
                        if (!serverChannel.isOpen() || Thread.currentThread().isInterrupted())
                            break;
                        logger.error("Socket error for " + serverSocket, e);
                    }
                }
            }
        };
    }

    /**
     * Interrupts every accept thread, closes every listening socket and waits for the accept threads to finish.
     * All sockets are closed even if some {@code close()} calls fail. Internal state is cleared, so the
     * server can be started again.
     *
     * @throws Exception the first {@code close()} failure (later ones attached as suppressed),
     *                   or {@link InterruptedException} if interrupted while joining
     */
    public void stop() throws Exception {
        Exception failure = null;
        for (ServerSocket serverSocket : serverSockets) {
            Thread acceptThread = acceptThreads.get(serverSocket);
            if (acceptThread != null) {
                acceptThread.interrupt();
            }
            try {
                serverSocket.close();
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        List<Thread> threadsToJoin = new ArrayList<>(acceptThreads.values());
        serverSockets.clear();
        acceptThreads.clear();
        for (Thread acceptThread : threadsToJoin) {
            acceptThread.join();
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Closes {@code resource}, attaching any close failure to {@code cause} so the original error stays primary. */
    private static void closeSuppressing(AutoCloseable resource, Throwable cause) {
        try {
            resource.close();
        } catch (Exception e) {
            cause.addSuppressed(e);
        }
    }
}