package org.projectodd.sockjs;
import io.undertow.Undertow;
import io.undertow.server.handlers.PathHandler;
import io.undertow.servlet.api.DeploymentManager;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.projectodd.sockjs.servlet.SockJsServlet;

import java.lang.reflect.Field;
import java.net.URI;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
public class SockJsConnectionTest extends AbstractSockJsTest {
@Before
public void startServer() throws Exception {
server = new SockJsServer();
pathHandler = new PathHandler();
undertow = createUndertow(pathHandler, "localhost", 8081);
baseUrl = "http://localhost:8081";
undertow.start();
}
@After
public void stopServer() {
undertow.stop();
}
@Test
public void testDecoratedFromHttpRootContextRootMapping() throws Exception {
verifyHttp("/", "/*");
}
@Test
public void testDecoratedFromHttpNonRootContextRootMapping() throws Exception {
verifyHttp("/foo", "/*");
}
@Test
public void testDecoratedFromHttpNonRootContextNonRootMapping() throws Exception {
verifyHttp("/foo", "/bar/*");
}
@Test
public void testDecoratedFromWebsocketRootContextRootMapping() throws Exception {
verifyWs("/", "/*");
}
@Test
public void testDecoratedFromWebsocketNonRootContextRootMapping() throws Exception {
verifyWs("/foo", "/*");
}
@Test
public void testDecoratedFromWebsocketNonRootContextNonRootMapping() throws Exception {
verifyWs("/foo", "/bar/*");
}
@Test
public void testDecoratedFromRawWebsocketRootContextRootMapping() throws Exception {
verifyRawWs("/", "/*");
}
@Test
public void testDecoratedFromRawWebsocketNonRootContextRootMapping() throws Exception {
verifyRawWs("/foo", "/*");
}
@Test
public void testDecoratedFromRawWebsocketNonRootContextNonRootMapping() throws Exception {
verifyRawWs("/foo", "/bar/*");
}
private void verifyHttp(final String context, final String mapping) throws Exception {
verify(context, mapping, new VerifyBlock() {
CloseableHttpResponse response = null;
@Override
public void connect(String prefix, String sessionId) throws Exception {
CloseableHttpClient httpClient = HttpClients.createDefault();
HttpPost httpPost = new HttpPost(baseUrl + (context.equals("/") ? "" : context) + prefix + "/000/" + sessionId + "/xhr?foo=bar");
response = httpClient.execute(httpPost);
assertEquals(200, response.getStatusLine().getStatusCode());
String body = EntityUtils.toString(response.getEntity());
assertEquals("o\n", body);
}
@Override
public void verify(SockJsConnection connection, String prefix, String sessionId) {
verifyNonRaw(connection, context, prefix, sessionId, false);
}
@Override
public void disconnect() throws Exception {
response.close();
}
});
}
private void verifyWs(final String context, final String mapping) throws Exception {
verify(context, mapping, new VerifyBlock() {
WebSocketClient client = null;
@Override
public void connect(String prefix, String sessionId) throws Exception {
client = new WebSocketClient();
client.start();
String baseWsUrl = baseUrl.replace("http", "ws");
URI uri = new URI(baseWsUrl + (context.equals("/") ? "" : context) + prefix + "/000/" + sessionId + "/websocket?foo=bar");
ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader("user-agent", "WebSocketClient");
client.connect(new WebSocketListener() {
@Override
public void onWebSocketBinary(byte[] payload, int offset, int len) {
}
@Override
public void onWebSocketClose(int statusCode, String reason) {
}
@Override
public void onWebSocketConnect(Session session) {
}
@Override
public void onWebSocketError(Throwable cause) {
}
@Override
public void onWebSocketText(String message) {
assertEquals("o", message);
}
}, uri, upgradeRequest);
}
@Override
public void verify(SockJsConnection connection, String prefix, String sessionId) {
verifyNonRaw(connection, context, prefix, sessionId, true);
}
@Override
public void disconnect() throws Exception {
client.stop();
// sleep prevents some xnio stack on undertow shutdown
Thread.sleep(25);
}
});
verifyNoSavedHeadersLeaked();
}
private void verifyRawWs(final String context, final String mapping) throws Exception {
verify(context, mapping, new VerifyBlock() {
WebSocketClient client = null;
@Override
public void connect(String prefix, String sessionId) throws Exception {
client = new WebSocketClient();
client.start();
String baseWsUrl = baseUrl.replace("http", "ws");
URI uri = new URI(baseWsUrl + (context.equals("/") ? "" : context) + prefix + "/websocket?foo=bar");
ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader("user-agent", "WebSocketClient");
client.connect(new WebSocketListener() {
@Override
public void onWebSocketBinary(byte[] payload, int offset, int len) {
}
@Override
public void onWebSocketClose(int statusCode, String reason) {
}
@Override
public void onWebSocketConnect(Session session) {
}
@Override
public void onWebSocketError(Throwable cause) {
}
@Override
public void onWebSocketText(String message) {
}
}, uri, upgradeRequest);
}
@Override
public void verify(SockJsConnection connection, String prefix, String sessionId) {
assertNotNull(connection.id);
// TODO: Figure out if there's any way to get the remote IP and port
assertEquals(null, connection.remoteAddress);
assertEquals(0, connection.remotePort);
// TODO: Figure out if there's a way to get headers to raw websockets
// assertEquals("localhost:8081", connection.headers.get("host"));
// assertNotNull(connection.headers.get("user-agent"));
String baseUrl = (context.equals("/") ? "" : context) + prefix + "/websocket";
assertEquals(baseUrl + "?foo=bar", connection.url);
assertEquals("/websocket", connection.pathname);
assertEquals(context + prefix, connection.prefix);
assertEquals(Transport.READY_STATE.OPEN, connection.getReadyState());
}
@Override
public void disconnect() throws Exception {
client.stop();
// sleep prevents some xnio stack on undertow shutdown
Thread.sleep(25);
}
});
verifyNoSavedHeadersLeaked();
}
private void verifyNonRaw(SockJsConnection connection, String context, String prefix, String sessionId, boolean websocket) {
assertNotNull(connection.id);
String suffix = "";
if (websocket) {
suffix = "/websocket";
// TODO: Figure out if there's any way to get the remote IP and port
assertEquals(null, connection.remoteAddress);
assertEquals(0, connection.remotePort);
} else {
suffix = "/xhr";
assertEquals("127.0.0.1", connection.remoteAddress);
assertTrue(connection.remotePort > 0);
}
assertEquals("localhost:8081", connection.headers.get("host"));
assertNotNull(connection.headers.get("user-agent"));
String baseUrl = (context.equals("/") ? "" : context) + prefix + "/000/" + sessionId + suffix;
assertEquals(baseUrl + "?foo=bar", connection.url);
assertEquals("/000/" + sessionId + suffix, connection.pathname);
assertEquals(context + prefix, connection.prefix);
assertEquals(Transport.READY_STATE.OPEN, connection.getReadyState());
}
private void verify(final String context, final String mapping, final VerifyBlock verifyBlock) throws Exception {
final String sessionId = UUID.randomUUID().toString();
final String prefix = extractPrefixFromMapping(mapping);
final CountDownLatch connected = new CountDownLatch(1);
server.onConnection(new SockJsServer.OnConnectionHandler() {
@Override
public void handle(SockJsConnection connection) {
verifyBlock.verify(connection, prefix, sessionId);
connected.countDown();
}
});
DeploymentManager manager = createDeploymentManager(server, context, mapping);
manager.deploy();
pathHandler.addPrefixPath(context, manager.start());
verifyBlock.connect(prefix, sessionId);
assertTrue(connected.await(5, TimeUnit.SECONDS));
verifyBlock.disconnect();
manager.stop();
}
private void verifyNoSavedHeadersLeaked() throws Exception {
// Make sure we're not leaking saved headers, using reflection so we
// don't have to expose this field to anyone else
Field savedHeadersField = SockJsServlet.class.getDeclaredField("savedHeaders");
savedHeadersField.setAccessible(true);
Map savedHeaders = (Map) savedHeadersField.get(SockJsServlet.class);
assertEquals(0, savedHeaders.size());
}
private String extractPrefixFromMapping(String mapping) {
if (mapping.endsWith("*")) {
mapping = mapping.substring(0, mapping.length() - 1);
}
if (mapping.endsWith("/")) {
mapping = mapping.substring(0, mapping.length() - 1);
}
return mapping;
}
private SockJsServer server;
private PathHandler pathHandler;
private Undertow undertow;
private String baseUrl;
private static interface VerifyBlock {
void connect(String prefix, String sessionId) throws Exception;
void verify(SockJsConnection connection, String prefix, String sessionId);
void disconnect() throws Exception;
}
}