/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.rs.security.oauth2.utils;
import java.lang.reflect.Method;
import java.security.MessageDigest;
import java.security.Principal;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import javax.security.auth.x500.X500Principal;
import javax.servlet.http.HttpSession;
import javax.ws.rs.core.MultivaluedMap;
import org.apache.cxf.common.util.Base64UrlUtility;
import org.apache.cxf.common.util.Base64Utility;
import org.apache.cxf.common.util.StringUtils;
import org.apache.cxf.jaxrs.ext.MessageContext;
import org.apache.cxf.jaxrs.impl.MetadataMap;
import org.apache.cxf.jaxrs.model.URITemplate;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.rs.security.jose.common.JoseConstants;
import org.apache.cxf.rs.security.jose.jwa.AlgorithmUtils;
import org.apache.cxf.rs.security.jose.jwa.ContentAlgorithm;
import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwe.JweDecryptionProvider;
import org.apache.cxf.rs.security.jose.jwe.JweEncryptionProvider;
import org.apache.cxf.rs.security.jose.jwe.JweUtils;
import org.apache.cxf.rs.security.jose.jws.JwsSignatureProvider;
import org.apache.cxf.rs.security.jose.jws.JwsSignatureVerifier;
import org.apache.cxf.rs.security.jose.jws.JwsUtils;
import org.apache.cxf.rs.security.oauth2.common.AuthenticationMethod;
import org.apache.cxf.rs.security.oauth2.common.Client;
import org.apache.cxf.rs.security.oauth2.common.ClientAccessToken;
import org.apache.cxf.rs.security.oauth2.common.OAuthPermission;
import org.apache.cxf.rs.security.oauth2.common.ServerAccessToken;
import org.apache.cxf.rs.security.oauth2.common.UserSubject;
import org.apache.cxf.rs.security.oauth2.provider.OAuthServiceException;
import org.apache.cxf.rt.security.crypto.CryptoUtils;
import org.apache.cxf.rt.security.crypto.MessageDigestUtils;
import org.apache.cxf.security.LoginSecurityContext;
import org.apache.cxf.security.SecurityContext;
import org.apache.cxf.security.transport.TLSSessionInfo;
/**
 * Static helper methods used by the OAuth2 implementation: mutual-TLS certificate checks,
 * HTTP-session authenticity tokens, {@link UserSubject} creation, scope parsing and validation,
 * token lifetime checks and client-secret based JWS/JWE providers.
 */
public final class OAuthUtils {

    private OAuthUtils() {
    }

    /**
     * Computes the SHA-256 digest of the DER encoding of a certificate.
     *
     * @param cert the certificate to hash
     * @return the raw SHA-256 digest bytes
     * @throws Exception if the certificate cannot be encoded or the digest cannot be created
     */
    public static byte[] createCertificateThumbprint(X509Certificate cert) throws Exception {
        return MessageDigestUtils.createDigest(cert.getEncoded(), MessageDigestUtils.ALGO_SHA_256);
    }

    /**
     * Stores the base64url-encoded SHA-256 thumbprint of {@code cert} in the message context
     * under {@link JoseConstants#HEADER_X509_THUMBPRINT_SHA256}.
     *
     * @param mc the current message context
     * @param cert the certificate whose thumbprint is recorded
     * @throws OAuthServiceException wrapping any failure to compute or encode the thumbprint
     */
    public static void setCertificateThumbprintConfirmation(MessageContext mc, X509Certificate cert) {
        try {
            byte[] thumbprint = createCertificateThumbprint(cert);
            String encodedThumbprint = Base64UrlUtility.encode(thumbprint);
            mc.put(JoseConstants.HEADER_X509_THUMBPRINT_SHA256, encodedThumbprint);
        } catch (Exception ex) {
            throw new OAuthServiceException(ex);
        }
    }

    /**
     * Checks whether the SHA-256 thumbprint of {@code cert} equals a base64url-encoded thumbprint.
     * The byte comparison uses {@link MessageDigest#isEqual}.
     *
     * @param cert the certificate to check
     * @param encodedThumbprint the expected base64url-encoded SHA-256 thumbprint
     * @return {@code true} if the thumbprints match; {@code false} if they differ or if
     *         computing or decoding either thumbprint fails
     */
    public static boolean compareCertificateThumbprints(X509Certificate cert, String encodedThumbprint) {
        try {
            byte[] thumbprint = createCertificateThumbprint(cert);
            byte[] currentThumbprint = Base64UrlUtility.decode(encodedThumbprint);
            return MessageDigest.isEqual(thumbprint, currentThumbprint);
        } catch (Exception ex) {
            // Best effort: any failure means the thumbprints cannot be shown to match
            return false;
        }
    }

