/** * Copyright (C) 2014 Red Hat, Inc, and individual contributors. */ package org.projectodd.sockjs.servlet; import org.projectodd.sockjs.SockJsException; import org.projectodd.sockjs.SockJsServer; import javax.servlet.AsyncContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.HandshakeResponse; import javax.websocket.server.HandshakeRequest; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerEndpointConfig; import java.io.IOException; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; public class SockJsServlet extends HttpServlet { public SockJsServlet() { } public SockJsServlet(SockJsServer sockJsServer) { this.sockJsServer = sockJsServer; } public void setServer(SockJsServer sockJsServer) { this.sockJsServer = sockJsServer; } public SockJsServer getServer() { return sockJsServer; } @Override public void init() throws ServletException { if (sockJsServer == null) { sockJsServer = new SockJsServer(); } sockJsServer.init(); if (sockJsServer.options.websocket) { // Make sure we listen on all possible mappings of the servlet for (String mapping : getServletContext().getServletRegistration(getServletName()).getMappings()) { final String commonPrefix = extractPrefixFromMapping(mapping); String websocketPath = commonPrefix + "/{server}/{session}/websocket"; ServerEndpointConfig sockJsConfig = ServerEndpointConfig.Builder .create(SockJsEndpoint.class, websocketPath) .configurator(configuratorFor(commonPrefix, false)) .build(); String rawWebsocketPath = commonPrefix + "/websocket"; ServerEndpointConfig rawWsConfig = ServerEndpointConfig.Builder .create(RawWebsocketEndpoint.class, rawWebsocketPath) .configurator(configuratorFor(commonPrefix, true)) .build(); ServerContainer serverContainer = (ServerContainer) getServletContext().getAttribute("javax.websocket.server.ServerContainer"); try { serverContainer.addEndpoint(sockJsConfig); serverContainer.addEndpoint(rawWsConfig); } catch (DeploymentException ex) { throw new ServletException("Error deploying websocket endpoint:", ex); } } } } private String extractPrefixFromMapping(String mapping) { if (mapping.endsWith("*")) { mapping = mapping.substring(0, mapping.length() - 1); } if (mapping.endsWith("/")) { mapping = mapping.substring(0, mapping.length() - 1); } return mapping; } private ServerEndpointConfig.Configurator configuratorFor(final String prefix, final boolean isRaw) { return new ServerEndpointConfig.Configurator() { @Override public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException { try { return endpointClass.getConstructor(SockJsServer.class, String.class, String.class) .newInstance(sockJsServer, getServletContext().getContextPath(), prefix); } catch (Exception e) { throw new RuntimeException(e); } } @Override public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { if (isRaw) { // We have no reliable key (like session id) to save // headers with for raw websocket requests return; } String path = request.getRequestURI().getPath(); Matcher matcher = SESSION_PATTERN.matcher(path); if (matcher.matches()) { String sessionId = matcher.group(1); saveHeaders(sessionId, request.getHeaders()); } } }; } @Override protected void service(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException { log.log(Level.FINE, "SockJsServlet#service for {0} {1}", new Object[] {req.getMethod(), req.getPathInfo()}); AsyncContext asyncContext = req.startAsync(); asyncContext.setTimeout(0); // no timeout SockJsServletRequest sockJsReq = new SockJsServletRequest(req); SockJsServletResponse sockJsRes = new SockJsServletResponse(res, asyncContext); try { sockJsServer.dispatch(sockJsReq, sockJsRes); } catch (SockJsException ex) { throw new ServletException("Error during SockJS request:", ex); } if ("application/x-www-form-urlencoded".equals(req.getHeader("Content-Type"))) { // Let the servlet parse data and just pretend like we did sockJsReq.onAllDataRead(); } else if (req.isAsyncStarted()) { req.getInputStream().setReadListener(sockJsReq); } } @Override public void destroy() { super.destroy(); sockJsServer.destroy(); } static void saveHeaders(String sessionId, Map<String, List<String>> headers) { savedHeaders.put(sessionId, headers); } static Map<String, List<String>> retrieveHeaders(String sessionId) { return savedHeaders.remove(sessionId); } private SockJsServer sockJsServer; private static final Pattern SESSION_PATTERN = Pattern.compile(".*/.+/(.+)/websocket$"); private static final Logger log = Logger.getLogger(SockJsServlet.class.getName()); /** * Store a map of SockJS sessionId to header values from the upgrade * request since JSR-356 gives us no way to access this from Endpoints * directly. The MAX_INFLIGHT_HEADERS and LinkedHashMap#removeEldestEntry * are used to make sure any really misbehaving clients don't cause * entries to accumulate in the map. Under normal circumstances, an entry * is removed very shortly after it's added since we don't add it until * the handshake process is complete and remove it as soon as the * Endpoint's onOpen gets called. */ private static final int MAX_INFLIGHT_HEADERS = 100; private static final Map<String, Map<String, List<String>>> savedHeaders = Collections.synchronizedMap(new LinkedHashMap<String,Map<String, List<String>>>() { @Override protected boolean removeEldestEntry(Map.Entry eldest) { return size() > MAX_INFLIGHT_HEADERS; } }); }