/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.cxf.sts.rest;

import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.Deflater;

import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import javax.xml.bind.JAXBElement;

import org.w3c.dom.Document;
import org.w3c.dom.Element;

import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.util.Base64Exception;
import org.apache.cxf.common.util.Base64Utility;
import org.apache.cxf.common.util.CompressionUtils;
import org.apache.cxf.common.util.PropertyUtils;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.jaxrs.ext.MessageContext;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.PhaseInterceptorChain;
import org.apache.cxf.security.SecurityContext;
import org.apache.cxf.security.transport.TLSSessionInfo;
import org.apache.cxf.sts.QNameConstants;
import org.apache.cxf.sts.STSConstants;
import org.apache.cxf.sts.token.provider.jwt.JWTTokenProvider;
import org.apache.cxf.ws.security.sts.provider.SecurityTokenServiceImpl;
import org.apache.cxf.ws.security.sts.provider.model.ClaimsType;
import org.apache.cxf.ws.security.sts.provider.model.ObjectFactory;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenResponseType;
import org.apache.cxf.ws.security.sts.provider.model.RequestSecurityTokenType;
import org.apache.cxf.ws.security.sts.provider.model.RequestedSecurityTokenType;
import org.apache.cxf.ws.security.sts.provider.model.UseKeyType;
import org.apache.cxf.ws.security.trust.STSUtils;
import org.apache.wss4j.common.util.DOM2Writer;
import org.apache.wss4j.dom.WSConstants;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.keys.content.X509Data;

/**
 * REST front-end for the security token service. It builds a WS-Trust
 * RequestSecurityToken from simple request parameters (token type, key type,
 * claims, AppliesTo address), dispatches it to the inherited issue / validate /
 * renew / cancel operations and returns the result as XML, JSON or plain text.
 */
public class RESTSecurityTokenServiceImpl extends SecurityTokenServiceImpl implements RESTSecurityTokenService {

    /** Default mapping from short claim names to their full claim URIs. */
    public static final Map<String, String> DEFAULT_CLAIM_TYPE_MAP;

    /** Default mapping from short token type names to their full token type URIs. */
    public static final Map<String, String> DEFAULT_TOKEN_TYPE_MAP;

    private static final Map<String, String> DEFAULT_KEY_TYPE_MAP = new HashMap<>();
    private static final String CLAIM_TYPE = "ClaimType";
    private static final String CLAIM_TYPE_NS = "http://schemas.xmlsoap.org/ws/2005/05/identity";
    private static final Logger LOG = LogUtils.getL7dLogger(RESTSecurityTokenServiceImpl.class);

    static {
        DEFAULT_CLAIM_TYPE_MAP = new HashMap<>();
        DEFAULT_CLAIM_TYPE_MAP.put("emailaddress", CLAIM_TYPE_NS + "/claims/emailaddress");
        DEFAULT_CLAIM_TYPE_MAP.put("role", CLAIM_TYPE_NS + "/claims/role");
        DEFAULT_CLAIM_TYPE_MAP.put("surname", CLAIM_TYPE_NS + "/claims/surname");
        DEFAULT_CLAIM_TYPE_MAP.put("givenname", CLAIM_TYPE_NS + "/claims/givenname");
        DEFAULT_CLAIM_TYPE_MAP.put("name", CLAIM_TYPE_NS + "/claims/name");
        DEFAULT_CLAIM_TYPE_MAP.put("upn", CLAIM_TYPE_NS + "/claims/upn");
        DEFAULT_CLAIM_TYPE_MAP.put("nameidentifier", CLAIM_TYPE_NS + "/claims/nameidentifier");

        DEFAULT_TOKEN_TYPE_MAP = new HashMap<>();
        DEFAULT_TOKEN_TYPE_MAP.put("saml", WSConstants.WSS_SAML2_TOKEN_TYPE);
        DEFAULT_TOKEN_TYPE_MAP.put("saml2.0", WSConstants.WSS_SAML2_TOKEN_TYPE);
        DEFAULT_TOKEN_TYPE_MAP.put("saml1.1", WSConstants.WSS_SAML_TOKEN_TYPE);
        DEFAULT_TOKEN_TYPE_MAP.put("jwt", JWTTokenProvider.JWT_TOKEN_TYPE);
        DEFAULT_TOKEN_TYPE_MAP.put("sct", STSUtils.TOKEN_TYPE_SCT_05_12);

        DEFAULT_KEY_TYPE_MAP.put("SymmetricKey", STSConstants.SYMMETRIC_KEY_KEYTYPE);
        DEFAULT_KEY_TYPE_MAP.put("PublicKey", STSConstants.PUBLIC_KEY_KEYTYPE);
        DEFAULT_KEY_TYPE_MAP.put("Bearer", STSConstants.BEARER_KEY_KEYTYPE);
    }

