package org.apereo.cas.oidc.jwks;
import com.google.common.cache.CacheLoader;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwk.RsaJsonWebKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.Resource;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
/**
 * This is {@link OidcDefaultJsonWebKeystoreCacheLoader}.
 * Only attempts to cache the default CAS keystore.
 * <p>
 * The keystore is read from the configured {@link Resource} on every
 * {@link #load(String)} call, and the first key in the set is returned
 * when it is an RSA key that carries a private key.
 *
 * @author Misagh Moayyed
 * @since 5.1.0
 */
public class OidcDefaultJsonWebKeystoreCacheLoader extends CacheLoader<String, Optional<RsaJsonWebKey>> {
    private static final Logger LOGGER = LoggerFactory.getLogger(OidcDefaultJsonWebKeystoreCacheLoader.class);

    private final Resource jwksFile;

    /**
     * Creates a loader backed by the given keystore resource.
     *
     * @param jwksFile the resource holding the JSON web keystore; may be {@code null},
     *                 in which case every load yields an empty result
     */
    public OidcDefaultJsonWebKeystoreCacheLoader(final Resource jwksFile) {
        this.jwksFile = jwksFile;
    }

    /**
     * Loads the signing key from the default keystore.
     *
     * @param issuer the cache key; it is not consulted, the same keystore is read for every key
     * @return the first key of the keystore, or empty when the keystore cannot be read,
     * holds no keys, or its first key is not an RSA key with a private key
     * @throws Exception declared by the overridden method
     */
    @Override
    public Optional<RsaJsonWebKey> load(final String issuer) throws Exception {
        final Optional<JsonWebKeySet> jwks = buildJsonWebKeySet();
        if (!jwks.isPresent() || jwks.get().getJsonWebKeys().isEmpty()) {
            return Optional.empty();
        }
        return Optional.ofNullable(getJsonSigningWebKeyFromJwks(jwks.get()));
    }

    /**
     * Build json web key set from the configured resource.
     * Failures are logged and reported as an empty result rather than propagated.
     *
     * @return the json web key set, or empty when no usable keystore is available
     */
    private Optional<JsonWebKeySet> buildJsonWebKeySet() {
        try {
            LOGGER.debug("Loading default JSON web key from [{}]", this.jwksFile);
            if (this.jwksFile == null) {
                return Optional.empty();
            }

            LOGGER.debug("Retrieving default JSON web key from [{}]", this.jwksFile);
            final JsonWebKeySet jsonWebKeySet = buildJsonWebKeySet(this.jwksFile);
            if (jsonWebKeySet == null || jsonWebKeySet.getJsonWebKeys().isEmpty()) {
                LOGGER.warn("No JSON web keys could be found");
                return Optional.empty();
            }

            // A key counts as "bad" only when algorithm, key id and key type are all blank.
            final long badKeysCount = jsonWebKeySet.getJsonWebKeys().stream()
                .filter(k -> StringUtils.isBlank(k.getAlgorithm())
                    && StringUtils.isBlank(k.getKeyId())
                    && StringUtils.isBlank(k.getKeyType()))
                .count();
            if (badKeysCount == jsonWebKeySet.getJsonWebKeys().size()) {
                LOGGER.warn("No valid JSON web keys could be found");
                return Optional.empty();
            }

            // The helper returns null (after logging the reason) when the first key is unusable,
            // so it must be null-checked before being treated as a valid signing key.
            final RsaJsonWebKey webKey = getJsonSigningWebKeyFromJwks(jsonWebKeySet);
            if (webKey == null || webKey.getPrivateKey() == null) {
                LOGGER.warn("JSON web key retrieved from [{}] is unusable or has no associated private key", this.jwksFile);
                return Optional.empty();
            }
            return Optional.of(jsonWebKeySet);
        } catch (final Exception e) {
            // Surface the failure at a visible level; keep the stack trace at debug.
            LOGGER.warn("Unable to load JSON web keystore from [{}]: [{}]", this.jwksFile, e.getMessage());
            LOGGER.debug(e.getMessage(), e);
        }
        return Optional.empty();
    }

    /**
     * Picks the signing key out of the given key set. Only the first key is considered.
     *
     * @param jwks the key set to inspect
     * @return the first key when it is an RSA key with a private key; {@code null} otherwise
     */
    private static RsaJsonWebKey getJsonSigningWebKeyFromJwks(final JsonWebKeySet jwks) {
        if (jwks.getJsonWebKeys().isEmpty()) {
            LOGGER.warn("No JSON web keys are available in the keystore");
            return null;
        }

        // Check the type before casting so a non-RSA first key is rejected instead of
        // raising a ClassCastException.
        final Object candidate = jwks.getJsonWebKeys().get(0);
        if (!(candidate instanceof RsaJsonWebKey)) {
            LOGGER.warn("Located JSON web key [{}] is not an RSA key", candidate);
            return null;
        }

        final RsaJsonWebKey key = (RsaJsonWebKey) candidate;
        if (StringUtils.isBlank(key.getAlgorithm())) {
            LOGGER.warn("Located JSON web key [{}] has no algorithm defined", key);
        }
        if (StringUtils.isBlank(key.getKeyId())) {
            LOGGER.warn("Located JSON web key [{}] has no key id defined", key);
        }
        if (key.getPrivateKey() == null) {
            LOGGER.warn("Located JSON web key [{}] has no private key", key);
            return null;
        }
        return key;
    }

    /**
     * Reads the given resource as UTF-8 text and parses it as a key set.
     *
     * @param resource the resource to read
     * @return the parsed key set, or {@code null} when it holds no usable signing key
     * @throws Exception if the resource cannot be read or its content cannot be parsed
     */
    private static JsonWebKeySet buildJsonWebKeySet(final Resource resource) throws Exception {
        final String json;
        // Close the stream once its content has been read.
        try (final InputStream in = resource.getInputStream()) {
            json = IOUtils.toString(in, StandardCharsets.UTF_8);
        }
        // The raw JSON is deliberately not logged: a usable keystore contains private key material.
        LOGGER.debug("Retrieved JSON web key from [{}]", resource);
        return buildJsonWebKeySet(json);
    }

    /**
     * Parses the given JSON as a key set and verifies that it offers a usable signing key.
     *
     * @param json the keystore content
     * @return the parsed key set, or {@code null} when its first key is missing, is not an
     * RSA key, or has no private key
     * @throws Exception if the JSON cannot be parsed as a key set
     */
    private static JsonWebKeySet buildJsonWebKeySet(final String json) throws Exception {
        final JsonWebKeySet jsonWebKeySet = new JsonWebKeySet(json);
        final RsaJsonWebKey webKey = getJsonSigningWebKeyFromJwks(jsonWebKeySet);
        if (webKey == null || webKey.getPrivateKey() == null) {
            LOGGER.warn("JSON web key retrieved [{}] is not found or has no associated private key", webKey);
            return null;
        }
        return jsonWebKeySet;
    }
}