/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.socket.sockjs.client; import java.net.URI; import java.security.Principal; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.context.Lifecycle; import org.springframework.http.HttpHeaders; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.SettableListenableFuture; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.util.UriComponentsBuilder; /** * A SockJS implementation of * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} * with fallback alternatives that simulate a WebSocket interaction through plain * HTTP streaming and long polling techniques.. * * <p>Implements {@link Lifecycle} in order to propagate lifecycle events to * the transports it is configured with. * * @author Rossen Stoyanchev * @since 4.1 * @see <a href="http://sockjs.org">http://sockjs.org</a> * @see org.springframework.web.socket.sockjs.client.Transport */ public class SockJsClient implements WebSocketClient, Lifecycle { private static final boolean jackson2Present = ClassUtils.isPresent( "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader()); private static final Log logger = LogFactory.getLog(SockJsClient.class); private static final Set<String> supportedProtocols = new HashSet<>(4); static { supportedProtocols.add("ws"); supportedProtocols.add("wss"); supportedProtocols.add("http"); supportedProtocols.add("https"); } private final List<Transport> transports; private String[] httpHeaderNames; private InfoReceiver infoReceiver; private SockJsMessageCodec messageCodec; private TaskScheduler connectTimeoutScheduler; private volatile boolean running = false; private final Map<URI, ServerInfo> serverInfoCache = new ConcurrentHashMap<>(); /** * Create a {@code SockJsClient} with the given transports. * <p>If the list includes an {@link XhrTransport} (or more specifically an * implementation of {@link InfoReceiver}) the instance is used to initialize * the {@link #setInfoReceiver(InfoReceiver) infoReceiver} property, or * otherwise is defaulted to {@link RestTemplateXhrTransport}. * @param transports the (non-empty) list of transports to use */ public SockJsClient(List<Transport> transports) { Assert.notEmpty(transports, "No transports provided"); this.transports = new ArrayList<>(transports); this.infoReceiver = initInfoReceiver(transports); if (jackson2Present) { this.messageCodec = new Jackson2SockJsMessageCodec(); } } private static InfoReceiver initInfoReceiver(List<Transport> transports) { for (Transport transport : transports) { if (transport instanceof InfoReceiver) { return ((InfoReceiver) transport); } } return new RestTemplateXhrTransport(); } /** * The names of HTTP headers that should be copied from the handshake headers * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)} * and also used with other HTTP requests issued as part of that SockJS * connection, e.g. the initial info request, XHR send or receive requests. * * <p>By default if this property is not set, all handshake headers are also * used for other HTTP requests. Set it if you want only a subset of handshake * headers (e.g. auth headers) to be used for other HTTP requests. * * @param httpHeaderNames HTTP header names */ public void setHttpHeaderNames(String... httpHeaderNames) { this.httpHeaderNames = httpHeaderNames; } /** * The configured HTTP header names to be copied from the handshake * headers and also included in other HTTP requests. */ public String[] getHttpHeaderNames() { return this.httpHeaderNames; } /** * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" * request before the SockJS session starts. * <p>If the list of transports provided to the constructor contained an * {@link XhrTransport} or an implementation of {@link InfoReceiver} that * instance would have been used to initialize this property, or otherwise * it defaults to {@link RestTemplateXhrTransport}. * @param infoReceiver the transport to use for the SockJS "Info" request */ public void setInfoReceiver(InfoReceiver infoReceiver) { Assert.notNull(infoReceiver, "InfoReceiver is required"); this.infoReceiver = infoReceiver; } /** * Return the configured {@code InfoReceiver} (never {@code null}). */ public InfoReceiver getInfoReceiver() { return this.infoReceiver; } /** * Set the SockJsMessageCodec to use. * <p>By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath. */ public void setMessageCodec(SockJsMessageCodec messageCodec) { Assert.notNull(messageCodec, "'messageCodec' is required"); this.messageCodec = messageCodec; } /** * Return the SockJsMessageCodec to use. */ public SockJsMessageCodec getMessageCodec() { return this.messageCodec; } /** * Configure a {@code TaskScheduler} for scheduling a connect timeout task * where the timeout value is calculated based on the duration of the initial * SockJS "Info" request. The connect timeout task ensures a more timely * fallback but is otherwise entirely optional. * <p>By default this is not configured in which case a fallback may take longer. * @param connectTimeoutScheduler the task scheduler to use */ public void setConnectTimeoutScheduler(TaskScheduler connectTimeoutScheduler) { this.connectTimeoutScheduler = connectTimeoutScheduler; } @Override public void start() { if (!isRunning()) { this.running = true; for (Transport transport : this.transports) { if (transport instanceof Lifecycle) { if (!((Lifecycle) transport).isRunning()) { ((Lifecycle) transport).start(); } } } } } @Override public void stop() { if (isRunning()) { this.running = false; for (Transport transport : this.transports) { if (transport instanceof Lifecycle) { if (((Lifecycle) transport).isRunning()) { ((Lifecycle) transport).stop(); } } } } } @Override public boolean isRunning() { return this.running; } @Override public ListenableFuture<WebSocketSession> doHandshake( WebSocketHandler handler, String uriTemplate, Object... uriVars) { Assert.notNull(uriTemplate, "uriTemplate must not be null"); URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri(); return doHandshake(handler, null, uri); } @Override public final ListenableFuture<WebSocketSession> doHandshake( WebSocketHandler handler, WebSocketHttpHeaders headers, URI url) { Assert.notNull(handler, "WebSocketHandler is required"); Assert.notNull(url, "URL is required"); String scheme = url.getScheme(); if (!supportedProtocols.contains(scheme)) { throw new IllegalArgumentException("Invalid scheme: '" + scheme + "'"); } SettableListenableFuture<WebSocketSession> connectFuture = new SettableListenableFuture<>(); try { SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers)); createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture); } catch (Throwable exception) { if (logger.isErrorEnabled()) { logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception); } connectFuture.setException(exception); } return connectFuture; } private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) { if (getHttpHeaderNames() == null) { return webSocketHttpHeaders; } else { HttpHeaders httpHeaders = new HttpHeaders(); for (String name : getHttpHeaderNames()) { if (webSocketHttpHeaders.containsKey(name)) { httpHeaders.put(name, webSocketHttpHeaders.get(name)); } } return httpHeaders; } } private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) { URI infoUrl = sockJsUrlInfo.getInfoUrl(); ServerInfo info = this.serverInfoCache.get(infoUrl); if (info == null) { long start = System.currentTimeMillis(); String response = this.infoReceiver.executeInfoRequest(infoUrl, headers); long infoRequestTime = System.currentTimeMillis() - start; info = new ServerInfo(response, infoRequestTime); this.serverInfoCache.put(infoUrl, info); } return info; } private DefaultTransportRequest createRequest(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) { List<DefaultTransportRequest> requests = new ArrayList<>(this.transports.size()); for (Transport transport : this.transports) { for (TransportType type : transport.getTransportTypes()) { if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) { requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers), transport, type, getMessageCodec())); } } } if (CollectionUtils.isEmpty(requests)) { throw new IllegalStateException( "No transports: " + urlInfo + ", webSocketEnabled=" + serverInfo.isWebSocketEnabled()); } for (int i = 0; i < requests.size() - 1; i++) { DefaultTransportRequest request = requests.get(i); request.setUser(getUser()); if (this.connectTimeoutScheduler != null) { request.setTimeoutValue(serverInfo.getRetransmissionTimeout()); request.setTimeoutScheduler(this.connectTimeoutScheduler); } request.setFallbackRequest(requests.get(i + 1)); } return requests.get(0); } /** * Return the user to associate with the SockJS session and make available via * {@link org.springframework.web.socket.WebSocketSession#getPrincipal()}. * <p>By default this method returns {@code null}. * @return the user to associate with the session (possibly {@code null}) */ protected Principal getUser() { return null; } /** * By default the result of a SockJS "Info" request, including whether the * server has WebSocket disabled and how long the request took (used for * calculating transport timeout time) is cached. This method can be used to * clear that cache hence causing it to re-populate. */ public void clearServerInfoCache() { this.serverInfoCache.clear(); } /** * A simple value object holding the result from a SockJS "Info" request. */ private static class ServerInfo { private final boolean webSocketEnabled; private final long responseTime; public ServerInfo(String response, long responseTime) { this.responseTime = responseTime; this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*"); } public boolean isWebSocketEnabled() { return this.webSocketEnabled; } public long getRetransmissionTimeout() { return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300); } } }