/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.transport.websocket.ahc;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.ning.http.client.AsyncHttpClient;
import com.ning.http.client.ws.WebSocket;
import com.ning.http.client.ws.WebSocketByteListener;
import com.ning.http.client.ws.WebSocketTextListener;
import com.ning.http.client.ws.WebSocketUpgradeHandler;

import org.apache.cxf.Bus;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.http.Address;
import org.apache.cxf.transport.http.Headers;
import org.apache.cxf.transport.http.URLConnectionHTTPConduit;
import org.apache.cxf.transport.https.HttpsURLConnectionInfo;
import org.apache.cxf.transport.websocket.WebSocketConstants;
import org.apache.cxf.transport.websocket.WebSocketUtils;
import org.apache.cxf.transports.http.configuration.HTTPClientPolicy;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
/**
 * An HTTP conduit that carries CXF requests over a WebSocket connection opened with the
 * Ning AsyncHttpClient (AHC).
 * <p>
 * All requests sent through one conduit instance share a single, lazily created WebSocket.
 * Each request is serialized with {@link WebSocketUtils#buildRequest} and tagged with a random
 * id under the {@code requestIdKey} header. Incoming messages are matched back to the waiting
 * request using the id found under the {@code responseIdKey} header.
 */
public class AhcWebSocketConduit extends URLConnectionHTTPConduit {
    private static final Logger LOG = LogUtils.getL7dLogger(AhcWebSocketConduit.class);

    private final AsyncHttpClient ahcclient;
    // Connection shared by all requests of this conduit. It is created by connect() while
    // holding connectLock and read without the lock elsewhere, hence volatile.
    private volatile WebSocket websocket;
    // Dedicated lock so that establishing the connection never blocks on the conduit's own monitor.
    private final Object connectLock = new Object();

    //REVISIT make these keys configurable
    private String requestIdKey = WebSocketConstants.DEFAULT_REQUEST_ID_KEY;
    private String responseIdKey = WebSocketConstants.DEFAULT_RESPONSE_ID_KEY;

    // Requests that have been created and still await their response, keyed by request id.
    private Map<String, RequestResponse> uncorrelatedRequests = new ConcurrentHashMap<String, RequestResponse>();

    public AhcWebSocketConduit(Bus b, EndpointInfo ei, EndpointReferenceType t) throws IOException {
        super(b, ei, t);
        ahcclient = new AsyncHttpClient();
    }

    /**
     * Validates that the target address uses the {@code ws} or {@code wss} scheme. It then stores
     * an {@link AhcWebSocketConduitRequest} describing the request on the message.
     *
     * @throws MalformedURLException if the address scheme is neither {@code ws} nor {@code wss}
     */
    @Override
    protected void setupConnection(Message message, Address address, HTTPClientPolicy csPolicy)
        throws IOException {

        URI currentURL = address.getURI();
        String s = currentURL.getScheme();
        if (!"ws".equals(s) && !"wss".equals(s)) {
            throw new MalformedURLException("unknown protocol: " + s);
        }

        message.put("http.scheme", currentURL.getScheme());

        String httpRequestMethod =
            (String)message.get(Message.HTTP_REQUEST_METHOD);
        if (httpRequestMethod == null) {
            // same default as the HTTP conduits
            httpRequestMethod = "POST";
            message.put(Message.HTTP_REQUEST_METHOD, httpRequestMethod);
        }

        final AhcWebSocketConduitRequest request = new AhcWebSocketConduitRequest(currentURL, httpRequestMethod);
        final int rtimeout = determineReceiveTimeout(message, csPolicy);
        request.setReceiveTimeout(rtimeout);
        message.put(AhcWebSocketConduitRequest.class, request);
    }

    @Override
    protected OutputStream createOutputStream(Message message, boolean needToCacheRequest,
                                              boolean isChunking, int chunkThreshold) throws IOException {
        AhcWebSocketConduitRequest entity = message.get(AhcWebSocketConduitRequest.class);
        return new AhcWebSocketWrappedOutputStream(message, needToCacheRequest, isChunking, chunkThreshold,
                                                   getConduitName(), entity.getUri());
    }

