package org.pac4j.saml.context;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.messaging.context.BaseContext;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.saml.common.messaging.context.SAMLMetadataContext;
import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext;
import org.opensaml.saml.common.messaging.context.SAMLSelfEntityContext;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.metadata.resolver.MetadataResolver;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
import org.opensaml.saml.saml2.metadata.RoleDescriptor;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.pac4j.core.context.WebContext;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.saml.exceptions.SAMLException;
import org.pac4j.saml.metadata.SAML2MetadataResolver;
import org.pac4j.saml.storage.SAMLMessageStorageFactory;
import org.pac4j.saml.transport.DefaultPac4jSAMLResponse;
import org.pac4j.saml.transport.Pac4jSAMLResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import javax.xml.namespace.QName;
import java.util.List;
/**
* Responsible for building a {@link SAML2MessageContext} from given SAML2 properties (idpEntityId and metadata
* manager) and current {@link WebContext}.
*
* @author Michael Remond
* @author Misagh Moayyed
* @since 1.7
*/
@SuppressWarnings("rawtypes")
public class SAML2ContextProvider implements SAMLContextProvider {

    /** Profile identifier set on the profile request context of every context built by this provider. */
    private static final String SAML2_WEBSSO_PROFILE_URI = "urn:oasis:names:tc:SAML:2.0:profiles:SSO:browser";

    protected static final Logger logger = LoggerFactory.getLogger(SAML2ContextProvider.class);

    /** Resolver queried for the {@link EntityDescriptor} of both the service provider and the identity provider. */
    protected final MetadataResolver metadata;

    /** Supplies the entity ID of the identity provider (despite its name, this is a resolver, not the ID itself). */
    protected final SAML2MetadataResolver idpEntityId;

    /** Supplies the entity ID of the service provider (despite its name, this is a resolver, not the ID itself). */
    protected final SAML2MetadataResolver spEntityId;

    /** Optional factory for the SAML message storage; when {@code null}, no storage is attached to built contexts. */
    protected final SAMLMessageStorageFactory samlMessageStorageFactory;

    /**
     * Creates a provider from its collaborators. The arguments are stored as given; no validation is performed.
     *
     * @param metadata the resolver used to look up entity descriptors
     * @param idpEntityId the source of the identity provider entity ID
     * @param spEntityId the source of the service provider entity ID
     * @param samlMessageStorageFactory the message storage factory, or {@code null} to build contexts without storage
     */
    public SAML2ContextProvider(final MetadataResolver metadata,
                                final SAML2MetadataResolver idpEntityId,
                                final SAML2MetadataResolver spEntityId,
                                @Nullable final SAMLMessageStorageFactory samlMessageStorageFactory) {
        this.metadata = metadata;
        this.idpEntityId = idpEntityId;
        this.spEntityId = spEntityId;
        this.samlMessageStorageFactory = samlMessageStorageFactory;
    }

    /**
     * Builds a message context holding the transport settings and the service provider (self) metadata only.
     * No identity provider (peer) information is added, and the web context is not set on the result.
     *
     * @param webContext the current web context
     * @return the new message context
     * @throws SAMLException if the service provider metadata cannot be resolved
     */
    @Override
    public final SAML2MessageContext buildServiceProviderContext(final WebContext webContext) {
        final SAML2MessageContext context = new SAML2MessageContext();
        addTransportContext(webContext, context);
        addSPContext(context);
        return context;
    }

    /**
     * Builds a full message context: the service provider context completed with the identity provider (peer)
     * metadata and a reference to the web context.
     *
     * @param webContext the current web context
     * @return the new message context
     * @throws SAMLException if the service provider or identity provider metadata cannot be resolved
     */
    @Override
    public SAML2MessageContext buildContext(final WebContext webContext) {
        final SAML2MessageContext context = buildServiceProviderContext(webContext);
        addIDPContext(context);
        context.setWebContext(webContext);
        return context;
    }

    /**
     * Configures the transport related parts of the context: the outbound message context, the SAML protocol,
     * the profile identifier and, when a storage factory was configured, the SAML message storage.
     *
     * @param webContext the current web context
     * @param context the message context to populate
     */
    protected final void addTransportContext(final WebContext webContext, final SAML2MessageContext context) {
        final ProfileRequestContext profile = context.getProfileRequestContext();
        profile.setOutboundMessageContext(prepareOutboundMessageContext(webContext));
        profile.setProfileId(SAML2_WEBSSO_PROFILE_URI);

        context.getSAMLProtocolContext().setProtocol(SAMLConstants.SAML20P_NS);

        // The storage factory is optional (see the constructor), so storage is only attached when one was supplied.
        if (this.samlMessageStorageFactory != null) {
            logger.debug("Creating message storage by {}", this.samlMessageStorageFactory.getClass().getName());
            context.setSAMLMessageStorage(this.samlMessageStorageFactory.getMessageStorage(webContext));
        }
    }

