package mujina.idp;
import mujina.api.IdpConfiguration;
import mujina.saml.ProxiedSAMLContextProviderLB;
import mujina.saml.SAMLPrincipal;
import org.joda.time.DateTime;
import org.opensaml.common.SAMLObject;
import org.opensaml.common.binding.BasicSAMLMessageContext;
import org.opensaml.common.binding.decoding.SAMLMessageDecoder;
import org.opensaml.common.binding.encoding.SAMLMessageEncoder;
import org.opensaml.common.xml.SAMLConstants;
import org.opensaml.saml2.core.Assertion;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.Issuer;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.Status;
import org.opensaml.saml2.core.StatusCode;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.SingleSignOnService;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.ws.security.SecurityPolicyResolver;
import org.opensaml.ws.transport.http.HttpServletRequestAdapter;
import org.opensaml.ws.transport.http.HttpServletResponseAdapter;
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.security.CriteriaSet;
import org.opensaml.xml.security.SecurityException;
import org.opensaml.xml.security.credential.Credential;
import org.opensaml.xml.security.criteria.EntityIDCriteria;
import org.opensaml.xml.signature.SignatureException;
import org.opensaml.xml.validation.ValidationException;
import org.opensaml.xml.validation.ValidatorSuite;
import org.springframework.security.saml.context.SAMLMessageContext;
import org.springframework.security.saml.key.KeyManager;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import static java.util.Arrays.asList;
import static mujina.saml.SAMLBuilder.buildAssertion;
import static mujina.saml.SAMLBuilder.buildIssuer;
import static mujina.saml.SAMLBuilder.buildSAMLObject;
import static mujina.saml.SAMLBuilder.buildStatus;
import static mujina.saml.SAMLBuilder.signAssertion;
import static org.opensaml.xml.Configuration.getValidatorSuite;
/**
 * Decodes and validates inbound SAML authentication requests and encodes the
 * signed authentication response that is sent back for a {@link SAMLPrincipal}.
 */
public class SAMLMessageHandler {

    private final KeyManager keyManager;
    private final Collection<SAMLMessageDecoder> decoders;
    private final SAMLMessageEncoder encoder;
    private final SecurityPolicyResolver resolver;
    private final IdpConfiguration idpConfiguration;
    private final List<ValidatorSuite> validatorSuites;
    private final ProxiedSAMLContextProviderLB proxiedSAMLContextProviderLB;

    /**
     * Creates a handler wired with the given collaborators.
     *
     * @param keyManager             used to resolve the credential that signs outgoing messages
     * @param decoders               candidate decoders; the one whose binding URI matches the request is used
     * @param encoder                encoder used to write the authentication response
     * @param securityPolicyResolver resolver installed on every inbound message context
     * @param idpConfiguration       supplies the entity ID of this IdP
     * @param idpBaseUrl             base URL handed to the {@link ProxiedSAMLContextProviderLB}
     * @throws URISyntaxException if {@code idpBaseUrl} is not a valid URI
     */
    public SAMLMessageHandler(KeyManager keyManager, Collection<SAMLMessageDecoder> decoders,
                              SAMLMessageEncoder encoder, SecurityPolicyResolver securityPolicyResolver,
                              IdpConfiguration idpConfiguration, String idpBaseUrl) throws URISyntaxException {
        this.keyManager = keyManager;
        this.encoder = encoder;
        this.decoders = decoders;
        this.resolver = securityPolicyResolver;
        this.idpConfiguration = idpConfiguration;
        this.validatorSuites = asList(
                getValidatorSuite("saml2-core-schema-validator"),
                getValidatorSuite("saml2-core-spec-validator"));
        this.proxiedSAMLContextProviderLB = new ProxiedSAMLContextProviderLB(new URI(idpBaseUrl));
    }

    /**
     * Decodes the SAML message carried by the HTTP request and runs it through
     * the configured validator suites.
     *
     * @param request     the incoming HTTP request
     * @param response    the HTTP response associated with the request
     * @param postRequest {@code true} to decode with the HTTP-POST binding decoder,
     *                    {@code false} to decode with the HTTP-Redirect binding decoder
     * @return the populated message context holding the decoded request
     * @throws ValidationException       if the decoded message is not an {@link AuthnRequest}
     *                                   or a validator suite rejects it
     * @throws SecurityException         if decoding fails a security check
     * @throws MessageDecodingException  if the message cannot be decoded
     * @throws MetadataProviderException if populating the context fails on metadata lookup
     */
    public SAMLMessageContext extractSAMLMessageContext(HttpServletRequest request, HttpServletResponse response, boolean postRequest) throws ValidationException, SecurityException, MessageDecodingException, MetadataProviderException {
        SAMLMessageContext messageContext = new SAMLMessageContext();
        proxiedSAMLContextProviderLB.populateGenericContext(request, response, messageContext);
        messageContext.setSecurityPolicyResolver(resolver);

        samlMessageDecoder(postRequest).decode(messageContext);

        // Reject anything that is not an AuthnRequest explicitly instead of
        // failing later with a ClassCastException (or validating null).
        SAMLObject inboundSAMLMessage = messageContext.getInboundSAMLMessage();
        if (!(inboundSAMLMessage instanceof AuthnRequest)) {
            throw new ValidationException(String.format(
                    "Expected an AuthnRequest but received %s",
                    inboundSAMLMessage == null ? "no SAML message" : inboundSAMLMessage.getClass().getName()));
        }
        AuthnRequest authnRequest = (AuthnRequest) inboundSAMLMessage;

        // Plain loop rather than forEach: validate throws the checked ValidationException.
        for (ValidatorSuite validatorSuite : validatorSuites) {
            validatorSuite.validate(authnRequest);
        }
        return messageContext;
    }

