package mujina.sp;
import mujina.api.SpConfiguration;
import org.opensaml.common.SAMLException;
import org.opensaml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.SPSSODescriptor;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.springframework.security.saml.context.SAMLMessageContext;
import org.springframework.security.saml.processor.SAMLBinding;
import org.springframework.security.saml.processor.SAMLProcessorImpl;
import java.util.Collection;
/**
 * {@link SAMLProcessorImpl} that rewrites every outgoing SAML message according to the
 * {@link SpConfiguration} at send time.
 *
 * <p>The following are taken from the configuration rather than from the static metadata:
 * <ul>
 *   <li>the local entity ID;</li>
 *   <li>the IdP SSO location;</li>
 *   <li>the binding and location of the default assertion consumer service;</li>
 *   <li>whether the message is signed.</li>
 * </ul>
 */
public class ConfigurableSAMLProcessor extends SAMLProcessorImpl {

    private final SpConfiguration spConfiguration;

    /**
     * @param bindings        SAML bindings passed through to {@link SAMLProcessorImpl}
     * @param spConfiguration configuration consulted on every {@link #sendMessage} call
     */
    public ConfigurableSAMLProcessor(Collection<SAMLBinding> bindings, SpConfiguration spConfiguration) {
        super(bindings);
        this.spConfiguration = spConfiguration;
    }

    /**
     * Applies the current {@link SpConfiguration} to {@code samlContext} and sends the message.
     *
     * <p>The {@code sign} argument is ignored. Signing is decided solely by
     * {@link SpConfiguration#isNeedsSigning()}.
     *
     * <p>NOTE(review): this mutates the local entity metadata, its default assertion consumer
     * service, and the peer endpoint held by the context. If those are shared or cached
     * instances, the changes outlive this single message. Confirm that this is intended.
     *
     * @param samlContext context of the message to send. It must carry a peer entity endpoint
     *                    and local metadata containing an {@link SPSSODescriptor} that has a
     *                    default assertion consumer service.
     * @param sign        ignored; see above
     * @return the context returned by the superclass send
     * @throws SAMLException          if the context has no peer entity endpoint, or as thrown
     *                                by the superclass (e.g. no binding supports the endpoint)
     * @throws IllegalStateException  if the local metadata has no {@link SPSSODescriptor} or
     *                                that descriptor has no default assertion consumer service
     */
    @Override
    public SAMLMessageContext sendMessage(SAMLMessageContext samlContext, boolean sign)
            throws SAMLException, MetadataProviderException, MessageEncodingException {
        Endpoint endpoint = samlContext.getPeerEntityEndpoint();
        if (endpoint == null) {
            // Fail with a clear, declared exception instead of an NPE deep inside a binding.
            throw new SAMLException("No peer entity endpoint in SAML message context");
        }
        // Resolve the binding before the endpoint's location is overridden below.
        SAMLBinding binding = getBinding(endpoint);

        String entityId = spConfiguration.getEntityId();
        samlContext.setLocalEntityId(entityId);
        samlContext.getLocalEntityMetadata().setEntityID(entityId);
        endpoint.setLocation(spConfiguration.getIdpSSOServiceURL());

        AssertionConsumerService assertionConsumerService = findDefaultAssertionConsumerService(samlContext);
        assertionConsumerService.setBinding(spConfiguration.getProtocolBinding());
        assertionConsumerService.setLocation(spConfiguration.getAssertionConsumerServiceURL());

        return super.sendMessage(samlContext, spConfiguration.isNeedsSigning(), binding);
    }

    /**
     * Returns the first explicitly-default assertion consumer service of the first
     * {@link SPSSODescriptor} in the context's local entity metadata.
     *
     * @throws IllegalStateException if there is no such descriptor or no default service
     */
    private static AssertionConsumerService findDefaultAssertionConsumerService(SAMLMessageContext samlContext) {
        // Search by type rather than assuming the SP descriptor is at index 0.
        SPSSODescriptor roleDescriptor = samlContext.getLocalEntityMetadata().getRoleDescriptors().stream()
                .filter(SPSSODescriptor.class::isInstance)
                .map(SPSSODescriptor.class::cast)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No SPSSODescriptor in local metadata"));
        return roleDescriptor.getAssertionConsumerServices().stream()
                // Compare null-safely in case isDefault() yields a null Boolean.
                .filter(service -> Boolean.TRUE.equals(service.isDefault()))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No default ACS"));
    }
}