/******************************************************************************
* Copyright © 2013-2016 The Nxt Core Developers. *
* *
* See the AUTHORS.txt, DEVELOPER-AGREEMENT.txt and LICENSE.txt files at *
* the top-level directory of this distribution for the individual copyright *
* holder information and the developer policies on copyright and licensing. *
* *
* Unless otherwise agreed in a custom licensing agreement, no part of the *
* Nxt software, including this file, may be copied, modified, propagated, *
* or distributed except according to the terms contained in the LICENSE.txt *
* file. *
* *
* Removal or modification of this copyright notice is prohibited. *
* *
******************************************************************************/
package nxt.peer;
import nxt.util.Logger;
import nxt.util.QueuedThreadPool;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeException;
import org.eclipse.jetty.websocket.api.WebSocketException;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketClose;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketConnect;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.ReentrantLock;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
* PeerWebSocket represents an HTTP/HTTPS upgraded connection
*/
@WebSocket
public class PeerWebSocket {
/** Compressed message flag */
private static final int FLAG_COMPRESSED = 1;
/** Our WebSocket message version */
private static final int VERSION = 1;
/** Create the WebSocket client */
private static WebSocketClient peerClient;
static {
try {
peerClient = new WebSocketClient();
peerClient.getPolicy().setIdleTimeout(Peers.webSocketIdleTimeout);
peerClient.getPolicy().setMaxBinaryMessageSize(Peers.MAX_MESSAGE_SIZE);
peerClient.setConnectTimeout(Peers.connectTimeout);
peerClient.start();
} catch (Exception exc) {
Logger.logErrorMessage("Unable to start WebSocket client", exc);
peerClient = null;
}
}
/** Negotiated WebSocket message version */
private int version = VERSION;
/** Thread pool for server request processing */
private static final ExecutorService threadPool = new QueuedThreadPool(
Runtime.getRuntime().availableProcessors(),
Runtime.getRuntime().availableProcessors() * 4);
/** WebSocket session */
private volatile Session session;
/** WebSocket endpoint - set for an accepted connection */
private final PeerServlet peerServlet;
/** WebSocket lock */
private final ReentrantLock lock = new ReentrantLock();
/** Pending POST request map */
private final ConcurrentHashMap<Long, PostRequest> requestMap = new ConcurrentHashMap<>();
/** Next POST request identifier */
private long nextRequestId = 0;
/** WebSocket connection timestamp */
private long connectTime = 0;
/**
* Create a client socket
*/
public PeerWebSocket() {
peerServlet = null;
}
/**
* Create a server socket
*
* @param peerServlet Servlet for request processing
*/
public PeerWebSocket(PeerServlet peerServlet) {
this.peerServlet = peerServlet;
}
/**
* Start a client session
*
* @param uri Server URI
* @return TRUE if the WebSocket connection was completed
* @throws IOException I/O error occurred
*/
public boolean startClient(URI uri) throws IOException {
if (peerClient == null) {
return false;
}
String address = String.format("%s:%d", uri.getHost(), uri.getPort());
boolean useWebSocket = false;
//
// Create a WebSocket connection. We need to serialize the connection requests
// since the NRS server will issue multiple concurrent requests to the same peer.
// After a successful connection, the subsequent connection requests will return
// immediately. After an unsuccessful connection, a new connect attempt will not
// be done until 60 seconds have passed.
//
lock.lock();
try {
if (session != null) {
useWebSocket = true;
} else if (System.currentTimeMillis() > connectTime + 10 * 1000) {
connectTime = System.currentTimeMillis();
ClientUpgradeRequest req = new ClientUpgradeRequest();
Future<Session> conn = peerClient.connect(this, uri, req);
conn.get(Peers.connectTimeout + 100, TimeUnit.MILLISECONDS);
useWebSocket = true;
}
} catch (ExecutionException exc) {
if (exc.getCause() instanceof UpgradeException) {
// We will use HTTP
} else if (exc.getCause() instanceof IOException) {
// Report I/O exception
throw (IOException)exc.getCause();
} else {
// We will use HTTP
Logger.logDebugMessage(String.format("WebSocket connection to %s failed", address), exc);
}
} catch (TimeoutException exc) {
throw new SocketTimeoutException(String.format("WebSocket connection to %s timed out", address));
} catch (IllegalStateException exc) {
if (! peerClient.isStarted()) {
Logger.logDebugMessage("WebSocket client not started or shutting down");
throw exc;
}
Logger.logDebugMessage(String.format("WebSocket connection to %s failed", address), exc);
} catch (Exception exc) {
Logger.logDebugMessage(String.format("WebSocket connection to %s failed", address), exc);
} finally {
if (!useWebSocket) {
close();
}
lock.unlock();
}
return useWebSocket;
}
/**
* WebSocket connection complete
*
* @param session WebSocket session
*/
@OnWebSocketConnect
public void onConnect(Session session) {
this.session = session;
if ((Peers.communicationLoggingMask & Peers.LOGGING_MASK_200_RESPONSES) != 0) {
Logger.logDebugMessage(String.format("%s WebSocket connection with %s completed",
peerServlet != null ? "Inbound" : "Outbound",
session.getRemoteAddress().getHostString()));
}
}
/**
* Check if we have a WebSocket connection
*
* @return TRUE if we have a WebSocket connection
*/
public boolean isOpen() {
Session s;
return ((s=session) != null && s.isOpen());
}
/**
* Return the remote address for this connection
*
* @return Remote address or null if the connection is closed
*/
public InetSocketAddress getRemoteAddress() {
Session s;
return ((s=session) != null && s.isOpen() ? s.getRemoteAddress() : null);
}
/**
* Process a POST request by sending the request message and then
* waiting for a response. This method is used by the connection
* originator.
*
* @param request Request message
* @return Response message
* @throws IOException I/O error occurred
*/
public String doPost(String request) throws IOException {
long requestId;
//
// Send the POST request
//
lock.lock();
try {
if (session == null || !session.isOpen()) {
throw new IOException("WebSocket session is not open");
}
requestId = nextRequestId++;
byte[] requestBytes = request.getBytes("UTF-8");
int requestLength = requestBytes.length;
int flags = 0;
if (Peers.isGzipEnabled && requestLength >= Peers.MIN_COMPRESS_SIZE) {
flags |= FLAG_COMPRESSED;
ByteArrayOutputStream outStream = new ByteArrayOutputStream(requestLength);
try (GZIPOutputStream gzipStream = new GZIPOutputStream(outStream)) {
gzipStream.write(requestBytes);
}
requestBytes = outStream.toByteArray();
}
ByteBuffer buf = ByteBuffer.allocate(requestBytes.length + 20);
buf.putInt(version)
.putLong(requestId)
.putInt(flags)
.putInt(requestLength)
.put(requestBytes)
.flip();
if (buf.limit() > Peers.MAX_MESSAGE_SIZE) {
throw new ProtocolException("POST request length exceeds max message size");
}
session.getRemote().sendBytes(buf);
} catch (WebSocketException exc) {
throw new SocketException(exc.getMessage());
} finally {
lock.unlock();
}
//
// Get the response
//
String response;
try {
PostRequest postRequest = new PostRequest();
requestMap.put(requestId, postRequest);
response = postRequest.get(Peers.readTimeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException exc) {
throw new SocketTimeoutException("WebSocket POST interrupted");
}
return response;
}
/**
* Send POST response
*
* This method is used by the connection acceptor to return the POST response
*
* @param requestId Request identifier
* @param response Response message
* @throws IOException I/O error occurred
*/
public void sendResponse(long requestId, String response) throws IOException {
lock.lock();
try {
if (session != null && session.isOpen()) {
byte[] responseBytes = response.getBytes("UTF-8");
int responseLength = responseBytes.length;
int flags = 0;
if (Peers.isGzipEnabled && responseLength >= Peers.MIN_COMPRESS_SIZE) {
flags |= FLAG_COMPRESSED;
ByteArrayOutputStream outStream = new ByteArrayOutputStream(responseLength);
try (GZIPOutputStream gzipStream = new GZIPOutputStream(outStream)) {
gzipStream.write(responseBytes);
}
responseBytes = outStream.toByteArray();
}
ByteBuffer buf = ByteBuffer.allocate(responseBytes.length + 20);
buf.putInt(version)
.putLong(requestId)
.putInt(flags)
.putInt(responseLength)
.put(responseBytes)
.flip();
if (buf.limit() > Peers.MAX_MESSAGE_SIZE) {
throw new ProtocolException("POST response length exceeds max message size");
}
session.getRemote().sendBytes(buf);
}
} catch (WebSocketException exc) {
throw new SocketException(exc.getMessage());
} finally {
lock.unlock();
}
}
/**
* Process a socket message
*
* @param inbuf Message buffer
* @param off Starting offset
* @param len Message length
*/
@OnWebSocketMessage
public void onMessage(byte[] inbuf, int off, int len) {
lock.lock();
try {
ByteBuffer buf = ByteBuffer.wrap(inbuf, off, len);
version = Math.min(buf.getInt(), VERSION);
Long requestId = buf.getLong();
int flags = buf.getInt();
int length = buf.getInt();
byte[] msgBytes = new byte[buf.remaining()];
buf.get(msgBytes);
if ((flags&FLAG_COMPRESSED) != 0) {
ByteArrayInputStream inStream = new ByteArrayInputStream(msgBytes);
try (GZIPInputStream gzipStream = new GZIPInputStream(inStream, 1024)) {
msgBytes = new byte[length];
int offset = 0;
while (offset < msgBytes.length) {
int count = gzipStream.read(msgBytes, offset, msgBytes.length - offset);
if (count < 0) {
throw new EOFException("End-of-data reading compressed data");
}
offset += count;
}
}
}
String message = new String(msgBytes, "UTF-8");
if (peerServlet != null) {
threadPool.execute(() -> peerServlet.doPost(this, requestId, message));
} else {
PostRequest postRequest = requestMap.remove(requestId);
if (postRequest != null) {
postRequest.complete(message);
}
}
} catch (Exception exc) {
Logger.logDebugMessage("Exception while processing WebSocket message", exc);
} finally {
lock.unlock();
}
}
/**
* WebSocket session has been closed
*
* @param statusCode Status code
* @param reason Reason message
*/
@OnWebSocketClose
public void onClose(int statusCode, String reason) {
lock.lock();
try {
if (session != null) {
if ((Peers.communicationLoggingMask & Peers.LOGGING_MASK_200_RESPONSES) != 0) {
Logger.logDebugMessage(String.format("%s WebSocket connection with %s closed",
peerServlet != null ? "Inbound" : "Outbound",
session.getRemoteAddress().getHostString()));
}
session = null;
}
SocketException exc = new SocketException("WebSocket connection closed");
Set<Map.Entry<Long, PostRequest>> requests = requestMap.entrySet();
requests.forEach((entry) -> entry.getValue().complete(exc));
requestMap.clear();
} finally {
lock.unlock();
}
}
/**
* Close the WebSocket
*/
public void close() {
lock.lock();
try {
if (session != null && session.isOpen()) {
session.close();
}
} catch (Exception exc) {
Logger.logDebugMessage("Exception while closing WebSocket", exc);
} finally {
lock.unlock();
}
}
/**
* POST request
*/
private class PostRequest {
/** Request latch */
private final CountDownLatch latch = new CountDownLatch(1);
/** Response message */
private volatile String response;
/** Socket exception */
private volatile IOException exception;
/**
* Create a post request
*/
public PostRequest() {
}
/**
* Wait for the response
*
* The caller must hold the lock for the request condition
*
* @param timeout Wait timeout
* @param unit Time unit
* @return Response message
* @throws InterruptedException Wait interrupted
* @throws IOException I/O error occurred
*/
public String get(long timeout, TimeUnit unit) throws InterruptedException, IOException {
if (!latch.await(timeout, unit)) {
throw new SocketTimeoutException("WebSocket read timeout exceeded");
}
if (exception != null) {
throw exception;
}
return response;
}
/**
* Complete the request with a response message
*
* The caller must hold the lock for the request condition
*
* @param response Response message
*/
public void complete(String response) {
this.response = response;
latch.countDown();
}
/**
* Complete the request with an exception
*
* The caller must hold the lock for the request condition
*
* @param exception I/O exception
*/
public void complete(IOException exception) {
this.exception = exception;
latch.countDown();
}
}
}