package com.googlecode.jsonrpc4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLException;
import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * A multi-threaded streaming server that uses JSON-RPC over sockets.
 * <p>
 * Connections are handled by {@link Server} tasks running on an internal thread pool of
 * {@code maxThreads + 1} threads. A task waits for a connection, submits a fresh task to wait
 * for the next one, and then serves its own client until the client disconnects, too many
 * errors occur, or the server is stopped. When every pool thread is busy, the next waiting
 * task sits in the pool's queue until a thread frees up.
 * <p>
 * {@link #stop()} shuts the thread pool down and closes the {@link ServerSocket}, so a
 * stopped instance cannot be started again.
 */
@SuppressWarnings({"unused", "WeakerAccess"})
public class StreamServer {
	private static final Logger logger = LoggerFactory.getLogger(StreamServer.class);
	// accept() blocks at most this long, so the accept loop regularly re-checks keepRunning
	private static final long SERVER_SOCKET_SO_TIMEOUT = 5000;
	// extra time (on top of SERVER_SOCKET_SO_TIMEOUT) that stop() waits for pool threads to finish
	private static final long TERMINATION_GRACE_PERIOD = 2000;
	private final ThreadPoolExecutor executor;
	private final ServerSocket serverSocket;
	private final JsonRpcBasicServer jsonRpcServer;
	private final AtomicBoolean isStarted = new AtomicBoolean(false);
	private final AtomicBoolean keepRunning = new AtomicBoolean(false);
	// mutated concurrently by pool threads and read by callers, hence a concurrent set
	private final Set<Server> servers = Collections.newSetFromMap(new ConcurrentHashMap<Server, Boolean>());
	// written via setMaxClientErrors() and read by pool threads
	private volatile int maxClientErrors = 5;

	/**
	 * Creates a {@code StreamServer} with the given max number
	 * of threads. A {@link ServerSocket} is created using the
	 * default {@link ServerSocketFactory} that listens on the
	 * given {@code port} and {@link InetAddress}.
	 * <p>
	 * NOTE(review): this constructor is private and not called within this class;
	 * confirm whether it is meant to be public.
	 *
	 * @param jsonRpcServer the {@link JsonRpcBasicServer} that will handle requests
	 * @param maxThreads the max number of client-serving threads the server will spawn
	 * @param port the port to listen on
	 * @param backlog the {@link ServerSocket} backlog
	 * @param bindAddress the address to listen on
	 * @throws IOException if the server socket cannot be created or bound
	 */
	private StreamServer(JsonRpcBasicServer jsonRpcServer, int maxThreads, int port, int backlog, InetAddress bindAddress) throws IOException {
		this(jsonRpcServer, maxThreads, ServerSocketFactory.getDefault().createServerSocket(port, backlog, bindAddress));
	}

	/**
	 * Creates a {@code StreamServer} with the given max number
	 * of threads using the given {@link ServerSocket} to listen
	 * for client connections.
	 * <p>
	 * Note that this calls {@code jsonRpcServer.setRethrowExceptions(false)}, which changes
	 * the configuration of the passed-in server.
	 *
	 * @param jsonRpcServer the {@link JsonRpcBasicServer} that will handle requests
	 * @param maxThreads the max number of client-serving threads the server will spawn
	 *                   (one extra pool thread is allocated on top of this)
	 * @param serverSocket the {@link ServerSocket} used for accepting client connections
	 */
	public StreamServer(JsonRpcBasicServer jsonRpcServer, int maxThreads, ServerSocket serverSocket) {
		this.jsonRpcServer = jsonRpcServer;
		this.serverSocket = serverSocket;
		executor = new ThreadPoolExecutor(maxThreads + 1, maxThreads + 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
		// the queue is unbounded, so submissions are only rejected after the executor is shut down
		executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
		jsonRpcServer.setRethrowExceptions(false);
	}

	/**
	 * Returns a read-only live view of the servers currently handling a client connection.
	 *
	 * @return the servers
	 */
	public Set<Server> getServers() {
		return Collections.unmodifiableSet(servers);
	}

	/**
	 * Starts accepting client connections in the background.
	 *
	 * @throws IllegalStateException if the server has already been started
	 */
	public void start() {
		if (!isStarted.compareAndSet(false, true)) {
			throw new IllegalStateException("The StreamServer is already started");
		}
		logger.debug("StreamServer starting {}:{}", serverSocket.getInetAddress(), serverSocket.getLocalPort());
		keepRunning.set(true);
		executor.submit(new Server());
	}

	/**
	 * Stops the server. The method stops accepting connections, closes all client
	 * connections, shuts the thread pool down, closes the {@link ServerSocket}, and waits
	 * a bounded time for the pool threads to finish. The instance cannot be restarted
	 * afterwards.
	 *
	 * @throws InterruptedException if interrupted while waiting for the pool threads to terminate
	 * @throws IllegalStateException if the server is not started
	 */
	public void stop() throws InterruptedException {
		if (!isStarted.get()) {
			throw new IllegalStateException("The StreamServer is not started");
		}
		stopServer();
		stopClients();
		closeSocket();
		try {
			waitForServerToTerminate();
			isStarted.set(false);
		} catch (InterruptedException e) {
			logger.error("InterruptedException while waiting for termination", e);
			throw e;
		}
	}

	/** Signals the accept and request loops to exit. */
	private void stopServer() {
		keepRunning.set(false);
	}

	/**
	 * Closes every connected client socket and interrupts the pool threads. Threads blocked
	 * reading from a socket do not react to interruption, so their sockets are closed first
	 * to unblock them.
	 */
	private void stopClients() {
		for (Server server : servers) {
			server.closeClientSocket();
		}
		executor.shutdownNow();
	}

	/** Closes the listening socket, logging (not propagating) any failure. */
	private void closeSocket() {
		try {
			serverSocket.close();
		} catch (IOException e) {
			logger.debug("Failed to close socket", e);
		}
	}

	/** Waits a bounded time for the pool to terminate and warns if it did not. */
	private void waitForServerToTerminate() throws InterruptedException {
		long timeoutMillis = TERMINATION_GRACE_PERIOD + SERVER_SOCKET_SO_TIMEOUT;
		if (!executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS)) {
			logger.warn("StreamServer threads did not terminate within {} ms", timeoutMillis);
		}
	}

	/**
	 * Closes something quietly, logging I/O and runtime failures instead of propagating them.
	 *
	 * @param c closable, may be {@code null}
	 */
	private void closeQuietly(Closeable c) {
		if (c == null) {
			return;
		}
		try {
			c.close();
		} catch (IOException | RuntimeException e) {
			logger.warn("Error closing, ignoring", e);
		}
	}

	/**
	 * @return the number of connected clients
	 */
	public int getNumberOfConnections() {
		return servers.size();
	}

	/**
	 * @return the number of request-handling errors after which a client connection is closed
	 */
	public int getMaxClientErrors() {
		return maxClientErrors;
	}

	/**
	 * @param maxClientErrors the number of request-handling errors after which a client
	 *                        connection is closed
	 */
	public void setMaxClientErrors(int maxClientErrors) {
		this.maxClientErrors = maxClientErrors;
	}

	/**
	 * @return whether the server has been started and not yet successfully stopped
	 */
	public boolean isStarted() {
		return isStarted.get();
	}

	/**
	 * Server task. It waits for one client connection, hands off accepting to a new task,
	 * and then serves that client.
	 */
	public class Server implements Runnable {
		// written only by the thread running this task; volatile so other threads see updates
		private volatile int errors;
		private volatile Throwable lastException;
		// the connected client, published so stop() can close it
		private volatile Socket clientSocket;

		public int getNumberOfErrors() {
			return errors;
		}

		public Throwable getLastException() {
			return lastException;
		}

		/**
		 * {@inheritDoc}
		 */
		@Override
		public void run() {
			Socket socket = acceptClient();
			if (socket != null) {
				serveClient(socket);
			}
		}

		/**
		 * Waits for a client connection and spawns the next acceptor once one arrives.
		 *
		 * @return the accepted socket, or {@code null} if the server stopped or accepting failed permanently
		 */
		private Socket acceptClient() {
			ServerSocket serverSocket = StreamServer.this.serverSocket;
			while (keepRunning.get()) {
				try {
					serverSocket.setSoTimeout((int) SERVER_SOCKET_SO_TIMEOUT);
					Socket socket = serverSocket.accept();
					logger.debug("Client connected: {}:{}", socket.getInetAddress().getHostAddress(), socket.getPort());
					// spawn a new Server for the next connection and leave the accept loop
					spawnNextAcceptor();
					return socket;
				} catch (SocketTimeoutException e) {
					handleSocketTimeoutException(e);
				} catch (SSLException sslException) {
					logger.error("SSLException while listening for clients, terminating", sslException);
					return null;
				} catch (IOException ioe) {
					// this could be because the ServerSocket was closed
					if (ioe instanceof SocketException && !keepRunning.get()) {
						return null;
					}
					logger.error("Exception while listening for clients", ioe);
					if (serverSocket.isClosed()) {
						// a closed ServerSocket fails every call immediately; retrying would spin forever
						return null;
					}
				}
			}
			return null;
		}

		/** Submits a new task to accept the next connection, tolerating a shut-down executor. */
		private void spawnNextAcceptor() {
			try {
				executor.submit(new Server());
			} catch (RejectedExecutionException e) {
				// only happens once stop() has shut the executor down; no further clients are accepted
				logger.debug("Executor is shut down, not accepting further clients", e);
			}
		}

		/**
		 * Handles requests from the given client until it disconnects, exceeds the error
		 * limit, or the server stops. The socket and its streams are always closed.
		 */
		private void serveClient(Socket socket) {
			clientSocket = socket;
			BufferedInputStream input;
			OutputStream output;
			try {
				input = new BufferedInputStream(socket.getInputStream());
				output = socket.getOutputStream();
			} catch (IOException e) {
				logger.error("Client socket failed", e);
				closeQuietly(socket);
				return;
			}
			servers.add(this);
			try {
				while (keepRunning.get()) {
					try {
						jsonRpcServer.handleRequest(input, output);
					} catch (Throwable t) {
						if (t instanceof StreamEndedException) {
							logger.debug("Client disconnected: {}:{}", socket.getInetAddress().getHostAddress(), socket.getPort());
							break;
						}
						if (!keepRunning.get()) {
							// stop() closed the socket to unblock this thread; not a client error
							logger.debug("Client connection closed by server shutdown: {}:{}", socket.getInetAddress().getHostAddress(), socket.getPort());
							break;
						}
						errors++;
						lastException = t;
						if (errors < maxClientErrors) {
							logger.error("Exception while handling request", t);
						} else {
							logger.error("Closing client connection due to repeated errors", t);
							break;
						}
					}
				}
			} finally {
				servers.remove(this);
				closeQuietly(socket);
				closeQuietly(input);
				closeQuietly(output);
			}
		}

		/** Closes this task's client socket (if any) so a blocked read returns. */
		private void closeClientSocket() {
			closeQuietly(clientSocket);
		}

		private void handleSocketTimeoutException(SocketTimeoutException e) {
			// this is expected because of so_timeout
		}
	}
}