/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.ws.security.sts.provider;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.annotation.Resource;
import javax.xml.bind.Binder;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.annotation.XmlAnyElement;
import javax.xml.bind.util.JAXBSource;
import javax.xml.namespace.QName;
import javax.xml.transform.Source;
import javax.xml.ws.Provider;
import javax.xml.ws.Service;
import javax.xml.ws.ServiceMode;
import javax.xml.ws.WebServiceContext;
import javax.xml.ws.handler.MessageContext;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;

import org.apache.cxf.binding.soap.SoapFault;
import org.apache.cxf.binding.soap.SoapVersion;
import org.apache.cxf.common.jaxb.JAXBContextCache;
import org.apache.cxf.common.jaxb.JAXBContextCache.CachedContextAndSchemas;
import org.apache.cxf.common.util.ReflectionUtil;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.cxf.ws.security.sts.provider.model.ObjectFactory;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenCollectionType;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenResponseCollectionType;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenResponseType;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenType;
import org.apache.cxf.ws.security.sts.provider.operation.CancelOperation;
import org.apache.cxf.ws.security.sts.provider.operation.IssueOperation;
import org.apache.cxf.ws.security.sts.provider.operation.IssueSingleOperation;
import org.apache.cxf.ws.security.sts.provider.operation.KeyExchangeTokenOperation;
import org.apache.cxf.ws.security.sts.provider.operation.RenewOperation;
import org.apache.cxf.ws.security.sts.provider.operation.RequestCollectionOperation;
import org.apache.cxf.ws.security.sts.provider.operation.ValidateOperation;

/**
 * Payload-mode JAX-WS {@link Provider} that unmarshals an incoming WS-Trust 1.3 request,
 * looks at its {@code RequestType} element (or at the fact that it is a request collection)
 * and reflectively invokes the matching operation implementation that was registered
 * through one of the setters.
 *
 * <p>Any failure while processing a request is converted into a {@link SoapFault}.
 *
 * <p>The operation setters are not synchronized; they are presumably meant to be called
 * while the provider is being configured, before requests are dispatched to it.
 */
@ServiceMode(value = Service.Mode.PAYLOAD)
public class SecurityTokenServiceProvider implements Provider<Source> {

    // Declared first so that it is initialized before the static block below uses it.
    private static final Logger LOG = Logger.getLogger(SecurityTokenServiceProvider.class.getName());

    private static final String WSTRUST_13_NAMESPACE = "http://docs.oasis-open.org/ws-sx/ws-trust/200512";
    private static final String WSTRUST_REQUESTTYPE_ELEMENTNAME = "RequestType";
    private static final String WSTRUST_REQUESTTYPE_ISSUE = WSTRUST_13_NAMESPACE
        + "/Issue";
    private static final String WSTRUST_REQUESTTYPE_CANCEL = WSTRUST_13_NAMESPACE
        + "/Cancel";
    private static final String WSTRUST_REQUESTTYPE_RENEW = WSTRUST_13_NAMESPACE
        + "/Renew";
    private static final String WSTRUST_REQUESTTYPE_VALIDATE = WSTRUST_13_NAMESPACE
        + "/Validate";
    private static final String WSTRUST_REQUESTTYPE_REQUESTCOLLECTION = WSTRUST_13_NAMESPACE
        + "/RequestCollection";
    private static final String WSTRUST_REQUESTTYPE_KEYEXCHANGETOKEN = WSTRUST_13_NAMESPACE
        + "/KeyExchangeToken";

    /** Qualified name of the element that carries the request type URI. */
    private static final QName WSTRUST_REQUESTTYPE_QNAME =
        new QName(WSTRUST_13_NAMESPACE, WSTRUST_REQUESTTYPE_ELEMENTNAME);

    /**
     * Default request-type-URI to operation-method table. It is read-only: per-instance
     * overrides live in {@link #operationMethods} so that configuring one provider can
     * never change the dispatching of another one.
     */
    private static final Map<String, Method> OPERATION_METHODS;

