/*******************************************************************************
* Cloud Foundry
* Copyright (c) [2009-2015] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
* You may not use this product except in compliance with the License.
*
* This product includes a number of subcomponents with
* separate copyright notices and license terms. Your use of these
* subcomponents is subject to the terms and conditions of the
* subcomponent's license, as noted in the LICENSE file.
*******************************************************************************/
package org.cloudfoundry.identity.uaa.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.cloudfoundry.identity.uaa.oauth.client.ClientConstants;
import org.cloudfoundry.identity.uaa.provider.IdentityProvider;
import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.util.StringUtils;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.Collections.EMPTY_LIST;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.UAA;
/**
 * Narrows a list of identity providers to those applicable to a given client and
 * user email address (used to decide which provider(s) a user may log in with).
 */
public class DomainFilter {

    // Not referenced within this class at present.
    private static final Log logger = LogFactory.getLog(DomainFilter.class);

    /**
     * Filters {@code activeProviders} for the given client and email.
     * <ol>
     *   <li>If {@code email} has no text, an empty list is returned.</li>
     *   <li>If {@code activeProviders} is {@code null} an empty list is returned; if it is
     *       empty, it is returned as-is.</li>
     *   <li>If the client declares a list under {@link ClientConstants#ALLOWED_PROVIDERS},
     *       only providers whose origin key appears in that list are kept.</li>
     *   <li>If {@code email} contains an {@code @}, the domain after the last {@code @} is
     *       matched: providers whose configured email domains match it explicitly are
     *       returned if there are any; otherwise the providers accepted by the non-explicit
     *       match (see {@link #doesEmailDomainMatchProvider}) are returned.</li>
     * </ol>
     *
     * @param activeProviders candidate providers; may be {@code null}
     * @param client          the requesting client; {@code null} disables client-based filtering
     * @param email           the user's email address
     * @return the filtered providers; never {@code null}
     */
    public static List<IdentityProvider> filter(List<IdentityProvider> activeProviders, ClientDetails client, String email) {
        if (!StringUtils.hasText(email)) {
            return Collections.emptyList();
        }
        if (activeProviders == null) {
            return Collections.emptyList();
        }
        if (activeProviders.isEmpty()) {
            return activeProviders;
        }

        List<IdentityProvider> candidates = activeProviders;

        // Restrict to the providers the client is allowed to use, if it declares any.
        List<String> clientFilter = getProvidersForClient(client);
        if (clientFilter != null) {
            candidates = candidates.stream()
                .filter(p -> clientFilter.contains(p.getOriginKey()))
                .collect(Collectors.toList());
        }

        String domain = extractDomain(email);
        if (domain == null) {
            // Not an email-shaped value: no domain-based narrowing is possible.
            return candidates;
        }

        // Providers that explicitly list a matching domain take precedence.
        List<IdentityProvider> explicitlyMatched = candidates.stream()
            .filter(p -> doesEmailDomainMatchProvider(p, domain, true))
            .collect(Collectors.toList());
        if (!explicitlyMatched.isEmpty()) {
            return explicitlyMatched;
        }
        return candidates.stream()
            .filter(p -> doesEmailDomainMatchProvider(p, domain, false))
            .collect(Collectors.toList());
    }

    /**
     * Returns the text after the last {@code @} in {@code email}, or {@code null} if there
     * is no {@code @}. The last one is used because a quoted local-part may itself contain
     * {@code @} (RFC 5321), while the domain never does.
     */
    private static String extractDomain(String email) {
        int at = email.lastIndexOf('@');
        return at < 0 ? null : email.substring(at + 1);
    }

    /**
     * Returns the origin keys stored under {@link ClientConstants#ALLOWED_PROVIDERS} in the
     * client's additional information, or {@code null} when there is no client, no additional
     * information, or no such entry (i.e. no restriction applies).
     *
     * @throws ClassCastException if the stored value is not a {@code List}
     */
    @SuppressWarnings("unchecked") // the entry is expected to be a list of origin-key strings
    protected static List<String> getProvidersForClient(ClientDetails client) {
        if (client == null) {
            return null;
        }
        Map<String, Object> info = client.getAdditionalInformation();
        if (info == null) {
            return null;
        }
        return (List<String>) info.get(ClientConstants.ALLOWED_PROVIDERS);
    }

    /**
     * Returns the email domain patterns configured on {@code provider}, or {@code null} if
     * the provider has no config (or the config returns {@code null}).
     */
    protected static List<String> getEmailDomain(IdentityProvider provider) {
        if (provider.getConfig() != null) {
            return provider.getConfig().getEmailDomain();
        }
        return null;
    }

    /**
     * Tests whether {@code domain} matches the wildcard patterns configured for {@code provider}.
     * <p>
     * In the non-explicit pass ({@code explicit == false}) the internal {@code UAA} provider
     * with no configured domains is given a default set of wildcards; every other provider
     * without configured domains never matches.
     *
     * @param provider the provider whose email domains are consulted
     * @param domain   the domain part of the user's email
     * @param explicit {@code true} to match only against explicitly configured domains
     * @return {@code true} if the domain matches one of the applicable patterns
     */
    protected static boolean doesEmailDomainMatchProvider(IdentityProvider provider, String domain, boolean explicit) {
        List<String> wildcardList = getEmailDomain(provider);
        if (wildcardList == null && !explicit && UAA.equals(provider.getOriginKey())) {
            // NOTE(review): presumably these accept any domain with 2-4 dot-separated labels,
            // per UaaStringUtils wildcard semantics - confirm.
            wildcardList = Arrays.asList("*.*", "*.*.*", "*.*.*.*");
        }
        if (wildcardList == null) {
            return false;
        }
        Set<Pattern> patterns = UaaStringUtils.constructWildcards(wildcardList);
        return UaaStringUtils.matches(patterns, domain);
    }
}