/** * JBoss, Home of Professional Open Source * Copyright Red Hat, Inc., and individual contributors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jboss.aerogear.simplepush.server.netty; import static io.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static io.netty.util.CharsetUtil.UTF_8; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.MatcherAssert.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerInvoker; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.*; import io.netty.handler.codec.http.HttpHeaders.Names; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.util.ReferenceCountUtil; import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.Set; import java.util.UUID; import org.jboss.aerogear.simplepush.protocol.Ack; import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsConfig; import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsService; import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsServiceFactory; import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsInboundHandler; import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsOutboundHandler; import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.SockJsHandler; import org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.Transports; import org.jboss.aerogear.simplepush.protocol.HelloResponse; import org.jboss.aerogear.simplepush.protocol.MessageType; import org.jboss.aerogear.simplepush.protocol.PingMessage; import org.jboss.aerogear.simplepush.protocol.RegisterResponse; import org.jboss.aerogear.simplepush.protocol.UnregisterResponse; import org.jboss.aerogear.simplepush.protocol.impl.AckMessageImpl; import org.jboss.aerogear.simplepush.protocol.impl.HelloResponseImpl; import org.jboss.aerogear.simplepush.protocol.impl.NotificationMessageImpl; import org.jboss.aerogear.simplepush.protocol.impl.PingMessageImpl; import org.jboss.aerogear.simplepush.protocol.impl.RegisterResponseImpl; import org.jboss.aerogear.simplepush.protocol.impl.UnregisterResponseImpl; import org.jboss.aerogear.simplepush.protocol.impl.AckImpl; import org.jboss.aerogear.simplepush.protocol.impl.json.JsonUtil; import org.jboss.aerogear.simplepush.server.DefaultSimplePushConfig; import org.jboss.aerogear.simplepush.server.DefaultSimplePushServer; import org.jboss.aerogear.simplepush.server.SimplePushServer; import org.jboss.aerogear.simplepush.server.SimplePushServerConfig; import org.jboss.aerogear.simplepush.server.datastore.ChannelNotFoundException; import org.jboss.aerogear.simplepush.server.datastore.DataStore; import org.jboss.aerogear.simplepush.server.datastore.InMemoryDataStore; import org.jboss.aerogear.simplepush.util.CryptoUtil; import org.jboss.aerogear.simplepush.util.UUIDUtil; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; public class SimplePushSockJSServiceTest { private SockJsServiceFactory factory; private String sessionUrl; @Before public void setup() { factory = defaultFactory(); sessionUrl = randomSessionIdUrl(factory); } @Test public void xhrPollingOpenFrame() throws Exception { final FullHttpResponse openFrameResponse = sendXhrOpenFrameRequest(factory, sessionUrl); assertThat(openFrameResponse.getStatus(), is(HttpResponseStatus.OK)); assertThat(openFrameResponse.content().toString(UTF_8), equalTo("o\n")); } @Test public void xhrPollingHelloWithChannelId() throws Exception { final String uaid = UUIDUtil.newUAID(); final String channelId = UUID.randomUUID().toString(); sendXhrOpenFrameRequest(factory, sessionUrl); final FullHttpResponse sendResponse = sendXhrHelloMessageRequest(factory, sessionUrl, uaid, channelId); assertThat(sendResponse.getStatus(), is(HttpResponseStatus.NO_CONTENT)); final HelloResponseImpl handshakeResponse = pollXhrHelloMessageResponse(factory, sessionUrl); assertThat(handshakeResponse.getUAID(), equalTo(uaid)); } @Test public void xhrPollingHelloWithInvalidUaid() throws Exception { final String uaid = "non-valie2233??"; final String channelId = UUID.randomUUID().toString(); sendXhrOpenFrameRequest(factory, sessionUrl); final FullHttpResponse sendResponse = sendXhrHelloMessageRequest(factory, sessionUrl, uaid, channelId); assertThat(sendResponse.getStatus(), is(HttpResponseStatus.NO_CONTENT)); final HelloResponseImpl handshakeResponse = pollXhrHelloMessageResponse(factory, sessionUrl); assertThat(handshakeResponse.getMessageType(), is(MessageType.Type.HELLO)); assertThat(handshakeResponse.getUAID(), not(equalTo(uaid))); } @Test public void xhrPollingRegister() throws Exception { final String channelId = UUID.randomUUID().toString(); sendXhrOpenFrameRequest(factory, sessionUrl); sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID()); pollXhrHelloMessageResponse(factory, sessionUrl); final FullHttpResponse registerChannelIdRequest = sendXhrRegisterChannelIdRequest(factory, sessionUrl, channelId); assertThat(registerChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT)); final RegisterResponseImpl registerChannelIdResponse = pollXhrRegisterChannelIdResponse(factory, sessionUrl); assertThat(registerChannelIdResponse.getChannelId(), equalTo(channelId)); assertThat(registerChannelIdResponse.getStatus().getCode(), equalTo(200)); assertThat(registerChannelIdResponse.getPushEndpoint().startsWith("http://127.0.0.1:7777/update/"), is(true)); } @Test public void xhrPollingUnregister() throws Exception { final String channelId = UUID.randomUUID().toString(); sendXhrOpenFrameRequest(factory, sessionUrl); sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID()); pollXhrHelloMessageResponse(factory, sessionUrl); sendXhrRegisterChannelIdRequest(factory, sessionUrl, channelId); pollXhrRegisterChannelIdResponse(factory, sessionUrl); final FullHttpResponse unregisterChannelIdRequest = unregisterChannelIdRequest(factory, sessionUrl, channelId); assertThat(unregisterChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT)); final UnregisterResponseImpl unregisterChannelIdResponse = unregisterChannelIdResponse(factory, sessionUrl); assertThat(unregisterChannelIdResponse.getStatus().getCode(), is(200)); assertThat(unregisterChannelIdResponse.getChannelId(), equalTo(channelId)); } @Test public void xhrPollingPing() throws Exception { sendXhrOpenFrameRequest(factory, sessionUrl); sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID()); pollXhrHelloMessageResponse(factory, sessionUrl); final FullHttpResponse registerChannelIdRequest = sendXhrPingRequest(factory, sessionUrl); assertThat(registerChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT)); final PingMessageImpl pingResponse = pollXhrPingMessageResponse(factory, sessionUrl); assertThat(pingResponse.getPingMessage(), equalTo(PingMessage.PING_MESSAGE)); } @Test public void websocketUpgradeRequest() throws Exception { final EmbeddedChannel channel = createChannel(factory); final HttpResponse response = websocketHttpUpgradeRequest(sessionUrl, channel); assertThat(response.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS)); assertThat(response.headers().get(HttpHeaders.Names.UPGRADE), equalTo("websocket")); assertThat(response.headers().get(HttpHeaders.Names.CONNECTION), equalTo("Upgrade")); assertThat(response.headers().get(Names.SEC_WEBSOCKET_ACCEPT), equalTo("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")); channel.close(); } public static HttpResponse decodeHttpResponse(final EmbeddedChannel channel) { final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); ch.writeInbound(channel.readOutbound()); return ch.readInbound(); } public static FullHttpResponse decodeFullHttpResponse(final EmbeddedChannel channel) { final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); ch.writeInbound(channel.outboundMessages().toArray()); final HttpResponse response = ch.readInbound(); final HttpContent content = ch.readInbound(); final DefaultFullHttpResponse fullResponse; if (content != null) { fullResponse = new DefaultFullHttpResponse(response.getProtocolVersion(), response.getStatus(), content.content()); } else { fullResponse = new DefaultFullHttpResponse(response.getProtocolVersion(), response.getStatus()); } fullResponse.headers().add(response.headers()); return fullResponse; } @Test public void rawWebSocketUpgradeRequest() throws Exception { final SimplePushServerConfig simplePushConfig = DefaultSimplePushConfig.create().password("test").build(); final SockJsConfig sockjsConf = SockJsConfig.withPrefix("/simplepush").webSocketProtocols("push-notification").build(); final byte[] privateKey = CryptoUtil.secretKey(simplePushConfig.password(), "someSaltForTesting".getBytes()); final SimplePushServer pushServer = new DefaultSimplePushServer(new InMemoryDataStore(), simplePushConfig, privateKey); final SimplePushServiceFactory factory = new SimplePushServiceFactory(sockjsConf, pushServer); final EmbeddedChannel channel = createChannel(factory); final FullHttpRequest request = websocketUpgradeRequest(factory.config().prefix() + Transports.Type.WEBSOCKET.path()); request.headers().set(Names.SEC_WEBSOCKET_PROTOCOL, "push-notification"); channel.writeInbound(request); final FullHttpResponse response = decodeFullHttpResponse(channel); assertThat(response.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS)); assertThat(response.headers().get(HttpHeaders.Names.UPGRADE), equalTo("websocket")); assertThat(response.headers().get(HttpHeaders.Names.CONNECTION), equalTo("Upgrade")); assertThat(response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL), equalTo("push-notification")); assertThat(response.headers().get(Names.SEC_WEBSOCKET_ACCEPT), equalTo("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")); channel.close(); } @Test public void websocketHello() { final EmbeddedChannel channel = createWebSocketChannel(factory); final String uaid = UUIDUtil.newUAID(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); final HelloResponse response = sendWebSocketHelloFrame(uaid, channel); assertThat(response.getMessageType(), equalTo(MessageType.Type.HELLO)); assertThat(response.getUAID(), equalTo(uaid)); channel.close(); } @Test public void websocketHelloWithInvalidUaid() { final String uaid = "non-valie2233??"; final EmbeddedChannel channel = createWebSocketChannel(factory); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); final HelloResponse response = sendWebSocketHelloFrame(uaid, channel); assertThat(response.getMessageType(), equalTo(MessageType.Type.HELLO)); assertThat(response.getUAID(), not(equalTo(uaid))); channel.close(); } @Test public void websocketRegister() { final EmbeddedChannel channel = createWebSocketChannel(factory); final String channelId = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel); final RegisterResponse registerResponse = sendWebSocketRegisterFrame(channelId, channel); assertThat(registerResponse.getStatus().getCode(), is(200)); assertThat(registerResponse.getChannelId(), equalTo(channelId)); channel.close(); } @Test public void websocketRegisterDuplicateChannelId() { final EmbeddedChannel channel = createWebSocketChannel(factory); final String channelId = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel); assertThat(sendWebSocketRegisterFrame(channelId, channel).getStatus().getCode(), is(200)); assertThat(sendWebSocketRegisterFrame(channelId, channel).getStatus().getCode(), is(409)); channel.close(); } @Test public void websocketUnregister() { final EmbeddedChannel channel = createWebSocketChannel(factory); final String channelId = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel); sendWebSocketRegisterFrame(channelId, channel); final UnregisterResponse registerResponse = websocketUnRegisterFrame(channelId, channel); assertThat(registerResponse.getStatus().getCode(), is(200)); channel.close(); } @Test public void websocketUnregisterNonRegistered() { final EmbeddedChannel channel = createWebSocketChannel(factory); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel); final UnregisterResponse registerResponse = websocketUnRegisterFrame("notRegistered", channel); assertThat(registerResponse.getMessageType(), equalTo(MessageType.Type.UNREGISTER)); assertThat(registerResponse.getChannelId(), equalTo("notRegistered")); assertThat(registerResponse.getStatus().getCode(), is(200)); channel.close(); } @Test public void websocketHandleAcknowledgement() throws Exception { final SimplePushServer simplePushServer = defaultPushServer(); final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer); final EmbeddedChannel channel = createWebSocketChannel(serviceFactory); final String uaid = UUIDUtil.newUAID(); final String channelId = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(uaid, channel); final RegisterResponse registerResponse = sendWebSocketRegisterFrame(channelId, channel); final String endpointToken = extractEndpointToken(registerResponse.getPushEndpoint()); sendNotification(endpointToken, 1L, simplePushServer); final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId, 1L)); assertThat(unacked.isEmpty(), is(true)); channel.close(); } @Test public void websocketHandleAcknowledgements() throws Exception { final SimplePushServer simplePushServer = defaultPushServer(); final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer); final EmbeddedChannel channel = createWebSocketChannel(serviceFactory); final String uaid = UUIDUtil.newUAID(); final String channelId1 = UUID.randomUUID().toString(); final String channelId2 = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(uaid, channel); final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel); final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint()); final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel); final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint()); sendNotification(endpointToken1, 1L, simplePushServer); sendNotification(endpointToken2, 1L, simplePushServer); final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId1, 1L), ack(channelId2, 1L)); assertThat(unacked.isEmpty(), is(true)); channel.close(); } private String extractEndpointToken(final String pushEndpoint) { return pushEndpoint.substring(pushEndpoint.lastIndexOf('/') + 1); } @Test @Ignore("Need to figure out how to run a schedules job with the new EmbeddedChannel") // https://groups.google.com/forum/#!topic/netty/Q-_wat_9Odo public void websocketHandleOneUnacknowledgement() throws Exception { final SimplePushServer simplePushServer = defaultPushServer(); final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer); final EmbeddedChannel channel = createWebSocketChannel(serviceFactory); final String uaid = UUIDUtil.newUAID(); final String channelId1 = UUID.randomUUID().toString(); final String channelId2 = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(uaid, channel); final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel); final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint()); sendNotification(endpointToken1, 1L, simplePushServer); final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel); final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint()); sendNotification(endpointToken2, 1L, simplePushServer); final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId1, 1L)); assertThat(unacked.size(), is(1)); assertThat(unacked, hasItem(new AckImpl(channelId2, 1L))); channel.close(); } @Test @Ignore("Need to figure out how to run a schedules job with the new EmbeddedChannel") // https://groups.google.com/forum/#!topic/netty/Q-_wat_9Odo public void websocketHandleUnacknowledgement() throws Exception { final SimplePushServer simplePushServer = defaultPushServer(); final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer); final EmbeddedChannel channel = createWebSocketChannel(serviceFactory); final String uaid = UUIDUtil.newUAID(); final String channelId1 = UUID.randomUUID().toString(); final String channelId2 = UUID.randomUUID().toString(); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(uaid, channel); final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel); final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint()); sendNotification(endpointToken1, 1L, simplePushServer); final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel); final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint()); sendNotification(endpointToken2, 1L, simplePushServer); final Set<Ack> unacked = sendAcknowledge(channel); assertThat(unacked.size(), is(1)); assertThat(unacked, hasItems(ack(channelId1, 1L), ack(channelId2, 1L))); channel.close(); } @Test public void websocketPing() { final EmbeddedChannel channel = createWebSocketChannel(factory); sendWebSocketHttpUpgradeRequest(sessionUrl, channel); sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel); final PingMessage pingResponse = sendWebSocketPingFrame(channel); assertThat(pingResponse.getPingMessage(), equalTo(PingMessage.PING_MESSAGE)); channel.close(); } private SimplePushServer defaultPushServer() { final DataStore store = new InMemoryDataStore(); final SimplePushServerConfig config = DefaultSimplePushConfig.create().password("test").build(); final byte[] privateKey = DefaultSimplePushServer.generateAndStorePrivateKey(store, config); return new DefaultSimplePushServer(store, config, privateKey); } private void sendNotification(final String endpointToken, final long version, final SimplePushServer simplePushServer) throws ChannelNotFoundException { simplePushServer.handleNotification(endpointToken, "version=" + version); } private Ack ack(final String channelId, final Long version) { return new AckImpl(channelId, version); } private Set<Ack> sendAcknowledge(final EmbeddedChannel channel, final Ack... acks) { final Set<Ack> ups = new HashSet<Ack>(Arrays.asList(acks)); final TextWebSocketFrame ackFrame = ackFrame(ups); channel.writeInbound(ackFrame); channel.runPendingTasks(); final Object out = channel.readOutbound(); if (out == null) { return Collections.emptySet(); } final NotificationMessageImpl unacked = responseToType(out, NotificationMessageImpl.class); return unacked.getAcks(); } private TextWebSocketFrame ackFrame(final Set<Ack> acks) { return new TextWebSocketFrame(JsonUtil.toJson(new AckMessageImpl(acks))); } private RegisterResponseImpl sendWebSocketRegisterFrame(final String channelId, final EmbeddedChannel ch) { ch.writeInbound(TestUtil.registerChannelIdWebSocketFrame(channelId)); return responseToType(readOutboundDiscardEmpty(ch), RegisterResponseImpl.class); } private PingMessageImpl sendWebSocketPingFrame(final EmbeddedChannel ch) { ch.writeInbound(TestUtil.pingWebSocketFrame()); return responseToType(ch.readOutbound(), PingMessageImpl.class); } private UnregisterResponse websocketUnRegisterFrame(final String channelId, final EmbeddedChannel ch) { ch.writeInbound(TestUtil.unregisterChannelIdWebSocketFrame(channelId)); return responseToType(ch.readOutbound(), UnregisterResponseImpl.class); } private HttpResponse websocketHttpUpgradeRequest(final String sessionUrl, final EmbeddedChannel ch) throws Exception{ ch.writeInbound(websocketUpgradeRequest(sessionUrl + Transports.Type.WEBSOCKET.path())); return decodeHttpResponse(ch); } private void sendWebSocketHttpUpgradeRequest(final String sessionUrl, final EmbeddedChannel ch) { ch.writeInbound(websocketUpgradeRequest(sessionUrl + Transports.Type.WEBSOCKET.path())); // Discarding the Http upgrade response ch.readOutbound(); ch.readOutbound(); // Discard open frame ch.readOutbound(); ch.readOutbound(); ch.pipeline().remove("wsencoder"); } private HelloResponse sendWebSocketHelloFrame(final String uaid, final EmbeddedChannel ch) { ch.writeInbound(TestUtil.helloWebSocketFrame(uaid)); return responseToType(ch.readOutbound(), HelloResponseImpl.class); } private Object readOutboundDiscardEmpty(final EmbeddedChannel ch) { final Object obj = ch.readOutbound(); if (obj instanceof ByteBuf) { final ByteBuf buf = (ByteBuf) obj; if (buf.capacity() == 0) { ReferenceCountUtil.release(buf); return ch.readOutbound(); } } return obj; } private <T> T responseToType(final Object response, Class<T> type) { if (response instanceof TextWebSocketFrame) { final TextWebSocketFrame frame = (TextWebSocketFrame) response; String content = frame.text(); if (content.startsWith("a[")) { content = TestUtil.extractJsonFromSockJSMessage(content); } return JsonUtil.fromJson(content, type); } throw new IllegalArgumentException("Response is expected to be of type TextWebSocketFrame was: " + response); } private FullHttpResponse sendXhrOpenFrameRequest(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final EmbeddedChannel openChannel = createChannel(factory); openChannel.writeInbound(httpGetRequest(sessionUrl + Transports.Type.XHR.path())); final FullHttpResponse openFrameResponse = decodeFullHttpResponse(openChannel); openChannel.close(); return openFrameResponse; } private FullHttpResponse sendXhrHelloMessageRequest(final SockJsServiceFactory factory, final String sessionUrl, final String uaid, final String... channelIds) throws Exception { return xhrSend(factory, sessionUrl, TestUtil.helloSockJSFrame(uaid, channelIds)); } private HelloResponseImpl pollXhrHelloMessageResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl); assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK)); final String helloJson = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8)); return JsonUtil.fromJson(helloJson, HelloResponseImpl.class); } private FullHttpResponse sendXhrRegisterChannelIdRequest(final SockJsServiceFactory factory, final String sessionUrl, final String channelId) throws Exception { return xhrSend(factory, sessionUrl, TestUtil.registerChannelIdMessageSockJSFrame(channelId)); } private RegisterResponseImpl pollXhrRegisterChannelIdResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl); assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK)); final String json = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8)); return JsonUtil.fromJson(json, RegisterResponseImpl.class); } private FullHttpResponse unregisterChannelIdRequest(final SockJsServiceFactory factory, final String sessionUrl, final String channelId) throws Exception { return xhrSend(factory, sessionUrl, TestUtil.unregisterChannelIdMessageSockJSFrame(channelId)); } private UnregisterResponseImpl unregisterChannelIdResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl); assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK)); final String json = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8)); return JsonUtil.fromJson(json, UnregisterResponseImpl.class); } private FullHttpResponse sendXhrPingRequest(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { return xhrSend(factory, sessionUrl, TestUtil.pingSockJSFrame()); } private PingMessageImpl pollXhrPingMessageResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl); assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK)); final String helloJson = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8)); return JsonUtil.fromJson(helloJson, PingMessageImpl.class); } private FullHttpResponse xhrSend(final SockJsServiceFactory factory, final String sessionUrl, final String content) throws Exception { final EmbeddedChannel sendChannel = createChannel(factory); final FullHttpRequest sendRequest = httpPostRequest(sessionUrl + Transports.Type.XHR_SEND.path()); sendRequest.content().writeBytes(Unpooled.copiedBuffer(content, UTF_8)); sendChannel.writeInbound(sendRequest); final FullHttpResponse sendResponse = decodeFullHttpResponse(sendChannel); sendChannel.close(); return sendResponse; } private FullHttpResponse xhrPoll(final SockJsServiceFactory factory, final String sessionUrl) throws Exception { final EmbeddedChannel pollChannel = createChannel(factory); pollChannel.writeInbound(httpGetRequest(sessionUrl + Transports.Type.XHR.path())); return decodeFullHttpResponse(pollChannel); } private FullHttpRequest httpGetRequest(final String path) { return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path); } private FullHttpRequest websocketUpgradeRequest(final String path) { final FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, path); req.headers().set(Names.HOST, "server.test.com"); req.headers().set(Names.UPGRADE, WEBSOCKET.toString()); req.headers().set(Names.CONNECTION, "Upgrade"); req.headers().set(Names.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ=="); req.headers().set(Names.SEC_WEBSOCKET_ORIGIN, "http://test.com"); req.headers().set(Names.SEC_WEBSOCKET_VERSION, "13"); req.headers().set(Names.CONTENT_LENGTH, "0"); return req; } private FullHttpRequest httpPostRequest(final String path) { return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path); } private SockJsServiceFactory defaultFactory() { final SimplePushServerConfig simplePushConfig = DefaultSimplePushConfig.create().password("test").build(); final SockJsConfig sockjsConf = SockJsConfig.withPrefix("/simplepush").build(); final byte[] privateKey = CryptoUtil.secretKey(simplePushConfig.password(), "someSaltForTesting".getBytes()); final SimplePushServer pushServer = new DefaultSimplePushServer(new InMemoryDataStore(), simplePushConfig, privateKey); return new SimplePushServiceFactory(sockjsConf, pushServer); } private SockJsServiceFactory defaultFactory(final SimplePushServer simplePushServer) { final SockJsConfig sockJSConfig = SockJsConfig.withPrefix("/simplepush").build(); return new SockJsServiceFactory() { @Override public SockJsService create() { return new SimplePushSockJSService(config(), simplePushServer); } @Override public SockJsConfig config() { return sockJSConfig; } }; } private String randomSessionIdUrl(final SockJsServiceFactory factory) { return factory.config().prefix() + "/111/" + UUID.randomUUID().toString(); } private EmbeddedChannel createChannel(final SockJsServiceFactory factory) { final EmbeddedChannel ch = new TestEmbeddedChannel( new HttpRequestDecoder(), new HttpResponseEncoder(), new CorsInboundHandler(), new SockJsHandler(factory), new CorsOutboundHandler()); ch.pipeline().remove("EmbeddedChannel$LastInboundHandler#0"); return ch; } private EmbeddedChannel createWebSocketChannel(final SockJsServiceFactory factory) { final EmbeddedChannel ch = new TestEmbeddedChannel( new HttpRequestDecoder(), new HttpResponseEncoder(), new CorsInboundHandler(), new SockJsHandler(factory), new CorsOutboundHandler()); ch.pipeline().remove("EmbeddedChannel$LastInboundHandler#0"); return ch; } private static class TestEmbeddedChannel extends EmbeddedChannel { public TestEmbeddedChannel(final ChannelHandler... handlers) { super(handlers); } @Override public Unsafe unsafe() { final AbstractUnsafe delegate = super.newUnsafe(); return new TestUnsafe(delegate, new StubEmbeddedEventLoop(super.eventLoop())); } private class TestUnsafe implements Unsafe { private final Unsafe delegate; private final ChannelHandlerInvoker invoker; public TestUnsafe(final Unsafe delegate, final ChannelHandlerInvoker invoker) { this.delegate = delegate; this.invoker = invoker; } @Override public ChannelHandlerInvoker invoker() { return invoker; } @Override public SocketAddress localAddress() { return delegate.localAddress(); } @Override public SocketAddress remoteAddress() { return delegate.remoteAddress(); } @Override public void register(ChannelPromise promise) { delegate.register(promise); } @Override public void bind(SocketAddress localAddress, ChannelPromise promise) { delegate.bind(localAddress, promise); } @Override public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { delegate.connect(remoteAddress, localAddress, promise); } @Override public void disconnect(ChannelPromise promise) { delegate.disconnect(promise); } @Override public void close(ChannelPromise promise) { delegate.close(promise); } @Override public void closeForcibly() { delegate.closeForcibly(); } @Override public void beginRead() { delegate.beginRead(); } @Override public void write(Object msg, ChannelPromise promise) { delegate.write(msg, promise); } @Override public void flush() { delegate.flush(); } @Override public ChannelPromise voidPromise() { return delegate.voidPromise(); } @Override public ChannelOutboundBuffer outboundBuffer() { return delegate.outboundBuffer(); } } } }