    @Context
    private MessageContext messageContext;
    @Context
    private javax.ws.rs.core.SecurityContext securityContext;

    private Map<String, String> claimTypeMap = DEFAULT_CLAIM_TYPE_MAP;
    private Map<String, String> tokenTypeMap = DEFAULT_TOKEN_TYPE_MAP;
    private String defaultKeyType = STSConstants.BEARER_KEY_KEYTYPE;
    private List<String> defaultClaims;
    private boolean requestClaimsOptional = true;
    private boolean useDeflateEncoding = true;

    /**
     * Issues a token and returns it as XML.
     *
     * @param tokenType       short name or full URI of the requested token type
     * @param keyType         short name or full URI of the requested key type, may be {@code null}
     * @param requestedClaims short names or full URIs of the requested claims, may be {@code null} or empty
     * @param appliesTo       address placed into the AppliesTo element, may be {@code null}
     * @param wstrustResponse when {@code true} the complete RequestSecurityTokenResponse is returned,
     *                        otherwise only the content of the RequestedSecurityToken element
     * @return the response entity, or a 500 response if no RequestedSecurityToken was produced
     */
    @Override
    public Response getXMLToken(String tokenType, String keyType,
        List<String> requestedClaims, String appliesTo, boolean wstrustResponse) {
        RequestSecurityTokenResponseType response =
            issueToken(tokenType, keyType, requestedClaims, appliesTo);

        if (wstrustResponse) {
            JAXBElement<RequestSecurityTokenResponseType> jaxbResponse =
                QNameConstants.WS_TRUST_FACTORY.createRequestSecurityTokenResponse(response);
            return Response.ok(jaxbResponse).build();
        }

        RequestedSecurityTokenType requestedToken = getRequestedSecurityToken(response);
        if (requestedToken == null) {
            return missingRequestedToken();
        }
        return Response.ok(requestedToken.getAny()).build();
    }

    /**
     * Issues a JWT token and returns it wrapped in a JSON object.
     *
     * @return 400 if {@code tokenType} is not {@code "jwt"}, 500 if no RequestedSecurityToken
     *         was produced, otherwise the wrapped token
     */
    @Override
    public Response getJSONToken(String tokenType, String keyType,
        List<String> requestedClaims, String appliesTo) {
        if (!"jwt".equals(tokenType)) {
            return Response.status(Response.Status.BAD_REQUEST).build();
        }

        RequestSecurityTokenResponseType response =
            issueToken(tokenType, keyType, requestedClaims, appliesTo);
        RequestedSecurityTokenType requestedToken = getRequestedSecurityToken(response);
        if (requestedToken == null) {
            return missingRequestedToken();
        }

        // Discard the XML Wrapper + create a new JSON Wrapper
        String token = ((Element)requestedToken.getAny()).getTextContent();
        return Response.ok(new JSONWrapper(token)).build();
    }

    /**
     * Issues a token and returns it as plain text: the raw text content for JWT tokens,
     * otherwise the serialized token passed through {@link #encodeToken(String)}.
     *
     * @return the token text, or a 500 response if no token was produced or encoding failed
     */
    @Override
    public Response getPlainToken(String tokenType, String keyType,
        List<String> requestedClaims, String appliesTo) {
        RequestSecurityTokenResponseType response =
            issueToken(tokenType, keyType, requestedClaims, appliesTo);
        RequestedSecurityTokenType requestedToken = getRequestedSecurityToken(response);
        if (requestedToken == null) {
            return missingRequestedToken();
        }

        if ("jwt".equals(tokenType)) {
            // Discard the wrapper here
            return Response.ok(((Element)requestedToken.getAny()).getTextContent()).build();
        }

        // Base-64 encode the token + return it
        try {
            String encodedToken = encodeToken(DOM2Writer.nodeToString((Element)requestedToken.getAny()));
            return Response.ok(encodedToken).build();
        } catch (Exception ex) {
            // Boundary of the REST call: any failure while serializing becomes a 500
            LOG.log(Level.WARNING, ex.getMessage(), ex);
            return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build();
        }
    }

