package org.apereo.cas.oidc.token;
import org.apache.commons.codec.digest.MessageDigestAlgorithms;
import org.apereo.cas.authentication.Authentication;
import org.apereo.cas.authentication.AuthenticationHandler;
import org.apereo.cas.authentication.principal.Principal;
import org.apereo.cas.configuration.CasConfigurationProperties;
import org.apereo.cas.oidc.OidcConstants;
import org.apereo.cas.services.OidcRegisteredService;
import org.apereo.cas.support.oauth.OAuth20Constants;
import org.apereo.cas.support.oauth.OAuth20ResponseTypes;
import org.apereo.cas.support.oauth.services.OAuthRegisteredService;
import org.apereo.cas.ticket.accesstoken.AccessToken;
import org.apereo.cas.util.CollectionUtils;
import org.apereo.cas.util.DigestUtils;
import org.apereo.cas.util.EncodingUtils;
import org.apereo.cas.web.support.WebUtils;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.NumericDate;
import org.pac4j.core.context.J2EContext;
import org.pac4j.core.profile.ProfileManager;
import org.pac4j.core.profile.UserProfile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Optional;
import java.util.UUID;
/**
 * This is {@link OidcIdTokenGeneratorService}.
 * <p>
 * Builds the claim set of an OpenID Connect ID token from an issued access token and the
 * authenticated user profile. The finished claims are then handed to the
 * {@link OidcIdTokenSigningAndEncryptionService}, which encodes them.
 *
 * @author Misagh Moayyed
 * @since 5.0.0
 */
public class OidcIdTokenGeneratorService {
    private static final Logger LOGGER = LoggerFactory.getLogger(OidcIdTokenGeneratorService.class);

    /** CAS settings; supplies the MFA context attribute name and the set of releasable OIDC claims. */
    @Autowired
    private CasConfigurationProperties casProperties;

    /** Value placed in the {@code iss} claim. */
    private final String issuer;

    /** Clock skew, in minutes, subtracted from "now" to compute the {@code nbf} claim. */
    private final int skew;

    /** Encodes the finished claim set; its signing algorithm also selects the {@code at_hash} digest. */
    private final OidcIdTokenSigningAndEncryptionService signingService;

    public OidcIdTokenGeneratorService(final String issuer,
                                       final int skew,
                                       final OidcIdTokenSigningAndEncryptionService signingService) {
        this.signingService = signingService;
        this.issuer = issuer;
        this.skew = skew;
    }

    /**
     * Generate an encoded ID token for the given access token.
     * <p>
     * The authenticated user profile is read from the pac4j profile manager bound to the request.
     * The claims are built and then encoded by the signing/encryption service.
     *
     * @param request           the current HTTP request
     * @param response          the current HTTP response
     * @param accessTokenId     the access token the ID token is issued alongside
     * @param timeout           lifetime of the ID token, in seconds, added to "now" for {@code exp}
     * @param responseType      the requested OAuth response type
     * @param registeredService the registered service; it is cast to {@link OidcRegisteredService}
     * @return the encoded ID token produced by the signing/encryption service
     * @throws IllegalArgumentException if no authenticated user profile can be found for the request
     * @throws Exception                if encoding the token fails
     */
    public String generate(final HttpServletRequest request,
                           final HttpServletResponse response,
                           final AccessToken accessTokenId,
                           final long timeout,
                           final OAuth20ResponseTypes responseType,
                           final OAuthRegisteredService registeredService) throws Exception {
        final OidcRegisteredService oidcRegisteredService = (OidcRegisteredService) registeredService;
        final J2EContext context = WebUtils.getPac4jJ2EContext(request, response);
        final ProfileManager manager = WebUtils.getPac4jProfileManager(request, response);
        final Optional<UserProfile> profile = manager.get(true);
        // Fail with a descriptive error rather than a bare NoSuchElementException from Optional.get().
        final UserProfile userProfile = profile.orElseThrow(() -> new IllegalArgumentException(
            "Unable to determine the authenticated user profile from the request; cannot produce an id token"));

        LOGGER.debug("Attempting to produce claims for the id token [{}]", accessTokenId);
        final JwtClaims claims = produceIdTokenClaims(request, accessTokenId, timeout,
            oidcRegisteredService, userProfile, context, responseType);
        LOGGER.debug("Produce claims for the id token [{}] as [{}]", accessTokenId, claims);
        return this.signingService.encode(oidcRegisteredService, claims);
    }

