/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.socket;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import static org.springframework.messaging.simp.SimpMessageType.*;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.mock.web.MockServletConfig;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.stereotype.Controller;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import java.util.HashMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
/**
 * Exercises the sample WebSocket security configuration ({@link WebSocketSecurityConfig})
 * by loading it into an {@link AnnotationConfigWebApplicationContext} and sending
 * hand-built STOMP-style messages through the {@code clientInboundChannel} bean.
 *
 * <p>The numbered {@code // <n>} comments in {@link WebSocketSecurityConfig} appear to be
 * documentation callout markers; each is kept on the same line as the rule it labels.
 */
public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests {

	/** Context created by {@link #loadConfig(Class...)}; closed in {@link #cleanup()}. */
	AnnotationConfigWebApplicationContext context;

	/** Principal attached to every message built by {@link #message(SimpMessageHeaderAccessor, String)}. */
	TestingAuthenticationToken messageUser;

	/** CSRF token initialised in {@link #setup()}; not read elsewhere in this class. */
	CsrfToken token;

	/** Session attribute name initialised in {@link #setup()}; not read elsewhere in this class. */
	String sessionAttr;

	@Before
	public void setup() {
		token = new DefaultCsrfToken("header", "param", "token");
		sessionAttr = "sessionAttr";
		messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER");
	}

	@After
	public void cleanup() {
		// loadConfig() may not have been called (or may have failed before assignment).
		if (context != null) {
			context.close();
		}
	}

	/**
	 * A SUBSCRIBE to {@code /user/queue/errors} must be accepted, while a MESSAGE sent to
	 * {@code /denyAll} must be rejected with a {@link MessageDeliveryException} whose cause
	 * is an {@link AccessDeniedException}.
	 */
	@Test
	public void securityMappings() {
		loadConfig(WebSocketSecurityConfig.class);

		clientInboundChannel().send(
				message("/user/queue/errors", SimpMessageType.SUBSCRIBE));

		try {
			clientInboundChannel().send(message("/denyAll", SimpMessageType.MESSAGE));
			// fail() throws an AssertionError, which the catch below deliberately does not match.
			fail("Expected Exception");
		}
		catch (MessageDeliveryException expected) {
			assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
		}
	}

	/**
	 * Creates and refreshes a new web application context containing the given
	 * configuration classes plus {@link WebSocketConfig} and {@link SyncExecutorConfig}.
	 *
	 * @param configs additional {@code @Configuration} classes to register
	 */
	private void loadConfig(Class<?>... configs) {
		context = new AnnotationConfigWebApplicationContext();
		context.register(configs);
		context.register(WebSocketConfig.class, SyncExecutorConfig.class);
		context.setServletConfig(new MockServletConfig());
		context.refresh();
	}

	/** Looks up the {@code clientInboundChannel} bean from the current context. */
	private MessageChannel clientInboundChannel() {
		return context.getBean("clientInboundChannel", MessageChannel.class);
	}

	/**
	 * Builds a message of the given type addressed to {@code destination}.
	 *
	 * @param destination the destination header value (may be {@code null})
	 * @param type the SiMP message type used to create the header accessor
	 * @return a message with payload {@code "hi"}
	 */
	private Message<String> message(String destination, SimpMessageType type) {
		SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type);
		return message(headers, destination);
	}

	/**
	 * Fills in session id {@code "123"}, an empty session-attribute map, the destination
	 * (when non-null) and the user (when {@link #messageUser} is non-null), then wraps
	 * the headers in a {@link GenericMessage} with payload {@code "hi"}.
	 *
	 * @param headers the header accessor to populate
	 * @param destination the destination header value, or {@code null} to leave it unset
	 * @return the assembled message
	 */
	private Message<String> message(SimpMessageHeaderAccessor headers, String destination) {
		headers.setSessionId("123");
		headers.setSessionAttributes(new HashMap<String, Object>());
		if (destination != null) {
			headers.setDestination(destination);
		}
		if (messageUser != null) {
			headers.setUser(messageUser);
		}
		return new GenericMessage<String>("hi", headers.getMessageHeaders());
	}

	/** Sample controller showing {@link AuthenticationPrincipal} on a message-handling method. */
	@Controller
	static class MyController {

		@MessageMapping("/authentication")
		public void authentication(@AuthenticationPrincipal String un) {
			// ... do something ...
		}
	}

	/** The inbound message authorization rules under test. */
	@Configuration
	static class WebSocketSecurityConfig
			extends AbstractSecurityWebSocketMessageBrokerConfigurer {

		@Override
		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
			// Each callout marker sits on the same line as the rule it labels.
			messages
					.nullDestMatcher().authenticated() // <1>
					.simpSubscribeDestMatchers("/user/queue/errors").permitAll() // <2>
					.simpDestMatchers("/app/**").hasRole("USER") // <3>
					.simpSubscribeDestMatchers("/user/**", "/topic/friends/*").hasRole("USER") // <4>
					.simpTypeMatchers(MESSAGE, SUBSCRIBE).denyAll() // <5>
					.anyMessage().denyAll(); // <6>
		}
	}

	/**
	 * Message broker setup: a {@code /chat} SockJS endpoint, a simple broker for
	 * {@code /queue/} and {@code /topic/}, and application prefixes {@code /permitAll}
	 * and {@code /denyAll}.
	 */
	@Configuration
	@EnableWebSocketMessageBroker
	static class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer {

		@Override
		public void registerStompEndpoints(StompEndpointRegistry registry) {
			registry.addEndpoint("/chat").withSockJS();
		}

		@Override
		public void configureMessageBroker(MessageBrokerRegistry registry) {
			registry.enableSimpleBroker("/queue/", "/topic/");
			registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
		}

		@Bean
		public MyController myController() {
			return new MyController();
		}
	}

	/**
	 * Registers {@link SyncExecutorSubscribableChannelPostProcessor}.
	 * NOTE(review): presumably this makes channel dispatch synchronous so that security
	 * exceptions surface from {@code send()} in the test thread; confirm against that class.
	 */
	@Configuration
	static class SyncExecutorConfig {

		// static so the post-processor can be created without instantiating this config class
		@Bean
		public static SyncExecutorSubscribableChannelPostProcessor postProcessor() {
			return new SyncExecutorSubscribableChannelPostProcessor();
		}
	}
}