    /**
     * Closes the conduit and releases the shared WebSocket and the AsyncHttpClient, which no
     * other code path in this conduit releases.
     */
    @Override
    public void close() {
        try {
            super.close();
        } finally {
            WebSocket ws = websocket;
            websocket = null;
            try {
                if (ws != null) {
                    ws.close();
                }
            } finally {
                ahcclient.close();
            }
        }
    }

    /**
     * Returns true if {@code uri} is an absolute WebSocket URI ({@code ws://} or {@code wss://}).
     * A plain {@code startsWith("ws")} check would also match relative paths such as {@code wsdl/x}.
     */
    private static boolean isWebSocketUri(String uri) {
        return uri.startsWith("ws://") || uri.startsWith("wss://");
    }

    /**
     * Returns the path (including any query) of an absolute WebSocket URI string. Returns
     * {@code "/"} when the URI has no path component, e.g. {@code ws://host:8080}.
     */
    private static String pathOf(String absoluteUri) {
        int authorityStart = absoluteUri.indexOf("://") + 3;
        int slash = absoluteUri.indexOf('/', authorityStart);
        return slash < 0 ? "/" : absoluteUri.substring(slash);
    }

    /**
     * Waits until a response has been attached to {@code rr}, or until {@code timeoutMillis} has
     * elapsed. Waiting happens in a loop so that spurious wakeups do not end it early.
     *
     * @param rr the pending request
     * @param timeoutMillis the receive timeout; a value of 0 or less waits indefinitely
     * @return the response, or null on timeout or interruption (the interrupt status is restored)
     */
    private static Response awaitResponse(RequestResponse rr, long timeoutMillis) {
        synchronized (rr) {
            try {
                final long deadline = System.nanoTime() + timeoutMillis * 1000000L;
                Response resp = rr.getResponse();
                while (resp == null) {
                    if (timeoutMillis <= 0) {
                        rr.wait();
                    } else {
                        long remainingMillis = (deadline - System.nanoTime()) / 1000000L;
                        if (remainingMillis <= 0) {
                            break;
                        }
                        rr.wait(remainingMillis);
                    }
                    resp = rr.getResponse();
                }
                return resp;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return rr.getResponse();
            }
        }
    }

    /**
     * Output stream for one request. Data written to it is sent as a WebSocket message, and the
     * correlated reply is read back through the HTTPConduit response hooks.
     */
    public class AhcWebSocketWrappedOutputStream extends WrappedOutputStream {
        private AhcWebSocketConduitRequest entity;
        private Response response;

        protected AhcWebSocketWrappedOutputStream(Message message, boolean possibleRetransmit,
                                                  boolean isChunking, int chunkThreshold,
                                                  String conduitName, URI url) {
            super(message, possibleRetransmit, isChunking, chunkThreshold, conduitName, url);

            entity = message.get(AhcWebSocketConduitRequest.class);
            //REVISIT how we prepare the request
            String requri = (String)message.getContextualProperty("org.apache.cxf.request.uri");
            if (requri != null) {
                // jaxrs specifies a sub-path using prop org.apache.cxf.request.uri
                if (isWebSocketUri(requri)) {
                    entity.setPath(pathOf(requri));
                } else {
                    entity.setPath(url.getPath() + requri);
                }
            } else {
                // jaxws
                entity.setPath(url.getPath());
            }
            entity.setId(UUID.randomUUID().toString());
            uncorrelatedRequests.put(entity.getId(), new RequestResponse(entity));
        }