    /**
     * Logs the absence of a RequestedSecurityToken and builds the matching 500 response.
     */
    private static Response missingRequestedToken() {
        LOG.warning("The token service response did not contain a RequestedSecurityToken element");
        return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build();
    }

    /**
     * Finds the RequestedSecurityToken entry among the children of the response.
     *
     * @return the requested token, or {@code null} if the response is {@code null} or has no such child
     */
    private RequestedSecurityTokenType getRequestedSecurityToken(RequestSecurityTokenResponseType response) {
        if (response == null) {
            return null;
        }
        for (Object obj : response.getAny()) {
            if (obj instanceof JAXBElement<?>) {
                JAXBElement<?> jaxbElement = (JAXBElement<?>)obj;
                if ("RequestedSecurityToken".equals(jaxbElement.getName().getLocalPart())) {
                    return (RequestedSecurityTokenType)jaxbElement.getValue();
                }
            }
        }
        return null;
    }

    /**
     * Builds a WS-Trust Issue request from the REST parameters and processes it.
     * Elements are added in this order: TokenType, RequestType, KeyType, UseKey (PublicKey
     * only, when a TLS client certificate is available), Claims, AppliesTo.
     */
    private RequestSecurityTokenResponseType issueToken(
        String tokenType,
        String keyType,
        List<String> requestedClaims,
        String appliesTo
    ) {
        String tokenTypeToUse = resolveAlias(tokenTypeMap, tokenType);
        String keyTypeToUse = resolveAlias(DEFAULT_KEY_TYPE_MAP, keyType);

        ObjectFactory of = new ObjectFactory();
        RequestSecurityTokenType request = of.createRequestSecurityTokenType();

        request.getAny().add(of.createTokenType(tokenTypeToUse));
        request.getAny().add(of.createRequestType("http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue"));

        String desiredKeyType = keyTypeToUse != null ? keyTypeToUse : defaultKeyType;
        request.getAny().add(of.createKeyType(desiredKeyType));

        // Add the TLS client Certificate as the UseKey Element if the KeyType is PublicKey
        if (STSConstants.PUBLIC_KEY_KEYTYPE.equals(desiredKeyType)) {
            JAXBElement<UseKeyType> useKey = createUseKey(of, getTLSClientCertificate());
            if (useKey != null) {
                request.getAny().add(useKey);
            }
        }

        // Claims: fall back to the configured defaults when the caller requested none
        List<String> claimsToRequest =
            requestedClaims == null || requestedClaims.isEmpty() ? defaultClaims : requestedClaims;
        if (claimsToRequest != null && !claimsToRequest.isEmpty()) {
            request.getAny().add(createClaims(of, claimsToRequest));
        }

        if (appliesTo != null) {
            request.getAny().add(createAppliesTo(appliesTo));
        }

        // TODO OnBehalfOf is not supported yet: when the user authenticated with a JWT or SAML
        // token, that token could be forwarded to the STS in an OnBehalfOf element.

        return processRequest(Action.issue, request);
    }

    /**
     * Returns the value {@code name} is mapped to, or {@code name} itself when the map is
     * {@code null} or has no entry for it.
     */
    private static String resolveAlias(Map<String, String> aliases, String name) {
        if (aliases != null && aliases.containsKey(name)) {
            return aliases.get(name);
        }
        return name;
    }

    /**
     * Wraps the given certificate in a UseKey element holding a ds:KeyInfo/X509Data.
     *
     * @return the UseKey element, or {@code null} if there is no certificate or it could not be
     *         added (the failure is logged and the request proceeds without UseKey)
     */
    private static JAXBElement<UseKeyType> createUseKey(ObjectFactory of, X509Certificate clientCert) {
        if (clientCert == null) {
            return null;
        }
        Document doc = DOMUtils.createDocument();
        Element keyInfoElement = doc.createElementNS("http://www.w3.org/2000/09/xmldsig#", "KeyInfo");
        try {
            X509Data certElem = new X509Data(doc);
            certElem.addCertificate(clientCert);
            keyInfoElement.appendChild(certElem.getElement());

            UseKeyType useKeyType = of.createUseKeyType();
            useKeyType.setAny(keyInfoElement);
            return of.createUseKey(useKeyType);
        } catch (XMLSecurityException ex) {
            LOG.log(Level.WARNING, ex.getMessage(), ex);
            return null;
        }
    }

