package org.apereo.cas.support.saml;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URIBuilder;
import org.apereo.cas.authentication.AuthenticationServiceSelectionStrategy;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.authentication.principal.ServiceFactory;
import org.apereo.cas.web.support.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.Ordered;
import javax.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.Optional;
/**
* This is {@link ShibbolethIdPEntityIdAuthenticationServiceSelectionStrategy}.
*
* @author Misagh Moayyed
* @since 5.0.0
*/
public class ShibbolethIdPEntityIdAuthenticationServiceSelectionStrategy implements AuthenticationServiceSelectionStrategy {
private static final long serialVersionUID = -2059445756475980894L;
private static final Logger LOGGER = LoggerFactory.getLogger(ShibbolethIdPEntityIdAuthenticationServiceSelectionStrategy.class);
private int order = Ordered.HIGHEST_PRECEDENCE;
private final ServiceFactory webApplicationServiceFactory;
private final String idpServerPrefix;
public ShibbolethIdPEntityIdAuthenticationServiceSelectionStrategy(final ServiceFactory webApplicationServiceFactory,
final String idpServerPrefix) {
this.webApplicationServiceFactory = webApplicationServiceFactory;
this.idpServerPrefix = idpServerPrefix;
}
@Override
public Service resolveServiceFrom(final Service service) {
final String entityId = getEntityIdAsParameter(service).get();
LOGGER.debug("Located entity id [{}] from service authentication request at [{}]", entityId, service.getId());
return this.webApplicationServiceFactory.createService(entityId);
}
@Override
public boolean supports(final Service service) {
final String casPattern = "^".concat(idpServerPrefix).concat(".*");
return service != null && service.getId().matches(casPattern)
&& getEntityIdAsParameter(service).isPresent();
}
/**
* Gets entity id as parameter.
*
* @param service the service
* @return the entity id as parameter
*/
protected static Optional<String> getEntityIdAsParameter(final Service service) {
try {
final URIBuilder builder = new URIBuilder(service.getId());
final Optional<NameValuePair> param = builder.getQueryParams()
.stream()
.filter(p -> p.getName().equals(SamlProtocolConstants.PARAMETER_ENTITY_ID))
.findFirst();
if (param.isPresent()) {
return Optional.of(param.get().getValue());
}
final HttpServletRequest request = WebUtils.getHttpServletRequest();
final String[] query = request.getQueryString().split("&");
final Optional<String> paramRequest = Arrays.stream(query)
.map(p -> {
final String[] params = p.split("=");
return Pair.of(params[0], params[1]);
})
.filter(p -> p.getKey().equals(SamlProtocolConstants.PARAMETER_ENTITY_ID))
.map(Pair::getValue)
.findFirst();
return paramRequest;
} catch (final Exception e) {
LOGGER.error(e.getMessage(), e);
}
return Optional.empty();
}
@Override
public int getOrder() {
return this.order;
}
}