/*
* JBoss, Home of Professional Open Source.
* See the COPYRIGHT.txt file distributed with this work for information
* regarding copyright ownership. Some portions may be licensed
* to Red Hat, Inc. under one or more contributor license agreements.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301 USA.
*/
package org.teiid.resource.adapter.ws;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import javax.activation.DataSource;
import javax.resource.ResourceException;
import javax.security.auth.Subject;
import javax.ws.rs.core.Response.Status;
import javax.xml.namespace.QName;
import javax.xml.ws.AsyncHandler;
import javax.xml.ws.Binding;
import javax.xml.ws.Dispatch;
import javax.xml.ws.EndpointReference;
import javax.xml.ws.Response;
import javax.xml.ws.Service;
import javax.xml.ws.Service.Mode;
import javax.xml.ws.WebServiceException;
import javax.xml.ws.handler.MessageContext;
import javax.xml.ws.http.HTTPBinding;
import org.apache.cxf.Bus;
import org.apache.cxf.BusFactory;
import org.apache.cxf.bus.spring.SpringBusFactory;
import org.apache.cxf.configuration.security.AuthorizationPolicy;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.interceptor.Interceptor;
import org.apache.cxf.jaxrs.client.JAXRSClientFactoryBean;
import org.apache.cxf.jaxrs.client.WebClient;
import org.apache.cxf.jaxws.DispatchImpl;
import org.apache.cxf.transport.http.HTTPConduit;
import org.apache.cxf.transport.http.HTTPConduitFactory;
import org.apache.cxf.transport.http.asyncclient.AsyncHTTPConduitFactory;
import org.apache.cxf.transports.http.configuration.HTTPClientPolicy;
import org.apache.cxf.rt.security.SecurityConstants;
import org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor;
import org.apache.cxf.ws.security.wss4j.WSS4JOutInterceptor;
import org.ietf.jgss.GSSCredential;
import org.teiid.OAuthCredential;
import org.teiid.core.util.ArgCheck;
import org.teiid.logging.LogConstants;
import org.teiid.logging.LogManager;
import org.teiid.logging.MessageLevel;
import org.teiid.resource.spi.BasicConnection;
import org.teiid.resource.spi.ConnectionContext;
import org.teiid.translator.WSConnection;
/**
* WebService connection implementation.
*
* TODO: set a handler chain
*/
public class WSConnectionImpl extends BasicConnection implements WSConnection {
private static final String CONNECTION_TIMEOUT = "javax.xml.ws.client.connectionTimeout"; //$NON-NLS-1$
private static final String RECEIVE_TIMEOUT = "javax.xml.ws.client.receiveTimeout"; //$NON-NLS-1$
/**
 * Read-only {@link DataSource} view of an HTTP response body.
 * <p>
 * The name is the path of the requested URL, the content type is whatever was
 * supplied at construction, and writing is not supported.
 */
private static final class HttpDataSource implements DataSource {
    private final URL url;
    private final InputStream content;
    private final String contentType;

    /**
     * @param url the URL the response was obtained from
     * @param entity the response body stream
     * @param contentType the content type to report for the body
     */
    private HttpDataSource(URL url, InputStream entity, String contentType) {
        this.contentType = contentType;
        this.content = entity;
        this.url = url;
    }

    /** Returns the path component of the source URL. */
    @Override
    public String getName() {
        return url.getPath();
    }

    /** Returns the content type given at construction. */
    @Override
    public String getContentType() {
        return contentType;
    }

    /**
     * Returns the stream given at construction; the same instance is handed
     * out on every call.
     */
    @Override
    public InputStream getInputStream() throws IOException {
        return content;
    }