    /**
     * Builds the Claims element, one ClaimType child per claim. Short claim names are
     * translated through the claim type map; unknown names are used unchanged as the Uri.
     */
    private JAXBElement<ClaimsType> createClaims(ObjectFactory of, List<String> claimNames) {
        ClaimsType claimsType = of.createClaimsType();
        claimsType.setDialect(CLAIM_TYPE_NS);
        JAXBElement<ClaimsType> claims = of.createClaims(claimsType);

        Document doc = DOMUtils.createDocument();
        for (String claimName : claimNames) {
            String claimUri = resolveAlias(claimTypeMap, claimName);

            Element claimElement = doc.createElementNS(CLAIM_TYPE_NS, CLAIM_TYPE);
            claimElement.setAttributeNS(null, "Uri", claimUri);
            claimElement.setAttributeNS(null, "Optional", Boolean.toString(requestClaimsOptional));
            claimsType.getAny().add(claimElement);
        }
        return claims;
    }

    /**
     * Builds a wsp:AppliesTo element containing a WS-Addressing EndpointReference whose
     * Address is the given value.
     */
    private static Element createAppliesTo(String address) {
        String wspNamespace = "http://www.w3.org/ns/ws-policy";
        String addressingNamespace = "http://www.w3.org/2005/08/addressing";

        Document doc = DOMUtils.createDocument();
        Element appliesToElement = doc.createElementNS(wspNamespace, "AppliesTo");
        Element eprElement = doc.createElementNS(addressingNamespace, "EndpointReference");
        Element addressElement = doc.createElementNS(addressingNamespace, "Address");

        addressElement.setTextContent(address);
        eprElement.appendChild(addressElement);
        appliesToElement.appendChild(eprElement);
        return appliesToElement;
    }

    /**
     * Processes a caller-supplied RequestSecurityToken with the given action and returns the
     * full RequestSecurityTokenResponse.
     */
    @Override
    public Response getToken(Action action, RequestSecurityTokenType request) {
        RequestSecurityTokenResponseType response = processRequest(action, request);
        JAXBElement<RequestSecurityTokenResponseType> jaxbResponse =
            QNameConstants.WS_TRUST_FACTORY.createRequestSecurityTokenResponse(response);
        return Response.ok(jaxbResponse).build();
    }

    /**
     * Dispatches the request to the inherited operation matching the action; any action other
     * than validate, renew or cancel is handled as issue.
     */
    private RequestSecurityTokenResponseType processRequest(Action action,
                                                            RequestSecurityTokenType request) {
        switch (action) {
        case validate:
            return validate(request);
        case renew:
            return renew(request);
        case cancel:
            return cancel(request);
        case issue:
        default:
            return issueSingle(request);
        }
    }

    /** Cancels the token described by the request and returns the cancel response. */
    @Override
    public Response removeToken(RequestSecurityTokenType request) {
        RequestSecurityTokenResponseType response = cancel(request);
        return Response.ok(response).build();
    }

    /** Delegates to the inherited key exchange operation and returns its response. */
    @Override
    public Response getKeyExchangeToken(RequestSecurityTokenType request) {
        RequestSecurityTokenResponseType response = keyExchangeToken(request);
        return Response.ok(response).build();
    }

    public Map<String, String> getTokenTypeMap() {
        return tokenTypeMap;
    }

    public void setTokenTypeMap(Map<String, String> tokenTypeMap) {
        this.tokenTypeMap = tokenTypeMap;
    }

    public String getDefaultKeyType() {
        return defaultKeyType;
    }

    public void setDefaultKeyType(String defaultKeyType) {
        this.defaultKeyType = defaultKeyType;
    }

    /**
     * @return the claims requested when a caller supplies none, may be {@code null}
     */
    public List<String> getDefaultClaims() {
        return defaultClaims;
    }