    /**
     * Compares the TLS peer certificate chain, position by position, with a list of
     * base64-encoded certificates.
     *
     * @param tlsInfo the TLS session information of the current request
     * @param base64EncodedCerts the expected certificates, base64-encoded, in chain order
     * @return {@code true} only if both chains have the same length and every certificate's
     *         encoded form is byte-identical to the corresponding decoded entry. Returns
     *         {@code false} if the peer chain or the list is {@code null}, or if any
     *         certificate cannot be encoded or decoded.
     */
    public static boolean compareTlsCertificates(TLSSessionInfo tlsInfo,
                                                 List<String> base64EncodedCerts) {
        Certificate[] clientCerts = tlsInfo.getPeerCertificates();
        if (clientCerts == null || base64EncodedCerts == null
            || clientCerts.length != base64EncodedCerts.size()) {
            return false;
        }
        try {
            for (int i = 0; i < clientCerts.length; i++) {
                X509Certificate x509Cert = (X509Certificate)clientCerts[i];
                byte[] peerCert = x509Cert.getEncoded();
                byte[] expectedCert = Base64Utility.decode(base64EncodedCerts.get(i));
                if (!Arrays.equals(peerCert, expectedCert)) {
                    return false;
                }
            }
            return true;
        } catch (Exception ex) {
            // A non-X.509 peer certificate, an encoding failure or malformed base64 input
            // means the chains cannot be shown to match
            return false;
        }
    }

    /**
     * Detects pure two-way TLS authentication: TLS session info is available, the JAX-RS
     * security context reports no authentication scheme, and the peer presented a certificate.
     *
     * @param sc the JAX-RS security context of the current request
     * @param tlsSessionInfo the TLS session information, may be {@code null}
     * @return {@code true} if the request was authenticated by the TLS client certificate alone
     */
    public static boolean isMutualTls(javax.ws.rs.core.SecurityContext sc, TLSSessionInfo tlsSessionInfo) {
        return tlsSessionInfo != null
            && StringUtils.isEmpty(sc.getAuthenticationScheme())
            && getRootTLSCertificate(tlsSessionInfo) != null;
    }

    /**
     * Returns the subject distinguished name of a certificate.
     *
     * @param cert the certificate
     * @return the subject DN as returned by {@link X500Principal#getName()}
     */
    public static String getSubjectDnFromTLSCertificates(X509Certificate cert) {
        X500Principal x509Principal = cert.getSubjectX500Principal();
        return x509Principal.getName();
    }

    /**
     * Returns the issuer distinguished name of a certificate.
     *
     * @param cert the certificate
     * @return the issuer DN as returned by {@link X500Principal#getName()}
     */
    public static String getIssuerDnFromTLSCertificates(X509Certificate cert) {
        X500Principal x509Principal = cert.getIssuerX500Principal();
        return x509Principal.getName();
    }

    /**
     * Returns the first certificate of the TLS peer certificate chain.
     *
     * @param tlsInfo the TLS session information
     * @return the first peer certificate, or {@code null} if no peer certificates are present
     */
    public static X509Certificate getRootTLSCertificate(TLSSessionInfo tlsInfo) {
        Certificate[] clientCerts = tlsInfo.getPeerCertificates();
        if (clientCerts != null && clientCerts.length > 0) {
            return (X509Certificate)clientCerts[0];
        }
        return null;
    }

