/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.activemq.transport.ws.jetty9; import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.BrokerServiceAware; import org.apache.activemq.transport.Transport; import org.apache.activemq.transport.TransportAcceptListener; import org.apache.activemq.transport.TransportFactory; import org.apache.activemq.transport.util.HttpTransportUtils; import org.apache.activemq.transport.ws.WSTransportProxy; import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.eclipse.jetty.websocket.servlet.WebSocketCreator; import org.eclipse.jetty.websocket.servlet.WebSocketServlet; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; /** * Handle connection upgrade requests and creates web sockets */ public class WSServlet extends WebSocketServlet implements BrokerServiceAware { private static final long serialVersionUID = -4716657876092884139L; private TransportAcceptListener listener; private final static Map<String, Integer> stompProtocols = new ConcurrentHashMap<>(); private final static Map<String, Integer> mqttProtocols = new ConcurrentHashMap<>(); private Map<String, Object> transportOptions; private BrokerService brokerService; private enum Protocol { MQTT, STOMP, UNKNOWN } static { stompProtocols.put("v12.stomp", 3); stompProtocols.put("v11.stomp", 2); stompProtocols.put("v10.stomp", 1); stompProtocols.put("stomp", 0); mqttProtocols.put("mqttv3.1", 1); mqttProtocols.put("mqtt", 0); } @Override public void init() throws ServletException { super.init(); listener = (TransportAcceptListener) getServletContext().getAttribute("acceptListener"); if (listener == null) { throw new ServletException("No such attribute 'acceptListener' available in the ServletContext"); } } @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { //return empty response - AMQ-6491 } @Override public void configure(WebSocketServletFactory factory) { factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) { WebSocketListener socket; Protocol requestedProtocol = Protocol.UNKNOWN; // When no sub-protocol is requested we default to STOMP for legacy reasons. if (!req.getSubProtocols().isEmpty()) { for (String subProtocol : req.getSubProtocols()) { if (subProtocol.startsWith("mqtt")) { requestedProtocol = Protocol.MQTT; } else if (subProtocol.contains("stomp")) { requestedProtocol = Protocol.STOMP; } } } else { requestedProtocol = Protocol.STOMP; } switch (requestedProtocol) { case MQTT: socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((MQTTSocket) socket).setTransportOptions(new HashMap<>(transportOptions)); ((MQTTSocket) socket).setPeerCertificates(req.getCertificates()); resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt")); break; case UNKNOWN: socket = findWSTransport(req, resp); if (socket != null) { break; } case STOMP: socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest())); ((StompSocket) socket).setPeerCertificates(req.getCertificates()); resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp")); break; default: socket = null; listener.onAcceptError(new IOException("Unknown protocol requested")); break; } if (socket != null) { listener.onAccept((Transport) socket); } return socket; } }); } private WebSocketListener findWSTransport(ServletUpgradeRequest request, ServletUpgradeResponse response) { WSTransportProxy proxy = null; for (String subProtocol : request.getSubProtocols()) { try { String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(request.getHttpServletRequest(), subProtocol); URI remoteURI = new URI(remoteAddress); TransportFactory factory = TransportFactory.findTransportFactory(remoteURI); if (factory instanceof BrokerServiceAware) { ((BrokerServiceAware) factory).setBrokerService(brokerService); } Transport transport = factory.doConnect(remoteURI); proxy = new WSTransportProxy(remoteAddress, transport); proxy.setPeerCertificates(request.getCertificates()); proxy.setTransportOptions(transportOptions); response.setAcceptedSubProtocol(proxy.getSubProtocol()); } catch (Exception e) { proxy = null; // Keep going and try any other sub-protocols present. continue; } } return proxy; } private String getAcceptedSubProtocol(final Map<String, Integer> protocols, List<String> subProtocols, String defaultProtocol) { List<SubProtocol> matchedProtocols = new ArrayList<>(); if (subProtocols != null && subProtocols.size() > 0) { // detect which subprotocols match accepted protocols and add to the // list for (String subProtocol : subProtocols) { Integer priority = protocols.get(subProtocol); if (subProtocol != null && priority != null) { // only insert if both subProtocol and priority are not null matchedProtocols.add(new SubProtocol(subProtocol, priority)); } } // sort the list by priority if (matchedProtocols.size() > 0) { Collections.sort(matchedProtocols, new Comparator<SubProtocol>() { @Override public int compare(SubProtocol s1, SubProtocol s2) { return s2.priority.compareTo(s1.priority); } }); return matchedProtocols.get(0).protocol; } } return defaultProtocol; } private class SubProtocol { private String protocol; private Integer priority; public SubProtocol(String protocol, Integer priority) { this.protocol = protocol; this.priority = priority; } } public void setTransportOptions(Map<String, Object> transportOptions) { this.transportOptions = transportOptions; } @Override public void setBrokerService(BrokerService brokerService) { this.brokerService = brokerService; } }