    /**
     * @param defaultClaims short names or full URIs of the claims to request when a caller
     *                      supplies none; {@code null} or empty disables the fallback
     */
    public void setDefaultClaims(List<String> defaultClaims) {
        this.defaultClaims = defaultClaims;
    }

    public boolean isRequestClaimsOptional() {
        return requestClaimsOptional;
    }

    public void setRequestClaimsOptional(boolean requestClaimsOptional) {
        this.requestClaimsOptional = requestClaimsOptional;
    }

    public Map<String, String> getClaimTypeMap() {
        return claimTypeMap;
    }

    public void setClaimTypeMap(Map<String, String> claimTypeMap) {
        this.claimTypeMap = claimTypeMap;
    }

    /**
     * Resolves the caller's principal from, in order: the JAX-RS security context, the CXF
     * security context held by the message context, and the TLS client certificate.
     *
     * @return the principal, or {@code null} if none of the sources provides one
     */
    @Override
    protected Principal getPrincipal() {
        // Try JAX-RS SecurityContext first
        if (securityContext != null && securityContext.getUserPrincipal() != null) {
            return securityContext.getUserPrincipal();
        }

        // Then try the CXF SecurityContext
        if (messageContext != null) {
            SecurityContext sc = (SecurityContext)messageContext.get(SecurityContext.class);
            if (sc != null && sc.getUserPrincipal() != null) {
                return sc.getUserPrincipal();
            }
        }

        // Get the TLS client principal if no security context is set up
        X509Certificate clientCert = getTLSClientCertificate();
        if (clientCert != null) {
            return clientCert.getSubjectX500Principal();
        }

        return null;
    }

    /**
     * Returns the first peer certificate of the current message's TLS session.
     *
     * @return the certificate, or {@code null} if there is no current message, no TLS session
     *         info, no peer certificate, or the first one is not an X.509 certificate
     */
    private X509Certificate getTLSClientCertificate() {
        Message message = PhaseInterceptorChain.getCurrentMessage();
        if (message == null) {
            return null;
        }
        TLSSessionInfo tlsInfo = (TLSSessionInfo)message.get(TLSSessionInfo.class);
        if (tlsInfo != null && tlsInfo.getPeerCertificates() != null
                && tlsInfo.getPeerCertificates().length > 0
                && (tlsInfo.getPeerCertificates()[0] instanceof X509Certificate)
        ) {
            return (X509Certificate)tlsInfo.getPeerCertificates()[0];
        }
        return null;
    }

    @Override
    protected Map<String, Object> getMessageContext() {
        return PhaseInterceptorChain.getCurrentMessage();
    }

    public void setUseDeflateEncoding(boolean deflate) {
        useDeflateEncoding = deflate;
    }

    /**
     * Encodes a serialized token as Base64, deflating its UTF-8 bytes first when deflate
     * encoding is enabled.
     *
     * @param assertion the serialized token
     * @return the Base64 text
     * @throws Base64Exception if Base64 encoding fails
     */
    protected String encodeToken(String assertion) throws Base64Exception {
        byte[] tokenBytes = assertion.getBytes(StandardCharsets.UTF_8);

        if (useDeflateEncoding) {
            tokenBytes = CompressionUtils.deflate(tokenBytes, getDeflateLevel(), true);
        }
        StringWriter writer = new StringWriter();
        Base64Utility.encode(tokenBytes, 0, tokenBytes.length, writer);
        return writer.toString();
    }

    /**
     * Reads the deflate level from the {@code "deflate.level"} property of the current
     * message, if there is one.
     */
    private static int getDeflateLevel() {
        Integer level = null;

        Message m = PhaseInterceptorChain.getCurrentMessage();
        if (m != null) {
            level = PropertyUtils.getInteger(m, "deflate.level");
        }
        if (level == null) {
            // NOTE: Deflater.DEFLATED is the compression *method* constant (value 8), not a
            // level constant. It is kept as the fallback so that the compressed output stays
            // byte-for-byte what existing deployments produce.
            level = Deflater.DEFLATED;
        }
        return level;
    }

    /** Minimal holder exposing the token through a getter, used as the JSON response entity. */
    private static class JSONWrapper {
        private final String token;

        JSONWrapper(String token) {
            this.token = token;
        }

        @SuppressWarnings("unused")
        public String getToken() {
            return token;
        }
    }
}