/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.cxf.transport.websocket.ahc; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.MalformedURLException; import java.net.SocketTimeoutException; import java.net.URI; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; import java.util.logging.Logger; import com.ning.http.client.AsyncHttpClient; import com.ning.http.client.ws.WebSocket; import com.ning.http.client.ws.WebSocketByteListener; import com.ning.http.client.ws.WebSocketTextListener; import com.ning.http.client.ws.WebSocketUpgradeHandler; import org.apache.cxf.Bus; import org.apache.cxf.common.logging.LogUtils; import org.apache.cxf.message.Message; import org.apache.cxf.service.model.EndpointInfo; import org.apache.cxf.transport.http.Address; import org.apache.cxf.transport.http.Headers; import org.apache.cxf.transport.http.URLConnectionHTTPConduit; import org.apache.cxf.transport.https.HttpsURLConnectionInfo; import org.apache.cxf.transport.websocket.WebSocketConstants; import org.apache.cxf.transport.websocket.WebSocketUtils; import org.apache.cxf.transports.http.configuration.HTTPClientPolicy; import org.apache.cxf.ws.addressing.EndpointReferenceType; /** * */ public class AhcWebSocketConduit extends URLConnectionHTTPConduit { private static final Logger LOG = LogUtils.getL7dLogger(AhcWebSocketConduit.class); private AsyncHttpClient ahcclient; private WebSocket websocket; //REVISIT make these keys configurable private String requestIdKey = WebSocketConstants.DEFAULT_REQUEST_ID_KEY; private String responseIdKey = WebSocketConstants.DEFAULT_RESPONSE_ID_KEY; private Map<String, RequestResponse> uncorrelatedRequests = new ConcurrentHashMap<String, RequestResponse>(); public AhcWebSocketConduit(Bus b, EndpointInfo ei, EndpointReferenceType t) throws IOException { super(b, ei, t); ahcclient = new AsyncHttpClient(); } @Override protected void setupConnection(Message message, Address address, HTTPClientPolicy csPolicy) throws IOException { URI currentURL = address.getURI(); String s = currentURL.getScheme(); if (!"ws".equals(s) && !"wss".equals(s)) { throw new MalformedURLException("unknown protocol: " + s); } message.put("http.scheme", currentURL.getScheme()); String httpRequestMethod = (String)message.get(Message.HTTP_REQUEST_METHOD); if (httpRequestMethod == null) { httpRequestMethod = "POST"; message.put(Message.HTTP_REQUEST_METHOD, httpRequestMethod); } final AhcWebSocketConduitRequest request = new AhcWebSocketConduitRequest(currentURL, httpRequestMethod); final int rtimeout = determineReceiveTimeout(message, csPolicy); request.setReceiveTimeout(rtimeout); message.put(AhcWebSocketConduitRequest.class, request); } @Override protected OutputStream createOutputStream(Message message, boolean needToCacheRequest, boolean isChunking, int chunkThreshold) throws IOException { AhcWebSocketConduitRequest entity = message.get(AhcWebSocketConduitRequest.class); return new AhcWebSocketWrappedOutputStream(message, needToCacheRequest, isChunking, chunkThreshold, getConduitName(), entity.getUri()); } public class AhcWebSocketWrappedOutputStream extends WrappedOutputStream { private AhcWebSocketConduitRequest entity; private Response response; protected AhcWebSocketWrappedOutputStream(Message message, boolean possibleRetransmit, boolean isChunking, int chunkThreshold, String conduitName, URI url) { super(message, possibleRetransmit, isChunking, chunkThreshold, conduitName, url); entity = message.get(AhcWebSocketConduitRequest.class); //REVISIT how we prepare the request String requri = (String)message.getContextualProperty("org.apache.cxf.request.uri"); if (requri != null) { // jaxrs speicfies a sub-path using prop org.apache.cxf.request.uri if (requri.startsWith("ws")) { entity.setPath(requri.substring(requri.indexOf('/', 3 + requri.indexOf(':')))); } else { entity.setPath(url.getPath() + requri); } } else { // jaxws entity.setPath(url.getPath()); } entity.setId(UUID.randomUUID().toString()); uncorrelatedRequests.put(entity.getId(), new RequestResponse(entity)); } @Override protected void setupWrappedStream() throws IOException { connect(); wrappedStream = new OutputStream() { @Override public void write(byte b[], int off, int len) throws IOException { //REVISIT support multiple writes and flush() to write the entire block data? // or provides the fragment mode? Map<String, String> headers = new HashMap<>(); headers.put("Content-Type", entity.getContentType()); headers.put(requestIdKey, entity.getId()); websocket.sendMessage(WebSocketUtils.buildRequest( entity.getMethod(), entity.getPath(), headers, b, off, len)); } @Override public void write(int b) throws IOException { //REVISIT support this single byte write and use flush() to write the block data? } @Override public void close() throws IOException { } }; } @Override protected void handleNoOutput() throws IOException { connect(); Map<String, String> headers = new HashMap<>(); headers.put(requestIdKey, entity.getId()); websocket.sendMessage(WebSocketUtils.buildRequest( entity.getMethod(), entity.getPath(), headers, null, 0, 0)); } @Override protected HttpsURLConnectionInfo getHttpsURLConnectionInfo() throws IOException { return null; } @Override protected void setProtocolHeaders() throws IOException { Headers h = new Headers(outMessage); entity.setContentType(h.determineContentType()); //REVISIT may provide an option to add other headers // boolean addHeaders = MessageUtils.isTrue(outMessage.getContextualProperty(Headers.ADD_HEADERS_PROPERTY)); } @Override protected void setFixedLengthStreamingMode(int i) { // ignore } @Override protected int getResponseCode() throws IOException { Response r = getResponse(); return r.getStatusCode(); } @Override protected String getResponseMessage() throws IOException { //TODO return a generic message based on the status code return null; } @Override protected void updateResponseHeaders(Message inMessage) throws IOException { Headers h = new Headers(inMessage); String ct = getResponse().getContentType(); inMessage.put(Message.CONTENT_TYPE, ct); //REVISIT if we are allowing more headers, we need to add them into the cxf's headers h.headerMap().put(Message.CONTENT_TYPE, Collections.singletonList(ct)); } @Override protected void handleResponseAsync() throws IOException { handleResponseOnWorkqueue(true, false); } @Override protected void closeInputStream() throws IOException { } @Override protected boolean usingProxy() { // TODO add proxy support ... return false; } @Override protected InputStream getInputStream() throws IOException { Response r = getResponse(); //REVISIT return new java.io.ByteArrayInputStream(r.getTextEntity().getBytes()); } @Override protected InputStream getPartialResponse() throws IOException { Response r = getResponse(); //REVISIT return new java.io.ByteArrayInputStream(r.getTextEntity().getBytes()); } @Override protected void setupNewConnection(String newURL) throws IOException { // TODO throw new IOException("not supported"); } @Override protected void retransmitStream() throws IOException { // TODO throw new IOException("not supported"); } @Override protected void updateCookiesBeforeRetransmit() throws IOException { // ignore for now and may consider a specific websocket binding variant to use cookies } @Override public void thresholdReached() throws IOException { // ignore } // // other methods follow // protected void connect() { LOG.log(Level.FINE, "connecting"); if (websocket == null) { try { websocket = ahcclient.prepareGet(url.toASCIIString()).execute( new WebSocketUpgradeHandler.Builder() .addWebSocketListener(new AhcWebSocketListener()).build()).get(); LOG.log(Level.FINE, "connected"); } catch (Exception e) { LOG.log(Level.SEVERE, "unable to connect", e); } } else { LOG.log(Level.FINE, "already connected"); } } Response getResponse() throws IOException { if (response == null) { String rid = entity.getId(); RequestResponse rr = uncorrelatedRequests.get(rid); synchronized (rr) { try { long timetowait = entity.getReceiveTimeout(); response = rr.getResponse(); if (response == null) { rr.wait(timetowait); response = rr.getResponse(); } } catch (InterruptedException e) { // ignore } } if (response == null) { throw new SocketTimeoutException("Read timed out while invoking " + entity.getUri()); } } return response; } } protected class AhcWebSocketListener implements WebSocketTextListener, WebSocketByteListener { public void onOpen(WebSocket ws) { if (LOG.isLoggable(Level.FINE)) { LOG.log(Level.FINE, "onOpen({0})", ws); } } public void onClose(WebSocket ws) { if (LOG.isLoggable(Level.FINE)) { LOG.log(Level.FINE, "onCose({0})", ws); } } public void onError(Throwable t) { LOG.log(Level.SEVERE, "[ws] onError", t); } public void onMessage(byte[] message) { if (LOG.isLoggable(Level.FINE)) { LOG.log(Level.FINE, "onMessage({0})", message); } Response resp = new Response(responseIdKey, message); RequestResponse rr = uncorrelatedRequests.get(resp.getId()); if (rr != null) { synchronized (rr) { rr.setResponse(resp); rr.notifyAll(); } } } public void onFragment(byte[] fragment, boolean last) { //TODO LOG.log(Level.WARNING, "NOT IMPLEMENTED onFragment({0}, {1})", new Object[]{fragment, last}); } public void onMessage(String message) { if (LOG.isLoggable(Level.FINE)) { LOG.log(Level.FINE, "onMessage({0})", message); } Response resp = new Response(responseIdKey, message); RequestResponse rr = uncorrelatedRequests.get(resp.getId()); if (rr != null) { synchronized (rr) { rr.setResponse(resp); rr.notifyAll(); } } } public void onFragment(String fragment, boolean last) { //TODO LOG.log(Level.WARNING, "NOT IMPLEMENTED onFragment({0}, {1})", new Object[]{fragment, last}); } } // Request and Response are used to represent request and response messages transfered over the websocket //REVIST move these classes to be used in other places after finalizing their contained information. static class Response { private Object data; private int pos; private int statusCode; private String contentType; private String id; private Object entity; Response(String idKey, Object data) { this.data = data; String line; boolean first = true; while ((line = readLine()) != null) { if (first && isStatusCode(line)) { statusCode = Integer.parseInt(line); continue; } else { first = false; } int del = line.indexOf(':'); String h = line.substring(0, del).trim(); String v = line.substring(del + 1).trim(); if ("Content-Type".equalsIgnoreCase(h)) { contentType = v; } else if (WebSocketConstants.DEFAULT_RESPONSE_ID_KEY.equals(h)) { id = v; } } if (data instanceof String) { entity = ((String)data).substring(pos); } else if (data instanceof byte[]) { entity = new byte[((byte[])data).length - pos]; System.arraycopy((byte[])data, pos, (byte[])entity, 0, ((byte[])entity).length); } } private static boolean isStatusCode(String line) { char c = line.charAt(0); return '0' <= c && c <= '9'; } public int getStatusCode() { return statusCode; } public String getContentType() { return contentType; } public String getId() { return id; } public Object getEntity() { return entity; } public String getTextEntity() { return gettext(entity); } private String readLine() { StringBuilder sb = new StringBuilder(); while (pos < length(data)) { int c = getchar(data, pos++); if (c == '\n') { break; } else if (c == '\r') { continue; } else { sb.append((char)c); } } if (sb.length() == 0) { return null; } return sb.toString(); } private int length(Object o) { if (o instanceof String) { return ((String)o).length(); } else if (o instanceof char[]) { return ((char[])o).length; } else if (o instanceof byte[]) { return ((byte[])o).length; } else { return 0; } } private int getchar(Object o, int p) { return 0xff & (o instanceof String ? ((String)o).charAt(p) : (o instanceof byte[] ? ((byte[])o)[p] : -1)); } private String gettext(Object o) { return o instanceof String ? (String)o : (o instanceof byte[] ? new String((byte[])o) : null); } } static class RequestResponse { private AhcWebSocketConduitRequest request; private Response response; RequestResponse(AhcWebSocketConduitRequest request) { this.request = request; } public AhcWebSocketConduitRequest getRequest() { return request; } public Response getResponse() { return response; } public void setResponse(Response response) { this.response = response; } } }