/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.socket;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.MethodParameter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletConfig;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.expression.SecurityExpressionHandler;
import org.springframework.security.access.expression.SecurityExpressionOperations;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.messaging.access.expression.DefaultMessageSecurityExpressionHandler;
import org.springframework.security.messaging.access.expression.MessageSecurityExpressionRoot;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
AnnotationConfigWebApplicationContext context;
TestingAuthenticationToken messageUser;
CsrfToken token;
String sessionAttr;
@Before
public void setup() {
token = new DefaultCsrfToken("header", "param", "token");
sessionAttr = "sessionAttr";
messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER");
}
@After
public void cleanup() {
if (context != null) {
context.close();
}
}
@Test
public void simpleRegistryMappings() {
loadConfig(SockJsSecurityConfig.class);
clientInboundChannel().send(message("/permitAll"));
try {
clientInboundChannel().send(message("/denyAll"));
fail("Expected Exception");
}
catch (MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Test
public void annonymousSupported() {
loadConfig(SockJsSecurityConfig.class);
messageUser = null;
clientInboundChannel().send(message("/permitAll"));
}
// gh-3797
@Test
public void beanResolver() {
loadConfig(SockJsSecurityConfig.class);
messageUser = null;
clientInboundChannel().send(message("/beanResolver"));
}
@Test
public void addsAuthenticationPrincipalResolver() throws InterruptedException {
loadConfig(SockJsSecurityConfig.class);
MessageChannel messageChannel = clientInboundChannel();
Message<String> message = message("/permitAll/authentication");
messageChannel.send(message);
assertThat(context.getBean(MyController.class).authenticationPrincipal)
.isEqualTo((String) messageUser.getPrincipal());
}
@Test
public void addsAuthenticationPrincipalResolverWhenNoAuthorization()
throws InterruptedException {
loadConfig(NoInboundSecurityConfig.class);
MessageChannel messageChannel = clientInboundChannel();
Message<String> message = message("/permitAll/authentication");
messageChannel.send(message);
assertThat(context.getBean(MyController.class).authenticationPrincipal)
.isEqualTo((String) messageUser.getPrincipal());
}
@Test
public void addsCsrfProtectionWhenNoAuthorization() throws InterruptedException {
loadConfig(NoInboundSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MessageChannel messageChannel = clientInboundChannel();
try {
messageChannel.send(message);
fail("Expected Exception");
}
catch (MessageDeliveryException success) {
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
}
}
@Test
public void csrfProtectionForConnect() throws InterruptedException {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MessageChannel messageChannel = clientInboundChannel();
try {
messageChannel.send(message);
fail("Expected Exception");
}
catch (MessageDeliveryException success) {
assertThat(success.getCause()).isInstanceOf(MissingCsrfTokenException.class);
}
}
@Test
public void csrfProtectionDisabledForConnect() throws InterruptedException {
loadConfig(CsrfDisabledSockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/permitAll/connect");
MessageChannel messageChannel = clientInboundChannel();
messageChannel.send(message);
}
@Test
public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = sockjsHttpRequest("/chat");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
@Test
public void messagesConnectUseCsrfTokenHandshakeInterceptorMultipleMappings()
throws Exception {
loadConfig(SockJsSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = sockjsHttpRequest("/other");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
@Test
public void messagesConnectWebSocketUseCsrfTokenHandshakeInterceptor()
throws Exception {
loadConfig(WebSocketSecurityConfig.class);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor
.create(SimpMessageType.CONNECT);
Message<?> message = message(headers, "/authentication");
MockHttpServletRequest request = websocketHttpRequest("/websocket");
HttpRequestHandler handler = handler(request);
handler.handleRequest(request, new MockHttpServletResponse());
assertHandshake(request);
}
@Test
public void msmsRegistryCustomPatternMatcher()
throws Exception {
loadConfig(MsmsRegistryCustomPatternMatcherConfig.class);
clientInboundChannel().send(message("/app/a.b"));
try {
clientInboundChannel().send(message("/app/a.b.c"));
fail("Expected Exception");
}
catch (MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class MsmsRegistryCustomPatternMatcherConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
// @formatter:off
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.setHandshakeHandler(testHandshakeHandler());
}
// @formatter:on
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/app/a.*").permitAll()
.anyMessage().denyAll();
}
// @formatter:on
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.setPathMatcher(new AntPathMatcher("."));
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/app");
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Test
public void overrideMsmsRegistryCustomPatternMatcher()
throws Exception {
loadConfig(OverrideMsmsRegistryCustomPatternMatcherConfig.class);
clientInboundChannel().send(message("/app/a/b"));
try {
clientInboundChannel().send(message("/app/a/b/c"));
fail("Expected Exception");
}
catch (MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class OverrideMsmsRegistryCustomPatternMatcherConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
// @formatter:off
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.setHandshakeHandler(testHandshakeHandler());
}
// @formatter:on
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestPathMatcher(new AntPathMatcher())
.simpDestMatchers("/app/a/*").permitAll()
.anyMessage().denyAll();
}
// @formatter:on
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.setPathMatcher(new AntPathMatcher("."));
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/app");
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Test
public void defaultPatternMatcher()
throws Exception {
loadConfig(DefaultPatternMatcherConfig.class);
clientInboundChannel().send(message("/app/a/b"));
try {
clientInboundChannel().send(message("/app/a/b/c"));
fail("Expected Exception");
}
catch (MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class DefaultPatternMatcherConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
// @formatter:off
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.setHandshakeHandler(testHandshakeHandler());
}
// @formatter:on
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/app/a/*").permitAll()
.anyMessage().denyAll();
}
// @formatter:on
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/app");
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Test
public void customExpression()
throws Exception {
loadConfig(CustomExpressionConfig.class);
clientInboundChannel().send(message("/denyRob"));
this.messageUser = new TestingAuthenticationToken("rob", "password", "ROLE_USER");
try {
clientInboundChannel().send(message("/denyRob"));
fail("Expected Exception");
}
catch (MessageDeliveryException expected) {
assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class CustomExpressionConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
// @formatter:off
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/other")
.setHandshakeHandler(testHandshakeHandler());
}
// @formatter:on
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.anyMessage().access("denyRob()");
}
// @formatter:on
@Bean
public static SecurityExpressionHandler<Message<Object>> messageSecurityExpressionHandler() {
return new DefaultMessageSecurityExpressionHandler<Object>() {
@Override
protected SecurityExpressionOperations createSecurityExpressionRoot(
Authentication authentication,
Message<Object> invocation) {
return new MessageSecurityExpressionRoot(authentication, invocation) {
public boolean denyRob() {
Authentication auth = getAuthentication();
return auth != null && !"rob".equals(auth.getName());
}
};
}
};
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/app");
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = context
.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(
token);
assertThat(handshakeHandler.attributes.get(sessionAttr)).isEqualTo(
request.getSession().getAttribute(sessionAttr));
}
private HttpRequestHandler handler(HttpServletRequest request) throws Exception {
HandlerMapping handlerMapping = context.getBean(HandlerMapping.class);
return (HttpRequestHandler) handlerMapping.getHandler(request).getHandler();
}
private MockHttpServletRequest websocketHttpRequest(String mapping) {
MockHttpServletRequest request = sockjsHttpRequest(mapping);
request.setRequestURI(mapping);
return request;
}
private MockHttpServletRequest sockjsHttpRequest(String mapping) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("GET");
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE,
"/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), token);
return request;
}
private Message<String> message(String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
return message(headers, destination);
}
private Message<String> message(SimpMessageHeaderAccessor headers, String destination) {
headers.setSessionId("123");
headers.setSessionAttributes(new HashMap<String, Object>());
if (destination != null) {
headers.setDestination(destination);
}
if (messageUser != null) {
headers.setUser(messageUser);
}
return new GenericMessage<String>("hi", headers.getMessageHeaders());
}
private MessageChannel clientInboundChannel() {
return context.getBean("clientInboundChannel", MessageChannel.class);
}
private void loadConfig(Class<?>... configs) {
context = new AnnotationConfigWebApplicationContext();
context.register(configs);
context.setServletConfig(new MockServletConfig());
context.refresh();
}
@Controller
static class MyController {
String authenticationPrincipal;
MyCustomArgument myCustomArgument;
@MessageMapping("/authentication")
public void authentication(@AuthenticationPrincipal String un) {
this.authenticationPrincipal = un;
}
@MessageMapping("/myCustom")
public void myCustom(MyCustomArgument myCustomArgument) {
this.myCustomArgument = myCustomArgument;
}
}
static class MyCustomArgument {
MyCustomArgument(String notDefaultConstr) {
}
}
static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver {
public boolean supportsParameter(MethodParameter parameter) {
return parameter.getParameterType().isAssignableFrom(MyCustomArgument.class);
}
public Object resolveArgument(MethodParameter parameter, Message<?> message)
throws Exception {
return new MyCustomArgument("");
}
}
static class TestHandshakeHandler implements HandshakeHandler {
Map<String, Object> attributes;
public boolean doHandshake(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) throws HandshakeFailureException {
this.attributes = attributes;
if (wsHandler instanceof SockJsWebSocketHandler) {
// work around SPR-12716
SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
WebSocketServerSockJsSession session = (WebSocketServerSockJsSession) ReflectionTestUtils
.getField(sockJs, "sockJsSession");
this.attributes = session.getAttributes();
}
return true;
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class SockJsSecurityConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/other").setHandshakeHandler(testHandshakeHandler())
.withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor());
registry.addEndpoint("/chat").setHandshakeHandler(testHandshakeHandler())
.withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor());
}
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/permitAll/**").permitAll()
.simpDestMatchers("/beanResolver/**").access("@security.check()")
.anyMessage().denyAll();
}
// @formatter:on
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
}
@Bean
public MyController myController() {
return new MyController();
}
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
@Bean
public SecurityCheck security() {
return new SecurityCheck();
}
static class SecurityCheck {
private boolean check;
public boolean check() {
check = !check;
return check;
}
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class NoInboundSecurityConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/other").withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
registry.addEndpoint("/chat").withSockJS()
.setInterceptors(new HttpSessionHandshakeInterceptor());
}
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
}
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableSimpleBroker("/queue/", "/topic/");
registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
}
@Bean
public MyController myController() {
return new MyController();
}
}
@Configuration
static class CsrfDisabledSockJsSecurityConfig extends SockJsSecurityConfig {
@Override
protected boolean sameOriginDisabled() {
return true;
}
}
@Configuration
@EnableWebSocketMessageBroker
@Import(SyncExecutorConfig.class)
static class WebSocketSecurityConfig extends
AbstractSecurityWebSocketMessageBrokerConfigurer {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/websocket")
.setHandshakeHandler(testHandshakeHandler())
.addInterceptors(new HttpSessionHandshakeInterceptor());
}
// @formatter:off
@Override
protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
messages
.simpDestMatchers("/permitAll/**").permitAll()
.simpDestMatchers("/customExpression/**").access("denyRob")
.anyMessage().denyAll();
}
// @formatter:on
@Bean
public TestHandshakeHandler testHandshakeHandler() {
return new TestHandshakeHandler();
}
}
@Configuration
static class SyncExecutorConfig {
@Bean
public static SyncExecutorSubscribableChannelPostProcessor postProcessor() {
return new SyncExecutorSubscribableChannelPostProcessor();
}
}
}