/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.transport.websocket.undertow;
import java.io.IOException;
import java.net.URL;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.cxf.Bus;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.util.StringUtils;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.http.DestinationRegistry;
import org.apache.cxf.transport.http_undertow.UndertowHTTPDestination;
import org.apache.cxf.transport.http_undertow.UndertowHTTPHandler;
import org.apache.cxf.transport.http_undertow.UndertowHTTPServerEngineFactory;
import org.apache.cxf.transport.websocket.WebSocketConstants;
import org.apache.cxf.transport.websocket.WebSocketDestinationService;
import org.apache.cxf.workqueue.WorkQueueManager;
import org.xnio.StreamConnection;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.HttpUpgradeListener;
import io.undertow.servlet.handlers.ServletRequestContext;
import io.undertow.servlet.spec.HttpServletRequestImpl;
import io.undertow.servlet.spec.HttpServletResponseImpl;
import io.undertow.servlet.spec.ServletContextImpl;
import io.undertow.util.Methods;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.protocol.Handshake;
import io.undertow.websockets.core.protocol.version07.Hybi07Handshake;
import io.undertow.websockets.core.protocol.version08.Hybi08Handshake;
import io.undertow.websockets.core.protocol.version13.Hybi13Handshake;
import io.undertow.websockets.spi.AsyncWebSocketHttpServerExchange;
/**
 * WebSocket-enabled variant of {@link UndertowHTTPDestination}.
 * <p>
 * Endpoint addresses using the {@code ws}/{@code wss} schemes are rewritten to their
 * {@code http}/{@code https} form whenever the HTTP destination resolves its address or base path.
 * Incoming GET requests that match one of the supported WebSocket handshakes (Hybi-07, -08, -13)
 * are upgraded to a {@link WebSocketChannel}. Every full text or binary message received on such a
 * channel is then processed on the bus's automatic work queue as a synthetic servlet
 * request/response pair. All other requests are served as ordinary HTTP requests.
 */
