/** * Copyright (C) 2014 Red Hat, Inc, and individual contributors. */ package org.projectodd.sockjs.servlet; import java.io.IOException; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.servlet.AsyncContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.HandshakeResponse; import javax.websocket.server.HandshakeRequest; import javax.websocket.server.ServerEndpointConfig; import org.projectodd.sockjs.SockJsException; import org.projectodd.sockjs.SockJsServer; public class SockJsServlet extends HttpServlet { public SockJsServlet() { } public void setServer(SockJsServer sockJsServer) { this.sockJsServer = sockJsServer; } public SockJsServer getServer() { return sockJsServer; } @Override public void init() throws ServletException { // } private String extractPrefixFromMapping(String mapping) { if (mapping.endsWith("*")) { mapping = mapping.substring(0, mapping.length() - 1); } if (mapping.endsWith("/")) { mapping = mapping.substring(0, mapping.length() - 1); } return mapping; } private ServerEndpointConfig.Configurator configuratorFor(final String prefix, final boolean isRaw) { return new ServerEndpointConfig.Configurator() { @Override public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException { try { return endpointClass.getConstructor(SockJsServer.class, String.class, String.class) .newInstance(sockJsServer, getServletContext().getContextPath(), prefix); } catch (Exception e) { throw new RuntimeException(e); } } @Override public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { if (isRaw) { // We have no reliable key (like session id) to save // headers with for raw websocket requests return; } String path = request.getRequestURI().getPath(); Matcher matcher = SESSION_PATTERN.matcher(path); if (matcher.matches()) { String sessionId = matcher.group(1); saveHeaders(sessionId, request.getHeaders()); } } }; } @Override protected void service(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException { res.setHeader("Access-Control-Allow-Origin", "true"); log.log(Level.FINE, "SockJsServlet#service for {0} {1}", new Object[]{ req.getMethod(), req.getPathInfo() }); AsyncContext asyncContext = req.startAsync(); asyncContext.setTimeout(0); // no timeout SockJsServletRequest sockJsReq = new SockJsServletRequest(req); SockJsServletResponse sockJsRes = new SockJsServletResponse(res, asyncContext); try { sockJsServer.dispatch(sockJsReq, sockJsRes); } catch (SockJsException ex) { throw new ServletException("Error during SockJS request:", ex); } if ("application/x-www-form-urlencoded".equals(req.getHeader("Content-Type"))) { // Let the servlet parse data and just pretend like we did sockJsReq.onAllDataRead(); } else if (req.isAsyncStarted()) { req.getInputStream().setReadListener(sockJsReq); } } @Override public void destroy() { super.destroy(); sockJsServer.destroy(); } static void saveHeaders(String sessionId, Map<String, List<String>> headers) { savedHeaders.put(sessionId, headers); } static Map<String, List<String>> retrieveHeaders(String sessionId) { return savedHeaders.remove(sessionId); } private SockJsServer sockJsServer; private static final Pattern SESSION_PATTERN = Pattern.compile(".*/.+/(.+)/websocket$"); private static final Logger log = Logger.getLogger(SockJsServlet.class.getName()); /** * Store a map of SockJS sessionId to header values from the upgrade request since JSR-356 gives us no way to access * this from Endpoints directly. The MAX_INFLIGHT_HEADERS and LinkedHashMap#removeEldestEntry are used to make sure * any really misbehaving clients don't cause entries to accumulate in the map. Under normal circumstances, an entry * is removed very shortly after it's added since we don't add it until the handshake process is complete and remove * it as soon as the Endpoint's onOpen gets called. */ private static final int MAX_INFLIGHT_HEADERS = 100; private static final Map<String, Map<String, List<String>>> savedHeaders = Collections.synchronizedMap(new LinkedHashMap<String, Map<String, List<String>>>() { @Override protected boolean removeEldestEntry(Map.Entry eldest) { return size() > MAX_INFLIGHT_HEADERS; } }); }