    /** Always throws: this data source cannot be written to. */
    @Override
    public OutputStream getOutputStream() throws IOException {
        throw new UnsupportedOperationException();
    }
}
private static final class HttpDispatch implements Dispatch<DataSource> {
private static final String AUTHORIZATION = "Authorization"; //$NON-NLS-1$
private HashMap<String, Object> requestContext = new HashMap<String, Object>();
private HashMap<String, Object> responseContext = new HashMap<String, Object>();
private WebClient client;
private String endpoint;
private String configFile;
public HttpDispatch(String endpoint, String configFile, @SuppressWarnings("unused") String configName) {
this.endpoint = endpoint;
this.configFile = configFile;
}
WebClient createWebClient(String baseAddress, Bus bus) {
JAXRSClientFactoryBean bean = new JAXRSClientFactoryBean();
bean.setBus(bus);
bean.setAddress(baseAddress);
return bean.createWebClient();
}
Bus getBus(String configLocation) {
if (configLocation != null) {
SpringBusFactory bf = new SpringBusFactory();
return bf.createBus(configLocation);
} else {
return BusFactory.getThreadDefaultBus();
}
}
@Override
public DataSource invoke(DataSource msg) {
try {
final URL url = new URL(this.endpoint);
url.toURI(); //ensure this is a valid uri
final String httpMethod = (String)this.requestContext.get(MessageContext.HTTP_REQUEST_METHOD);
// see to use patch
// http://stackoverflow.com/questions/32067687/how-to-use-patch-method-in-cxf
Bus bus = getBus(this.configFile);
if (httpMethod.equals("PATCH")) {
bus.setProperty("use.async.http.conduit", Boolean.TRUE);
bus.setExtension(new AsyncHTTPConduitFactory(bus), HTTPConduitFactory.class);
}
this.client = createWebClient(this.endpoint, bus);
Map<String, List<String>> header = (Map<String, List<String>>)this.requestContext.get(MessageContext.HTTP_REQUEST_HEADERS);
for (Map.Entry<String, List<String>> entry : header.entrySet()) {
this.client.header(entry.getKey(), entry.getValue().toArray());
}
if (this.requestContext.get(AuthorizationPolicy.class.getName()) != null) {
HTTPConduit conduit = (HTTPConduit)WebClient.getConfig(this.client).getConduit();
AuthorizationPolicy policy = (AuthorizationPolicy)this.requestContext.get(AuthorizationPolicy.class.getName());
conduit.setAuthorization(policy);
}
else if (this.requestContext.get(GSSCredential.class.getName()) != null) {
WebClient.getConfig(this.client).getRequestContext().put(GSSCredential.class.getName(), this.requestContext.get(GSSCredential.class.getName()));
WebClient.getConfig(this.client).getRequestContext().put("auth.spnego.requireCredDelegation", true); //$NON-NLS-1$
}
else if (this.requestContext.get(OAuthCredential.class.getName()) != null) {
OAuthCredential credential = (OAuthCredential)this.requestContext.get(OAuthCredential.class.getName());
this.client.header(AUTHORIZATION, credential.getAuthorizationHeader(this.endpoint, httpMethod));
}
InputStream payload = null;
if (msg != null) {
payload = msg.getInputStream();
}
HTTPClientPolicy clientPolicy = WebClient.getConfig(this.client).getHttpConduit().getClient();
Long timeout = (Long) this.requestContext.get(RECEIVE_TIMEOUT);
if (timeout != null) {
clientPolicy.setReceiveTimeout(timeout);
}
timeout = (Long) this.requestContext.get(CONNECTION_TIMEOUT);
if (timeout != null) {
clientPolicy.setConnectionTimeout(timeout);
}
javax.ws.rs.core.Response response = this.client.invoke(httpMethod, payload);
this.responseContext.put(WSConnection.STATUS_CODE, response.getStatus());
this.responseContext.putAll(response.getMetadata());
ArrayList contentTypes = (ArrayList)this.responseContext.get("content-type"); //$NON-NLS-1$
String contentType = contentTypes != null ? (String)contentTypes.get(0):"application/octet-stream"; //$NON-NLS-1$
return new HttpDataSource(url, (InputStream)response.getEntity(), contentType);
} catch (IOException e) {
throw new WebServiceException(e);
} catch (URISyntaxException e) {
throw new WebServiceException(e);
}
}
@Override
public Map<String, Object> getRequestContext() {
return this.requestContext;
}
@Override
public Map<String, Object> getResponseContext() {
return this.responseContext;
}
@Override
public Binding getBinding() {
throw new UnsupportedOperationException();
}
@Override
public EndpointReference getEndpointReference() {
throw new UnsupportedOperationException();
}
@Override
public <T extends EndpointReference> T getEndpointReference(Class<T> clazz) {
throw new UnsupportedOperationException();
}
@Override
public Response<DataSource> invokeAsync(DataSource msg) {
throw new UnsupportedOperationException();
}
@Override
public Future<?> invokeAsync(DataSource msg,AsyncHandler<DataSource> handler) {
throw new UnsupportedOperationException();
}
@Override
public void invokeOneWay(DataSource msg) {
throw new UnsupportedOperationException();
}
}
// Factory supplying the endpoint, WSDL, security and timeout settings read by this connection.
private final WSManagedConnectionFactory mcf;
// Created on the first createDispatch(Class, Mode) call and reused afterwards.
private Service wsdlService;

/**
 * Creates a connection that reads its configuration from the given factory.
 *
 * @param mcf the managed connection factory backing this connection
 */
public WSConnectionImpl(WSManagedConnectionFactory mcf) {
    this.mcf = mcf;
}
/**
 * Creates a dispatch for the configured port from the WSDL-backed service.
 * <p>
 * The service is built from the factory's WSDL URL and service name on the
 * first call and cached on this connection. The dispatch is then given the
 * WS-Security and common dispatch settings.
 *
 * @param type the message type the dispatch works with
 * @param mode the service mode (message or payload)
 * @return the configured dispatch
 * @throws IOException declared by the signature; nothing in this method
 *         throws it directly
 */
public <T> Dispatch<T> createDispatch(Class<T> type, Mode mode) throws IOException {
    final Service service = getOrCreateWsdlService();
    final Dispatch<T> created = service.createDispatch(this.mcf.getPortQName(), type, mode);
    configureWSSecurity(created);
    setDispatchProperties(created, "SOAP12"); //$NON-NLS-1$
    return created;
}

/**
 * Returns the cached WSDL service, creating it first if needed. While the
 * service is created the factory's bus is installed as the thread default
 * bus; the previous thread default is restored afterwards, also on failure.
 */
private Service getOrCreateWsdlService() {
    if (this.wsdlService != null) {
        return this.wsdlService;
    }
    final Bus previousDefault = BusFactory.getThreadDefaultBus();
    BusFactory.setThreadDefaultBus(this.mcf.getBus());
    try {
        this.wsdlService = Service.create(this.mcf.getWsdlUrl(), this.mcf.getServiceQName());
    } finally {
        BusFactory.setThreadDefaultBus(previousDefault);
    }
    if (LogManager.isMessageToBeRecorded(LogConstants.CTX_WS, MessageLevel.DETAIL)) {
        LogManager.logDetail(LogConstants.CTX_WS, "Created the WSDL service for", this.mcf.getWsdl()); //$NON-NLS-1$
    }
    return this.wsdlService;
}
/**
 * Creates a dispatch for an explicit binding and endpoint.
 * <p>
 * An absolute {@code endpoint} is used as is. A relative one is resolved
 * against the factory's default endpoint, and {@code null} selects the
 * default endpoint itself. For the HTTP binding with a {@link DataSource}
 * message type an {@link HttpDispatch} is returned; otherwise a JAX-WS
 * dispatch is created on a new service and given the WS-Security settings.
 *
 * @param binding the binding identifier; checked with {@code ArgCheck.isNotNull}
 * @param endpoint absolute URL, a value relative to the default endpoint, or
 *        {@code null}
 * @param type the message type the dispatch works with
 * @param mode the service mode
 * @return the configured dispatch
 * @throws WebServiceException if an endpoint is needed from the factory but
 *         none is configured
 */
public <T> Dispatch<T> createDispatch(String binding, String endpoint, Class<T> type, Mode mode) {
    ArgCheck.isNotNull(binding);
    if (endpoint != null) {
        try {
            new URL(endpoint);
            //valid url, just use the endpoint
        } catch (MalformedURLException e) {
            //otherwise it should be a relative value
            //but we should still preserve the base path and query string
            endpoint = resolveRelativeEndpoint(endpoint);
        }
    } else {
        endpoint = this.mcf.getEndPoint();
        if (endpoint == null) {
            throw new WebServiceException(WSManagedConnectionFactory.UTIL.getString("null_endpoint")); //$NON-NLS-1$
        }
    }
    Dispatch<T> dispatch = null;
    if (HTTPBinding.HTTP_BINDING.equals(binding) && (type == DataSource.class)) {
        Bus bus = BusFactory.getThreadDefaultBus();
        BusFactory.setThreadDefaultBus(this.mcf.getBus());
        try {
            // type == DataSource.class was checked above, so T is DataSource here
            @SuppressWarnings("unchecked")
            Dispatch<T> httpDispatch = (Dispatch<T>) new HttpDispatch(endpoint, this.mcf.getConfigFile(), this.mcf.getConfigName());
            dispatch = httpDispatch;
        } finally {
            BusFactory.setThreadDefaultBus(bus);
        }
    } else {
        //TODO: cache service/port/dispatch instances?
        Bus bus = BusFactory.getThreadDefaultBus();
        BusFactory.setThreadDefaultBus(this.mcf.getBus());
        Service svc;
        try {
            svc = Service.create(this.mcf.getServiceQName());
        } finally {
            BusFactory.setThreadDefaultBus(bus);
        }
        if (LogManager.isMessageToBeRecorded(LogConstants.CTX_WS, MessageLevel.DETAIL)) {
            LogManager.logDetail(LogConstants.CTX_WS, "Creating a dispatch with endpoint", endpoint); //$NON-NLS-1$
        }
        svc.addPort(this.mcf.getPortQName(), binding, endpoint);
        dispatch = svc.createDispatch(this.mcf.getPortQName(), type, mode);
        configureWSSecurity(dispatch);
    }
    setDispatchProperties(dispatch, binding);
    return dispatch;
}

/**
 * Resolves a relative endpoint against the factory's default endpoint,
 * carrying over the default query string and, unless the relative value has
 * its own, the default fragment.
 * <p>
 * The default endpoint is split fragment first and query second, so a
 * fragment without a query, or a {@code ?} inside the fragment, no longer
 * ends up in the base path. The default query string is inserted before any
 * fragment so the result keeps the {@code path?query#fragment} order.
 *
 * @param endpoint the relative endpoint
 * @return the combined endpoint
 * @throws WebServiceException if the factory has no default endpoint
 */
private String resolveRelativeEndpoint(String endpoint) {
    String base = this.mcf.getEndPoint();
    if (base == null) {
        throw new WebServiceException(WSManagedConnectionFactory.UTIL.getString("null_default_endpoint")); //$NON-NLS-1$
    }

    String defaultFragment = null;
    int hash = base.indexOf('#');
    if (hash >= 0) {
        // an empty trailing fragment is dropped
        if (hash < base.length() - 1) {
            defaultFragment = base.substring(hash + 1);
        }
        base = base.substring(0, hash);
    }

    String defaultQueryString = null;
    int question = base.indexOf('?');
    if (question >= 0) {
        defaultQueryString = base.substring(question + 1);
        base = base.substring(0, question);
    }

    String resolved;
    if (endpoint.startsWith("?") || endpoint.startsWith("/") || base.endsWith("/")) { //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$
        resolved = base + endpoint;
    } else {
        resolved = base + "/" + endpoint; //$NON-NLS-1$
    }

    // base no longer holds a '#', so any fragment found here came from the relative value
    String ownFragment = null;
    int resolvedHash = resolved.indexOf('#');
    if (resolvedHash >= 0) {
        ownFragment = resolved.substring(resolvedHash);
        resolved = resolved.substring(0, resolvedHash);
    }

    if ((defaultQueryString != null) && (defaultQueryString.trim().length() > 0)) {
        resolved = WSConnection.Util.appendQueryString(resolved, defaultQueryString);
    }

    if (ownFragment != null) {
        resolved = resolved + ownFragment;
    } else if (defaultFragment != null) {
        resolved = resolved + '#' + defaultFragment;
    }
    return resolved;
}
/**
 * Applies WS-Security configuration to a CXF dispatch when the factory's
 * security type is {@code WSSecurity}; does nothing for any other type.
 * <p>
 * In order: the factory's configured out-interceptors are added to the
 * endpoint; if the current subject carries a {@link WSSecurityCredential},
 * its STS client and WSS4J interceptors or policy properties are applied; and
 * finally any {@code ws-security.*} entries of a property map attached to
 * the subject are copied onto the endpoint. The factory's bus is the thread
 * default bus for the duration of the call.
 *
 * @param dispatch the dispatch to configure; cast to {@link DispatchImpl}
 *        when the security type is {@code WSSecurity}
 */
private <T> void configureWSSecurity(Dispatch<T> dispatch) {
    if (this.mcf.getAsSecurityType() != WSManagedConnectionFactory.SecurityType.WSSecurity) {
        return;
    }
    Bus bus = BusFactory.getThreadDefaultBus();
    BusFactory.setThreadDefaultBus(this.mcf.getBus());
    try {
        Client client = ((DispatchImpl)dispatch).getClient();
        Endpoint ep = client.getEndpoint();

        // spring configuration file
        if (this.mcf.getOutInterceptors() != null) {
            for (Interceptor i : this.mcf.getOutInterceptors()) {
                ep.getOutInterceptors().add(i);
            }
        }

        // ws-security pass-thru from custom jaas domain
        Subject subject = ConnectionContext.getSubject();
        if (subject != null) {
            WSSecurityCredential credential = ConnectionContext.getSecurityCredential(subject, WSSecurityCredential.class);
            if (credential != null) {
                if (credential.useSts()) {
                    // NOTE(review): 'bus' is the caller's previous thread default bus,
                    // not this.mcf.getBus() - confirm which one the STS client should use.
                    dispatch.getRequestContext().put(SecurityConstants.STS_CLIENT, credential.buildStsClient(bus));
                }
                if(credential.getSecurityHandler() == WSSecurityCredential.SecurityHandler.WSS4J) {
                    ep.getOutInterceptors().add(new WSS4JOutInterceptor(credential.getRequestPropterties()));
                    ep.getInInterceptors().add(new WSS4JInInterceptor(credential.getResponsePropterties()));
                }
                else if (credential.getSecurityHandler() == WSSecurityCredential.SecurityHandler.WSPOLICY) {
                    dispatch.getRequestContext().putAll(credential.getRequestPropterties());
                    dispatch.getResponseContext().putAll(credential.getResponsePropterties());
                }
            }

            // When properties are set on subject treat them as they can configure WS-Security
            HashMap<String, String> properties = ConnectionContext.getSecurityCredential(subject, HashMap.class);
            // the lookup yields null when the subject has no such credential (the other
            // call sites check for it as well); previously this threw a NullPointerException
            if (properties != null) {
                for (Map.Entry<String, String> property : properties.entrySet()) {
                    String key = property.getKey();
                    if (key != null && key.startsWith("ws-security.")) { //$NON-NLS-1$
                        ep.put(key, property.getValue());
                    }
                }
            }
        }
    } finally {
        BusFactory.setThreadDefaultBus(bus);
    }
}
/**
 * Populates the dispatch's request context from the factory configuration.
 * <ul>
 * <li>{@code HTTPBasic}/{@code Digest}: stores an {@link AuthorizationPolicy}
 * built from the factory's user name and password, overridden by the current
 * subject's values when a subject is present.</li>
 * <li>{@code Kerberos}: stores the subject's {@link GSSCredential}.</li>
 * <li>{@code OAuth}: stores the subject's {@link OAuthCredential}.</li>
 * </ul>
 * After that the configured receive and connection timeouts are stored, and
 * for the HTTP binding the {@code Content-Type} and {@code User-Agent}
 * request headers are set.
 *
 * @param dispatch the dispatch whose request context is filled
 * @param binding the binding identifier the dispatch was created for
 * @throws WebServiceException if Kerberos or OAuth security is configured
 *         but the current subject offers no matching credential
 */
private <T> void setDispatchProperties(Dispatch<T> dispatch, String binding) {
    final WSManagedConnectionFactory.SecurityType securityType = this.mcf.getAsSecurityType();
    final boolean digest = securityType == WSManagedConnectionFactory.SecurityType.Digest;

    if (digest || securityType == WSManagedConnectionFactory.SecurityType.HTTPBasic) {
        String userName = this.mcf.getAuthUserName();
        String password = this.mcf.getAuthPassword();

        // with a security domain and caller identity, the subject's credentials win
        final Subject caller = ConnectionContext.getSubject();
        if (caller != null) {
            userName = ConnectionContext.getUserName(caller, this.mcf, userName);
            password = ConnectionContext.getPassword(caller, this.mcf, userName, password);
        }

        final AuthorizationPolicy policy = new AuthorizationPolicy();
        policy.setUserName(userName);
        policy.setPassword(password);
        policy.setAuthorizationType(digest ? "Digest" : "Basic"); //$NON-NLS-1$ //$NON-NLS-2$
        dispatch.getRequestContext().put(AuthorizationPolicy.class.getName(), policy);
    }
    else if (securityType == WSManagedConnectionFactory.SecurityType.Kerberos) {
        GSSCredential gssCredential = null;
        final Subject caller = ConnectionContext.getSubject();
        if (caller != null) {
            gssCredential = ConnectionContext.getSecurityCredential(caller, GSSCredential.class);
        }
        if (gssCredential == null) {
            throw new WebServiceException(WSManagedConnectionFactory.UTIL.getString("no_gss_credential")); //$NON-NLS-1$
        }
        dispatch.getRequestContext().put(GSSCredential.class.getName(), gssCredential);
    }
    else if (securityType == WSManagedConnectionFactory.SecurityType.OAuth) {
        OAuthCredential oauthCredential = null;
        final Subject caller = ConnectionContext.getSubject();
        if (caller != null) {
            oauthCredential = ConnectionContext.getSecurityCredential(caller, OAuthCredential.class);
        }
        if (oauthCredential == null) {
            throw new WebServiceException(WSManagedConnectionFactory.UTIL.getString("no_oauth_credential")); //$NON-NLS-1$
        }
        dispatch.getRequestContext().put(OAuthCredential.class.getName(), oauthCredential);
    }

    if (this.mcf.getRequestTimeout() != null) {
        dispatch.getRequestContext().put(RECEIVE_TIMEOUT, this.mcf.getRequestTimeout());
    }
    if (this.mcf.getConnectTimeout() != null) {
        dispatch.getRequestContext().put(CONNECTION_TIMEOUT, this.mcf.getConnectTimeout());
    }

    if (!HTTPBinding.HTTP_BINDING.equals(binding)) {
        return;
    }
    @SuppressWarnings("unchecked")
    Map<String, List<String>> headers = (Map<String, List<String>>)dispatch.getRequestContext().get(MessageContext.HTTP_REQUEST_HEADERS);
    if (headers == null) {
        headers = new HashMap<String, List<String>>();
    }
    headers.put("Content-Type", Collections.singletonList("text/xml; charset=utf-8"));//$NON-NLS-1$ //$NON-NLS-2$
    headers.put("User-Agent", Collections.singletonList("Teiid Server"));//$NON-NLS-1$ //$NON-NLS-2$
    dispatch.getRequestContext().put(MessageContext.HTTP_REQUEST_HEADERS, headers);
}
/**
 * Closes the connection. This implementation is a no-op: it performs no
 * cleanup of its own.
 */
@Override
public void close() throws ResourceException {
    // intentionally empty
}
/**
 * @return the WSDL URL configured on the connection factory
 */
@Override
public URL getWsdl() {
    final URL wsdlUrl = mcf.getWsdlUrl();
    return wsdlUrl;
}
/**
 * @return the service name configured on the connection factory
 */
@Override
public QName getServiceQName() {
    final QName serviceName = mcf.getServiceQName();
    return serviceName;
}
/**
 * @return the port name configured on the connection factory
 */
@Override
public QName getPortQName() {
    final QName portName = mcf.getPortQName();
    return portName;
}
/**
 * Looks up the reason phrase for an HTTP status code.
 *
 * @param status the numeric HTTP status code
 * @return the reason phrase, or {@code null} when the lookup yields no
 *         {@link Status} for the code
 */
@Override
public String getStatusMessage(int status) {
    final Status known = Status.fromStatusCode(status);
    return (known == null) ? null : known.getReasonPhrase();
}
}