    /**
     * Picks the decoder whose binding URI matches the requested binding.
     *
     * @param postRequest {@code true} for the HTTP-POST binding, {@code false} for HTTP-Redirect
     * @return a matching decoder
     * @throws RuntimeException if no configured decoder supports the binding
     */
    private SAMLMessageDecoder samlMessageDecoder(boolean postRequest) {
        // Decide the binding once; comparing constant-first tolerates a decoder with a null URI.
        String bindingUri = postRequest
                ? SAMLConstants.SAML2_POST_BINDING_URI
                : SAMLConstants.SAML2_REDIRECT_BINDING_URI;
        return decoders.stream()
                .filter(decoder -> bindingUri.equals(decoder.getBindingURI()))
                .findAny()
                .orElseThrow(() -> new RuntimeException(String.format("Only %s and %s are supported",
                        SAMLConstants.SAML2_REDIRECT_BINDING_URI,
                        SAMLConstants.SAML2_POST_BINDING_URI)));
    }

    /**
     * Builds a successful SAML response containing a signed assertion for the
     * principal and encodes it onto the HTTP response.
     *
     * @param principal the authenticated principal; supplies the request ID, the
     *                  assertion consumer service URL and the relay state
     * @param response  the HTTP response the encoded message is written to
     * @throws MarshallingException     if the assertion cannot be marshalled for signing
     * @throws SignatureException       if signing the assertion fails
     * @throws MessageEncodingException if the response cannot be encoded
     */
    public void sendAuthnResponse(SAMLPrincipal principal, HttpServletResponse response) throws MarshallingException, SignatureException, MessageEncodingException {
        Status status = buildStatus(StatusCode.SUCCESS_URI);

        String entityId = idpConfiguration.getEntityId();
        Credential signingCredential = resolveCredential(entityId);

        Response authResponse = buildSAMLObject(Response.class, Response.DEFAULT_ELEMENT_NAME);
        Issuer issuer = buildIssuer(entityId);

        authResponse.setIssuer(issuer);
        authResponse.setID(UUID.randomUUID().toString());
        authResponse.setIssueInstant(new DateTime());
        authResponse.setInResponseTo(principal.getRequestID());

        // The assertion is signed before it is attached to the response.
        Assertion assertion = buildAssertion(principal, status, entityId);
        signAssertion(assertion, signingCredential);

        authResponse.getAssertions().add(assertion);
        authResponse.setDestination(principal.getAssertionConsumerServiceURL());
        authResponse.setStatus(status);

        // NOTE(review): the endpoint is created with the SingleSignOnService element
        // name although its location is the assertion consumer service URL - confirm
        // this is intentional.
        Endpoint endpoint = buildSAMLObject(Endpoint.class, SingleSignOnService.DEFAULT_ELEMENT_NAME);
        endpoint.setLocation(principal.getAssertionConsumerServiceURL());

        HttpServletResponseAdapter outTransport = new HttpServletResponseAdapter(response, false);

        BasicSAMLMessageContext messageContext = new BasicSAMLMessageContext();
        messageContext.setOutboundMessageTransport(outTransport);
        messageContext.setPeerEntityEndpoint(endpoint);
        messageContext.setOutboundSAMLMessage(authResponse);
        messageContext.setOutboundSAMLMessageSigningCredential(signingCredential);
        messageContext.setOutboundMessageIssuer(entityId);
        messageContext.setRelayState(principal.getRelayState());

        encoder.encode(messageContext);
    }

    /**
     * Resolves the single credential registered for the given entity ID.
     *
     * @param entityId the entity whose credential is looked up
     * @return the credential returned by the key manager
     * @throws RuntimeException wrapping the {@link SecurityException} if resolution fails
     */
    private Credential resolveCredential(String entityId) {
        try {
            return keyManager.resolveSingle(new CriteriaSet(new EntityIDCriteria(entityId)));
        } catch (SecurityException e) {
            // Keep the cause and say which entity failed so the error is diagnosable.
            throw new RuntimeException(
                    String.format("Unable to resolve credential for entity %s", entityId), e);
        }
    }
}