package org.dcache.gplazma.oidc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.net.InternetDomainName;
import org.codehaus.jackson.JsonNode;
import org.dcache.auth.BearerTokenCredential;
import org.dcache.auth.EmailAddressPrincipal;
import org.dcache.auth.FullNamePrincipal;
import org.dcache.auth.OidcSubjectPrincipal;
import org.dcache.gplazma.AuthenticationException;
import org.dcache.gplazma.oidc.exceptions.OidcException;
import org.dcache.gplazma.oidc.helpers.JsonHttpClient;
import org.dcache.gplazma.plugins.GPlazmaAuthenticationPlugin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.Map;
import java.util.HashSet;
import java.util.Base64;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static org.dcache.gplazma.util.Preconditions.checkAuthentication;

/**
 * gPlazma authentication plugin that validates bearer tokens against one or
 * more OpenID Connect providers.
 * <p>
 * The providers are named by the whitespace-separated list of host names in
 * the {@code gplazma.oidc.hostnames} property.  For each host, the discovery
 * document is fetched from
 * {@code https://<host>/.well-known/openid-configuration} (and cached), and
 * the bearer token is presented to the {@code userinfo_endpoint} advertised
 * there.  The first provider that returns a UserInfo response containing a
 * {@code sub} claim determines the principals that are added.
 */
public class OidcAuthPlugin implements GPlazmaAuthenticationPlugin
{
    private static final Logger LOG = LoggerFactory.getLogger(OidcAuthPlugin.class);
    private static final String OIDC_HOSTNAMES = "gplazma.oidc.hostnames";

    /** Discovery documents, keyed by provider host name. */
    private final LoadingCache<String, JsonNode> cache;

    /** Host names of the configured OpenID Connect providers. */
    private final Set<String> discoveryDocs;

    private final JsonHttpClient jsonHttpClient;

    /** Source of the identifiers used to correlate client errors with log entries. */
    private final Random random = new Random();

    /**
     * Creates the plugin with a default HTTP client and discovery-document cache.
     *
     * @param properties configuration; must define {@code gplazma.oidc.hostnames}
     * @throws IllegalArgumentException if the property is missing, lists no
     *         host, or lists a host that is not a valid internet domain name
     */
    public OidcAuthPlugin(Properties properties)
    {
        this(properties, new JsonHttpClient());
    }

    @VisibleForTesting
    OidcAuthPlugin(Properties properties, JsonHttpClient client)
    {
        this(properties, client, createLoadingCache(client));
    }

    @VisibleForTesting
    OidcAuthPlugin(Properties properties, JsonHttpClient client, LoadingCache<String, JsonNode> cache)
    {
        String oidcHostnamesProperty = properties.getProperty(OIDC_HOSTNAMES);
        checkArgument(oidcHostnamesProperty != null, OIDC_HOSTNAMES + " not defined");

        // Split the configured names into valid (TRUE) and invalid (FALSE)
        // groups; a key is only present if at least one name fell into it.
        Map<Boolean, Set<String>> validHosts =
                Arrays.stream(oidcHostnamesProperty.split("\\s+"))
                        .filter(not(String::isEmpty))
                        .collect(Collectors.groupingBy(InternetDomainName::isValid,
                                                       Collectors.toSet()));

        checkArgument(!validHosts.containsKey(Boolean.FALSE),
                      String.format("Invalid hosts in %s: %s",
                                    OIDC_HOSTNAMES,
                                    Joiner.on(", ").join(nullToEmpty(validHosts.get(Boolean.FALSE)))));
        checkArgument(validHosts.containsKey(Boolean.TRUE),
                      String.format("No hosts specified in %s", OIDC_HOSTNAMES));

        this.discoveryDocs = validHosts.get(Boolean.TRUE);
        this.jsonHttpClient = client;
        this.cache = cache;
    }

