/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.config.annotation.web.socket; import java.util.ArrayList; import java.util.List; import java.util.Map; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.messaging.Message; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.access.vote.AffirmativeBased; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.configuration.ObjectPostProcessorConfiguration; import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; import org.springframework.security.messaging.access.expression.DefaultMessageSecurityExpressionHandler; import org.springframework.security.messaging.access.expression.MessageExpressionVoter; import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor; import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.AntPathMatcher; import org.springframework.util.PathMatcher; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; /** * Allows configuring WebSocket Authorization. * * <p> * For example: * </p> * * <pre> * @Configuration * public class WebSocketSecurityConfig extends * AbstractSecurityWebSocketMessageBrokerConfigurer { * * @Override * protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { * messages.simpDestMatchers("/user/queue/errors").permitAll() * .simpDestMatchers("/admin/**").hasRole("ADMIN").anyMessage() * .authenticated(); * } * } * </pre> * * * @since 4.0 * @author Rob Winch */ @Order(Ordered.HIGHEST_PRECEDENCE + 100) @Import(ObjectPostProcessorConfiguration.class) public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer implements SmartInitializingSingleton { private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); private SecurityExpressionHandler<Message<Object>> defaultExpressionHandler = new DefaultMessageSecurityExpressionHandler<Object>(); private SecurityExpressionHandler<Message<Object>> expressionHandler; private ApplicationContext context; public void registerStompEndpoints(StompEndpointRegistry registry) { } @Override public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) { argumentResolvers.add(new AuthenticationPrincipalArgumentResolver()); } @Override public final void configureClientInboundChannel(ChannelRegistration registration) { ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity(); registration.setInterceptors(securityContextChannelInterceptor()); if (!sameOriginDisabled()) { registration.setInterceptors(csrfChannelInterceptor()); } if (inboundRegistry.containsMapping()) { registration.setInterceptors(inboundChannelSecurity); } customizeClientInboundChannel(registration); } private PathMatcher getDefaultPathMatcher() { try { return context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher(); } catch(NoSuchBeanDefinitionException e) { return new AntPathMatcher(); } } /** * <p> * Determines if a CSRF token is required for connecting. This protects against remote * sites from connecting to the application and being able to read/write data over the * connection. The default is false (the token is required). * </p> * <p> * Subclasses can override this method to disable CSRF protection * </p> * * @return false if a CSRF token is required for connecting, else true */ protected boolean sameOriginDisabled() { return false; } /** * Allows subclasses to customize the configuration of the {@link ChannelRegistration} * . * * @param registration the {@link ChannelRegistration} to customize */ protected void customizeClientInboundChannel(ChannelRegistration registration) { } @Bean public CsrfChannelInterceptor csrfChannelInterceptor() { return new CsrfChannelInterceptor(); } @Bean public ChannelSecurityInterceptor inboundChannelSecurity() { ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor( inboundMessageSecurityMetadataSource()); MessageExpressionVoter<Object> voter = new MessageExpressionVoter<Object>(); voter.setExpressionHandler(getMessageExpressionHandler()); List<AccessDecisionVoter<? extends Object>> voters = new ArrayList<AccessDecisionVoter<? extends Object>>(); voters.add(voter); AffirmativeBased manager = new AffirmativeBased(voters); channelSecurityInterceptor.setAccessDecisionManager(manager); return channelSecurityInterceptor; } @Bean public SecurityContextChannelInterceptor securityContextChannelInterceptor() { return new SecurityContextChannelInterceptor(); } @Bean public MessageSecurityMetadataSource inboundMessageSecurityMetadataSource() { inboundRegistry.expressionHandler(getMessageExpressionHandler()); configureInbound(inboundRegistry); return inboundRegistry.createMetadataSource(); } /** * * @param messages */ protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { } private static class WebSocketMessageSecurityMetadataSourceRegistry extends MessageSecurityMetadataSourceRegistry { @Override public MessageSecurityMetadataSource createMetadataSource() { return super.createMetadataSource(); } @Override protected boolean containsMapping() { return super.containsMapping(); } @Override protected boolean isSimpDestPathMatcherConfigured() { return super.isSimpDestPathMatcherConfigured(); } } @Autowired public void setApplicationContext(ApplicationContext context) { this.context = context; } @Deprecated public void setMessageExpessionHandler(List<SecurityExpressionHandler<Message<Object>>> expressionHandlers) { setMessageExpressionHandler(expressionHandlers); } @Autowired(required = false) public void setMessageExpressionHandler(List<SecurityExpressionHandler<Message<Object>>> expressionHandlers) { if(expressionHandlers.size() == 1) { this.expressionHandler = expressionHandlers.get(0); } } @Autowired(required = false) public void setObjectPostProcessor(ObjectPostProcessor<Object> objectPostProcessor) { defaultExpressionHandler = objectPostProcessor.postProcess(defaultExpressionHandler); } private SecurityExpressionHandler<Message<Object>> getMessageExpressionHandler() { if(expressionHandler == null) { return defaultExpressionHandler; } return expressionHandler; } public void afterSingletonsInstantiated() { if (sameOriginDisabled()) { return; } String beanName = "stompWebSocketHandlerMapping"; SimpleUrlHandlerMapping mapping = context.getBean(beanName, SimpleUrlHandlerMapping.class); Map<String, Object> mappings = mapping.getHandlerMap(); for (Object object : mappings.values()) { if (object instanceof SockJsHttpRequestHandler) { SockJsHttpRequestHandler sockjsHandler = (SockJsHttpRequestHandler) object; SockJsService sockJsService = sockjsHandler.getSockJsService(); if (!(sockJsService instanceof TransportHandlingSockJsService)) { throw new IllegalStateException( "sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService); } TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService) sockJsService; List<HandshakeInterceptor> handshakeInterceptors = transportHandlingSockJsService .getHandshakeInterceptors(); List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>( handshakeInterceptors.size() + 1); interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); interceptorsToSet.addAll(handshakeInterceptors); transportHandlingSockJsService .setHandshakeInterceptors(interceptorsToSet); } else if (object instanceof WebSocketHttpRequestHandler) { WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object; List<HandshakeInterceptor> handshakeInterceptors = handler .getHandshakeInterceptors(); List<HandshakeInterceptor> interceptorsToSet = new ArrayList<HandshakeInterceptor>( handshakeInterceptors.size() + 1); interceptorsToSet.add(new CsrfTokenHandshakeInterceptor()); interceptorsToSet.addAll(handshakeInterceptors); handler.setHandshakeInterceptors(interceptorsToSet); } else { throw new IllegalStateException( "Bean " + beanName + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object); } } if (inboundRegistry.containsMapping() && !inboundRegistry.isSimpDestPathMatcherConfigured()) { PathMatcher pathMatcher = getDefaultPathMatcher(); inboundRegistry.simpDestPathMatcher(pathMatcher); } } }