        @Override
        protected void setupWrappedStream() throws IOException {
            final WebSocket ws = connectOrFail();
            wrappedStream = new OutputStream() {
                @Override
                public void write(byte[] b, int off, int len) throws IOException {
                    //REVISIT support multiple writes and flush() to write the entire block data?
                    // or provides the fragment mode?
                    // NOTE: every call is sent as a separate, complete request message.
                    Map<String, String> headers = new HashMap<>();
                    headers.put("Content-Type", entity.getContentType());
                    headers.put(requestIdKey, entity.getId());
                    ws.sendMessage(WebSocketUtils.buildRequest(
                        entity.getMethod(), entity.getPath(),
                        headers,
                        b, off, len));
                }

                @Override
                public void write(int b) throws IOException {
                    //REVISIT support this single byte write and use flush() to write the block data?
                    // NOTE(review): single-byte writes are currently dropped without notice.
                }

                @Override
                public void close() throws IOException {
                }
            };
        }

        @Override
        protected void handleNoOutput() throws IOException {
            WebSocket ws = connectOrFail();
            Map<String, String> headers = new HashMap<>();
            headers.put(requestIdKey, entity.getId());
            ws.sendMessage(WebSocketUtils.buildRequest(
                entity.getMethod(), entity.getPath(),
                headers,
                null, 0, 0));
        }

        @Override
        protected HttpsURLConnectionInfo getHttpsURLConnectionInfo() throws IOException {
            return null;
        }

        @Override
        protected void setProtocolHeaders() throws IOException {
            Headers h = new Headers(outMessage);
            entity.setContentType(h.determineContentType());
            //REVISIT may provide an option to add other headers
            // boolean addHeaders = MessageUtils.isTrue(outMessage.getContextualProperty(Headers.ADD_HEADERS_PROPERTY));
        }

        @Override
        protected void setFixedLengthStreamingMode(int i) {
            // ignore
        }

        @Override
        protected int getResponseCode() throws IOException {
            return getResponse().getStatusCode();
        }

        @Override
        protected String getResponseMessage() throws IOException {
            //TODO return a generic message based on the status code
            return null;
        }

        @Override
        protected void updateResponseHeaders(Message inMessage) throws IOException {
            Headers h = new Headers(inMessage);
            String ct = getResponse().getContentType();
            // don't register an absent content type as a [null] header value
            if (ct != null) {
                inMessage.put(Message.CONTENT_TYPE, ct);
                //REVISIT if we are allowing more headers, we need to add them into the cxf's headers
                h.headerMap().put(Message.CONTENT_TYPE, Collections.singletonList(ct));
            }
        }

        @Override
        protected void handleResponseAsync() throws IOException {
            handleResponseOnWorkqueue(true, false);
        }

        @Override
        protected void closeInputStream() throws IOException {
        }

        @Override
        protected boolean usingProxy() {
            // TODO add proxy support ...
            return false;
        }

        @Override
        protected InputStream getInputStream() throws IOException {
            return new ByteArrayInputStream(getResponse().getEntityBytes());
        }

        @Override
        protected InputStream getPartialResponse() throws IOException {
            return new ByteArrayInputStream(getResponse().getEntityBytes());
        }

        @Override
        protected void setupNewConnection(String newURL) throws IOException {
            // TODO
            throw new IOException("not supported");
        }

        @Override
        protected void retransmitStream() throws IOException {
            // TODO
            throw new IOException("not supported");
        }

        @Override
        protected void updateCookiesBeforeRetransmit() throws IOException {
            // ignore for now and may consider a specific websocket binding variant to use cookies
        }

        @Override
        public void thresholdReached() throws IOException {
            // ignore
        }

        //
        // other methods follow
        //

        /**
         * Opens the shared WebSocket if it is not open yet. Failures are logged and leave
         * {@code websocket} null. Holding {@code connectLock} prevents concurrent first requests
         * from opening (and leaking) several connections.
         */
        protected void connect() {
            LOG.log(Level.FINE, "connecting");
            synchronized (connectLock) {
                if (websocket == null) {
                    try {
                        websocket = ahcclient.prepareGet(url.toASCIIString()).execute(
                            new WebSocketUpgradeHandler.Builder()
                                .addWebSocketListener(new AhcWebSocketListener()).build()).get();
                        LOG.log(Level.FINE, "connected");
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        LOG.log(Level.SEVERE, "unable to connect", e);
                    } catch (Exception e) {
                        LOG.log(Level.SEVERE, "unable to connect", e);
                    }
                } else {
                    LOG.log(Level.FINE, "already connected");
                }
            }
        }