    static {
        Map<String, Method> methods = new HashMap<>();
        registerMethod(methods, WSTRUST_REQUESTTYPE_ISSUE,
                       IssueOperation.class, "issue", RequestSecurityTokenType.class);
        registerMethod(methods, WSTRUST_REQUESTTYPE_CANCEL,
                       CancelOperation.class, "cancel", RequestSecurityTokenType.class);
        registerMethod(methods, WSTRUST_REQUESTTYPE_RENEW,
                       RenewOperation.class, "renew", RequestSecurityTokenType.class);
        registerMethod(methods, WSTRUST_REQUESTTYPE_VALIDATE,
                       ValidateOperation.class, "validate", RequestSecurityTokenType.class);
        registerMethod(methods, WSTRUST_REQUESTTYPE_KEYEXCHANGETOKEN,
                       KeyExchangeTokenOperation.class, "keyExchangeToken",
                       RequestSecurityTokenType.class);
        registerMethod(methods, WSTRUST_REQUESTTYPE_REQUESTCOLLECTION,
                       RequestCollectionOperation.class, "requestCollection",
                       RequestSecurityTokenCollectionType.class);
        OPERATION_METHODS = Collections.unmodifiableMap(methods);
    }

    protected JAXBContext jaxbContext;
    protected Set<Class<?>> jaxbContextClasses;

    private CancelOperation cancelOperation;
    private IssueOperation issueOperation;
    private IssueSingleOperation issueSingleOperation;
    private KeyExchangeTokenOperation keyExchangeTokenOperation;
    private RenewOperation renewOperation;
    private RequestCollectionOperation requestCollectionOperation;
    private ValidateOperation validateOperation;

    /** Request type URI -> operation implementation registered on this instance. */
    private Map<String, Object> operationMap = new HashMap<>();

    /** Request type URI -> method to invoke; starts as a private copy of the defaults. */
    private final Map<String, Method> operationMethods = new HashMap<>(OPERATION_METHODS);

    @Resource
    private WebServiceContext context;

    /**
     * Creates the provider and obtains the JAXB context for the WS-Trust model classes
     * from the {@link JAXBContextCache}.
     *
     * @throws Exception if the JAXB context cannot be obtained
     */
    public SecurityTokenServiceProvider() throws Exception {
        Set<Class<?>> classes = new HashSet<>();
        classes.add(ObjectFactory.class);
        classes.add(org.apache.cxf.ws.security.sts.provider.model.wstrust14.ObjectFactory.class);

        CachedContextAndSchemas cache =
            JAXBContextCache.getCachedContextAndSchemas(classes, null, null, null, false);
        jaxbContext = cache.getContext();
        jaxbContextClasses = cache.getClasses();
    }

    /**
     * Looks up {@code methodName(requestClass, Principal, Map)} on the given operation
     * interface.
     *
     * @return the method, or {@code null} (after logging) if it cannot be resolved
     */
    private static Method findOperationMethod(Class<?> operationType, String methodName,
                                              Class<?> requestClass) {
        try {
            return operationType.getDeclaredMethod(methodName, requestClass, Principal.class, Map.class);
        } catch (NoSuchMethodException | SecurityException ex) {
            LOG.log(Level.SEVERE,
                    "Unable to resolve STS operation method "
                        + operationType.getName() + "." + methodName,
                    ex);
            return null;
        }
    }

    /** Adds the resolved operation method to {@code target}; unresolved methods are skipped. */
    private static void registerMethod(Map<String, Method> target, String requestTypeUri,
                                       Class<?> operationType, String methodName,
                                       Class<?> requestClass) {
        Method method = findOperationMethod(operationType, methodName, requestClass);
        if (method != null) {
            target.put(requestTypeUri, method);
        }
    }