    /**
     * Produce the JWT claims of the ID token.
     * <p>
     * Sets the following:
     * <ul>
     * <li>standard claims: {@code jti}, {@code iss}, {@code aud}, {@code exp}, {@code iat},
     * {@code nbf} and {@code sub};</li>
     * <li>{@code acr} and {@code amr}, when the authentication carries the corresponding attributes;</li>
     * <li>{@code state} and {@code nonce}, when present;</li>
     * <li>{@code at_hash};</li>
     * <li>every principal attribute listed in the configured OIDC claims;</li>
     * <li>{@code preferred_username}, falling back to the profile id.</li>
     * </ul>
     *
     * @param request       the request (not used by this implementation)
     * @param accessTokenId the access token whose authentication supplies principal and attributes
     * @param timeout       lifetime of the ID token, in seconds
     * @param service       the OIDC registered service; its client id becomes the audience
     * @param profile       the authenticated user profile
     * @param context       the pac4j context (not used by this implementation)
     * @param responseType  the response type (not used by this implementation)
     * @return the populated jwt claims
     */
    protected JwtClaims produceIdTokenClaims(final HttpServletRequest request,
                                             final AccessToken accessTokenId,
                                             final long timeout,
                                             final OidcRegisteredService service,
                                             final UserProfile profile,
                                             final J2EContext context,
                                             final OAuth20ResponseTypes responseType) {
        final Authentication authentication = accessTokenId.getAuthentication();
        final Principal principal = authentication.getPrincipal();

        final JwtClaims claims = new JwtClaims();
        claims.setJwtId(UUID.randomUUID().toString());
        claims.setIssuer(this.issuer);
        claims.setAudience(service.getClientId());

        final NumericDate expirationDate = NumericDate.now();
        expirationDate.addSeconds(timeout);
        claims.setExpirationTime(expirationDate);
        claims.setIssuedAtToNow();
        claims.setNotBeforeMinutesInThePast(this.skew);
        claims.setSubject(principal.getId());

        final String mfaContextAttribute = casProperties.getAuthn().getMfa().getAuthenticationContextAttribute();
        if (authentication.getAttributes().containsKey(mfaContextAttribute)) {
            final Collection<Object> val = CollectionUtils.toCollection(
                authentication.getAttributes().get(mfaContextAttribute));
            // An attribute can be present yet empty; iterator().next() would then throw.
            if (!val.isEmpty()) {
                claims.setStringClaim(OidcConstants.ACR, val.iterator().next().toString());
            }
        }

        if (authentication.getAttributes().containsKey(AuthenticationHandler.SUCCESSFUL_AUTHENTICATION_HANDLERS)) {
            final Collection<Object> val = CollectionUtils.toCollection(
                authentication.getAttributes().get(AuthenticationHandler.SUCCESSFUL_AUTHENTICATION_HANDLERS));
            // Convert element-wise: toArray(new String[]{}) throws ArrayStoreException on non-String values.
            final String[] handlerNames = val.stream()
                .filter(value -> value != null)
                .map(Object::toString)
                .toArray(String[]::new);
            claims.setStringListClaim(OidcConstants.AMR, handlerNames);
        }

        // Only emit state/nonce when they were captured; otherwise the token would carry explicit nulls.
        final Object state = authentication.getAttributes().get(OAuth20Constants.STATE);
        if (state != null) {
            claims.setClaim(OAuth20Constants.STATE, state);
        }
        final Object nonce = authentication.getAttributes().get(OAuth20Constants.NONCE);
        if (nonce != null) {
            claims.setClaim(OAuth20Constants.NONCE, nonce);
        }

        claims.setClaim(OidcConstants.CLAIM_AT_HASH, generateAccessTokenHash(accessTokenId, service));

        principal.getAttributes().entrySet().stream()
            .filter(entry -> casProperties.getAuthn().getOidc().getClaims().contains(entry.getKey()))
            .forEach(entry -> claims.setClaim(entry.getKey(), entry.getValue()));

        if (!claims.hasClaim(OidcConstants.CLAIM_PREFERRED_USERNAME)) {
            claims.setClaim(OidcConstants.CLAIM_PREFERRED_USERNAME, profile.getId());
        }
        return claims;
    }

    /**
     * Compute the {@code at_hash} value for the access token.
     * <p>
     * Per OpenID Connect Core 3.1.3.6, the token octets are hashed with the hash function of the
     * ID token's signing algorithm. The left-most half of the digest is then base64url-encoded
     * without padding.
     *
     * @param accessTokenId the access token
     * @param service       the registered service (currently unused)
     * @return the base64url (unpadded) encoded left half of the access token digest
     */
    private String generateAccessTokenHash(final AccessToken accessTokenId,
                                           final OidcRegisteredService service) {
        // Explicit charset so the digest does not depend on the platform default encoding.
        final byte[] tokenBytes = accessTokenId.getId().getBytes(StandardCharsets.UTF_8);
        final String hashAlg = resolveAccessTokenHashAlgorithm(signingService.getJsonWebKeySigningAlgorithm());
        LOGGER.debug("Digesting access token hash via algorithm [{}]", hashAlg);
        final byte[] digested = DigestUtils.rawDigest(hashAlg, tokenBytes);
        final byte[] hashBytesLeftHalf = Arrays.copyOf(digested, digested.length / 2);
        // at_hash must be base64url without '=' padding; standard base64 breaks RP validation.
        return Base64.getUrlEncoder().withoutPadding().encodeToString(hashBytesLeftHalf);
    }

    /**
     * Map a JWS algorithm identifier to the message digest it is based on.
     * <p>
     * Unknown or absent algorithms fall back to SHA-256.
     *
     * @param signingAlgorithm the JWS {@code alg} identifier; may be null
     * @return the message digest algorithm name
     */
    private static String resolveAccessTokenHashAlgorithm(final String signingAlgorithm) {
        if (signingAlgorithm == null) {
            return MessageDigestAlgorithms.SHA_256;
        }
        switch (signingAlgorithm) {
            case AlgorithmIdentifiers.RSA_USING_SHA512:
            case AlgorithmIdentifiers.RSA_PSS_USING_SHA512:
            case AlgorithmIdentifiers.ECDSA_USING_P521_CURVE_AND_SHA512:
            case AlgorithmIdentifiers.HMAC_SHA512:
                return MessageDigestAlgorithms.SHA_512;
            case AlgorithmIdentifiers.RSA_USING_SHA384:
            case AlgorithmIdentifiers.RSA_PSS_USING_SHA384:
            case AlgorithmIdentifiers.ECDSA_USING_P384_CURVE_AND_SHA384:
            case AlgorithmIdentifiers.HMAC_SHA384:
                return MessageDigestAlgorithms.SHA_384;
            case AlgorithmIdentifiers.RSA_USING_SHA256:
            default:
                return MessageDigestAlgorithms.SHA_256;
        }
    }
}