    /**
     * Injects the message context into a provider that exposes a public
     * {@code setMessageContext(MessageContext)} method. Providers without such a method,
     * and a {@code null} provider, are silently skipped.
     *
     * @param context the message context to inject
     * @param provider the provider instance, may be {@code null}
     * @throws RuntimeException wrapping any failure raised while invoking the setter
     */
    public static void injectContextIntoOAuthProvider(MessageContext context, Object provider) {
        if (provider == null) {
            return;
        }
        Method setter;
        try {
            setter = provider.getClass().getMethod("setMessageContext", MessageContext.class);
        } catch (NoSuchMethodException | SecurityException ex) {
            // The provider does not (accessibly) support context injection
            return;
        }
        try {
            setter.invoke(provider, context);
        } catch (ReflectiveOperationException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Generates a random authenticity token and stores it in the HTTP session under
     * {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}. The session timeout is left unchanged.
     *
     * @param mc the current message context
     * @return the generated token
     */
    public static String setSessionToken(MessageContext mc) {
        return setSessionToken(mc, 0);
    }

    /**
     * Generates a random authenticity token and stores it in the HTTP session under
     * {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}.
     *
     * @param mc the current message context
     * @param maxInactiveInterval session timeout in seconds to apply; values {@code <= 0}
     *        leave the current session timeout unchanged
     * @return the generated token
     */
    public static String setSessionToken(MessageContext mc, int maxInactiveInterval) {
        return setSessionToken(mc, generateRandomTokenKey(), maxInactiveInterval);
    }

    /**
     * Stores the given token in the HTTP session under
     * {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}. The session timeout is left unchanged.
     *
     * @param mc the current message context
     * @param sessionToken the token to store
     * @return {@code sessionToken}
     */
    public static String setSessionToken(MessageContext mc, String sessionToken) {
        return setSessionToken(mc, sessionToken, 0);
    }

    /**
     * Stores the given token in the HTTP session under
     * {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}.
     *
     * @param mc the current message context
     * @param sessionToken the token to store
     * @param maxInactiveInterval session timeout in seconds to apply; values {@code <= 0}
     *        leave the current session timeout unchanged
     * @return {@code sessionToken}
     */
    public static String setSessionToken(MessageContext mc, String sessionToken, int maxInactiveInterval) {
        return setSessionToken(mc, sessionToken, null, maxInactiveInterval);
    }

    /**
     * Stores the given token in the HTTP session. The session is obtained via
     * {@code getSession()}, which creates one if none exists.
     *
     * @param mc the current message context
     * @param sessionToken the token to store
     * @param attribute the session attribute name; {@code null} selects
     *        {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}
     * @param maxInactiveInterval session timeout in seconds to apply; values {@code <= 0}
     *        leave the current session timeout unchanged
     * @return {@code sessionToken}
     */
    public static String setSessionToken(MessageContext mc, String sessionToken,
                                         String attribute, int maxInactiveInterval) {
        HttpSession session = mc.getHttpServletRequest().getSession();
        if (maxInactiveInterval > 0) {
            session.setMaxInactiveInterval(maxInactiveInterval);
        }
        String theAttribute = attribute == null ? OAuthConstants.SESSION_AUTHENTICITY_TOKEN : attribute;
        session.setAttribute(theAttribute, sessionToken);
        return sessionToken;
    }

    /**
     * Reads and removes the token stored under {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}.
     *
     * @param mc the current message context
     * @return the stored token, or {@code null} if none is present
     */
    public static String getSessionToken(MessageContext mc) {
        return getSessionToken(mc, null);
    }

    /**
     * Reads and removes the token stored under the given session attribute.
     *
     * @param mc the current message context
     * @param attribute the session attribute name; {@code null} selects
     *        {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}
     * @return the stored token, or {@code null} if none is present
     */
    public static String getSessionToken(MessageContext mc, String attribute) {
        return getSessionToken(mc, attribute, true);
    }

    /**
     * Reads the token stored under the given session attribute, optionally removing it so it
     * can be used only once. The session is obtained via {@code getSession()}, which creates
     * one if none exists.
     *
     * @param mc the current message context
     * @param attribute the session attribute name; {@code null} selects
     *        {@link OAuthConstants#SESSION_AUTHENTICITY_TOKEN}
     * @param remove whether to remove a found token from the session
     * @return the stored token, or {@code null} if none is present
     */
    public static String getSessionToken(MessageContext mc, String attribute, boolean remove) {
        HttpSession session = mc.getHttpServletRequest().getSession();
        String theAttribute = attribute == null ? OAuthConstants.SESSION_AUTHENTICITY_TOKEN : attribute;
        String sessionToken = (String)session.getAttribute(theAttribute);
        if (sessionToken != null && remove) {
            session.removeAttribute(theAttribute);
        }
        return sessionToken;
    }

    /**
     * Returns the {@link UserSubject} already attached to the message content, or creates one
     * from the security context if none is attached.
     *
     * @param mc the current message context
     * @param sc the security context used when no subject is attached
     * @return the existing or newly created subject
     */
    public static UserSubject createSubject(MessageContext mc, SecurityContext sc) {
        UserSubject subject = mc.getContent(UserSubject.class);
        return subject != null ? subject : createSubject(sc);
    }

    /**
     * Creates a {@link UserSubject} from a security context.
     * <ul>
     * <li>Role names are collected only for a {@link LoginSecurityContext}.</li>
     * <li>The authentication method is copied from the current message, if one is set.</li>
     * <li>The context must have a non-null user principal; a {@code null} principal causes
     *     a {@link NullPointerException}.</li>
     * </ul>
     *
     * @param securityContext the security context of the authenticated user
     * @return the new subject
     */
    public static UserSubject createSubject(SecurityContext securityContext) {
        List<String> roleNames = Collections.emptyList();
        if (securityContext instanceof LoginSecurityContext) {
            roleNames = new ArrayList<>();
            Set<Principal> roles = ((LoginSecurityContext)securityContext).getUserRoles();
            for (Principal p : roles) {
                roleNames.add(p.getName());
            }
        }
        UserSubject subject = new UserSubject(securityContext.getUserPrincipal().getName(), roleNames);
        Message m = JAXRSUtils.getCurrentMessage();
        if (m != null) {
            AuthenticationMethod authMethod = m.get(AuthenticationMethod.class);
            if (authMethod != null) {
                subject.setAuthenticationMethod(authMethod);
            }
        }
        return subject;
    }

    /**
     * Builds a space-separated scope string from the given permissions. Permissions that are
     * invisible to the client or have a {@code null} permission value are skipped.
     *
     * @param perms the permissions
     * @return the scope string, empty if no permission qualifies
     */
    public static String convertPermissionsToScope(List<OAuthPermission> perms) {
        StringBuilder sb = new StringBuilder();
        for (OAuthPermission perm : perms) {
            if (perm.isInvisibleToClient() || perm.getPermission() == null) {
                continue;
            }
            if (sb.length() > 0) {
                sb.append(" ");
            }
            sb.append(perm.getPermission());
        }
        return sb.toString();
    }

    /**
     * Lists the permission value of every permission. Unlike
     * {@link #convertPermissionsToScope(List)}, invisible permissions and {@code null}
     * values are not filtered out.
     *
     * @param perms the permissions
     * @return a new mutable list of permission values, in input order
     */
    public static List<String> convertPermissionsToScopeList(List<OAuthPermission> perms) {
        List<String> list = new LinkedList<String>();
        for (OAuthPermission perm : perms) {
            list.add(perm.getPermission());
        }
        return list;
    }

    /**
     * Checks whether a client may use the given grant type.
     *
     * @param client the client
     * @param canSupportPublicClients whether non-confidential clients may use the grant
     * @param grantType the requested grant type
     * @return {@code false} if {@code grantType} is {@code null}, or if the client is public
     *         and public clients are not supported. Otherwise {@code true} if the client has
     *         no grant restrictions or explicitly allows {@code grantType}.
     */
    public static boolean isGrantSupportedForClient(Client client,
                                                    boolean canSupportPublicClients,
                                                    String grantType) {
        if (grantType == null || (!client.isConfidential() && !canSupportPublicClients)) {
            return false;
        }
        List<String> allowedGrants = client.getAllowedGrantTypes();
        return allowedGrants.isEmpty() || allowedGrants.contains(grantType);
    }

    /**
     * Splits a space-separated scope parameter into individual scopes, dropping empty entries.
     *
     * @param requestedScope the scope parameter, may be {@code null}
     * @return a new mutable list of scopes, empty if {@code requestedScope} is {@code null}
     */
    public static List<String> parseScope(String requestedScope) {
        List<String> list = new LinkedList<String>();
        if (requestedScope != null) {
            String[] scopeValues = requestedScope.split(" ");
            for (String scope : scopeValues) {
                if (!StringUtils.isEmpty(scope)) {
                    list.add(scope);
                }
            }
        }
        return list;
    }

    /**
     * Generates a hex-encoded token from 16 secure random bytes.
     *
     * @return the token
     */
    public static String generateRandomTokenKey() throws OAuthServiceException {
        return generateRandomTokenKey(16);
    }

    /**
     * Generates a hex-encoded token from {@code byteSize} secure random bytes.
     *
     * @param byteSize number of random bytes, at least 16
     * @return the token
     * @throws OAuthServiceException if {@code byteSize} is less than 16
     */
    public static String generateRandomTokenKey(int byteSize) {
        if (byteSize < 16) {
            throw new OAuthServiceException();
        }
        return StringUtils.toHexString(CryptoUtils.generateSecureRandomBytes(byteSize));
    }

    /**
     * Returns the current time in whole seconds since the epoch.
     *
     * @return the current time in seconds
     */
    public static long getIssuedAt() {
        return System.currentTimeMillis() / 1000L;
    }

    /**
     * Checks whether a token lifetime has elapsed. Both values are in seconds, matching
     * {@link #getIssuedAt()}.
     *
     * @param issuedAt issue time in seconds since the epoch; must be non-null when
     *        {@code lifetime} is positive
     * @param lifetime lifetime in seconds; {@code 0} and {@code -1} mean unlimited
     * @return {@code true} if {@code lifetime} is {@code null} or below {@code -1}, or if it is
     *         positive and {@code issuedAt + lifetime} is in the past
     */
    public static boolean isExpired(Long issuedAt, Long lifetime) {
        // At some point -1 was used to indicate an unlimited lifetime
        // with 0 being introduced instead at a later stage.
        // In theory there still could be a code around initializing the tokens with -1.
        // Treating -1 and 0 the same way is reasonable and it also makes it easier to
        // deal with the token introspection responses with no issuedAt time reported
        return lifetime == null
            || lifetime < -1
            || (lifetime > 0L && issuedAt + lifetime < getIssuedAt());
    }

    /**
     * Validates a single audience against the allowed audiences.
     *
     * @param providedAudience the audience to check; {@code null} is always accepted
     * @param allowedAudiences the allowed audiences
     * @return {@code true} if the audience is acceptable
     */
    public static boolean validateAudience(String providedAudience,
                                           List<String> allowedAudiences) {
        return providedAudience == null
            || validateAudiences(Collections.singletonList(providedAudience), allowedAudiences);
    }

    /**
     * Validates audiences against the allowed audiences.
     *
     * @param providedAudiences the audiences to check
     * @param allowedAudiences the allowed audiences, may be {@code null}
     * @return {@code true} if both lists are empty according to {@link StringUtils#isEmpty},
     *         or if every provided audience is allowed. Returns {@code false} if
     *         {@code allowedAudiences} is {@code null} while audiences are provided.
     */
    public static boolean validateAudiences(List<String> providedAudiences,
                                            List<String> allowedAudiences) {
        if (StringUtils.isEmpty(providedAudiences) && StringUtils.isEmpty(allowedAudiences)) {
            return true;
        }
        // Fail closed instead of throwing a NullPointerException when nothing is allowed
        return allowedAudiences != null && allowedAudiences.containsAll(providedAudiences);
    }

    /**
     * Checks whether a servlet path matches a URI template.
     *
     * @param servletPath the request path to test
     * @param uri the template; a trailing {@code *} accepts any remainder after the match
     * @return {@code true} if the template matches and either it is a wildcard template or
     *         nothing other than an optional {@code /} remains. Returns {@code false} on no
     *         match or if the template cannot be parsed or matched.
     */
    public static boolean checkRequestURI(String servletPath, String uri) {
        boolean wildcard = uri.endsWith("*");
        String theURI = wildcard ? uri.substring(0, uri.length() - 1) : uri;
        try {
            URITemplate template = new URITemplate(theURI);
            MultivaluedMap<String, String> map = new MetadataMap<String, String>();
            if (template.match(servletPath, map)) {
                String finalGroup = map.getFirst(URITemplate.FINAL_MATCH_GROUP);
                if (wildcard || StringUtils.isEmpty(finalGroup) || "/".equals(finalGroup)) {
                    return true;
                }
            }
        } catch (Exception ex) {
            // An invalid template is treated as a non-match
        }
        return false;
    }

    /**
     * Resolves the scopes for a request.
     * <ul>
     * <li>If no scope is requested, all registered scopes are used.</li>
     * <li>Otherwise the requested scopes are validated with
     *     {@link #validateScopes(List, List, boolean)}.</li>
     * <li>If {@code useAllClientScopes} is set, any registered scopes not already requested
     *     are appended.</li>
     * </ul>
     *
     * @param client the client whose registered scopes apply
     * @param scopeParameter the space-separated requested scopes, may be {@code null}
     * @param useAllClientScopes whether to add all registered scopes to the requested ones
     * @param partialMatchScopeValidation whether a requested scope may merely start with a
     *        registered scope
     * @return the resolved scopes
     * @throws OAuthServiceException if a requested scope is not permitted
     */
    public static List<String> getRequestedScopes(Client client,
                                                  String scopeParameter,
                                                  boolean useAllClientScopes,
                                                  boolean partialMatchScopeValidation) {
        List<String> requestScopes = parseScope(scopeParameter);
        List<String> registeredScopes = client.getRegisteredScopes();
        if (requestScopes.isEmpty()) {
            requestScopes.addAll(registeredScopes);
            return requestScopes;
        }
        if (!validateScopes(requestScopes, registeredScopes, partialMatchScopeValidation)) {
            throw new OAuthServiceException("Unexpected scope");
        }
        if (useAllClientScopes) {
            for (String registeredScope : registeredScopes) {
                if (!requestScopes.contains(registeredScope)) {
                    requestScopes.add(registeredScope);
                }
            }
        }
        return requestScopes;
    }

    /**
     * Validates requested scopes against registered scopes.
     *
     * @param requestScopes the requested scopes
     * @param registeredScopes the registered scopes; an empty list accepts everything
     * @param partialMatchScopeValidation if {@code false}, every requested scope must be
     *        registered exactly; if {@code true}, every requested scope must start with some
     *        registered scope
     * @return {@code true} if all requested scopes are acceptable
     */
    public static boolean validateScopes(List<String> requestScopes, List<String> registeredScopes,
                                         boolean partialMatchScopeValidation) {
        if (registeredScopes.isEmpty()) {
            return true;
        }
        if (!partialMatchScopeValidation) {
            // Strict validation: the registered scopes must contain every requested scope
            return registeredScopes.containsAll(requestScopes);
        }
        for (String requestScope : requestScopes) {
            boolean match = false;
            for (String registeredScope : registeredScopes) {
                if (requestScope.startsWith(registeredScope)) {
                    match = true;
                    break;
                }
            }
            if (!match) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts a server-side access token into the representation returned to clients.
     *
     * @param serverToken the server token
     * @param supportOptionalParams whether to copy the expiry, the approved scope (only if
     *        non-empty) and a copy of the token parameters
     * @return the client token; the type, key and refresh token are always copied
     */
    public static ClientAccessToken toClientAccessToken(ServerAccessToken serverToken, boolean supportOptionalParams) {
        ClientAccessToken clientToken = new ClientAccessToken(serverToken.getTokenType(),
                                                              serverToken.getTokenKey());
        clientToken.setRefreshToken(serverToken.getRefreshToken());
        if (supportOptionalParams) {
            clientToken.setExpiresIn(serverToken.getExpiresIn());
            List<OAuthPermission> perms = serverToken.getScopes();
            String scopeString = OAuthUtils.convertPermissionsToScope(perms);
            if (!StringUtils.isEmpty(scopeString)) {
                clientToken.setApprovedScope(scopeString);
            }
            clientToken.setParameters(new HashMap<String, String>(serverToken.getParameters()));
        }
        return clientToken;
    }

    /**
     * Creates an HMAC signature provider keyed with the client secret. The algorithm is chosen
     * by {@link #getClientSecretSignatureAlgorithm(Properties)} from the signature-out
     * properties.
     *
     * @param clientSecret the client secret
     * @return the signature provider
     */
    public static JwsSignatureProvider getClientSecretSignatureProvider(String clientSecret) {
        Properties sigProps = JwsUtils.loadSignatureOutProperties(false);
        return JwsUtils.getHmacSignatureProvider(clientSecret,
                                                 getClientSecretSignatureAlgorithm(sigProps));
    }

    /**
     * Creates an HMAC signature verifier keyed with the client secret. The algorithm is chosen
     * from the same signature-out properties as
     * {@link #getClientSecretSignatureProvider(String)}.
     *
     * @param clientSecret the client secret
     * @return the signature verifier
     */
    public static JwsSignatureVerifier getClientSecretSignatureVerifier(String clientSecret) {
        Properties sigProps = JwsUtils.loadSignatureOutProperties(false);
        return JwsUtils.getHmacSignatureVerifier(clientSecret,
                                                 getClientSecretSignatureAlgorithm(sigProps));
    }

    /**
     * Creates a direct-key JWE decryption provider. The UTF-8 bytes of the client secret are
     * used as the key.
     *
     * @param clientSecret the client secret
     * @return the decryption provider
     */
    public static JweDecryptionProvider getClientSecretDecryptionProvider(String clientSecret) {
        Properties props = JweUtils.loadEncryptionInProperties(false);
        byte[] key = StringUtils.toBytesUTF8(clientSecret);
        return JweUtils.getDirectKeyJweDecryption(key, getClientSecretContentAlgorithm(props));
    }

    /**
     * Creates a direct-key JWE encryption provider. The UTF-8 bytes of the client secret are
     * used as the key, and the content algorithm is read from the same encryption-in
     * properties as the decryption provider.
     *
     * @param clientSecret the client secret
     * @return the encryption provider
     */
    public static JweEncryptionProvider getClientSecretEncryptionProvider(String clientSecret) {
        Properties props = JweUtils.loadEncryptionInProperties(false);
        byte[] key = StringUtils.toBytesUTF8(clientSecret);
        return JweUtils.getDirectKeyJweEncryption(key, getClientSecretContentAlgorithm(props));
    }

    /**
     * Selects the content encryption algorithm for client-secret JWE.
     * <ul>
     * <li>{@link OAuthConstants#CLIENT_SECRET_CONTENT_ENCRYPTION_ALGORITHM} takes precedence.</li>
     * <li>Otherwise {@link JoseConstants#RSSEC_ENCRYPTION_CONTENT_ALGORITHM} is used.</li>
     * <li>{@code A128GCM} is the default when neither resolves to an algorithm.</li>
     * </ul>
     */
    private static ContentAlgorithm getClientSecretContentAlgorithm(Properties props) {
        String ctAlgoProp = props.getProperty(OAuthConstants.CLIENT_SECRET_CONTENT_ENCRYPTION_ALGORITHM);
        if (ctAlgoProp == null) {
            ctAlgoProp = props.getProperty(JoseConstants.RSSEC_ENCRYPTION_CONTENT_ALGORITHM);
        }
        ContentAlgorithm ctAlgo = ContentAlgorithm.getAlgorithm(ctAlgoProp);
        return ctAlgo != null ? ctAlgo : ContentAlgorithm.A128GCM;
    }

    /**
     * Selects the HMAC signature algorithm for client-secret JWS.
     * <ul>
     * <li>{@link OAuthConstants#CLIENT_SECRET_SIGNATURE_ALGORITHM} takes precedence.</li>
     * <li>Otherwise {@link JoseConstants#RSSEC_SIGNATURE_ALGORITHM} is used, but only if it
     *     names an HMAC algorithm.</li>
     * <li>{@code HS256} is the default when no algorithm is resolved.</li>
     * </ul>
     *
     * @param sigProps the signature properties
     * @return an HMAC signature algorithm
     * @throws OAuthServiceException if the explicitly configured algorithm is not HMAC-based
     */
    public static SignatureAlgorithm getClientSecretSignatureAlgorithm(Properties sigProps) {
        String clientSecretSigProp = sigProps.getProperty(OAuthConstants.CLIENT_SECRET_SIGNATURE_ALGORITHM);
        if (clientSecretSigProp == null) {
            String sigProp = sigProps.getProperty(JoseConstants.RSSEC_SIGNATURE_ALGORITHM);
            if (AlgorithmUtils.isHmacSign(sigProp)) {
                clientSecretSigProp = sigProp;
            }
        }
        SignatureAlgorithm sigAlgo = SignatureAlgorithm.getAlgorithm(clientSecretSigProp);
        sigAlgo = sigAlgo != null ? sigAlgo : SignatureAlgorithm.HS256;
        if (!AlgorithmUtils.isHmacSign(sigAlgo)) {
            // Must be HS-based for the symmetric signature
            throw new OAuthServiceException(OAuthConstants.SERVER_ERROR);
        }
        return sigAlgo;
    }

    /**
     * Joins scopes with {@code ", "}.
     *
     * @param registeredScopes the scopes
     * @return the joined string, empty if there are no scopes
     */
    public static String convertListOfScopesToString(List<String> registeredScopes) {
        StringBuilder sb = new StringBuilder();
        for (String s : registeredScopes) {
            if (sb.length() > 0) {
                sb.append(", ");
            }
            sb.append(s);
        }
        return sb.toString();
    }
}