package org.jooby.internal.undertow; import static org.easymock.EasyMock.eq; import static org.easymock.EasyMock.expect; import static org.easymock.EasyMock.expectLastCall; import static org.easymock.EasyMock.isA; import static org.junit.Assert.assertEquals; import io.undertow.websockets.core.BufferedBinaryMessage; import io.undertow.websockets.core.BufferedTextMessage; import io.undertow.websockets.core.CloseMessage; import io.undertow.websockets.core.WebSocketCallback; import io.undertow.websockets.core.WebSocketChannel; import io.undertow.websockets.core.WebSockets; import java.nio.ByteBuffer; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; import org.jooby.WebSocket; import org.jooby.test.MockUnit; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; import org.xnio.ChannelListener.Setter; import org.xnio.IoUtils; import org.xnio.Pooled; import com.typesafe.config.Config; @RunWith(PowerMockRunner.class) @PrepareForTest({UndertowWebSocket.class, CountDownLatch.class, Thread.class, WebSockets.class, IoUtils.class }) public class UndertowWebSocketTest { private MockUnit.Block config = unit -> { Config config = unit.get(Config.class); expect(config.getDuration("undertow.ws.IdleTimeout", TimeUnit.MILLISECONDS)) .andReturn(6000L); expect(config.getBytes("undertow.ws.MaxBinaryBufferSize")).andReturn(60L); expect(config.getBytes("undertow.ws.MaxTextBufferSize")).andReturn(80L); }; @SuppressWarnings("unchecked") private MockUnit.Block connect = unit -> { Setter<WebSocketChannel> setter = unit.mock(Setter.class); setter.set(isA(UndertowWebSocket.class)); WebSocketChannel ws = unit.get(WebSocketChannel.class); ws.setIdleTimeout(6000L); expect(ws.getReceiveSetter()).andReturn(setter); ws.resumeReceives(); unit.get(Runnable.class).run(); }; @Test public void defaults() throws Exception { new MockUnit(Config.class) .expect(config) .run(unit -> { new UndertowWebSocket(unit.get(Config.class)); }); } @SuppressWarnings("unchecked") @Test public void connect() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, Runnable.class) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.countDown(); }) .expect(unit -> { Setter<WebSocketChannel> setter = unit.mock(Setter.class); setter.set(isA(UndertowWebSocket.class)); WebSocketChannel ws = unit.get(WebSocketChannel.class); ws.setIdleTimeout(6000L); expect(ws.getReceiveSetter()).andReturn(setter); ws.resumeReceives(); }) .expect(unit -> { unit.get(Runnable.class).run(); }) .expect(config) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); }); } @Test public void maxBinaryBufferSize() throws Exception { new MockUnit(Config.class) .expect(config) .run(unit -> { assertEquals(60L, new UndertowWebSocket(unit.get(Config.class)).getMaxBinaryBufferSize()); }); } @Test public void maxTextBufferSize() throws Exception { new MockUnit(Config.class) .expect(config) .run(unit -> { assertEquals(80L, new UndertowWebSocket(unit.get(Config.class)).getMaxTextBufferSize()); }); } @SuppressWarnings("unchecked") @Test public void onFullTextMessage() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, BufferedTextMessage.class, Consumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); }) .expect(unit -> { BufferedTextMessage msg = unit.get(BufferedTextMessage.class); expect(msg.getData()).andReturn("x"); Consumer<String> callback = unit.get(Consumer.class); callback.accept("x"); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onTextMessage(unit.get(Consumer.class)); ws.onFullTextMessage(unit.get(WebSocketChannel.class), unit.get(BufferedTextMessage.class)); }); } @SuppressWarnings("unchecked") @Test public void onFullTextMessageInterrupted() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, BufferedTextMessage.class, Consumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); expectLastCall().andThrow(new InterruptedException("intentional err")); }) .expect(unit -> { Thread thread = unit.mock(Thread.class); thread.interrupt(); unit.mockStatic(Thread.class); expect(Thread.currentThread()).andReturn(thread); }) .expect(unit -> { BufferedTextMessage msg = unit.get(BufferedTextMessage.class); expect(msg.getData()).andReturn("x"); Consumer<String> callback = unit.get(Consumer.class); callback.accept("x"); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onTextMessage(unit.get(Consumer.class)); ws.onFullTextMessage(unit.get(WebSocketChannel.class), unit.get(BufferedTextMessage.class)); }); } @SuppressWarnings("unchecked") @Test public void onFullBinaryMessage() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, BufferedBinaryMessage.class, Consumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); }) .expect(unit -> { ByteBuffer buff = ByteBuffer.wrap(new byte[0]); ByteBuffer[] resource = {buff }; unit.mockStatic(WebSockets.class); expect(WebSockets.mergeBuffers(resource)).andReturn(buff); Pooled<ByteBuffer[]> pooled = unit.mock(Pooled.class); expect(pooled.getResource()).andReturn(resource); pooled.free(); BufferedBinaryMessage msg = unit.get(BufferedBinaryMessage.class); expect(msg.getData()).andReturn(pooled); Consumer<ByteBuffer> callback = unit.get(Consumer.class); callback.accept(buff); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onBinaryMessage(unit.get(Consumer.class)); ws.onFullBinaryMessage(unit.get(WebSocketChannel.class), unit.get(BufferedBinaryMessage.class)); }); } @SuppressWarnings("unchecked") @Test(expected = IllegalStateException.class) public void onFullBinaryMessageFailure() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, BufferedBinaryMessage.class, Consumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); }) .expect(unit -> { Pooled<ByteBuffer[]> pooled = unit.mock(Pooled.class); expect(pooled.getResource()).andThrow(new IllegalStateException("intentional err")); pooled.free(); BufferedBinaryMessage msg = unit.get(BufferedBinaryMessage.class); expect(msg.getData()).andReturn(pooled); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onBinaryMessage(unit.get(Consumer.class)); ws.onFullBinaryMessage(unit.get(WebSocketChannel.class), unit.get(BufferedBinaryMessage.class)); }); } @SuppressWarnings("unchecked") @Test public void onCloseMessage() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, BiConsumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); }) .expect(unit -> { CloseMessage msg = unit.get(CloseMessage.class); expect(msg.getCode()).andReturn(1000); expect(msg.getReason()).andReturn(null); BiConsumer<Integer, Optional<String>> callback = unit.get(BiConsumer.class); callback.accept(1000, Optional.empty()); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onCloseMessage(unit.get(BiConsumer.class)); ws.onCloseMessage(unit.get(CloseMessage.class), unit.get(WebSocketChannel.class)); }); } @SuppressWarnings("unchecked") @Test public void onError() throws Exception { Throwable cause = new IllegalStateException("intentional err"); new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.await(); }) .expect(unit -> { Consumer<Throwable> callback = unit.get(Consumer.class); callback.accept(cause); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onErrorMessage(unit.get(Consumer.class)); ws.onError(unit.get(WebSocketChannel.class), cause); }); } @SuppressWarnings("unchecked") @Test public void close() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.countDown(); }) .expect(unit -> { Setter<WebSocketChannel> setter = unit.mock(Setter.class); setter.set(isA(UndertowWebSocket.class)); WebSocketChannel ws = unit.get(WebSocketChannel.class); ws.setIdleTimeout(6000L); expect(ws.getReceiveSetter()).andReturn(setter); ws.resumeReceives(); }) .expect(unit -> { unit.get(Runnable.class).run(); }) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendClose(eq(1000), eq("reason"), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.mockStatic(IoUtils.class); IoUtils.safeClose(unit.get(WebSocketChannel.class)); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.close(1000, "reason"); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.complete(unit.get(WebSocketChannel.class), null); }); } @SuppressWarnings("unchecked") @Test public void closeWithErr() throws Exception { Throwable cause = new IllegalStateException("intentional err"); new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class) .expect(config) .expect(unit -> { CountDownLatch ready = unit.mockConstructor(CountDownLatch.class, new Class[]{int.class }, 1); ready.countDown(); }) .expect(unit -> { Setter<WebSocketChannel> setter = unit.mock(Setter.class); setter.set(isA(UndertowWebSocket.class)); WebSocketChannel ws = unit.get(WebSocketChannel.class); ws.setIdleTimeout(6000L); expect(ws.getReceiveSetter()).andReturn(setter); ws.resumeReceives(); }) .expect(unit -> { unit.get(Runnable.class).run(); }) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendClose(eq(1000), eq("reason"), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.mockStatic(IoUtils.class); IoUtils.safeClose(unit.get(WebSocketChannel.class)); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.close(1000, "reason"); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.onError(unit.get(WebSocketChannel.class), null, cause); }); } @Test public void resume() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, Runnable.class) .expect(config) .expect(connect) .expect(unit -> { WebSocketChannel ch = unit.get(WebSocketChannel.class); ch.resumeReceives(); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.resume(); }); } @Test public void isOpen() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, Runnable.class) .expect(config) .expect(connect) .expect(unit -> { WebSocketChannel ch = unit.get(WebSocketChannel.class); expect(ch.isOpen()).andReturn(true); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.isOpen(); }); } @Test public void pause() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, Runnable.class) .expect(config) .expect(connect) .expect(unit -> { WebSocketChannel ch = unit.get(WebSocketChannel.class); ch.suspendReceives(); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.pause(); }); } @SuppressWarnings("unchecked") @Test public void terminate() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, Runnable.class, BiConsumer.class) .expect(config) .expect(connect) .expect(unit -> { BiConsumer<Integer, Optional<String>> callback = unit.get(BiConsumer.class); callback.accept(1006, Optional.of("Harsh disconnect")); }) .expect(unit -> { unit.mockStatic(IoUtils.class); IoUtils.safeClose(unit.get(WebSocketChannel.class)); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.onCloseMessage(unit.get(BiConsumer.class)); ws.terminate(); }); } @SuppressWarnings("unchecked") @Test public void sendText() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class, WebSocket.SuccessCallback.class, WebSocket.OnError.class) .expect(config) .expect(connect) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendText(eq("data"), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.get(WebSocket.SuccessCallback.class).invoke(); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.sendText("data", unit.get(WebSocket.SuccessCallback.class), unit.get(WebSocket.OnError.class)); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.complete(unit.get(WebSocketChannel.class), null); }); } @SuppressWarnings("unchecked") @Test public void sendTextCallbackErr() throws Exception { new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class, WebSocket.SuccessCallback.class, WebSocket.OnError.class) .expect(config) .expect(connect) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendText(eq("data"), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.get(WebSocket.SuccessCallback.class).invoke(); expectLastCall().andThrow(new IllegalStateException("intentional err")); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.sendText("data", unit.get(WebSocket.SuccessCallback.class), unit.get(WebSocket.OnError.class)); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.complete(unit.get(WebSocketChannel.class), null); }); } @SuppressWarnings("unchecked") @Test public void sendBytes() throws Exception { ByteBuffer data = ByteBuffer.wrap(new byte[0]); new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class, WebSocket.SuccessCallback.class, WebSocket.OnError.class) .expect(config) .expect(connect) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendBinary(eq(data), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.get(WebSocket.SuccessCallback.class).invoke(); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.sendBytes(data, unit.get(WebSocket.SuccessCallback.class), unit.get(WebSocket.OnError.class)); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.complete(unit.get(WebSocketChannel.class), null); }); } @SuppressWarnings("unchecked") @Test public void sendTextErrCallback() throws Exception { Throwable cause = new IllegalStateException("intentional err"); new MockUnit(Config.class, WebSocketChannel.class, CloseMessage.class, Consumer.class, Runnable.class, WebSocket.SuccessCallback.class, WebSocket.OnError.class) .expect(config) .expect(connect) .expect(unit -> { unit.mockStatic(WebSockets.class); WebSockets.sendText(eq("data"), eq(unit.get(WebSocketChannel.class)), unit.capture(WebSocketCallback.class)); }) .expect(unit -> { unit.get(WebSocket.OnError.class).onError(cause); }) .run(unit -> { UndertowWebSocket ws = new UndertowWebSocket(unit.get(Config.class)); ws.onConnect(unit.get(Runnable.class)); ws.connect(unit.get(WebSocketChannel.class)); ws.sendText("data", unit.get(WebSocket.SuccessCallback.class), unit.get(WebSocket.OnError.class)); }, unit -> { WebSocketCallback<Void> callback = unit.captured(WebSocketCallback.class).iterator() .next(); callback.onError(unit.get(WebSocketChannel.class), null, cause); }); } }