/*
* Copyright 2014-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.websocket.inbound;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.Lifecycle;
import org.springframework.integration.channel.FixedSubscriberChannel;
import org.springframework.integration.endpoint.MessageProducerSupport;
import org.springframework.integration.support.json.JacksonPresent;
import org.springframework.integration.websocket.IntegrationWebSocketContainer;
import org.springframework.integration.websocket.ServerWebSocketContainer;
import org.springframework.integration.websocket.WebSocketListener;
import org.springframework.integration.websocket.event.ReceiptEvent;
import org.springframework.integration.websocket.support.PassThruSubProtocolHandler;
import org.springframework.integration.websocket.support.SubProtocolHandlerRegistry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandlingException;
import org.springframework.messaging.converter.ByteArrayMessageConverter;
import org.springframework.messaging.converter.CompositeMessageConverter;
import org.springframework.messaging.converter.DefaultContentTypeResolver;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
/**
* @author Artem Bilan
* @since 4.1
*/
public class WebSocketInboundChannelAdapter extends MessageProducerSupport
implements WebSocketListener, ApplicationEventPublisherAware {
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private final List<MessageConverter> defaultConverters = new ArrayList<MessageConverter>(3);
private ApplicationEventPublisher eventPublisher;
{
this.defaultConverters.add(new StringMessageConverter());
this.defaultConverters.add(new ByteArrayMessageConverter());
if (JacksonPresent.isJackson2Present()) {
DefaultContentTypeResolver resolver = new DefaultContentTypeResolver();
resolver.setDefaultMimeType(MimeTypeUtils.APPLICATION_JSON);
MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter();
converter.setContentTypeResolver(resolver);
this.defaultConverters.add(converter);
}
}
private final CompositeMessageConverter messageConverter = new CompositeMessageConverter(this.defaultConverters);
private final IntegrationWebSocketContainer webSocketContainer;
private final boolean server;
private final SubProtocolHandlerRegistry subProtocolHandlerRegistry;
private final MessageChannel subProtocolHandlerChannel;
private final AtomicReference<Class<?>> payloadType = new AtomicReference<Class<?>>(String.class);
private volatile List<MessageConverter> messageConverters;
private volatile boolean mergeWithDefaultConverters = false;
private volatile boolean active;
private volatile boolean useBroker;
private AbstractBrokerMessageHandler brokerHandler;
public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketContainer) {
this(webSocketContainer, new SubProtocolHandlerRegistry(new PassThruSubProtocolHandler()));
}
public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketContainer,
SubProtocolHandlerRegistry protocolHandlerRegistry) {
Assert.notNull(webSocketContainer, "'webSocketContainer' must not be null");
Assert.notNull(protocolHandlerRegistry, "'protocolHandlerRegistry' must not be null");
this.webSocketContainer = webSocketContainer;
this.server = this.webSocketContainer instanceof ServerWebSocketContainer;
this.subProtocolHandlerRegistry = protocolHandlerRegistry;
this.subProtocolHandlerChannel = new FixedSubscriberChannel(message -> {
try {
handleMessageAndSend(message);
}
catch (Exception e) {
throw new MessageHandlingException(message, e);
}
});
}
/**
* Set the message converters to use. These converters are used to convert the message to send for appropriate
* internal subProtocols type.
* @param messageConverters The message converters.
*/
public void setMessageConverters(List<MessageConverter> messageConverters) {
Assert.noNullElements(messageConverters.toArray(), "'messageConverters' must not contain null entries");
this.messageConverters = new ArrayList<MessageConverter>(messageConverters);
}
/**
* Flag which determines if the default converters should be available after
* custom converters.
* @param mergeWithDefaultConverters true to merge, false to replace.
*/
public void setMergeWithDefaultConverters(boolean mergeWithDefaultConverters) {
this.mergeWithDefaultConverters = mergeWithDefaultConverters;
}
/**
* Set the type for target message payload to convert the WebSocket message body to.
* @param payloadType to convert inbound WebSocket message body
* @see CompositeMessageConverter
*/
public void setPayloadType(Class<?> payloadType) {
Assert.notNull(payloadType, "'payloadType' must not be null");
this.payloadType.set(payloadType);
}
/**
* Specify if this adapter should use an existing single {@link AbstractBrokerMessageHandler}
* bean for {@code non-MESSAGE} {@link org.springframework.web.socket.WebSocketMessage}s
* and to route messages with broker destinations.
* Since only single {@link AbstractBrokerMessageHandler} bean is allowed in the current
* application context, the algorithm to lookup the former by type, rather than applying
* the bean reference.
* This is used only on server side and is ignored from client side.
* @param useBroker the boolean flag.
*/
public void setUseBroker(boolean useBroker) {
this.useBroker = useBroker;
}
@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}
@Override
protected void onInit() {
super.onInit();
this.webSocketContainer.setMessageListener(this);
if (!CollectionUtils.isEmpty(this.messageConverters)) {
List<MessageConverter> converters = this.messageConverter.getConverters();
if (this.mergeWithDefaultConverters) {
ListIterator<MessageConverter> iterator =
this.messageConverters.listIterator(this.messageConverters.size());
while (iterator.hasPrevious()) {
MessageConverter converter = iterator.previous();
converters.add(0, converter);
}
}
else {
converters.clear();
converters.addAll(this.messageConverters);
}
}
if (this.server && this.useBroker) {
Map<String, AbstractBrokerMessageHandler> brokers = getApplicationContext()
.getBeansOfType(AbstractBrokerMessageHandler.class);
for (AbstractBrokerMessageHandler broker : brokers.values()) {
if (broker instanceof SimpleBrokerMessageHandler || broker instanceof StompBrokerRelayMessageHandler) {
this.brokerHandler = broker;
break;
}
}
Assert.state(this.brokerHandler != null,
"WebSocket Broker Relay isn't present in the application context; " +
"it is required when 'useBroker = true'.");
}
}
@Override
public List<String> getSubProtocols() {
return this.subProtocolHandlerRegistry.getSubProtocols();
}
@Override
public void afterSessionStarted(WebSocketSession session) throws Exception {
if (isActive()) {
this.subProtocolHandlerRegistry.findProtocolHandler(session)
.afterSessionStarted(session, this.subProtocolHandlerChannel);
}
}
@Override
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus) throws Exception {
if (isActive()) {
this.subProtocolHandlerRegistry.findProtocolHandler(session)
.afterSessionEnded(session, closeStatus, this.subProtocolHandlerChannel);
}
}
@Override
public void onMessage(WebSocketSession session, WebSocketMessage<?> webSocketMessage) throws Exception {
if (isActive()) {
this.subProtocolHandlerRegistry.findProtocolHandler(session)
.handleMessageFromClient(session, webSocketMessage, this.subProtocolHandlerChannel);
}
}
@Override
public String getComponentType() {
return "websocket:inbound-channel-adapter";
}
@Override
protected void doStart() {
this.active = true;
if (this.webSocketContainer instanceof Lifecycle) {
((Lifecycle) this.webSocketContainer).start();
}
}
@Override
protected void doStop() {
this.active = false;
}
private boolean isActive() {
if (!this.active) {
logger.warn("MessageProducer '" + this + " 'isn't started to accept WebSocket events.");
}
return this.active;
}
@SuppressWarnings("unchecked")
private void handleMessageAndSend(Message<?> message) throws Exception {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
StompCommand stompCommand = (StompCommand) headerAccessor.getHeader("stompCommand");
SimpMessageType messageType = headerAccessor.getMessageType();
if ((messageType == null || SimpMessageType.MESSAGE.equals(messageType)
|| (SimpMessageType.CONNECT.equals(messageType) && !this.useBroker)
|| StompCommand.CONNECTED.equals(stompCommand)
|| StompCommand.RECEIPT.equals(stompCommand))
&& !checkDestinationPrefix(headerAccessor.getDestination())) {
if (SimpMessageType.CONNECT.equals(messageType)) {
String sessionId = headerAccessor.getSessionId();
SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
connectAck.setSessionId(sessionId);
connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders());
WebSocketSession session = this.webSocketContainer.getSession(sessionId);
this.subProtocolHandlerRegistry.findProtocolHandler(session).handleMessageToClient(session, ackMessage);
}
else if (StompCommand.CONNECTED.equals(stompCommand)) {
this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
}
else if (StompCommand.RECEIPT.equals(stompCommand)) {
this.eventPublisher.publishEvent(new ReceiptEvent(this, (Message<byte[]>) message));
}
else {
headerAccessor.removeHeader(SimpMessageHeaderAccessor.NATIVE_HEADERS);
Object payload = this.messageConverter.fromMessage(message, this.payloadType.get());
sendMessage(getMessageBuilderFactory().withPayload(payload).copyHeaders(headerAccessor.toMap()).build());
}
}
else {
if (this.useBroker) {
this.brokerHandler.handleMessage(message);
}
else if (logger.isDebugEnabled()) {
logger.debug("Messages with non 'SimpMessageType.MESSAGE' type are ignored for sending to the " +
"'outputChannel'. They have to be emitted as 'ApplicationEvent's " +
"from the 'SubProtocolHandler'. Or using 'AbstractBrokerMessageHandler'(useBroker = true) " +
"from server side. Received message: " + message);
}
}
}
private boolean checkDestinationPrefix(String destination) {
if (this.useBroker) {
Collection<String> destinationPrefixes = this.brokerHandler.getDestinationPrefixes();
if ((destination == null) || CollectionUtils.isEmpty(destinationPrefixes)) {
return false;
}
for (String prefix : destinationPrefixes) {
if (destination.startsWith(prefix)) {
return true;
}
}
}
return false;
}
}