package org.apereo.cas.oidc.claims;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.tuple.Pair;
import org.apereo.cas.authentication.principal.Principal;
import org.apereo.cas.oidc.claims.mapping.OidcAttributeToScopeClaimMapper;
import org.apereo.cas.services.AbstractRegisteredServiceAttributeReleasePolicy;
import org.apereo.cas.services.RegisteredService;
import org.apereo.cas.util.spring.ApplicationContextProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
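/**
 * This is {@link BaseOidcScopeAttributeReleasePolicy}, the base attribute release
 * policy for an OpenID Connect scope. It releases only the claims listed as allowed
 * attributes, resolving each claim through the {@link OidcAttributeToScopeClaimMapper}
 * bean when a mapping is defined for it, and falling back to the claim name otherwise.
 *
 * @author Misagh Moayyed
 * @since 5.1.0
 */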
public abstract class BaseOidcScopeAttributeReleasePolicy extends AbstractRegisteredServiceAttributeReleasePolicy {
    private static final long serialVersionUID = -7302163334687300920L;

    private static final Logger LOGGER = LoggerFactory.getLogger(BaseOidcScopeAttributeReleasePolicy.class);

    /** Claim names this policy is allowed to release; {@code null} until a list is set. */
    private List<String> allowedAttributes;

    /** Name of the scope this policy is bound to; excluded from JSON via {@link JsonIgnore}. */
    @JsonIgnore
    private String scopeName;

    /**
     * Creates a policy bound to the given scope.
     *
     * @param scopeName the name of the scope this policy represents
     */
    public BaseOidcScopeAttributeReleasePolicy(final String scopeName) {
        this.scopeName = scopeName;
    }

    /**
     * Gets the scope name given at construction time.
     *
     * @return the scope name
     */
    public String getScopeName() {
        return scopeName;
    }

    /**
     * Sets the claims that may be released. The list is stored as-is (no copy is made).
     *
     * @param allowed the allowed claim names
     */
    public void setAllowedAttributes(final List<String> allowed) {
        this.allowedAttributes = allowed;
    }

    /**
     * Gets the claims that may be released.
     *
     * @return the allowed claim names, or {@code null} if none were ever set
     */
    public List<String> getAllowedAttributes() {
        return this.allowedAttributes;
    }

    /**
     * Two policies are equal when they are of the exact same class, the superclass
     * considers them equal, and both the allowed attributes and the scope name match.
     *
     * @param obj the object to compare against
     * @return {@code true} if both policies are equal
     */
    @Override
    public boolean equals(final Object obj) {
        if (obj == null) {
            return false;
        }
        if (obj == this) {
            return true;
        }
        if (obj.getClass() != getClass()) {
            return false;
        }
        final BaseOidcScopeAttributeReleasePolicy rhs = (BaseOidcScopeAttributeReleasePolicy) obj;
        return new EqualsBuilder()
            .appendSuper(super.equals(obj))
            .append(getAllowedAttributes(), rhs.getAllowedAttributes())
            .append(getScopeName(), rhs.getScopeName())
            .isEquals();
    }

    /**
     * Hash code built from the same fields used by {@link #equals(Object)}.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        return new HashCodeBuilder(13, 133)
            .appendSuper(super.hashCode())
            .append(getAllowedAttributes())
            .append(getScopeName())
            .toHashCode();
    }

    /**
     * String form that appends the allowed attributes and scope name to the superclass output.
     *
     * @return the string representation
     */
    @Override
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("allowedAttributes", getAllowedAttributes())
            .append("scopeName", scopeName)
            .toString();
    }

    /**
     * Resolves the attributes to release for the allowed claims.
     * <p>
     * Attribute names are matched case-insensitively. Each allowed claim is resolved via
     * {@link #mapClaimToAttribute(String, Map, OidcAttributeToScopeClaimMapper)}; only claims
     * that resolve to a non-null value are released, keyed by the claim name.
     *
     * @param principal  the principal (not used by this implementation)
     * @param attributes the attributes available for release
     * @param service    the registered service (not used by this implementation)
     * @return a new mutable map of claim name to value; empty when there are no allowed
     *         attributes or no input attributes
     */
    @Override
    protected Map<String, Object> getAttributesInternal(final Principal principal,
                                                        final Map<String, Object> attributes,
                                                        final RegisteredService service) {
        final List<String> allowedClaims = getAllowedAttributes();
        // Nothing can be released without allowed claims or input attributes; bail out
        // instead of failing with a NullPointerException.
        if (allowedClaims == null || allowedClaims.isEmpty() || attributes == null) {
            return new HashMap<>(0);
        }

        final Map<String, Object> resolvedAttributes = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        resolvedAttributes.putAll(attributes);

        // The mapper bean does not depend on the claim, so look it up once rather than per claim.
        final ApplicationContext applicationContext = ApplicationContextProvider.getApplicationContext();
        final OidcAttributeToScopeClaimMapper attributeToScopeClaimMapper =
            applicationContext.getBean("oidcAttributeToScopeClaimMapper", OidcAttributeToScopeClaimMapper.class);

        final Map<String, Object> attributesToRelease = new HashMap<>(resolvedAttributes.size());
        for (final String claim : allowedClaims) {
            final Pair<String, Object> resolved = mapClaimToAttribute(claim, resolvedAttributes, attributeToScopeClaimMapper);
            if (resolved.getValue() != null) {
                attributesToRelease.put(resolved.getKey(), resolved.getValue());
            }
        }
        return attributesToRelease;
    }

    /**
     * Resolves the value for a single claim.
     * <p>
     * If the mapper defines a mapped attribute for the claim, the value is read from that
     * attribute; otherwise the claim name itself is used as the attribute name.
     *
     * @param claim                       the claim to resolve
     * @param resolvedAttributes          the available attributes
     * @param attributeToScopeClaimMapper the mapper that translates claims to attribute names
     * @return a pair of the claim name and its value; the value is {@code null} when not found
     */
    private Pair<String, Object> mapClaimToAttribute(final String claim,
                                                     final Map<String, Object> resolvedAttributes,
                                                     final OidcAttributeToScopeClaimMapper attributeToScopeClaimMapper) {
        LOGGER.debug("Attempting to process claim [{}]", claim);
        if (attributeToScopeClaimMapper.containsMappedAttribute(claim)) {
            final String mappedAttr = attributeToScopeClaimMapper.getMappedAttribute(claim);
            final Object value = resolvedAttributes.get(mappedAttr);
            LOGGER.debug("Found mapped attribute [{}] with value [{}] for claim [{}]", mappedAttr, value, claim);
            return Pair.of(claim, value);
        }
        final Object value = resolvedAttributes.get(claim);
        // Three placeholders need three arguments: the claim, the lookup key (also the claim), and the value.
        LOGGER.debug("No mapped attribute is defined for claim [{}]; Used [{}] to locate value [{}]", claim, claim, value);
        return Pair.of(claim, value);
    }
}