    public void setCancelOperation(CancelOperation cancelOperation) {
        this.cancelOperation = cancelOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_CANCEL, cancelOperation);
    }

    public void setIssueOperation(IssueOperation issueOperation) {
        this.issueOperation = issueOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_ISSUE, issueOperation);
    }

    /**
     * Setting an IssueSingleOperation instance will override the default behaviour of issuing
     * a token in a RequestSecurityTokenResponseCollection.
     *
     * <p>The override is recorded on this provider instance only; other providers keep
     * dispatching Issue requests to {@code IssueOperation.issue}.
     */
    public void setIssueSingleOperation(IssueSingleOperation issueSingleOperation) {
        this.issueSingleOperation = issueSingleOperation;
        Method issueSingle = findOperationMethod(IssueSingleOperation.class, "issueSingle",
                                                 RequestSecurityTokenType.class);
        if (issueSingle != null) {
            operationMethods.put(WSTRUST_REQUESTTYPE_ISSUE, issueSingle);
            operationMap.put(WSTRUST_REQUESTTYPE_ISSUE, issueSingleOperation);
        }
    }

    public void setKeyExchangeTokenOperation(
        KeyExchangeTokenOperation keyExchangeTokenOperation) {
        this.keyExchangeTokenOperation = keyExchangeTokenOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_KEYEXCHANGETOKEN,
                         keyExchangeTokenOperation);
    }

    public void setRenewOperation(RenewOperation renewOperation) {
        this.renewOperation = renewOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_RENEW, renewOperation);
    }

    public void setRequestCollectionOperation(
        RequestCollectionOperation requestCollectionOperation) {
        this.requestCollectionOperation = requestCollectionOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_REQUESTCOLLECTION,
                         requestCollectionOperation);
    }

    public void setValidateOperation(ValidateOperation validateOperation) {
        this.validateOperation = validateOperation;
        operationMap.put(WSTRUST_REQUESTTYPE_VALIDATE, validateOperation);
    }

    /**
     * Unmarshals the request payload, dispatches it to the registered operation and
     * returns the marshallable response.
     *
     * @param request the request payload
     * @return a {@link JAXBSource} wrapping the operation's response
     * @throws SoapFault if no operation is registered for the request type, the operation
     *         returns {@code null}, or any step of the processing fails
     */
    @Override
    public Source invoke(Source request) {
        try {
            Object requestObject = convertToJAXBObject(request);
            String requestType = resolveRequestType(requestObject);

            Object operationImpl = operationMap.get(requestType);
            Method method = operationMethods.get(requestType);
            if (operationImpl == null || method == null) {
                throw new Exception(
                    "Implementation for this operation not found.");
            }

            Object result = method.invoke(operationImpl, requestObject,
                                          context.getUserPrincipal(), context.getMessageContext());
            if (result == null) {
                throw new Exception("Error in implementation class.");
            }
            return toResponseSource(result);
        } catch (InvocationTargetException ex) {
            // Report the exception raised by the operation itself, not the reflection wrapper.
            throw createSOAPFault(ex.getCause());
        } catch (Exception ex) {
            throw createSOAPFault(ex);
        }
    }

    /**
     * Determines the request type URI used as dispatch key.
     *
     * @return the RequestCollection URI for a request collection, otherwise the string
     *         value of the first {@code RequestType} element; {@code null} if there is no
     *         such element or it carries no value
     */
    private String resolveRequestType(Object requestObject) {
        if (requestObject instanceof RequestSecurityTokenCollectionType) {
            return WSTRUST_REQUESTTYPE_REQUESTCOLLECTION;
        }
        RequestSecurityTokenType rst = (RequestSecurityTokenType)requestObject;
        List<?> children = rst.getAny();
        for (Object child : children) {
            if (child instanceof JAXBElement) {
                JAXBElement<?> element = (JAXBElement<?>)child;
                if (WSTRUST_REQUESTTYPE_QNAME.equals(element.getName())) {
                    // Only the first RequestType element is considered. A missing value
                    // results in the regular "not found" fault rather than an NPE.
                    Object value = element.getValue();
                    return value == null ? null : value.toString();
                }
            }
        }
        return null;
    }

    /** Wraps the operation result in the matching WS-Trust response element. */
    private Source toResponseSource(Object result) throws Exception {
        ObjectFactory factory = new ObjectFactory();
        if (result instanceof RequestSecurityTokenResponseCollectionType) {
            return new JAXBSource(jaxbContext,
                factory.createRequestSecurityTokenResponseCollection(
                    (RequestSecurityTokenResponseCollectionType)result));
        }
        return new JAXBSource(jaxbContext,
            factory.createRequestSecurityTokenResponse(
                (RequestSecurityTokenResponseType)result));
    }

    /**
     * Builds the fault reported to the client for the given failure.
     *
     * <p>The fault string is the exception message, or "Internal STS error" when there is
     * no exception or it has no message. A fault code supplied by an {@link STSException}
     * becomes the fault code for SOAP 1.1 and the sub-code otherwise.
     */
    private SoapFault createSOAPFault(Throwable ex) {
        String faultString = "Internal STS error";
        QName faultCode = null;

        if (ex != null) {
            LOG.log(Level.FINE, "STS request processing failed", ex);
            if (ex instanceof STSException && ((STSException)ex).getFaultCode() != null) {
                faultCode = ((STSException)ex).getFaultCode();
            }
            // Many runtime exceptions carry no message; keep the generic text in that case.
            if (ex.getMessage() != null) {
                faultString = ex.getMessage();
            }
        }

        MessageContext messageContext = context.getMessageContext();
        SoapVersion soapVersion = (SoapVersion)messageContext.get(SoapVersion.class.getName());
        SoapFault fault;
        if (soapVersion.getVersion() == 1.1 && faultCode != null) {
            fault = new SoapFault(faultString, faultCode);
        } else {
            fault = new SoapFault(faultString, soapVersion.getSender());
            if (soapVersion.getVersion() != 1.1 && faultCode != null) {
                fault.setSubCode(faultCode);
            }
        }
        return fault;
    }

    /**
     * Unmarshals the request into its JAXB representation.
     *
     * @throws Exception if the source cannot be read or unmarshalled
     */
    private Object convertToJAXBObject(Source source) throws Exception {
        //this is entirely to work around http://java.net/jira/browse/JAXB-909
        //if that bug is ever fixed and we can detect it, we can remove this
        //complete and total HACK HACK HACK and replace with just:
        //Unmarshaller unmarshaller = jaxbContext.createUnmarshaller();
        //JAXBElement<?> jaxbElement = (JAXBElement<?>) unmarshaller.unmarshal(source);
        //return jaxbElement.getValue();
        Document d = StaxUtils.read(source);
        Binder<Node> binder = jaxbContext.createBinder();
        JAXBElement<?> jaxbElement = (JAXBElement<?>)binder.unmarshal(d);
        walkDom(d.getDocumentElement(), binder, null);
        return jaxbElement.getValue();
    }

    /**
     * Recursively walks the DOM next to the bound JAXB tree. For an element that has no
     * bound JAXB object, the parent's {@code any} field (if annotated with
     * {@link XmlAnyElement}) is replaced by the original DOM element when it currently
     * holds an element of the same qualified name.
     *
     * <p>This is a best-effort workaround: every exception is deliberately ignored.
     */
    private void walkDom(Element element, Binder<Node> binder, Object parent) {
        try {
            Object bound = binder.getJAXBNode(element);
            if (bound instanceof JAXBElement) {
                bound = ((JAXBElement<?>)bound).getValue();
            }
            if (bound == null) {
                if (parent != null) {
                    // if it's not able to bind to an object, it's possibly an xsd:any
                    // we'll check the parent for the standard "any" and replace with
                    // the original element.
                    Field anyField = parent.getClass().getDeclaredField("any");
                    if (anyField.getAnnotation(XmlAnyElement.class) != null) {
                        Field accessible = ReflectionUtil.setAccessible(anyField);
                        Object old = accessible.get(parent);
                        if (old instanceof Element
                            && DOMUtils.getElementQName(element)
                                .equals(DOMUtils.getElementQName((Element)old))) {
                            accessible.set(parent, element);
                        }
                    }
                }
                return;
            }
            for (Node nd = element.getFirstChild(); nd != null; nd = nd.getNextSibling()) {
                if (nd instanceof Element) {
                    walkDom((Element)nd, binder, bound);
                }
            }
        } catch (Exception ignored) {
            //ignore -this is a complete hack anyway (e.g. the parent has no "any" field)
        }
    }

    public CancelOperation getCancelOperation() {
        return cancelOperation;
    }

    public IssueOperation getIssueOperation() {
        return issueOperation;
    }

    public IssueSingleOperation getIssueSingleOperation() {
        return issueSingleOperation;
    }

    public KeyExchangeTokenOperation getKeyExchangeTokenOperation() {
        return keyExchangeTokenOperation;
    }

    public RenewOperation getRenewOperation() {
        return renewOperation;
    }

    public RequestCollectionOperation getRequestCollectionOperation() {
        return requestCollectionOperation;
    }

    public ValidateOperation getValidateOperation() {
        return validateOperation;
    }
}