package org.apereo.cas.authentication; import com.google.common.base.Splitter; import org.apereo.cas.authentication.principal.Principal; import org.apereo.cas.services.MultifactorAuthenticationProvider; import org.apereo.cas.services.RegisteredService; import org.apereo.cas.services.RegisteredServiceMultifactorPolicy; import org.apereo.cas.util.CollectionUtils; import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import java.util.Collection; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Predicate; import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.StreamSupport; /** * Default MFA Trigger selection strategy. This strategy looks for valid triggers in the following order: request * parameter, RegisteredService policy, principal attribute. * * @author Daniel Frett * @since 5.0.0 */ public class DefaultMultifactorTriggerSelectionStrategy implements MultifactorTriggerSelectionStrategy { private static final Splitter ATTR_NAMES = Splitter.on(',').trimResults().omitEmptyStrings(); private final String requestParameter; private final String globalPrincipalAttributeNameTriggers; public DefaultMultifactorTriggerSelectionStrategy(final String attributeNameTriggers, final String requestParameter) { this.globalPrincipalAttributeNameTriggers = attributeNameTriggers; this.requestParameter = requestParameter; } @Override public Optional<String> resolve(final Collection<MultifactorAuthenticationProvider> providers, final HttpServletRequest request, final RegisteredService service, final Principal principal) { Optional<String> provider = Optional.empty(); // short-circuit if we don't have any available MFA providers if (providers == null || providers.isEmpty()) { return provider; } final Set<String> validProviderIds = providers.stream() .map(MultifactorAuthenticationProvider::getId) .collect(Collectors.toSet()); // check for an opt-in provider id parameter trigger, we only care about the first value if (!provider.isPresent() && request != null) { provider = Optional.ofNullable(request.getParameter(requestParameter)) .filter(validProviderIds::contains); } // check for a RegisteredService configured trigger if (!provider.isPresent() && service != null) { final RegisteredServiceMultifactorPolicy policy = service.getMultifactorPolicy(); if (shouldApplyRegisteredServiceMultifactorPolicy(policy, principal)) { provider = policy.getMultifactorAuthenticationProviders().stream() .filter(validProviderIds::contains) .findFirst(); } } // check for principal attribute trigger if (!provider.isPresent() && principal != null && StringUtils.hasText(globalPrincipalAttributeNameTriggers)) { provider = StreamSupport.stream(ATTR_NAMES.split(globalPrincipalAttributeNameTriggers).spliterator(), false) // principal.getAttribute(name).values .map(principal.getAttributes()::get).filter(Objects::nonNull) .map(CollectionUtils::toCollection).flatMap(Set::stream) // validProviderIds.contains((String) value) .filter(String.class::isInstance).map(String.class::cast).filter(validProviderIds::contains) .findFirst(); } // return the resolved trigger return provider; } private static boolean shouldApplyRegisteredServiceMultifactorPolicy(final RegisteredServiceMultifactorPolicy policy, final Principal principal) { final String attrName = policy.getPrincipalAttributeNameTrigger(); final String attrValue = policy.getPrincipalAttributeValueToMatch(); // Principal attribute name and/or value is not defined if (!StringUtils.hasText(attrName) || !StringUtils.hasText(attrValue)) { return true; } // no Principal, we should enforce policy if (principal == null) { return true; } // check to see if any of the specified attributes match the attrValue pattern final Predicate<String> attrValuePredicate = Pattern.compile(attrValue).asPredicate(); return StreamSupport.stream(ATTR_NAMES.split(attrName).spliterator(), false) .map(principal.getAttributes()::get) .filter(Objects::nonNull) .map(CollectionUtils::toCollection) .flatMap(Set::stream) .filter(String.class::isInstance) .map(String.class::cast) .anyMatch(attrValuePredicate); } }