/* * Copyright (c) 2012, WSO2 Inc. (http://www.wso2.org) All Rights Reserved. * * WSO2 Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except * in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.wso2.carbon.identity.oauth2.authcontext; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSSigner; import com.nimbusds.jose.crypto.RSASSASigner; import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.SignedJWT; import org.apache.commons.codec.binary.Base64; import org.apache.commons.io.Charsets; import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.wso2.carbon.base.MultitenantConstants; import org.wso2.carbon.core.util.KeyStoreManager; import org.wso2.carbon.identity.oauth.dao.OAuthAppDO; import org.wso2.carbon.identity.core.util.IdentityCoreConstants; import org.wso2.carbon.identity.core.util.IdentityUtil; import org.wso2.carbon.identity.oauth.common.exception.InvalidOAuthClientException; import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration; import org.wso2.carbon.identity.oauth.dao.OAuthAppDAO; import org.wso2.carbon.identity.oauth.internal.OAuthComponentServiceHolder; import org.wso2.carbon.identity.oauth.util.ClaimCache; import org.wso2.carbon.identity.oauth.util.ClaimCacheKey; import org.wso2.carbon.identity.oauth.util.UserClaims; import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception; import org.wso2.carbon.identity.oauth2.dto.OAuth2TokenValidationResponseDTO; import org.wso2.carbon.identity.oauth2.model.AccessTokenDO; import org.wso2.carbon.identity.oauth2.util.OAuth2Util; import org.wso2.carbon.identity.oauth2.validators.OAuth2TokenValidationMessageContext; import org.wso2.carbon.user.api.RealmConfiguration; import org.wso2.carbon.user.api.UserRealm; import org.wso2.carbon.user.api.UserStoreException; import org.wso2.carbon.user.core.UserStoreManager; import org.wso2.carbon.user.core.service.RealmService; import org.wso2.carbon.utils.multitenancy.MultitenantUtils; import java.security.Key; import java.security.KeyStore; import java.security.MessageDigest; import java.security.cert.Certificate; import java.security.interfaces.RSAPrivateKey; import java.util.ArrayList; import java.util.Calendar; import java.util.Date; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.SortedMap; import java.util.StringTokenizer; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; /** * This class represents the JSON Web Token generator. * By default the following properties are encoded to each authenticated API request: * subscriber, applicationName, apiContext, version, tier, and endUserName * Additional properties can be encoded by engaging the ClaimsRetrieverImplClass callback-handler. * The JWT header and body are base64 encoded separately and concatenated with a dot. * Finally the token is signed using SHA256 with RSA algorithm. */ public class JWTTokenGenerator implements AuthorizationContextTokenGenerator { private static final Log log = LogFactory.getLog(JWTTokenGenerator.class); private static final String API_GATEWAY_ID = "http://wso2.org/gateway"; private static final String NONE = "NONE"; private static final Base64 base64Url = new Base64(0, null, true); private static volatile long ttl = -1L; private ClaimsRetriever claimsRetriever; private JWSAlgorithm signatureAlgorithm = new JWSAlgorithm(JWSAlgorithm.RS256.getName()); private boolean includeClaims = true; private boolean enableSigning = true; private static Map<Integer, Key> privateKeys = new ConcurrentHashMap<Integer, Key>(); private static Map<Integer, Certificate> publicCerts = new ConcurrentHashMap<Integer, Certificate>(); private ClaimCache claimsLocalCache; public JWTTokenGenerator() { claimsLocalCache = ClaimCache.getInstance(); } private String userAttributeSeparator = IdentityCoreConstants.MULTI_ATTRIBUTE_SEPARATOR_DEFAULT; //constructor for testing purposes public JWTTokenGenerator(boolean includeClaims, boolean enableSigning) { this.includeClaims = includeClaims; this.enableSigning = enableSigning; signatureAlgorithm = new JWSAlgorithm(JWSAlgorithm.NONE.getName()); } /** * Reads the ClaimsRetrieverImplClass from identity.xml -> * OAuth -> TokenGeneration -> ClaimsRetrieverImplClass. * * @throws IdentityOAuth2Exception */ @Override public void init() throws IdentityOAuth2Exception { if (includeClaims && enableSigning) { String claimsRetrieverImplClass = OAuthServerConfiguration.getInstance().getClaimsRetrieverImplClass(); String sigAlg = OAuthServerConfiguration.getInstance().getSignatureAlgorithm(); if(sigAlg != null && !sigAlg.trim().isEmpty()){ signatureAlgorithm = mapSignatureAlgorithm(sigAlg); } if(claimsRetrieverImplClass != null){ try{ claimsRetriever = (ClaimsRetriever)Class.forName(claimsRetrieverImplClass).newInstance(); claimsRetriever.init(); } catch (ClassNotFoundException e){ log.error("Cannot find class: " + claimsRetrieverImplClass, e); } catch (InstantiationException e) { log.error("Error instantiating " + claimsRetrieverImplClass, e); } catch (IllegalAccessException e) { log.error("Illegal access to " + claimsRetrieverImplClass, e); } catch (IdentityOAuth2Exception e){ log.error("Error while initializing " + claimsRetrieverImplClass, e); } } } } /** * Method that generates the JWT. * * @return signed JWT token * @throws IdentityOAuth2Exception */ @Override public void generateToken(OAuth2TokenValidationMessageContext messageContext) throws IdentityOAuth2Exception { String clientId = ((AccessTokenDO)messageContext.getProperty("AccessTokenDO")).getConsumerKey(); long issuedTime = ((AccessTokenDO)messageContext.getProperty("AccessTokenDO")).getIssuedTime().getTime(); String authzUser = messageContext.getResponseDTO().getAuthorizedUser(); int tenantID = ((AccessTokenDO)messageContext.getProperty("AccessTokenDO")).getTenantID(); String tenantDomain = OAuth2Util.getTenantDomain(tenantID); boolean isExistingUser = false; RealmService realmService = OAuthComponentServiceHolder.getRealmService(); // TODO : Need to handle situation where federated user name is similar to a one we have in our user store if (realmService != null && tenantID != MultitenantConstants.INVALID_TENANT_ID ) { try { UserRealm userRealm = realmService.getTenantUserRealm(tenantID); if (userRealm != null) { UserStoreManager userStoreManager = (UserStoreManager) userRealm.getUserStoreManager(); isExistingUser = userStoreManager.isExistingUser(MultitenantUtils.getTenantAwareUsername (authzUser)); } } catch (UserStoreException e) { log.error("Error occurred while loading the realm service", e); } } OAuthAppDAO appDAO = new OAuthAppDAO(); OAuthAppDO appDO; try { appDO = appDAO.getAppInformation(clientId); // Adding the OAuthAppDO as a context property for further use messageContext.addProperty("OAuthAppDO", appDO); } catch (IdentityOAuth2Exception e) { log.debug(e.getMessage(), e); throw new IdentityOAuth2Exception(e.getMessage()); } catch (InvalidOAuthClientException e) { log.debug(e.getMessage(), e); throw new IdentityOAuth2Exception(e.getMessage()); } String subscriber = appDO.getUser().toString(); String applicationName = appDO.getApplicationName(); //generating expiring timestamp long currentTime = Calendar.getInstance().getTimeInMillis(); long expireIn = currentTime + 1000 * 60 * getTTL(); // Prepare JWT with claims set JWTClaimsSet claimsSet = new JWTClaimsSet(); claimsSet.setIssuer(API_GATEWAY_ID); claimsSet.setSubject(authzUser); claimsSet.setIssueTime(new Date(issuedTime)); claimsSet.setExpirationTime(new Date(expireIn)); claimsSet.setClaim(API_GATEWAY_ID+"/subscriber",subscriber); claimsSet.setClaim(API_GATEWAY_ID+"/applicationname",applicationName); claimsSet.setClaim(API_GATEWAY_ID+"/enduser",authzUser); if(claimsRetriever != null){ //check in local cache String[] requestedClaims = messageContext.getRequestDTO().getRequiredClaimURIs(); if(requestedClaims == null && isExistingUser) { // if no claims were requested, return all requestedClaims = claimsRetriever.getDefaultClaims(authzUser); } ClaimCacheKey cacheKey = null; UserClaims result = null; if(requestedClaims != null) { cacheKey = new ClaimCacheKey(authzUser, requestedClaims); result = claimsLocalCache.getValueFromCache(cacheKey); } SortedMap<String,String> claimValues = null; if (result != null) { claimValues = result.getClaimValues(); } else if (isExistingUser) { claimValues = claimsRetriever.getClaims(authzUser, requestedClaims); UserClaims userClaims = new UserClaims(claimValues); claimsLocalCache.addToCache(cacheKey, userClaims); } if(isExistingUser) { String claimSeparator = getMultiAttributeSeparator(authzUser, tenantID); if (StringUtils.isBlank(claimSeparator)) { userAttributeSeparator = claimSeparator; } } if(claimValues != null) { Iterator<String> it = new TreeSet(claimValues.keySet()).iterator(); while (it.hasNext()) { String claimURI = it.next(); String claimVal = claimValues.get(claimURI); List<String> claimList = new ArrayList<String>(); if (userAttributeSeparator != null && claimVal.contains(userAttributeSeparator)) { StringTokenizer st = new StringTokenizer(claimVal, userAttributeSeparator); while (st.hasMoreElements()) { String attValue = st.nextElement().toString(); if (StringUtils.isNotBlank(attValue)) { claimList.add(attValue); } } claimsSet.setClaim(claimURI, claimList.toArray(new String[claimList.size()])); } else { claimsSet.setClaim(claimURI, claimVal); } } } } JWT jwt = null; if(!JWSAlgorithm.NONE.equals(signatureAlgorithm)){ JWSHeader header = new JWSHeader(JWSAlgorithm.RS256); header.setX509CertThumbprint(new Base64URL(getThumbPrint(tenantDomain, tenantID))); jwt = new SignedJWT(header, claimsSet); jwt = signJWT((SignedJWT)jwt, tenantDomain, tenantID); } else { jwt = new PlainJWT(claimsSet); } if (log.isDebugEnabled()) { log.debug("JWT Assertion Value : " + jwt.serialize()); } OAuth2TokenValidationResponseDTO.AuthorizationContextToken token; token = messageContext.getResponseDTO().new AuthorizationContextToken("JWT", jwt.serialize()); messageContext.getResponseDTO().setAuthorizationContextToken(token); } /** * Sign with given RSA Algorithm * * @param signedJWT * @param jwsAlgorithm * @param tenantDomain * @param tenantId * @return * @throws IdentityOAuth2Exception */ protected SignedJWT signJWTWithRSA(SignedJWT signedJWT, JWSAlgorithm jwsAlgorithm, String tenantDomain, int tenantId) throws IdentityOAuth2Exception { try { Key privateKey = getPrivateKey(tenantDomain, tenantId); JWSSigner signer = new RSASSASigner((RSAPrivateKey) privateKey); signedJWT.sign(signer); return signedJWT; } catch (JOSEException e) { log.error("Error in obtaining tenant's keystore", e); throw new IdentityOAuth2Exception("Error in obtaining tenant's keystore", e); } catch (Exception e) { log.error("Error in obtaining tenant's keystore", e); throw new IdentityOAuth2Exception("Error in obtaining tenant's keystore", e); } } /** * Generic Signing function * * @param signedJWT * @param tenantDomain * @param tenantId * @return * @throws IdentityOAuth2Exception */ protected JWT signJWT(SignedJWT signedJWT, String tenantDomain, int tenantId) throws IdentityOAuth2Exception { if (JWSAlgorithm.RS256.equals(signatureAlgorithm) || JWSAlgorithm.RS384.equals(signatureAlgorithm) || JWSAlgorithm.RS512.equals(signatureAlgorithm)) { return signJWTWithRSA(signedJWT, signatureAlgorithm, tenantDomain, tenantId); } else if (JWSAlgorithm.HS256.equals(signatureAlgorithm) || JWSAlgorithm.HS384.equals(signatureAlgorithm) || JWSAlgorithm.HS512.equals(signatureAlgorithm)) { // return signWithHMAC(payLoad,jwsAlgorithm,tenantDomain,tenantId); implementation // need to be done } else if (JWSAlgorithm.ES256.equals(signatureAlgorithm) || JWSAlgorithm.ES384.equals(signatureAlgorithm) || JWSAlgorithm.ES512.equals(signatureAlgorithm)) { // return signWithEC(payLoad,jwsAlgorithm,tenantDomain,tenantId); implementation // need to be done } log.error("UnSupported Signature Algorithm"); throw new IdentityOAuth2Exception("UnSupported Signature Algorithm"); } /** * This method map signature algorithm define in identity.xml to nimbus * signature algorithm * format, Strings are defined inline hence there are not being used any * where * * @param signatureAlgorithm * @return * @throws IdentityOAuth2Exception */ protected JWSAlgorithm mapSignatureAlgorithm(String signatureAlgorithm) throws IdentityOAuth2Exception { if ("SHA256withRSA".equals(signatureAlgorithm)) { return JWSAlgorithm.RS256; } else if ("SHA384withRSA".equals(signatureAlgorithm)) { return JWSAlgorithm.RS384; } else if ("SHA512withRSA".equals(signatureAlgorithm)) { return JWSAlgorithm.RS512; } else if ("SHA256withHMAC".equals(signatureAlgorithm)) { return JWSAlgorithm.HS256; } else if ("SHA384withHMAC".equals(signatureAlgorithm)) { return JWSAlgorithm.HS384; } else if ("SHA512withHMAC".equals(signatureAlgorithm)) { return JWSAlgorithm.HS512; } else if ("SHA256withEC".equals(signatureAlgorithm)) { return JWSAlgorithm.ES256; } else if ("SHA384withEC".equals(signatureAlgorithm)) { return JWSAlgorithm.ES384; } else if ("SHA512withEC".equals(signatureAlgorithm)) { return JWSAlgorithm.ES512; } else if(NONE.equals(signatureAlgorithm)){ return new JWSAlgorithm(JWSAlgorithm.NONE.getName()); } log.error("Unsupported Signature Algorithm in identity.xml"); throw new IdentityOAuth2Exception("Unsupported Signature Algorithm in identity.xml"); } private long getTTL() { if (ttl != -1) { return ttl; } synchronized (JWTTokenGenerator.class) { if (ttl != -1) { return ttl; } String ttlValue = OAuthServerConfiguration.getInstance().getAuthorizationContextTTL(); if (ttlValue != null) { ttl = Long.parseLong(ttlValue); } else { ttl = 15L; } return ttl; } } /** * Helper method to add public certificate to JWT_HEADER to signature verification. * * @param tenantDomain * @param tenantId * @throws IdentityOAuth2Exception */ private String getThumbPrint(String tenantDomain, int tenantId) throws IdentityOAuth2Exception { try { Certificate certificate = getCertificate(tenantDomain, tenantId); // TODO: maintain a hashmap with tenants' pubkey thumbprints after first initialization //generate the SHA-1 thumbprint of the certificate MessageDigest digestValue = MessageDigest.getInstance("SHA-1"); byte[] der = certificate.getEncoded(); digestValue.update(der); byte[] digestInBytes = digestValue.digest(); String publicCertThumbprint = hexify(digestInBytes); String base64EncodedThumbPrint = new String(base64Url.encode(publicCertThumbprint.getBytes(Charsets.UTF_8)), Charsets.UTF_8); return base64EncodedThumbPrint; } catch (Exception e) { String error = "Error in obtaining certificate for tenant " + tenantDomain; throw new IdentityOAuth2Exception(error, e); } } private Key getPrivateKey(String tenantDomain, int tenantId) throws IdentityOAuth2Exception { if (tenantDomain == null) { tenantDomain = MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; } if (tenantId == 0) { tenantId = OAuth2Util.getTenantId(tenantDomain); } Key privateKey = null; if (!(privateKeys.containsKey(tenantId))) { // get tenant's key store manager KeyStoreManager tenantKSM = KeyStoreManager.getInstance(tenantId); if (!tenantDomain.equals(MultitenantConstants.SUPER_TENANT_DOMAIN_NAME)) { // derive key store name String ksName = tenantDomain.trim().replace(".", "-"); String jksName = ksName + ".jks"; // obtain private key privateKey = tenantKSM.getPrivateKey(jksName, tenantDomain); } else { try { privateKey = tenantKSM.getDefaultPrivateKey(); } catch (Exception e) { log.error("Error while obtaining private key for super tenant", e); } } if (privateKey != null) { privateKeys.put(tenantId, privateKey); } } else { privateKey = privateKeys.get(tenantId); } return privateKey; } private Certificate getCertificate(String tenantDomain, int tenantId) throws Exception { if (tenantDomain == null) { tenantDomain = MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; } if (tenantId == 0) { tenantId = OAuth2Util.getTenantId(tenantDomain); } Certificate publicCert = null; if (!(publicCerts.containsKey(tenantId))) { // get tenant's key store manager KeyStoreManager tenantKSM = KeyStoreManager.getInstance(tenantId); KeyStore keyStore = null; if (!tenantDomain.equals(MultitenantConstants.SUPER_TENANT_DOMAIN_NAME)) { // derive key store name String ksName = tenantDomain.trim().replace(".", "-"); String jksName = ksName + ".jks"; keyStore = tenantKSM.getKeyStore(jksName); publicCert = keyStore.getCertificate(tenantDomain); } else { publicCert = tenantKSM.getDefaultPrimaryCertificate(); } if (publicCert != null) { publicCerts.put(tenantId, publicCert); } } else { publicCert = publicCerts.get(tenantId); } return publicCert; } /** * Helper method to hexify a byte array. * TODO:need to verify the logic * * @param bytes * @return hexadecimal representation */ private String hexify(byte bytes[]) { char[] hexDigits = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; StringBuilder buf = new StringBuilder(bytes.length * 2); for (int i = 0; i < bytes.length; ++i) { buf.append(hexDigits[(bytes[i] & 0xf0) >> 4]); buf.append(hexDigits[bytes[i] & 0x0f]); } return buf.toString(); } private String getMultiAttributeSeparator(String authenticatedUser, int tenantId) { String claimSeparator = null; String userDomain = IdentityUtil.extractDomainFromName(authenticatedUser); try { RealmConfiguration realmConfiguration = null; RealmService realmService = OAuthComponentServiceHolder.getRealmService(); if (realmService != null && tenantId != MultitenantConstants.INVALID_TENANT_ID) { UserStoreManager userStoreManager = (UserStoreManager) realmService.getTenantUserRealm(tenantId) .getUserStoreManager(); realmConfiguration = userStoreManager.getSecondaryUserStoreManager(userDomain).getRealmConfiguration(); } if (realmConfiguration != null) { claimSeparator = realmConfiguration.getUserStoreProperty(IdentityCoreConstants.MULTI_ATTRIBUTE_SEPARATOR); if (claimSeparator != null && !claimSeparator.trim().isEmpty()) { return claimSeparator; } } } catch (UserStoreException e) { log.error("Error occurred while getting the realm configuration, User store properties might not be " + "returned", e); } return null; } }