package org.apereo.cas.support.saml.services;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URIBuilder;
import org.apereo.cas.CasProtocolConstants;
import org.apereo.cas.authentication.principal.Principal;
import org.apereo.cas.services.RegisteredService;
import org.apereo.cas.services.ReturnAllowedAttributeReleasePolicy;
import org.apereo.cas.support.saml.SamlProtocolConstants;
import org.apereo.cas.support.saml.services.idp.metadata.SamlRegisteredServiceServiceProviderMetadataFacade;
import org.apereo.cas.support.saml.services.idp.metadata.cache.SamlRegisteredServiceCachingMetadataResolver;
import org.apereo.cas.util.spring.ApplicationContextProvider;
import org.apereo.cas.web.support.WebUtils;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import javax.servlet.http.HttpServletRequest;
import java.util.Map;
import java.util.Optional;
/**
* This is {@link BaseSamlRegisteredServiceAttributeReleasePolicy}.
*
* @author Misagh Moayyed
* @since 5.1.0
*/
public abstract class BaseSamlRegisteredServiceAttributeReleasePolicy extends ReturnAllowedAttributeReleasePolicy {
private static final long serialVersionUID = -3301632236702329694L;
private static final Logger LOGGER = LoggerFactory.getLogger(BaseSamlRegisteredServiceAttributeReleasePolicy.class);
@Override
protected Map<String, Object> getAttributesInternal(final Principal principal,
final Map<String, Object> attrs, final RegisteredService service) {
if (service instanceof SamlRegisteredService) {
final SamlRegisteredService saml = (SamlRegisteredService) service;
final HttpServletRequest request = WebUtils.getHttpServletRequestFromRequestAttributes();
if (request == null) {
LOGGER.warn("Could not locate the request context to process attributes");
return super.getAttributesInternal(principal, attrs, service);
}
String entityId = request.getParameter(SamlProtocolConstants.PARAMETER_ENTITY_ID);
if (StringUtils.isBlank(entityId)) {
final String svcParam = request.getParameter(CasProtocolConstants.PARAMETER_SERVICE);
if (StringUtils.isNotBlank(svcParam)) {
try {
final URIBuilder builder = new URIBuilder(svcParam);
entityId = builder.getQueryParams().stream()
.filter(p -> p.getName().equals(SamlProtocolConstants.PARAMETER_ENTITY_ID))
.map(NameValuePair::getValue)
.findFirst()
.orElse(StringUtils.EMPTY);
} catch (final Exception e) {
LOGGER.error(e.getMessage());
}
}
}
final ApplicationContext ctx = ApplicationContextProvider.getApplicationContext();
if (ctx == null) {
LOGGER.warn("Could not locate the application context to process attributes");
return super.getAttributesInternal(principal, attrs, service);
}
final SamlRegisteredServiceCachingMetadataResolver resolver =
ctx.getBean("defaultSamlRegisteredServiceCachingMetadataResolver",
SamlRegisteredServiceCachingMetadataResolver.class);
final Optional<SamlRegisteredServiceServiceProviderMetadataFacade> facade =
SamlRegisteredServiceServiceProviderMetadataFacade.get(resolver, saml, entityId);
if (facade == null || !facade.isPresent()) {
LOGGER.warn("Could not locate metadata for [{}] to process attributes", entityId);
return super.getAttributesInternal(principal, attrs, service);
}
final EntityDescriptor input = facade.get().getEntityDescriptor();
if (input == null) {
LOGGER.warn("Could not locate entity descriptor for [{}] to process attributes", entityId);
return super.getAttributesInternal(principal, attrs, service);
}
return getAttributesForSamlRegisteredService(attrs, saml, ctx, resolver, facade.get(), input);
}
return super.getAttributesInternal(principal, attrs, service);
}
/**
* Gets attributes for saml registered service.
*
* @param attrs the attrs
* @param service the service
* @param applicationContext the application context
* @param resolver the resolver
* @param facade the facade
* @param entityDescriptor the entity descriptor
* @return the attributes for saml registered service
*/
protected abstract Map<String, Object> getAttributesForSamlRegisteredService(Map<String, Object> attrs,
SamlRegisteredService service,
ApplicationContext applicationContext,
SamlRegisteredServiceCachingMetadataResolver resolver,
SamlRegisteredServiceServiceProviderMetadataFacade facade,
EntityDescriptor entityDescriptor);
}