public class UndertowWebSocketDestination extends UndertowHTTPDestination
    implements WebSocketDestinationService {
    private static final Logger LOG = LogUtils.getL7dLogger(UndertowWebSocketDestination.class);

    /** Work queue on which each received WebSocket message is processed. */
    private final Executor executor;

    /**
     * Creates the destination and looks up the bus's automatic work queue.
     *
     * @param bus the CXF bus; expected to provide a {@link WorkQueueManager} extension
     * @param registry the destination registry handed to the HTTP destination
     * @param ei the endpoint served by this destination
     * @param serverEngineFactory factory passed through to the HTTP destination
     * @throws IOException propagated from the {@code UndertowHTTPDestination} constructor
     */
    public UndertowWebSocketDestination(Bus bus, DestinationRegistry registry, EndpointInfo ei,
                                        UndertowHTTPServerEngineFactory serverEngineFactory)
        throws IOException {
        super(bus, registry, ei, serverEngineFactory);
        executor = bus.getExtension(WorkQueueManager.class).getAutomaticWorkQueue();
    }

    /**
     * Delegates directly to {@code UndertowHTTPDestination#invoke}.
     */
    @Override
    public void invokeInternal(ServletConfig config, ServletContext context, HttpServletRequest req,
                               HttpServletResponse resp) throws IOException {
        super.invoke(config, context, req, resp);
    }

    /**
     * Maps a WebSocket address to its HTTP form by replacing a leading {@code "ws"} with
     * {@code "http"}. As a result {@code ws://...} becomes {@code http://...} and
     * {@code wss://...} becomes {@code https://...}. Other addresses are returned unchanged.
     *
     * @param endpointInfo the endpoint whose address is converted
     * @return the HTTP form of the address, or {@code null} when the endpoint has no address
     */
    private static String getNonWSAddress(EndpointInfo endpointInfo) {
        String address = endpointInfo.getAddress();
        if (address != null && address.startsWith("ws")) {
            return "http" + address.substring(2);
        }
        return address;
    }

    /**
     * Returns the endpoint address with any {@code ws}/{@code wss} scheme mapped to
     * {@code http}/{@code https}.
     */
    @Override
    protected String getAddress(EndpointInfo endpointInfo) {
        return getNonWSAddress(endpointInfo);
    }

    /**
     * Returns the path component of the (HTTP-form) endpoint address, or an empty string when the
     * endpoint has no address. {@code contextPath} is not consulted.
     *
     * @throws IOException if the address is not a valid URL
     */
    @Override
    protected String getBasePath(String contextPath) throws IOException {
        if (StringUtils.isEmpty(endpointInfo.getAddress())) {
            return "";
        }
        return new URL(getAddress(endpointInfo)).getPath();
    }

    /**
     * Installs the WebSocket-aware handler in place of the plain HTTP handler.
     */
    @Override
    protected UndertowHTTPHandler createUndertowHTTPHandler(UndertowHTTPDestination jhd, boolean cmExact) {
        return new AtmosphereUndertowWebSocketHandler(jhd, cmExact);
    }

    /**
     * Undertow handler that upgrades matching GET requests to WebSocket channels and serves all
     * other requests through the regular servlet-style path.
     */
    private class AtmosphereUndertowWebSocketHandler extends UndertowHTTPHandler {
        /** Supported WebSocket handshake versions. */
        private final Set<Handshake> handshakes;
        /**
         * WebSocket channels opened by this handler. Given to Undertow through
         * {@link AsyncWebSocketHttpServerExchange}; backed by a concurrent map.
         */
        private final Set<WebSocketChannel> peerConnections = Collections
            .newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>());

        AtmosphereUndertowWebSocketHandler(UndertowHTTPDestination jhd, boolean cmExact) {
            super(jhd, cmExact);
            handshakes = new HashSet<>();
            handshakes.add(new Hybi13Handshake());
            handshakes.add(new Hybi08Handshake());
            handshakes.add(new Hybi07Handshake());
        }

        /**
         * Handles an incoming request. A GET request that matches a supported handshake is
         * upgraded to a WebSocket. Anything else goes to {@link #handleNormalRequest(HttpServerExchange)}.
         */
        @Override
        public void handleRequest(HttpServerExchange undertowExchange) throws Exception {
            if (undertowExchange.isInIoThread()) {
                // Re-run this handler on a worker thread. The processing below may block,
                // which must not happen on an I/O thread.
                undertowExchange.dispatch(this);
                return;
            }
            if (!Methods.GET.equals(undertowExchange.getRequestMethod())) {
                // Only GET is supported to start the handshake
                handleNormalRequest(undertowExchange);
                return;
            }
            final AsyncWebSocketHttpServerExchange facade =
                new AsyncWebSocketHttpServerExchange(undertowExchange, peerConnections);
            final Handshake handshaker = selectHandshake(facade);
            if (handshaker == null) {
                handleNormalRequest(undertowExchange);
                return;
            }
            // The listener runs once the handshake response has been sent and the connection upgraded.
            undertowExchange.upgradeChannel((streamConnection, exchange) ->
                openChannel(handshaker, facade, streamConnection, exchange));
            handshaker.handshake(facade);
        }

        /** Returns the first supported handshake matching the request, or {@code null} if none does. */
        private Handshake selectHandshake(AsyncWebSocketHttpServerExchange facade) {
            for (Handshake candidate : handshakes) {
                if (candidate.matches(facade)) {
                    return candidate;
                }
            }
            return null;
        }

        /**
         * Sets up an upgraded connection:
         * <ul>
         * <li>creates the WebSocket channel and registers it as a peer;</li>
         * <li>installs a listener that forwards every full text or binary message to
         * {@link #handleReceivedMessage}.</li>
         * </ul>
         * If setup fails, the connection is closed. Otherwise nothing would ever read from it.
         */
        private void openChannel(Handshake handshaker, AsyncWebSocketHttpServerExchange facade,
                                 StreamConnection streamConnection, HttpServerExchange exchange) {
            try {
                WebSocketChannel channel = handshaker.createChannel(facade, streamConnection,
                    facade.getBufferPool());
                peerConnections.add(channel);
                channel.getReceiveSetter().set(new AbstractReceiveListener() {
                    @Override
                    protected void onFullTextMessage(WebSocketChannel ch, BufferedTextMessage message) {
                        handleReceivedMessage(ch, message, exchange);
                    }

                    // NOTE(review): Undertow's default implementation of this callback frees the
                    // pooled message data. Overriding it leaves that to whoever consumes the
                    // message -- confirm WebSocketUndertowServletRequest releases it.
                    @Override
                    protected void onFullBinaryMessage(WebSocketChannel ch, BufferedBinaryMessage message)
                        throws IOException {
                        handleReceivedMessage(ch, message, exchange);
                    }
                });
                channel.resumeReceives();
            } catch (Exception e) {
                LOG.log(Level.WARNING, "Failed to invoke service", e);
                try {
                    streamConnection.close();
                } catch (IOException closeFailure) {
                    LOG.log(Level.FINE, "Failed to close upgraded connection", closeFailure);
                }
            }
        }

        /**
         * Serves a non-WebSocket request. It wraps the Undertow exchange in Undertow's servlet
         * request/response implementations and attaches a {@link ServletRequestContext} to the
         * exchange. It then calls {@code doService}.
         * NOTE(review): assumes the inherited {@code servletContext} is Undertow's
         * {@link ServletContextImpl}; the casts fail otherwise.
         */
        public void handleNormalRequest(HttpServerExchange undertowExchange) throws Exception {
            ServletContextImpl undertowContext = (ServletContextImpl)servletContext;
            HttpServletResponseImpl response = new HttpServletResponseImpl(undertowExchange, undertowContext);
            HttpServletRequestImpl request = new HttpServletRequestImpl(undertowExchange, undertowContext);
            ServletRequestContext servletRequestContext = new ServletRequestContext(
                undertowContext.getDeployment(), request, response, null);
            undertowExchange.putAttachment(ServletRequestContext.ATTACHMENT_KEY, servletRequestContext);
            doService(request, response);
        }

        /** Serves an already-built servlet request/response pair via {@code doService}. */
        public void handleNormalRequest(HttpServletRequest request, HttpServletResponse response)
            throws Exception {
            doService(request, response);
        }

        /**
         * Processes one received WebSocket message on {@link #executor}:
         * <ul>
         * <li>wraps the message in a synthetic servlet request and response;</li>
         * <li>copies the request id header, if present, to the response id header;</li>
         * <li>serves the pair like a normal request.</li>
         * </ul>
         * Failures are logged, not propagated.
         *
         * @param channel the channel the message arrived on; responses are written to it
         * @param message a {@link BufferedTextMessage} or {@link BufferedBinaryMessage}
         * @param exchange the HTTP exchange that was upgraded to this channel
         */
        private void handleReceivedMessage(WebSocketChannel channel, Object message, HttpServerExchange exchange) {
            executor.execute(() -> {
                try {
                    HttpServletRequest request = new WebSocketUndertowServletRequest(channel, message, exchange);
                    HttpServletResponse response = new WebSocketUndertowServletResponse(channel);
                    String requestId = request.getHeader(WebSocketConstants.DEFAULT_REQUEST_ID_KEY);
                    if (requestId != null) {
                        // Echo the id, presumably so the client can correlate the response.
                        response.setHeader(WebSocketConstants.DEFAULT_RESPONSE_ID_KEY, requestId);
                    }
                    handleNormalRequest(request, response);
                } catch (Exception ex) {
                    LOG.log(Level.WARNING, "Failed to process websocket message", ex);
                }
            });
        }
    }
}