/*
* Copyright 2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.bearchoke.platform.tests.web.websocket.support.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.List;
/**
 * A {@link StompClient} that speaks STOMP over a plain WebSocket connection, exchanging
 * STOMP frames as WebSocket text messages.
 *
 * <p>{@link #connect(StompMessageHandler)} blocks until the WebSocket handshake has completed.
 * It then reports STOMP frames received from the server to the supplied {@link StompMessageHandler}.
 */
public class WebSocketStompClient implements StompClient {

	private static final Log logger = LogFactory.getLog(WebSocketStompClient.class);

	private final URI uri;

	private final WebSocketHttpHeaders headers;

	private final WebSocketClient webSocketClient;

	/** Converter handed to STOMP sessions; when {@code null} a Jackson-based default is used. */
	private MessageConverter messageConverter;

	/**
	 * Create a client for the given WebSocket endpoint.
	 * @param uri the WebSocket endpoint to connect to
	 * @param headers HTTP headers sent with the handshake request
	 * @param webSocketClient the underlying client used to perform the handshake
	 */
	public WebSocketStompClient(URI uri, WebSocketHttpHeaders headers, WebSocketClient webSocketClient) {
		this.uri = uri;
		this.headers = headers;
		this.webSocketClient = webSocketClient;
	}

	/**
	 * Set the converter passed to the STOMP sessions created by this client.
	 * <p>If never set (or set to {@code null}), a {@link MappingJackson2MessageConverter} is used.
	 * @param messageConverter the converter to use, or {@code null} for the default
	 */
	public void setMessageConverter(MessageConverter messageConverter) {
		this.messageConverter = messageConverter;
	}

	/**
	 * Open the WebSocket connection and start the STOMP handshake.
	 * <p>Blocks until the WebSocket handshake completes. The STOMP CONNECTED frame is reported
	 * later, when the server sends it, via {@link StompMessageHandler#afterConnected}.
	 * @param stompMessageHandler receives the STOMP session and the frames sent by the server
	 * @throws IllegalStateException if the handshake fails or the calling thread is interrupted
	 */
	@Override
	public void connect(StompMessageHandler stompMessageHandler) {
		try {
			// Previously a null converter was passed straight through; fall back to the
			// Jackson-based default the single-argument constructor already provides.
			StompWebSocketHandler webSocketHandler = (this.messageConverter != null ?
					new StompWebSocketHandler(stompMessageHandler, this.messageConverter) :
					new StompWebSocketHandler(stompMessageHandler));
			this.webSocketClient.doHandshake(webSocketHandler, this.headers, this.uri).get();
		}
		catch (InterruptedException ex) {
			// Restore the interrupt flag so callers further up can still observe it.
			Thread.currentThread().interrupt();
			throw new IllegalStateException(ex);
		}
		catch (Exception ex) {
			throw new IllegalStateException(ex);
		}
	}


	/**
	 * Bridges WebSocket callbacks to STOMP. It sends a CONNECT frame once the socket opens,
	 * decodes incoming text messages into STOMP frames, and dispatches each frame to the
	 * {@link StompMessageHandler} by command.
	 */
	private static class StompWebSocketHandler extends AbstractWebSocketHandler {

		private static final Charset UTF_8 = Charset.forName("UTF-8");

		private final StompMessageHandler stompMessageHandler;

		private final MessageConverter messageConverter;

		private final StompEncoder encoder = new StompEncoder();

		private final StompDecoder decoder = new StompDecoder();

		/** Create a handler whose sessions convert payloads with a {@link MappingJackson2MessageConverter}. */
		private StompWebSocketHandler(StompMessageHandler delegate) {
			this(delegate, new MappingJackson2MessageConverter());
		}

		/** Create a handler whose sessions use the given converter. */
		private StompWebSocketHandler(StompMessageHandler delegate, MessageConverter messageConverter) {
			this.stompMessageHandler = delegate;
			this.messageConverter = messageConverter;
		}

		/**
		 * Send the STOMP CONNECT frame as soon as the WebSocket is open.
		 * <p>Accepts STOMP versions 1.1 and 1.2 and requests heart-beat 0,0 (no heart-beating).
		 */
		@Override
		public void afterConnectionEstablished(WebSocketSession session) throws Exception {
			StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
			headers.setAcceptVersion("1.1,1.2");
			headers.setHeartbeat(0, 0);
			Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
			TextMessage textMessage = new TextMessage(new String(this.encoder.encode(message), UTF_8));
			session.sendMessage(textMessage);
		}

		/**
		 * Decode the STOMP frame(s) carried by a text message and dispatch each one by command.
		 * <p>CONNECTED creates a new session for the handler. MESSAGE, RECEIPT and ERROR are
		 * forwarded to the handler. Anything else is logged at debug level.
		 */
		@Override
		protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) throws Exception {
			ByteBuffer payload = ByteBuffer.wrap(textMessage.getPayload().getBytes(UTF_8));
			List<Message<byte[]>> messages = this.decoder.decode(payload);
			for (Message<byte[]> message : messages) {
				StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
				// May be null (e.g. for frames without a command); the equals() checks below are null-safe.
				StompCommand command = headers.getCommand();
				if (StompCommand.CONNECTED.equals(command)) {
					WebSocketStompSession stompSession = new WebSocketStompSession(session, this.messageConverter);
					this.stompMessageHandler.afterConnected(stompSession, headers);
				}
				else if (StompCommand.MESSAGE.equals(command)) {
					this.stompMessageHandler.handleMessage(message);
				}
				else if (StompCommand.RECEIPT.equals(command)) {
					this.stompMessageHandler.handleReceipt(headers.getReceiptId());
				}
				else if (StompCommand.ERROR.equals(command)) {
					this.stompMessageHandler.handleError(message);
				}
				// Disconnection is reported once, from afterConnectionClosed(); the former
				// duplicate ERROR branch that did so here was unreachable and has been removed.
				else if (logger.isDebugEnabled()) {
					logger.debug("Unhandled message " + message);
				}
			}
		}

		/** Log transport-level errors; the connection is not closed or reported from here. */
		@Override
		public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
			logger.error("WebSocket transport error", exception);
		}

		/** Notify the handler that the WebSocket connection, and with it the STOMP session, has closed. */
		@Override
		public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
			this.stompMessageHandler.afterDisconnected();
		}
	}

}