    /**
     * Tries each bearer token found in {@code privateCredentials} against each
     * configured provider, returning as soon as one provider accepts a token.
     *
     * @param publicCredentials not used by this plugin
     * @param privateCredentials searched for {@link BearerTokenCredential} instances
     * @param identifiedPrincipals receives the principals derived from the
     *        UserInfo response of the accepting provider
     * @throws AuthenticationException if there is no bearer token or no
     *         provider accepted any of the tokens.  With exactly one distinct
     *         failure the message carries that failure; otherwise the details
     *         are logged and the message carries the log-entry identifier.
     */
    @Override
    public void authenticate(Set<Object> publicCredentials,
                             Set<Object> privateCredentials,
                             Set<Principal> identifiedPrincipals)
            throws AuthenticationException
    {
        Set<String> failures = new HashSet<>();
        boolean foundBearerToken = false;

        for (Object credential : privateCredentials) {
            if (credential instanceof BearerTokenCredential) {
                BearerTokenCredential token = (BearerTokenCredential) credential;
                foundBearerToken = true;

                for (String host : discoveryDocs) {
                    try {
                        identifiedPrincipals.addAll(
                                validateBearerTokenWithOpenIdProvider(token,
                                        extractUserInfoEndPoint(cache.get(host)), host));
                        return;
                    } catch (OidcException oe) {
                        failures.add(oe.getMessage());
                    } catch (ExecutionException e) {
                        // The discovery document could not be loaded.
                        failures.add("(\"" + host + "\", " + e.getMessage() + ")");
                    }
                }
            }
        }

        checkAuthentication(foundBearerToken, "No bearer token in the credentials");

        if (failures.size() == 1) {
            throw new AuthenticationException("OpenId Validation Failed: " + failures.iterator().next());
        } else {
            // Several failures: keep the client-visible message short and
            // leave the details in the log under a correlating identifier.
            String randomId = randomId();
            LOG.warn("OpenId Validation Failure ({}): {}", randomId, buildErrorMessage(failures));
            throw new AuthenticationException("OpenId Validation Failed check [log entry #" + randomId + "]");
        }
    }

    /**
     * Presents the token to a provider's UserInfo endpoint and converts the
     * response into principals.
     *
     * @param credential the bearer token to validate
     * @param infoUrl the provider's UserInfo endpoint
     * @param host the provider's host name, used to label failures
     * @return the subject principal plus, when the response supplies them,
     *         full-name and email principals
     * @throws OidcException if the response is missing, has no {@code sub}
     *         claim, reports an error, or could not be fetched or parsed
     */
    private Set<Principal> validateBearerTokenWithOpenIdProvider(BearerTokenCredential credential,
                                                                 String infoUrl,
                                                                 String host)
            throws OidcException
    {
        try {
            JsonNode userInfo = getUserInfo(infoUrl, credential.getToken());
            if (userInfo != null && userInfo.has("sub")) {
                LOG.debug("UserInfo from OpenId Provider: {}", userInfo);
                Set<Principal> principals = new HashSet<>();
                addSub(userInfo, principals);
                addNames(userInfo, principals);
                addEmail(userInfo, principals);
                return principals;
            } else {
                throw new OidcException(host, "No OpendId \"sub\"");
            }
        } catch (IllegalArgumentException iae) {
            throw new OidcException(host, "Error parsing UserInfo: " + iae.getMessage());
        } catch (AuthenticationException e) {
            throw new OidcException(host, e.getMessage());
        } catch (IOException e) {
            throw new OidcException(host, "Failed to fetch UserInfo: " + e.getMessage());
        }
    }

    /**
     * Fetches the UserInfo document for a token.
     *
     * @param url the UserInfo endpoint
     * @param token the bearer token
     * @return the response, or {@code null} if the HTTP client returned none;
     *         the caller reports a missing response as a failure
     * @throws AuthenticationException if the response carries an {@code error} member
     * @throws IOException if the HTTP client fails
     */
    private JsonNode getUserInfo(String url, String token) throws AuthenticationException, IOException
    {
        JsonNode userInfo = jsonHttpClient.doGetWithToken(url, token);
        if (userInfo == null || !userInfo.has("error")) {
            return userInfo;
        }

        String error = userInfo.get("error").asText();
        // The description does not always accompany the error code, so it
        // must not be dereferenced unconditionally.
        JsonNode errorDescription = userInfo.get("error_description");
        String details = errorDescription == null
                ? error
                : error + ", " + errorDescription.asText();
        throw new AuthenticationException("Error: [" + details + " ]");
    }

