/*
 * Copyright 2002-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.handler;

import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.After;
import org.junit.Test;

import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;

import static org.junit.Assert.*;

/**
 * Unit tests for
 * {@link org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator}.
* * @author Rossen Stoyanchev */ @SuppressWarnings("resource") public class ConcurrentWebSocketSessionDecoratorTests { @Test public void send() throws IOException { TestWebSocketSession session = new TestWebSocketSession(); session.setOpen(true); ConcurrentWebSocketSessionDecorator concurrentSession = new ConcurrentWebSocketSessionDecorator(session, 1000, 1024); TextMessage textMessage = new TextMessage("payload"); concurrentSession.sendMessage(textMessage); assertEquals(1, session.getSentMessages().size()); assertEquals(textMessage, session.getSentMessages().get(0)); assertEquals(0, concurrentSession.getBufferSize()); assertEquals(0, concurrentSession.getTimeSinceSendStarted()); assertTrue(session.isOpen()); } @Test public void sendAfterBlockedSend() throws IOException, InterruptedException { BlockingSession blockingSession = new BlockingSession(); blockingSession.setOpen(true); CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); final ConcurrentWebSocketSessionDecorator concurrentSession = new ConcurrentWebSocketSessionDecorator(blockingSession, 10 * 1000, 1024); Executors.newSingleThreadExecutor().submit((Runnable) () -> { TextMessage message = new TextMessage("slow message"); try { concurrentSession.sendMessage(message); } catch (IOException e) { e.printStackTrace(); } }); assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); // ensure some send time elapses Thread.sleep(100); assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); TextMessage payload = new TextMessage("payload"); for (int i = 0; i < 5; i++) { concurrentSession.sendMessage(payload); } assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); } @Test public void sendTimeLimitExceeded() throws IOException, InterruptedException { BlockingSession blockingSession = new BlockingSession(); blockingSession.setId("123"); blockingSession.setOpen(true); 
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); int sendTimeLimit = 100; int bufferSizeLimit = 1024; final ConcurrentWebSocketSessionDecorator concurrentSession = new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); Executors.newSingleThreadExecutor().submit((Runnable) () -> { TextMessage message = new TextMessage("slow message"); try { concurrentSession.sendMessage(message); } catch (IOException e) { e.printStackTrace(); } }); assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); // ensure some send time elapses Thread.sleep(sendTimeLimit + 100); try { TextMessage payload = new TextMessage("payload"); concurrentSession.sendMessage(payload); fail("Expected exception"); } catch (SessionLimitExceededException ex) { String actual = ex.getMessage(); String regex = "Message send time [\\d]+ \\(ms\\) for session '123' exceeded the allowed limit 100"; assertTrue("Unexpected message: " + actual, actual.matches(regex)); assertEquals(CloseStatus.SESSION_NOT_RELIABLE, ex.getStatus()); } } @Test public void sendBufferSizeExceeded() throws IOException, InterruptedException { BlockingSession blockingSession = new BlockingSession(); blockingSession.setId("123"); blockingSession.setOpen(true); CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); int sendTimeLimit = 10 * 1000; int bufferSizeLimit = 1024; final ConcurrentWebSocketSessionDecorator concurrentSession = new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); Executors.newSingleThreadExecutor().submit((Runnable) () -> { TextMessage message = new TextMessage("slow message"); try { concurrentSession.sendMessage(message); } catch (IOException e) { e.printStackTrace(); } }); assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); StringBuilder sb = new StringBuilder(); for (int i = 0 ; i < 1023; i++) { sb.append("a"); } TextMessage message = new TextMessage(sb.toString()); 
concurrentSession.sendMessage(message); assertEquals(1023, concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); try { concurrentSession.sendMessage(message); fail("Expected exception"); } catch (SessionLimitExceededException ex) { String actual = ex.getMessage(); String regex = "The send buffer size [\\d]+ bytes for session '123' exceeded the allowed limit 1024"; assertTrue("Unexpected message: " + actual, actual.matches(regex)); assertEquals(CloseStatus.SESSION_NOT_RELIABLE, ex.getStatus()); } } @Test public void closeStatusNormal() throws Exception { BlockingSession delegate = new BlockingSession(); delegate.setOpen(true); WebSocketSession decorator = new ConcurrentWebSocketSessionDecorator(delegate, 10 * 1000, 1024); decorator.close(CloseStatus.PROTOCOL_ERROR); assertEquals(CloseStatus.PROTOCOL_ERROR, delegate.getCloseStatus()); decorator.close(CloseStatus.SERVER_ERROR); assertEquals("Should have been ignored", CloseStatus.PROTOCOL_ERROR, delegate.getCloseStatus()); } @Test public void closeStatusChangesToSessionNotReliable() throws Exception { BlockingSession blockingSession = new BlockingSession(); blockingSession.setId("123"); blockingSession.setOpen(true); CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); int sendTimeLimit = 100; int bufferSizeLimit = 1024; final ConcurrentWebSocketSessionDecorator concurrentSession = new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); Executors.newSingleThreadExecutor().submit((Runnable) () -> { TextMessage message = new TextMessage("slow message"); try { concurrentSession.sendMessage(message); } catch (IOException e) { e.printStackTrace(); } }); assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); // ensure some send time elapses Thread.sleep(sendTimeLimit + 100); concurrentSession.close(CloseStatus.PROTOCOL_ERROR); assertEquals("CloseStatus should have changed to SESSION_NOT_RELIABLE", CloseStatus.SESSION_NOT_RELIABLE, 
blockingSession.getCloseStatus()); } private static class BlockingSession extends TestWebSocketSession { private AtomicReference<CountDownLatch> nextMessageLatch = new AtomicReference<>(); private AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>(); public CountDownLatch getSentMessageLatch() { this.nextMessageLatch.set(new CountDownLatch(1)); return this.nextMessageLatch.get(); } @Override public void sendMessage(WebSocketMessage<?> message) throws IOException { super.sendMessage(message); if (this.nextMessageLatch != null) { this.nextMessageLatch.get().countDown(); } block(); } private void block() { try { this.releaseLatch.set(new CountDownLatch(1)); this.releaseLatch.get().await(); } catch (InterruptedException e) { e.printStackTrace(); } } } }