    /**
     * Creates the outbound message context whose message is a {@link DefaultPac4jSAMLResponse} wrapping the
     * given web context. Subclasses may override this to supply a different response adapter.
     *
     * @param webContext the current web context
     * @return the outbound message context
     */
    protected MessageContext<Pac4jSAMLResponse> prepareOutboundMessageContext(final WebContext webContext) {
        final MessageContext<Pac4jSAMLResponse> outCtx = new MessageContext<>();
        outCtx.setMessage(new DefaultPac4jSAMLResponse(webContext));
        return outCtx;
    }

    /**
     * Fills the self entity context with the service provider entity ID, the {@link SPSSODescriptor} role and
     * the matching metadata.
     *
     * @param context the message context to populate
     * @throws SAMLException if the service provider metadata cannot be resolved
     */
    protected final void addSPContext(final SAML2MessageContext context) {
        final SAMLSelfEntityContext selfContext = context.getSAMLSelfEntityContext();
        selfContext.setEntityId(this.spEntityId.getEntityId());
        selfContext.setRole(SPSSODescriptor.DEFAULT_ELEMENT_NAME);
        addContext(this.spEntityId, selfContext, SPSSODescriptor.DEFAULT_ELEMENT_NAME);
    }

    /**
     * Fills the peer entity context with the identity provider entity ID, the {@link IDPSSODescriptor} role and
     * the matching metadata.
     *
     * @param context the message context to populate
     * @throws SAMLException if the identity provider metadata cannot be resolved
     */
    protected final void addIDPContext(final SAML2MessageContext context) {
        final SAMLPeerEntityContext peerContext = context.getSAMLPeerEntityContext();
        peerContext.setEntityId(this.idpEntityId.getEntityId());
        peerContext.setRole(IDPSSODescriptor.DEFAULT_ELEMENT_NAME);
        addContext(this.idpEntityId, peerContext, IDPSSODescriptor.DEFAULT_ELEMENT_NAME);
    }

    /**
     * Resolves the entity descriptor for the given entity and its first role descriptor of the requested type
     * supporting the SAML 2.0 protocol, then stores both in a {@link SAMLMetadataContext} subcontext of
     * {@code parentContext} (created if absent).
     *
     * @param entityId the source of the entity ID to look up
     * @param parentContext the context receiving the metadata subcontext
     * @param elementName the qualified name of the role descriptor element to select
     * @throws SAMLException if the entity or the role is not found, or if the metadata resolver fails
     */
    protected final void addContext(final SAML2MetadataResolver entityId, final BaseContext parentContext,
                                    final QName elementName) {
        // Resolve the identifier once so the lookup and every error message refer to the same value,
        // rather than to the string form of the resolver object.
        final String resolvedEntityId = entityId.getEntityId();
        final EntityDescriptor entityDescriptor;
        final RoleDescriptor roleDescriptor;
        try {
            final CriteriaSet criteria = new CriteriaSet();
            criteria.add(new EntityIdCriterion(resolvedEntityId));
            entityDescriptor = this.metadata.resolveSingle(criteria);
            if (entityDescriptor == null) {
                throw new SAMLException("Cannot find entity " + resolvedEntityId + " in metadata provider");
            }
            final List<RoleDescriptor> roles = entityDescriptor.getRoleDescriptors(elementName,
                    SAMLConstants.SAML20P_NS);
            // Only the first matching role descriptor is used.
            roleDescriptor = CommonHelper.isNotEmpty(roles) ? roles.get(0) : null;
            if (roleDescriptor == null) {
                throw new SAMLException("Cannot find entity " + resolvedEntityId + " or role "
                        + elementName + " in metadata provider");
            }
        } catch (final ResolverException e) {
            // This method serves both the SP and the IdP lookups, so the message must not assume either role.
            throw new SAMLException("An error occurred while resolving metadata for entity "
                    + resolvedEntityId, e);
        }

        final SAMLMetadataContext mdCtx = parentContext.getSubcontext(SAMLMetadataContext.class, true);
        mdCtx.setEntityDescriptor(entityDescriptor);
        mdCtx.setRoleDescriptor(roleDescriptor);
    }
}