/* * Copyright 2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.bearchoke.platform.server.frontend.config; import com.bearchoke.platform.base.SpringSecurityHelper; import com.bearchoke.platform.server.common.ServerConstants; import com.bearchoke.platform.domain.user.repositories.ActiveWebSocketUserRepository; import com.bearchoke.platform.server.common.web.websocket.WebSocketConnectHandler; import com.bearchoke.platform.server.common.web.websocket.WebSocketDisconnectHandler; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.env.Environment; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.converter.MappingJackson2MessageConverter; import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptorAdapter; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; import org.springframework.session.ExpiringSession; import org.springframework.session.web.socket.config.annotation.AbstractSessionWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import javax.inject.Inject; import java.security.Principal; import java.util.List; import java.util.Map; /** * Created by Bjorn Harvold * Date: 8/25/14 * Time: 1:12 AM * Responsibility: */ @Configuration @EnableWebSocketMessageBroker @Log4j2 public class WebSocketConfig<S extends ExpiringSession> extends AbstractSessionWebSocketMessageBrokerConfigurer<S> { @Inject private Environment environment; @Inject private ObjectMapper objectMapper; @Inject @Qualifier("preAuthAuthenticationManager") private AuthenticationManager preAuthAuthenticationManager; @Override public void configureStompEndpoints(StompEndpointRegistry registry) { log.info("WebSocket config: Allowing only origins: " + environment.getProperty("allowed.origin")); registry.addEndpoint("/ws").setAllowedOrigins(environment.getProperty("allowed.origin")).setHandshakeHandler(new SecureHandshakeHandler(preAuthAuthenticationManager)) .withSockJS() .setStreamBytesLimit(512 * 1024) .setHttpMessageCacheSize(1000) .setDisconnectDelay(30 * 1000) .setInterceptors(new HttpSessionIdHandshakeInterceptor()) ; } @Override public void configureWebSocketTransport(WebSocketTransportRegistration registration) { registration.setSendTimeLimit(15 * 1000).setSendBufferSizeLimit(512 * 1024); registration.setMessageSizeLimit(128 * 1024); } @Override public void configureClientInboundChannel(ChannelRegistration channelRegistration) { channelRegistration.setInterceptors(sessionContextChannelInterceptorAdapter()); } @Override public void configureClientOutboundChannel(ChannelRegistration channelRegistration) { } @Override public boolean configureMessageConverters(List<MessageConverter> converters) { MappingJackson2MessageConverter jacksonConverter = new MappingJackson2MessageConverter(); jacksonConverter.setObjectMapper(objectMapper); converters.add(jacksonConverter); return true; } @Override public void configureMessageBroker(MessageBrokerRegistry config) { config.setApplicationDestinationPrefixes("/app"); config.enableSimpleBroker("/queue/", "/topic/"); // This uses too much data for CF AMPQ service // StompBrokerRelayRegistration stompBrokerRelayRegistration = config.enableStompBrokerRelay("/queue/", "/topic/"); // // stompBrokerRelayRegistration.setRelayHost(environment.getProperty("rabbitmq.host")); // stompBrokerRelayRegistration.setVirtualHost(environment.getProperty("rabbitmq.virtualhost")); // stompBrokerRelayRegistration.setClientLogin(environment.getProperty("rabbitmq.username")); // stompBrokerRelayRegistration.setSystemLogin(environment.getProperty("rabbitmq.username")); // stompBrokerRelayRegistration.setClientPasscode(environment.getProperty("rabbitmq.password")); // stompBrokerRelayRegistration.setSystemPasscode(environment.getProperty("rabbitmq.password")); // only if we want to use . instead of / for path separator e.g. /app/user.chat // config.setPathMatcher(new AntPathMatcher(".")); } @Bean public WebSocketConnectHandler<S> webSocketConnectHandler(SimpMessageSendingOperations messagingTemplate, ActiveWebSocketUserRepository repository) { return new WebSocketConnectHandler<>(messagingTemplate, repository); } @Bean public WebSocketDisconnectHandler<S> webSocketDisconnectHandler(SimpMessageSendingOperations messagingTemplate, ActiveWebSocketUserRepository repository) { return new WebSocketDisconnectHandler<S>(messagingTemplate, repository); } /** * For serving up websockets in a Tomcat / GlassFish / WildFly environment * * @return */ @Bean public ServletServerContainerFactoryBean createWebSocketContainer() { ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean(); container.setMaxTextMessageBufferSize(8192); container.setMaxBinaryMessageBufferSize(8192); return container; } @Bean public ChannelInterceptorAdapter sessionContextChannelInterceptorAdapter() { return new ChannelInterceptorAdapter() { @Override public Message<?> preSend(Message<?> message, MessageChannel channel) { StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); StompCommand command = accessor.getCommand(); if (log.isDebugEnabled() && command != null) { log.debug("StompCommand: " + command.toString()); } String authToken = accessor.getFirstNativeHeader(ServerConstants.X_AUTH_TOKEN); if (log.isDebugEnabled() && StringUtils.isNotEmpty(authToken)) { log.debug("Header auth token: " + authToken); } if (StringUtils.isNotBlank(authToken)) { // set cached authenticated user back in the spring security context Authentication authentication = preAuthAuthenticationManager.authenticate(new PreAuthenticatedAuthenticationToken(authToken, "N/A")); if (log.isDebugEnabled()) { log.debug("Adding Authentication to SecurityContext for WebSocket call: " + authentication); } SpringSecurityHelper.setAuthentication(authentication); } return super.preSend(message, channel); } }; } static class HttpSessionIdHandshakeInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { if (request instanceof ServletServerHttpRequest) { HttpHeaders headers = request.getHeaders(); if (headers.containsKey(ServerConstants.X_AUTH_TOKEN)) { List<String> authToken = headers.get(ServerConstants.X_AUTH_TOKEN); if (log.isDebugEnabled()) { log.debug("Header auth token: " + authToken.get(0)); } attributes.put(ServerConstants.X_AUTH_TOKEN, headers.get(ServerConstants.X_AUTH_TOKEN)); } } return true; } public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) { } } static class SecureHandshakeHandler extends DefaultHandshakeHandler { private final AuthenticationManager authenticationManager; public SecureHandshakeHandler(AuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; } @Override protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler, Map<String, Object> attributes) { Principal result = null; String authToken = null; HttpHeaders headers = request.getHeaders(); if (log.isDebugEnabled()) { log.debug("Determining user..."); } if (headers.containsKey(ServerConstants.X_AUTH_TOKEN)) { authToken = headers.getFirst(ServerConstants.X_AUTH_TOKEN); authenticate(authToken); } else if (attributes.containsKey(ServerConstants.X_AUTH_TOKEN)) { authToken = (String) attributes.get(ServerConstants.X_AUTH_TOKEN); authenticate(authToken); } else { result = super.determineUser(request, wsHandler, attributes); } return result; } private void authenticate(String authToken) { if (log.isDebugEnabled() && StringUtils.isNotEmpty(authToken)) { log.debug("Header auth token: " + authToken); } if (StringUtils.isNotBlank(authToken)) { // set cached authenticated user back in the spring security context Authentication authentication = authenticationManager.authenticate(new PreAuthenticatedAuthenticationToken(authToken, "N/A")); if (log.isDebugEnabled()) { log.debug("Adding Authentication to SecurityContext for WebSocket call: " + authentication); } SpringSecurityHelper.setAuthentication(authentication); } } } }