/**
* Copyright (c) Codice Foundation
* <p>
* This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
* General Public License as published by the Free Software Foundation, either version 3 of the
* License, or any later version.
* <p>
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details. A copy of the GNU Lesser General Public License
* is distributed along with this program and can be found at
* <http://www.gnu.org/licenses/lgpl.html>.
*/
package ddf.security.samlp.impl;
import static java.util.Objects.nonNull;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import javax.annotation.concurrent.Immutable;
import org.apache.commons.lang.StringUtils;
import org.opensaml.core.xml.schema.XSBase64Binary;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
import org.opensaml.saml.saml2.metadata.Endpoint;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.KeyDescriptor;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
import org.opensaml.security.credential.UsageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.ImmutableSet;
import ddf.security.samlp.SamlProtocol;
import ddf.security.samlp.SamlProtocol.Binding;
@Immutable
public class EntityInformation {
private static final Logger LOGGER = LoggerFactory.getLogger(EntityInformation.class);
private final String signingCertificate;
private final String encryptionCertificate;
private final ServiceInfo defaultAssertionConsumerService;
private final Map<Binding, ServiceInfo> assertionConsumerServices;
private final Map<Binding, ServiceInfo> logoutServices;
private final Set<Binding> supportedBindings;
protected static final Binding PREFERRED_BINDING = Binding.HTTP_REDIRECT;
private EntityInformation(Builder builder) {
signingCertificate = builder.signingCertificate;
encryptionCertificate = builder.encryptionCertificate;
defaultAssertionConsumerService = builder.defaultAssertionConsumerService;
assertionConsumerServices = builder.assertionConsumerServices;
logoutServices = builder.logoutServices;
supportedBindings = builder.supportedBindings;
}
public String getSigningCertificate() {
return signingCertificate;
}
public String getEncryptionCertificate() {
return encryptionCertificate;
}
public ServiceInfo getLogoutService() {
return getLogoutService(null);
}
public ServiceInfo getLogoutService(Binding preferred) {
Binding binding = getBinding(null, preferred);
ServiceInfo logoutServiceInfo = logoutServices.get(binding);
if (logoutServiceInfo == null) {
logoutServiceInfo = logoutServices.values()
.stream()
.findFirst()
.orElse(null);
}
return logoutServiceInfo;
}
Binding getBinding(AuthnRequest request, Binding preferred) {
if (request != null && request.getProtocolBinding() != null && supportedBindings.contains(
Binding.from(request.getProtocolBinding()))) {
return Binding.from(request.getProtocolBinding());
}
return preferred != null ? preferred : PREFERRED_BINDING;
}
public ServiceInfo getAssertionConsumerService(AuthnRequest request, Binding preferred) {
ServiceInfo si;
if (request != null && request.getProtocolBinding() != null && supportedBindings.contains(
Binding.from(request.getProtocolBinding()))) {
si = assertionConsumerServices.get(Binding.from(request.getProtocolBinding()));
if (si != null) {
return si;
}
}
Binding binding = preferred != null ? preferred : PREFERRED_BINDING;
si = assertionConsumerServices.get(binding);
if (si != null) {
return si;
}
if (defaultAssertionConsumerService != null) {
return defaultAssertionConsumerService;
}
return assertionConsumerServices.values()
.stream()
.findFirst()
.orElse(null);
}
public static class Builder {
private static final ImmutableSet<UsageType> SIGNING_TYPES =
ImmutableSet.of(UsageType.UNSPECIFIED, UsageType.SIGNING);
private final SPSSODescriptor spssoDescriptor;
private final Set<Binding> supportedBindings;
private String signingCertificate;
private String encryptionCertificate;
private ServiceInfo defaultAssertionConsumerService;
private final Map<Binding, ServiceInfo> assertionConsumerServices = new HashMap<>();
private final Map<Binding, ServiceInfo> logoutServices = new HashMap<>();
public Builder(EntityDescriptor ed, Set<Binding> supportedBindings) {
spssoDescriptor = getSpssoDescriptor(ed);
this.supportedBindings = supportedBindings;
}
public EntityInformation build() {
if (spssoDescriptor == null) {
LOGGER.debug("Unable to build EntityInformation without a descriptor");
return null;
}
return new EntityInformation(parseSigningCertificate().parseEncryptionCertificate()
.parseAssertionConsumerServiceInfo()
.parseLogoutServices());
}
SPSSODescriptor getSpssoDescriptor(EntityDescriptor ed) {
SPSSODescriptor spssoDescriptor =
ed.getSPSSODescriptor(SamlProtocol.SUPPORTED_PROTOCOL);
if (spssoDescriptor == null) {
LOGGER.debug("Unable to find supported protocol in EntityDescriptor {}",
ed.getEntityID());
}
return spssoDescriptor;
}
Builder parseSigningCertificate() {
signingCertificate = extractCertificate(spssoDescriptor,
kd -> SIGNING_TYPES.contains(kd.getUse()));
return this;
}
Builder parseEncryptionCertificate() {
encryptionCertificate = extractCertificate(spssoDescriptor,
kd -> UsageType.ENCRYPTION.equals(kd.getUse()));
return this;
}
Builder parseAssertionConsumerServiceInfo() {
AssertionConsumerService defaultACS =
spssoDescriptor.getDefaultAssertionConsumerService();
//see if the default service uses our supported bindings, and then use that
//as we add more bindings, we'll need to update this
if (defaultACS != null
&& supportedBindings.contains(Binding.from(defaultACS.getBinding()))) {
LOGGER.debug(
"Using AssertionConsumerServiceURL from default assertion consumer service: {}",
defaultACS.getLocation());
defaultAssertionConsumerService = new ServiceInfo(defaultACS.getLocation(),
Binding.from(defaultACS.getBinding()));
}
putAllSupported(assertionConsumerServices,
spssoDescriptor.getAssertionConsumerServices());
return this;
}
Builder parseLogoutServices() {
putAllSupported(logoutServices, spssoDescriptor.getSingleLogoutServices());
return this;
}
void putAllSupported(Map<Binding, ServiceInfo> target, List<? extends Endpoint> services) {
for (Binding binding : supportedBindings) {
ServiceInfo serviceInfo = parseServiceInfo(services,
e -> binding.isEqual(e.getBinding()));
if (serviceInfo.url != null) {
target.put(binding, serviceInfo);
}
}
}
ServiceInfo parseServiceInfo(List<? extends Endpoint> services,
Predicate<Endpoint> bindingFilter) {
return services.stream()
.filter(bindingFilter)
.findFirst()
.map(si -> new ServiceInfo(si.getLocation(), Binding.from(si.getBinding())))
.orElse(new ServiceInfo(null, null));
}
String extractCertificate(SPSSODescriptor spssoDescriptor,
Predicate<KeyDescriptor> usageTypePredicate) {
return spssoDescriptor.getKeyDescriptors()
.stream()
.filter(Objects::nonNull)
.filter(kd -> nonNull(kd.getUse()))
.filter(usageTypePredicate)
.filter(kd -> nonNull(extractCertificateFromKeyDescriptor(kd)))
.reduce((acc, val) -> val.getUse()
.equals(UsageType.SIGNING) || acc == null ? val : acc)
.map(this::extractCertificateFromKeyDescriptor)
.orElse(null);
}
String extractCertificateFromKeyDescriptor(KeyDescriptor kd) {
return kd.getKeyInfo()
.getX509Datas()
.stream()
.flatMap(datas -> datas.getX509Certificates()
.stream())
.map(XSBase64Binary::getValue)
.filter(StringUtils::isNotBlank)
.findFirst()
.orElse(null);
}
}
public static class ServiceInfo {
private final String url;
private final Binding binding;
ServiceInfo(String url, Binding binding) {
this.url = url;
this.binding = binding;
}
public String getUrl() {
return url;
}
public Binding getBinding() {
return binding;
}
}
}