        /**
         * Calls {@link #connect()} and returns the connected WebSocket. If no connection could be
         * established, it fails with an IOException rather than letting the caller dereference
         * null. It also drops this request's pending entry, because no reply can arrive.
         */
        private WebSocket connectOrFail() throws IOException {
            connect();
            WebSocket ws = websocket;
            if (ws == null) {
                uncorrelatedRequests.remove(entity.getId());
                throw new IOException("Unable to connect to " + url);
            }
            return ws;
        }

        /**
         * Returns the response correlated with this request, waiting up to the receive timeout.
         * The result is cached, and the pending entry is removed from {@code uncorrelatedRequests}
         * once the wait ends so that the map does not grow with every request.
         *
         * @throws SocketTimeoutException if no response arrived in time
         */
        Response getResponse() throws IOException {
            if (response == null) {
                String rid = entity.getId();
                RequestResponse rr = uncorrelatedRequests.get(rid);
                if (rr != null) {
                    try {
                        response = awaitResponse(rr, entity.getReceiveTimeout());
                    } finally {
                        uncorrelatedRequests.remove(rid);
                    }
                }
                if (response == null) {
                    throw new SocketTimeoutException("Read timed out while invoking " + entity.getUri());
                }
            }
            return response;
        }
    }

    /**
     * Receives messages from the shared WebSocket and hands each one to the pending request
     * whose id it carries.
     */
    protected class AhcWebSocketListener implements WebSocketTextListener, WebSocketByteListener {

        public void onOpen(WebSocket ws) {
            if (LOG.isLoggable(Level.FINE)) {
                LOG.log(Level.FINE, "onOpen({0})", ws);
            }
        }

        public void onClose(WebSocket ws) {
            if (LOG.isLoggable(Level.FINE)) {
                LOG.log(Level.FINE, "onClose({0})", ws);
            }
        }

        public void onError(Throwable t) {
            LOG.log(Level.SEVERE, "[ws] onError", t);
        }

        public void onMessage(byte[] message) {
            if (LOG.isLoggable(Level.FINE)) {
                LOG.log(Level.FINE, "onMessage({0})", message);
            }
            correlate(message);
        }

        public void onFragment(byte[] fragment, boolean last) {
            //TODO
            LOG.log(Level.WARNING, "NOT IMPLEMENTED onFragment({0}, {1})", new Object[]{fragment, last});
        }

        public void onMessage(String message) {
            if (LOG.isLoggable(Level.FINE)) {
                LOG.log(Level.FINE, "onMessage({0})", message);
            }
            correlate(message);
        }

        public void onFragment(String fragment, boolean last) {
            //TODO
            LOG.log(Level.WARNING, "NOT IMPLEMENTED onFragment({0}, {1})", new Object[]{fragment, last});
        }

        /**
         * Parses a String or byte[] message and wakes up the request it answers. Unparseable
         * messages, and messages without a known id, are logged and dropped. This matters
         * because a null id would otherwise make ConcurrentHashMap.get throw.
         */
        private void correlate(Object message) {
            Response resp;
            try {
                resp = new Response(responseIdKey, message);
            } catch (RuntimeException e) {
                LOG.log(Level.WARNING, "unable to parse websocket message", e);
                return;
            }
            String id = resp.getId();
            RequestResponse rr = id == null ? null : uncorrelatedRequests.get(id);
            if (rr == null) {
                LOG.log(Level.FINE, "ignoring uncorrelated message (id={0})", id);
                return;
            }
            synchronized (rr) {
                rr.setResponse(resp);
                rr.notifyAll();
            }
        }
    }

    // Request and Response are used to represent request and response messages transferred over the websocket
    //REVISIT move these classes to be used in other places after finalizing their contained information.

    /**
     * A parsed response message. The message consists of an optional leading status-code line,
     * then {@code name: value} header lines up to the first empty line, then the entity.
     */
    static class Response {
        private Object data;
        private int pos;
        private int statusCode;
        private String contentType;
        private String id;
        private Object entity;

