/* * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.security.config.websocket; import static org.springframework.security.config.Elements.*; import java.util.Comparator; import java.util.List; import java.util.Map; import org.springframework.beans.BeansException; import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanReference; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.ManagedMap; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.BeanDefinitionParser; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.beans.factory.xml.XmlReaderContext; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.security.access.vote.ConsensusBased; import org.springframework.security.config.Elements; import org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory; import org.springframework.security.messaging.access.expression.MessageExpressionVoter; import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor; import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor; import org.springframework.util.AntPathMatcher; import org.springframework.util.PathMatcher; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.w3c.dom.Element; /** * Parses Spring Security's websocket namespace support. A simple example is: * * <code> * <websocket-message-broker> * <intercept-message pattern='/permitAll' access='permitAll' /> * <intercept-message pattern='/denyAll' access='denyAll' /> * </websocket-message-broker> * </code> * * <p> * The above configuration will ensure that any SimpAnnotationMethodMessageHandler has the * AuthenticationPrincipalArgumentResolver registered as a custom argument resolver. It * also ensures that the SecurityContextChannelInterceptor is automatically registered for * the clientInboundChannel. Last, it ensures that a ChannelSecurityInterceptor is * registered with the clientInboundChannel. * </p> * * <p> * If finer control is necessary, the id attribute can be used as shown below: * </p> * * <code> * <websocket-message-broker id="channelSecurityInterceptor"> * <intercept-message pattern='/permitAll' access='permitAll' /> * <intercept-message pattern='/denyAll' access='denyAll' /> * </websocket-message-broker> * </code> * * <p> * Now the configuration will only create a bean named ChannelSecurityInterceptor and * assign it to the id of channelSecurityInterceptor. Users can explicitly wire Spring * Security using the standard Spring Messaging XML namespace support. * </p> * * @author Rob Winch * @since 4.0 */ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements BeanDefinitionParser { private static final String ID_ATTR = "id"; private static final String DISABLED_ATTR = "same-origin-disabled"; private static final String PATTERN_ATTR = "pattern"; private static final String ACCESS_ATTR = "access"; private static final String TYPE_ATTR = "type"; private static final String PATH_MATCHER_BEAN_NAME = "springSecurityMessagePathMatcher"; /** * @param element * @param parserContext * @return */ public BeanDefinition parse(Element element, ParserContext parserContext) { BeanDefinitionRegistry registry = parserContext.getRegistry(); XmlReaderContext context = parserContext.getReaderContext(); ManagedMap<BeanDefinition, String> matcherToExpression = new ManagedMap<BeanDefinition, String>(); String id = element.getAttribute(ID_ATTR); Element expressionHandlerElt = DomUtils.getChildElementByTagName(element, EXPRESSION_HANDLER); String expressionHandlerRef = expressionHandlerElt == null ? null : expressionHandlerElt.getAttribute("ref"); boolean expressionHandlerDefined = StringUtils.hasText(expressionHandlerRef); boolean sameOriginDisabled = Boolean.parseBoolean(element .getAttribute(DISABLED_ATTR)); List<Element> interceptMessages = DomUtils.getChildElementsByTagName(element, Elements.INTERCEPT_MESSAGE); for (Element interceptMessage : interceptMessages) { String matcherPattern = interceptMessage.getAttribute(PATTERN_ATTR); String accessExpression = interceptMessage.getAttribute(ACCESS_ATTR); String messageType = interceptMessage.getAttribute(TYPE_ATTR); BeanDefinition matcher = createMatcher(matcherPattern, messageType, parserContext, interceptMessage); matcherToExpression.put(matcher, accessExpression); } BeanDefinitionBuilder mds = BeanDefinitionBuilder .rootBeanDefinition(ExpressionBasedMessageSecurityMetadataSourceFactory.class); mds.setFactoryMethod("createExpressionMessageMetadataSource"); mds.addConstructorArgValue(matcherToExpression); if(expressionHandlerDefined) { mds.addConstructorArgReference(expressionHandlerRef); } String mdsId = context.registerWithGeneratedName(mds.getBeanDefinition()); ManagedList<BeanDefinition> voters = new ManagedList<BeanDefinition>(); BeanDefinitionBuilder messageExpressionVoterBldr = BeanDefinitionBuilder.rootBeanDefinition(MessageExpressionVoter.class); if(expressionHandlerDefined) { messageExpressionVoterBldr.addPropertyReference("expressionHandler", expressionHandlerRef); } voters.add(messageExpressionVoterBldr.getBeanDefinition()); BeanDefinitionBuilder adm = BeanDefinitionBuilder .rootBeanDefinition(ConsensusBased.class); adm.addConstructorArgValue(voters); BeanDefinitionBuilder inboundChannelSecurityInterceptor = BeanDefinitionBuilder .rootBeanDefinition(ChannelSecurityInterceptor.class); inboundChannelSecurityInterceptor.addConstructorArgValue(registry .getBeanDefinition(mdsId)); inboundChannelSecurityInterceptor.addPropertyValue("accessDecisionManager", adm.getBeanDefinition()); String inSecurityInterceptorName = context .registerWithGeneratedName(inboundChannelSecurityInterceptor .getBeanDefinition()); if (StringUtils.hasText(id)) { registry.registerAlias(inSecurityInterceptorName, id); if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class)); } } else { BeanDefinitionBuilder mspp = BeanDefinitionBuilder .rootBeanDefinition(MessageSecurityPostProcessor.class); mspp.addConstructorArgValue(inSecurityInterceptorName); mspp.addConstructorArgValue(sameOriginDisabled); context.registerWithGeneratedName(mspp.getBeanDefinition()); } return null; } private BeanDefinition createMatcher(String matcherPattern, String messageType, ParserContext parserContext, Element interceptMessage) { boolean hasPattern = StringUtils.hasText(matcherPattern); boolean hasMessageType = StringUtils.hasText(messageType); if (!hasPattern) { BeanDefinitionBuilder matcher = BeanDefinitionBuilder .rootBeanDefinition(SimpMessageTypeMatcher.class); matcher.addConstructorArgValue(messageType); return matcher.getBeanDefinition(); } String factoryName = null; if (hasPattern && hasMessageType) { SimpMessageType type = SimpMessageType.valueOf(messageType); if (SimpMessageType.MESSAGE == type) { factoryName = "createMessageMatcher"; } else if (SimpMessageType.SUBSCRIBE == type) { factoryName = "createSubscribeMatcher"; } else { parserContext .getReaderContext() .error("Cannot use intercept-websocket@message-type=" + messageType + " with a pattern because the type does not have a destination.", interceptMessage); } } BeanDefinitionBuilder matcher = BeanDefinitionBuilder .rootBeanDefinition(SimpDestinationMessageMatcher.class); matcher.setFactoryMethod(factoryName); matcher.addConstructorArgValue(matcherPattern); matcher.addConstructorArgValue(new RuntimeBeanReference("springSecurityMessagePathMatcher")); return matcher.getBeanDefinition(); } static class MessageSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor { /** * This is not available prior to Spring 4.2 */ private static final String WEB_SOCKET_AMMH_CLASS_NAME = "org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler"; private static final String CLIENT_INBOUND_CHANNEL_BEAN_ID = "clientInboundChannel"; private static final String INTERCEPTORS_PROP = "interceptors"; private static final String CUSTOM_ARG_RESOLVERS_PROP = "customArgumentResolvers"; private final String inboundSecurityInterceptorId; private final boolean sameOriginDisabled; public MessageSecurityPostProcessor(String inboundSecurityInterceptorId, boolean sameOriginDisabled) { this.inboundSecurityInterceptorId = inboundSecurityInterceptorId; this.sameOriginDisabled = sameOriginDisabled; } public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { String[] beanNames = registry.getBeanDefinitionNames(); for (String beanName : beanNames) { BeanDefinition bd = registry.getBeanDefinition(beanName); String beanClassName = bd.getBeanClassName(); if (SimpAnnotationMethodMessageHandler.class.getName().equals(beanClassName) || WEB_SOCKET_AMMH_CLASS_NAME.equals(beanClassName)) { PropertyValue current = bd.getPropertyValues().getPropertyValue( CUSTOM_ARG_RESOLVERS_PROP); ManagedList<Object> argResolvers = new ManagedList<Object>(); if (current != null) { argResolvers.addAll((ManagedList<?>) current.getValue()); } argResolvers.add(new RootBeanDefinition( AuthenticationPrincipalArgumentResolver.class)); bd.getPropertyValues().add(CUSTOM_ARG_RESOLVERS_PROP, argResolvers); if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { PropertyValue pathMatcherProp = bd.getPropertyValues().getPropertyValue("pathMatcher"); Object pathMatcher = pathMatcherProp == null ? null : pathMatcherProp.getValue(); if(pathMatcher instanceof BeanReference) { registry.registerAlias(((BeanReference) pathMatcher).getBeanName(), PATH_MATCHER_BEAN_NAME); } } } else if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler" .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } else if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService" .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } else if ("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService" .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } } if (!registry.containsBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID)) { return; } ManagedList<Object> interceptors = new ManagedList(); interceptors.add(new RootBeanDefinition( SecurityContextChannelInterceptor.class)); if (!sameOriginDisabled) { interceptors.add(new RootBeanDefinition(CsrfChannelInterceptor.class)); } interceptors.add(registry.getBeanDefinition(inboundSecurityInterceptorId)); BeanDefinition inboundChannel = registry .getBeanDefinition(CLIENT_INBOUND_CHANNEL_BEAN_ID); PropertyValue currentInterceptorsPv = inboundChannel.getPropertyValues() .getPropertyValue(INTERCEPTORS_PROP); if (currentInterceptorsPv != null) { ManagedList<?> currentInterceptors = (ManagedList<?>) currentInterceptorsPv .getValue(); interceptors.addAll(currentInterceptors); } inboundChannel.getPropertyValues().add(INTERCEPTORS_PROP, interceptors); if(!registry.containsBeanDefinition(PATH_MATCHER_BEAN_NAME)) { registry.registerBeanDefinition(PATH_MATCHER_BEAN_NAME, new RootBeanDefinition(AntPathMatcher.class)); } } private void addCsrfTokenHandshakeInterceptor(BeanDefinition bd) { if (sameOriginDisabled) { return; } String interceptorPropertyName = "handshakeInterceptors"; ManagedList<? super Object> interceptors = new ManagedList<Object>(); interceptors.add(new RootBeanDefinition(CsrfTokenHandshakeInterceptor.class)); interceptors.addAll((ManagedList<Object>) bd.getPropertyValues().get( interceptorPropertyName)); bd.getPropertyValues().add(interceptorPropertyName, interceptors); } public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { } } static class DelegatingPathMatcher implements PathMatcher { private PathMatcher delegate = new AntPathMatcher(); public boolean isPattern(String path) { return delegate.isPattern(path); } public boolean match(String pattern, String path) { return delegate.match(pattern, path); } public boolean matchStart(String pattern, String path) { return delegate.matchStart(pattern, path); } public String extractPathWithinPattern(String pattern, String path) { return delegate.extractPathWithinPattern(pattern, path); } public Map<String, String> extractUriTemplateVariables(String pattern, String path) { return delegate.extractUriTemplateVariables(pattern, path); } public Comparator<String> getPatternComparator(String path) { return delegate.getPatternComparator(path); } public String combine(String pattern1, String pattern2) { return delegate.combine(pattern1, pattern2); } void setPathMatcher(PathMatcher pathMatcher) { this.delegate = pathMatcher; } } }