/** * GRANITE DATA SERVICES * Copyright (C) 2006-2015 GRANITE DATA SERVICES S.A.S. * * This file is part of the Granite Data Services Platform. * * Granite Data Services is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * Granite Data Services is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser * General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, * USA, or see <http://www.gnu.org/licenses/>. */ package org.granite.gravity.tomcat; import java.lang.reflect.Field; import java.util.List; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpSession; import org.apache.catalina.websocket.StreamInbound; import org.apache.catalina.websocket.WebSocketServlet; import org.apache.catalina.websocket.WsHttpServletRequestWrapper; import org.granite.context.GraniteContext; import org.granite.gravity.GravityInternal; import org.granite.gravity.GravityManager; import org.granite.gravity.GravityServletUtil; import org.granite.gravity.websocket.WebSocketUtil; import org.granite.logging.Logger; import org.granite.messaging.webapp.ServletGraniteContext; import org.granite.util.ContentType; import flex.messaging.messages.CommandMessage; import flex.messaging.messages.Message; public class TomcatWebSocketServlet extends WebSocketServlet { private static final long serialVersionUID = 1L; private static final Logger log = Logger.getLogger(TomcatWebSocketServlet.class); private static Field requestField = null; static { try { requestField = WsHttpServletRequestWrapper.class.getDeclaredField("request"); requestField.setAccessible(true); } catch (NoSuchFieldException e) { } } @Override public void init(ServletConfig config) throws ServletException { super.init(config); GravityServletUtil.init(config); } @Override protected String selectSubProtocol(List<String> subProtocols) { for (String protocol : subProtocols) { if (protocol.startsWith("org.granite.gravity")) return protocol; } return null; } @Override protected StreamInbound createWebSocketInbound(String protocol, HttpServletRequest request) { GravityInternal gravity = (GravityInternal)GravityManager.getGravity(getServletContext()); TomcatWebSocketChannelFactory channelFactory = new TomcatWebSocketChannelFactory(gravity); try { String connectMessageId = getHeaderOrParameter(request, "connectId"); String clientId = getHeaderOrParameter(request, "GDSClientId"); String clientType = getHeaderOrParameter(request, "GDSClientType"); String sessionId = null; HttpSession session = request.getSession(false); if (session != null) { ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), getServletContext(), session, clientType); sessionId = session.getId(); } else if (request.getCookies() != null) { for (int i = 0; i < request.getCookies().length; i++) { if ("JSESSIONID".equals(request.getCookies()[i].getName())) { sessionId = request.getCookies()[i].getValue(); break; } } ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), getServletContext(), sessionId, clientType); } else { ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), getServletContext(), (String)null, clientType); } log.info("WebSocket connection started %s clientId %s sessionId %s", protocol, clientId, sessionId); CommandMessage pingMessage = new CommandMessage(); pingMessage.setMessageId(connectMessageId != null ? connectMessageId : "OPEN_CONNECTION"); pingMessage.setOperation(CommandMessage.CLIENT_PING_OPERATION); if (clientId != null) pingMessage.setClientId(clientId); Message ackMessage = gravity.handleMessage(channelFactory, pingMessage); if (sessionId != null) ackMessage.setHeader("JSESSIONID", sessionId); TomcatWebSocketChannel channel = gravity.getChannel(channelFactory, (String)ackMessage.getClientId()); channel.setSession(session); if (gravity.getGraniteConfig().getSecurityService() != null) { try { gravity.getGraniteConfig().getSecurityService().prelogin(session, request instanceof WsHttpServletRequestWrapper ? requestField.get(request) : request, getServletConfig().getServletName()); } catch (IllegalAccessException e) { log.warn(e, "Could not get internal request object"); } } String ctype = request.getContentType(); ContentType contentType = WebSocketUtil.getContentType(ctype, protocol); channel.setContentType(contentType); if (!ackMessage.getClientId().equals(clientId)) channel.setConnectAckMessage(ackMessage); return channel.getStreamInbound(); } finally { GraniteContext.release(); } } private static String getHeaderOrParameter(HttpServletRequest request, String key) { String value = request.getHeader(key); if (value == null) value = request.getParameter(key); return value; } // @Override // protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { // // // Information required to send the server handshake message // String key; // String subProtocol = null; // List<String> extensions = Collections.emptyList(); // // if (!headerContainsToken(req, "upgrade", "websocket")) { // resp.sendError(HttpServletResponse.SC_BAD_REQUEST); // return; // } // // if (!headerContainsToken(req, "connection", "upgrade")) { // resp.sendError(HttpServletResponse.SC_BAD_REQUEST); // return; // } // // if (!headerContainsToken(req, "sec-websocket-version", "13")) { // resp.setStatus(426); // resp.setHeader("Sec-WebSocket-Version", "13"); // return; // } // // key = req.getHeader("Sec-WebSocket-Key"); // if (key == null) { // resp.sendError(HttpServletResponse.SC_BAD_REQUEST); // return; // } // // String origin = req.getHeader("Origin"); // if (!verifyOrigin(origin)) { // resp.sendError(HttpServletResponse.SC_FORBIDDEN); // return; // } // // // Fix for Tomcat-7.0.29 bad header name (was Sec-WebSocket-Protocol-Client") // List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol"); // if (!subProtocols.isEmpty()) // subProtocol = selectSubProtocol(subProtocols); // // // TODO Read client handshake - Sec-WebSocket-Extensions // // // TODO Extensions require the ability to specify something (API TBD) // // that can be passed to the Tomcat internals and process extension // // data present when the frame is fragmented. // // // If we got this far, all is good. Accept the connection. // resp.setHeader("Upgrade", "websocket"); // resp.setHeader("Connection", "upgrade"); // resp.setHeader("Sec-WebSocket-Accept", getWebSocketAccept(key)); // if (subProtocol != null) // resp.setHeader("Sec-WebSocket-Protocol", subProtocol); // // if (!extensions.isEmpty()) { // // TODO // } // // WsHttpServletRequestWrapper wrapper = new WsHttpServletRequestWrapper(req); // StreamInbound inbound = createWebSocketInbound(subProtocol, wrapper); // wrapper.invalidate(); // // // Hack to avoid chunked transfer // resp.setContentLength(((TomcatWebSocketChannel.MessageInboundImpl)inbound).getAckLength()); // // // Small hack until the Servlet API provides a way to do this. // ServletRequest inner = req; // // Unwrap the request // while (inner instanceof ServletRequestWrapper) // inner = ((ServletRequestWrapper)inner).getRequest(); // // if (inner instanceof RequestFacade) // ((RequestFacade)inner).doUpgrade(inbound); // else // resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, sm.getString("servlet.reqUpgradeFail")); // } // // // private boolean headerContainsToken(HttpServletRequest req, // String headerName, String target) { // Enumeration<String> headers = req.getHeaders(headerName); // while (headers.hasMoreElements()) { // String header = headers.nextElement(); // String[] tokens = header.split(","); // for (String token : tokens) { // if (target.equalsIgnoreCase(token.trim())) { // return true; // } // } // } // return false; // } // // private List<String> getTokensFromHeader(HttpServletRequest req, // String headerName) { // List<String> result = new ArrayList<String>(); // // Enumeration<String> headers = req.getHeaders(headerName); // while (headers.hasMoreElements()) { // String header = headers.nextElement(); // String[] tokens = header.split(","); // for (String token : tokens) { // result.add(token.trim()); // } // } // return result; // } // // private String getWebSocketAccept(String key) throws ServletException { // // MessageDigest sha1Helper = sha1Helpers.poll(); // if (sha1Helper == null) { // try { // sha1Helper = MessageDigest.getInstance("SHA1"); // } catch (NoSuchAlgorithmException e) { // throw new ServletException(e); // } // } // // sha1Helper.reset(); // sha1Helper.update(key.getBytes(B2CConverter.ISO_8859_1)); // String result = Base64.encode(sha1Helper.digest(WS_ACCEPT)); // // sha1Helpers.add(sha1Helper); // // return result; // } }