    /**
     * @return the {@code userinfo_endpoint} value of the discovery document,
     *         or {@code null} if the document has no such member
     */
    private String extractUserInfoEndPoint(JsonNode discoveryDoc)
    {
        if (discoveryDoc.has("userinfo_endpoint")) {
            return discoveryDoc.get("userinfo_endpoint").asText();
        } else {
            return null;
        }
    }

    /**
     * Builds the cache of discovery documents: at most 100 entries, each
     * dropped one hour after it was last accessed.  Loading fails for a host
     * whose discovery document is absent or lacks {@code userinfo_endpoint}.
     */
    private static LoadingCache<String, JsonNode> createLoadingCache(final JsonHttpClient client)
    {
        return CacheBuilder.newBuilder()
                .maximumSize(100)
                .expireAfterAccess(1, TimeUnit.HOURS)
                .build(new CacheLoader<String, JsonNode>() {
                    @Override
                    public JsonNode load(String hostname) throws OidcException, IOException
                    {
                        String url = "https://" + hostname + "/.well-known/openid-configuration";
                        JsonNode discoveryDoc = client.doGet(url);
                        if (discoveryDoc != null && discoveryDoc.has("userinfo_endpoint")) {
                            return discoveryDoc;
                        } else {
                            // Report where the document was looked up, not its content.
                            throw new OidcException(hostname,
                                    "Discovery Document at " + url
                                            + " does not contain userinfo endpoint url");
                        }
                    }
                });
    }

    /** Adds an email principal if the response has an {@code email} member. */
    private static void addEmail(JsonNode userInfo, Set<Principal> principals)
    {
        if (userInfo.has("email")) {
            principals.add(new EmailAddressPrincipal(userInfo.get("email").asText()));
        }
    }

    /**
     * Adds a full-name principal from {@code name} or, failing that, from
     * {@code given_name} and {@code family_name}.  Nothing is added when the
     * response does not supply enough to form a name: only {@code sub} is
     * treated as mandatory by this plugin, so absent name claims must not be
     * turned into a principal built from {@code null} values.
     */
    private static void addNames(JsonNode userInfo, Set<Principal> principals)
    {
        String fullName = textOrNull(userInfo.get("name"));
        if (fullName != null && !fullName.isEmpty()) {
            principals.add(new FullNamePrincipal(fullName));
            return;
        }

        String givenName = textOrNull(userInfo.get("given_name"));
        String familyName = textOrNull(userInfo.get("family_name"));
        if (givenName != null && !givenName.isEmpty()
                && familyName != null && !familyName.isEmpty()) {
            principals.add(new FullNamePrincipal(givenName, familyName));
        }
    }

    /** @return the node's text, or {@code null} if the node is absent */
    private static String textOrNull(JsonNode node)
    {
        return node == null ? null : node.asText();
    }

    /**
     * Adds the subject principal taken from the {@code sub} member, which the
     * caller has already checked to be present.
     *
     * @return the result of adding the principal to the set
     */
    private static boolean addSub(JsonNode userInfo, Set<Principal> principals)
    {
        return principals.add(new OidcSubjectPrincipal(userInfo.get("sub").asText()));
    }

    /** Negates a predicate; allows {@code not(String::isEmpty)} in stream filters. */
    private static <T> Predicate<T> not(Predicate<T> t)
    {
        return t.negate();
    }

    /** @return the collection itself, or an empty set if it is {@code null} */
    private static <T> Collection<T> nullToEmpty(final Collection<T> collection)
    {
        return collection == null ? Collections.<T>emptySet() : collection;
    }

    /** Formats the failures as {@code [a, b, ...]}, or {@code (unknown)} if there are none. */
    private static String buildErrorMessage(Set<String> errors)
    {
        return errors.isEmpty()
                ? "(unknown)"
                : errors.stream().collect(Collectors.joining(", ", "[", "]"));
    }

    /** @return an 8-character identifier for correlating a client error with a log entry */
    private String randomId()
    {
        byte[] rawId = new byte[6]; // a Base64 char represents 6 bits; 4 chars represent 3 bytes.
        random.nextBytes(rawId);
        return Base64.getEncoder().withoutPadding().encodeToString(rawId);
    }
}