        /**
         * @param idKey name of the header carrying the id of the request this message answers
         *              (matched case-insensitively, like Content-Type)
         * @param data the raw message, either a String or a byte[]
         */
        Response(String idKey, Object data) {
            this.data = data;
            String line;
            boolean first = true;
            while ((line = readLine()) != null) {
                boolean statusLine = first && isStatusCode(line);
                first = false;
                if (statusLine) {
                    statusCode = parseStatusCode(line);
                    continue;
                }
                int del = line.indexOf(':');
                if (del < 0) {
                    // not a "name: value" header; skip it rather than failing the whole message
                    continue;
                }
                String h = line.substring(0, del).trim();
                String v = line.substring(del + 1).trim();
                if ("Content-Type".equalsIgnoreCase(h)) {
                    contentType = v;
                } else if (h.equalsIgnoreCase(idKey)) {
                    id = v;
                }
            }
            if (data instanceof String) {
                entity = ((String)data).substring(pos);
            } else if (data instanceof byte[]) {
                byte[] bytes = (byte[])data;
                entity = Arrays.copyOfRange(bytes, pos, bytes.length);
            }
        }

        private static boolean isStatusCode(String line) {
            char c = line.charAt(0);
            return '0' <= c && c <= '9';
        }

        /** Parses the leading digits of a status line, so "200" and "200 OK" both yield 200. */
        private static int parseStatusCode(String line) {
            int end = 0;
            while (end < line.length() && '0' <= line.charAt(end) && line.charAt(end) <= '9') {
                end++;
            }
            return Integer.parseInt(line.substring(0, end));
        }

        public int getStatusCode() {
            return statusCode;
        }

        public String getContentType() {
            return contentType;
        }

        public String getId() {
            return id;
        }

        public Object getEntity() {
            return entity;
        }

        /** Returns the entity as text; byte entities are decoded as UTF-8. */
        public String getTextEntity() {
            return gettext(entity);
        }

        /**
         * Returns the entity as bytes. Byte entities are returned as-is, so binary payloads are
         * not corrupted by a String round-trip. Text entities are encoded as UTF-8.
         */
        public byte[] getEntityBytes() {
            if (entity instanceof byte[]) {
                return (byte[])entity;
            } else if (entity instanceof String) {
                return ((String)entity).getBytes(StandardCharsets.UTF_8);
            }
            return new byte[0];
        }

        /** Reads the next line (without CR/LF) from the header section; returns null for an empty line or end of data. */
        private String readLine() {
            StringBuilder sb = new StringBuilder();
            int len = length(data);
            while (pos < len) {
                int c = getchar(data, pos++);
                if (c == '\n') {
                    break;
                } else if (c != '\r') {
                    sb.append((char)c);
                }
            }
            return sb.length() == 0 ? null : sb.toString();
        }

        private static int length(Object o) {
            if (o instanceof String) {
                return ((String)o).length();
            } else if (o instanceof byte[]) {
                return ((byte[])o).length;
            }
            return 0;
        }

        private static int getchar(Object o, int p) {
            if (o instanceof String) {
                return ((String)o).charAt(p);
            }
            // each byte maps to the char with the same value (ISO-8859-1 style)
            return 0xff & ((byte[])o)[p];
        }

        private static String gettext(Object o) {
            if (o instanceof String) {
                return (String)o;
            } else if (o instanceof byte[]) {
                return new String((byte[])o, StandardCharsets.UTF_8);
            }
            return null;
        }
    }

    /** A sent request together with its response once one has been correlated; guarded by its own monitor. */
    static class RequestResponse {
        private final AhcWebSocketConduitRequest request;
        private Response response;

        RequestResponse(AhcWebSocketConduitRequest request) {
            this.request = request;
        }

        public AhcWebSocketConduitRequest getRequest() {
            return request;
        }

        public Response getResponse() {
            return response;
        }

        public void setResponse(Response response) {
            this.response = response;
        }
    }
}