/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.cxf.sts.token.renewer; import java.security.Principal; import java.security.cert.Certificate; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; import javax.security.auth.callback.CallbackHandler; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.apache.cxf.common.logging.LogUtils; import org.apache.cxf.helpers.CastUtils; import org.apache.cxf.helpers.DOMUtils; import org.apache.cxf.security.transport.TLSSessionInfo; import org.apache.cxf.sts.STSConstants; import org.apache.cxf.sts.STSPropertiesMBean; import org.apache.cxf.sts.cache.CacheUtils; import org.apache.cxf.sts.request.ReceivedToken; import org.apache.cxf.sts.request.ReceivedToken.STATE; import org.apache.cxf.sts.token.provider.AbstractSAMLTokenProvider; import org.apache.cxf.sts.token.provider.ConditionsProvider; import org.apache.cxf.sts.token.provider.DefaultConditionsProvider; import org.apache.cxf.sts.token.provider.TokenProviderParameters; import org.apache.cxf.sts.token.realm.RealmProperties; import org.apache.cxf.ws.security.sts.provider.STSException; import org.apache.cxf.ws.security.tokenstore.SecurityToken; import org.apache.cxf.ws.security.tokenstore.TokenStore; import org.apache.wss4j.common.crypto.Crypto; import org.apache.wss4j.common.ext.WSSecurityException; import org.apache.wss4j.common.saml.SAMLKeyInfo; import org.apache.wss4j.common.saml.SamlAssertionWrapper; import org.apache.wss4j.common.saml.bean.ConditionsBean; import org.apache.wss4j.common.saml.builder.SAML1ComponentBuilder; import org.apache.wss4j.common.saml.builder.SAML2ComponentBuilder; import org.apache.wss4j.dom.WSConstants; import org.apache.wss4j.dom.WSDocInfo; import org.apache.wss4j.dom.engine.WSSConfig; import org.apache.wss4j.dom.engine.WSSecurityEngineResult; import org.apache.wss4j.dom.handler.RequestData; import org.apache.wss4j.dom.handler.WSHandlerConstants; import org.apache.wss4j.dom.handler.WSHandlerResult; import org.apache.wss4j.dom.saml.DOMSAMLUtil; import org.apache.wss4j.dom.saml.WSSSAMLKeyInfoProcessor; import org.apache.xml.security.stax.impl.util.IDGenerator; import org.joda.time.DateTime; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.saml1.core.Audience; import org.opensaml.saml.saml1.core.AudienceRestrictionCondition; import org.opensaml.saml.saml2.core.AudienceRestriction; /** * A TokenRenewer implementation that renews a (valid or expired) SAML Token. */ public class SAMLTokenRenewer extends AbstractSAMLTokenProvider implements TokenRenewer { // The default maximum expired time a token is allowed to be is 30 minutes public static final long DEFAULT_MAX_EXPIRY = 60L * 30L; private static final Logger LOG = LogUtils.getL7dLogger(SAMLTokenRenewer.class); private boolean signToken = true; private ConditionsProvider conditionsProvider = new DefaultConditionsProvider(); private Map<String, RealmProperties> realmMap = new HashMap<>(); private long maxExpiry = DEFAULT_MAX_EXPIRY; // boolean to enable/disable the check of proof of possession private boolean verifyProofOfPossession = true; private boolean allowRenewalAfterExpiry; /** * Return true if this TokenRenewer implementation is able to renew a token. */ public boolean canHandleToken(ReceivedToken renewTarget) { return canHandleToken(renewTarget, null); } /** * Return true if this TokenRenewer implementation is able to renew a token in the given realm. */ public boolean canHandleToken(ReceivedToken renewTarget, String realm) { if (realm != null && !realmMap.containsKey(realm)) { return false; } Object token = renewTarget.getToken(); if (token instanceof Element) { Element tokenElement = (Element)token; String namespace = tokenElement.getNamespaceURI(); String localname = tokenElement.getLocalName(); if ((WSConstants.SAML_NS.equals(namespace) || WSConstants.SAML2_NS.equals(namespace)) && "Assertion".equals(localname)) { return true; } } return false; } /** * Set whether proof of possession is required or not to renew a token */ public void setVerifyProofOfPossession(boolean verifyProofOfPossession) { this.verifyProofOfPossession = verifyProofOfPossession; } /** * Get whether we allow renewal after expiry. The default is false. */ public boolean isAllowRenewalAfterExpiry() { return allowRenewalAfterExpiry; } /** * Set whether we allow renewal after expiry. The default is false. */ public void setAllowRenewalAfterExpiry(boolean allowRenewalAfterExpiry) { this.allowRenewalAfterExpiry = allowRenewalAfterExpiry; } /** * Set a new value (in seconds) for how long a token is allowed to be expired for before renewal. * The default is 30 minutes. */ public void setMaxExpiry(long newExpiry) { maxExpiry = newExpiry; } /** * Get how long a token is allowed to be expired for before renewal (in seconds). The default is * 30 minutes. */ public long getMaxExpiry() { return maxExpiry; } /** * Renew a token given a TokenRenewerParameters */ public TokenRenewerResponse renewToken(TokenRenewerParameters tokenParameters) { TokenRenewerResponse response = new TokenRenewerResponse(); ReceivedToken tokenToRenew = tokenParameters.getToken(); if (tokenToRenew == null || tokenToRenew.getToken() == null || (tokenToRenew.getState() != STATE.EXPIRED && tokenToRenew.getState() != STATE.VALID)) { LOG.log(Level.WARNING, "The token to renew is null or invalid"); throw new STSException( "The token to renew is null or invalid", STSException.INVALID_REQUEST ); } TokenStore tokenStore = tokenParameters.getTokenStore(); if (tokenStore == null) { LOG.log(Level.FINE, "A cache must be configured to use the SAMLTokenRenewer"); throw new STSException("Can't renew SAML assertion", STSException.REQUEST_FAILED); } try { SamlAssertionWrapper assertion = new SamlAssertionWrapper((Element)tokenToRenew.getToken()); byte[] oldSignature = assertion.getSignatureValue(); int hash = Arrays.hashCode(oldSignature); SecurityToken cachedToken = tokenStore.getToken(Integer.toString(hash)); if (cachedToken == null) { LOG.log(Level.FINE, "The token to be renewed must be stored in the cache"); throw new STSException("Can't renew SAML assertion", STSException.REQUEST_FAILED); } // Validate the Assertion validateAssertion(assertion, tokenToRenew, cachedToken, tokenParameters); SamlAssertionWrapper renewedAssertion = new SamlAssertionWrapper(assertion.getSamlObject()); String oldId = createNewId(renewedAssertion); // Remove the previous token (now expired) from the cache tokenStore.remove(oldId); tokenStore.remove(Integer.toString(hash)); // Create new Conditions & sign the Assertion createNewConditions(renewedAssertion, tokenParameters); signAssertion(renewedAssertion, tokenParameters); Document doc = DOMUtils.createDocument(); Element token = renewedAssertion.toDOM(doc); if (renewedAssertion.getSaml1() != null) { token.setIdAttributeNS(null, "AssertionID", true); } else { token.setIdAttributeNS(null, "ID", true); } doc.appendChild(token); // Cache the token storeTokenInCache( tokenStore, renewedAssertion, tokenParameters.getPrincipal(), tokenParameters ); response.setToken(token); response.setTokenId(renewedAssertion.getId()); DateTime validFrom = null; DateTime validTill = null; if (renewedAssertion.getSamlVersion().equals(SAMLVersion.VERSION_20)) { validFrom = renewedAssertion.getSaml2().getConditions().getNotBefore(); validTill = renewedAssertion.getSaml2().getConditions().getNotOnOrAfter(); } else { validFrom = renewedAssertion.getSaml1().getConditions().getNotBefore(); validTill = renewedAssertion.getSaml1().getConditions().getNotOnOrAfter(); } response.setCreated(validFrom.toDate().toInstant()); response.setExpires(validTill.toDate().toInstant()); LOG.fine("SAML Token successfully renewed"); return response; } catch (Exception ex) { LOG.log(Level.WARNING, "", ex); throw new STSException("Can't renew SAML assertion", ex, STSException.REQUEST_FAILED); } } /** * Set the ConditionsProvider */ public void setConditionsProvider(ConditionsProvider conditionsProvider) { this.conditionsProvider = conditionsProvider; } /** * Get the ConditionsProvider */ public ConditionsProvider getConditionsProvider() { return conditionsProvider; } /** * Return whether the provided token will be signed or not. Default is true. */ public boolean isSignToken() { return signToken; } /** * Set whether the provided token will be signed or not. Default is true. */ public void setSignToken(boolean signToken) { this.signToken = signToken; } /** * Set the map of realm->RealmProperties for this token provider * @param realms the map of realm->RealmProperties for this token provider */ public void setRealmMap(Map<String, ? extends RealmProperties> realms) { this.realmMap.clear(); this.realmMap.putAll(realms); } /** * Get the map of realm->RealmProperties for this token provider * @return the map of realm->RealmProperties for this token provider */ public Map<String, RealmProperties> getRealmMap() { return Collections.unmodifiableMap(realmMap); } private void validateAssertion( SamlAssertionWrapper assertion, ReceivedToken tokenToRenew, SecurityToken token, TokenRenewerParameters tokenParameters ) throws WSSecurityException { // Check the cached renewal properties Map<String, Object> props = token.getProperties(); if (props == null) { LOG.log(Level.WARNING, "Error in getting properties from cached token"); throw new STSException( "Error in getting properties from cached token", STSException.REQUEST_FAILED ); } String isAllowRenewal = (String)props.get(STSConstants.TOKEN_RENEWING_ALLOW); String isAllowRenewalAfterExpiry = (String)props.get(STSConstants.TOKEN_RENEWING_ALLOW_AFTER_EXPIRY); if (isAllowRenewal == null || !Boolean.valueOf(isAllowRenewal)) { LOG.log(Level.WARNING, "The token is not allowed to be renewed"); throw new STSException("The token is not allowed to be renewed", STSException.REQUEST_FAILED); } // Check to see whether the token has expired greater than the configured max expiry time if (tokenToRenew.getState() == STATE.EXPIRED) { if (!allowRenewalAfterExpiry || isAllowRenewalAfterExpiry == null || !Boolean.valueOf(isAllowRenewalAfterExpiry)) { LOG.log(Level.WARNING, "Renewal after expiry is not allowed"); throw new STSException( "Renewal after expiry is not allowed", STSException.REQUEST_FAILED ); } DateTime expiryDate = getExpiryDate(assertion); DateTime currentDate = new DateTime(); if ((currentDate.getMillis() - expiryDate.getMillis()) > (maxExpiry * 1000L)) { LOG.log(Level.WARNING, "The token expired too long ago to be renewed"); throw new STSException( "The token expired too long ago to be renewed", STSException.REQUEST_FAILED ); } } // Verify Proof of Possession ProofOfPossessionValidator popValidator = new ProofOfPossessionValidator(); if (verifyProofOfPossession) { STSPropertiesMBean stsProperties = tokenParameters.getStsProperties(); Crypto sigCrypto = stsProperties.getSignatureCrypto(); CallbackHandler callbackHandler = stsProperties.getCallbackHandler(); RequestData requestData = new RequestData(); requestData.setSigVerCrypto(sigCrypto); WSSConfig wssConfig = WSSConfig.getNewInstance(); requestData.setWssConfig(wssConfig); WSDocInfo docInfo = new WSDocInfo(((Element)tokenToRenew.getToken()).getOwnerDocument()); requestData.setWsDocInfo(docInfo); // Parse the HOK subject if it exists assertion.parseSubject( new WSSSAMLKeyInfoProcessor(requestData), sigCrypto, callbackHandler ); SAMLKeyInfo keyInfo = assertion.getSubjectKeyInfo(); if (keyInfo == null) { keyInfo = new SAMLKeyInfo((byte[])null); } if (!popValidator.checkProofOfPossession(tokenParameters, keyInfo)) { throw new STSException( "Failed to verify the proof of possession of the key associated with the " + "saml token. No matching key found in the request.", STSException.INVALID_REQUEST ); } } // Check the AppliesTo address String appliesToAddress = tokenParameters.getAppliesToAddress(); if (appliesToAddress != null) { if (assertion.getSaml1() != null) { List<AudienceRestrictionCondition> restrConditions = assertion.getSaml1().getConditions().getAudienceRestrictionConditions(); if (!matchSaml1AudienceRestriction(appliesToAddress, restrConditions)) { LOG.log(Level.WARNING, "The AppliesTo address does not match the Audience Restriction"); throw new STSException( "The AppliesTo address does not match the Audience Restriction", STSException.INVALID_REQUEST ); } } else { List<AudienceRestriction> audienceRestrs = assertion.getSaml2().getConditions().getAudienceRestrictions(); if (!matchSaml2AudienceRestriction(appliesToAddress, audienceRestrs)) { LOG.log(Level.WARNING, "The AppliesTo address does not match the Audience Restriction"); throw new STSException( "The AppliesTo address does not match the Audience Restriction", STSException.INVALID_REQUEST ); } } } } private boolean matchSaml1AudienceRestriction( String appliesTo, List<AudienceRestrictionCondition> restrConditions ) { boolean found = false; if (restrConditions != null && !restrConditions.isEmpty()) { for (AudienceRestrictionCondition restrCondition : restrConditions) { if (restrCondition.getAudiences() != null) { for (Audience audience : restrCondition.getAudiences()) { if (appliesTo.equals(audience.getUri())) { return true; } } } } } return found; } private boolean matchSaml2AudienceRestriction( String appliesTo, List<AudienceRestriction> audienceRestrictions ) { boolean found = false; if (audienceRestrictions != null && !audienceRestrictions.isEmpty()) { for (AudienceRestriction audienceRestriction : audienceRestrictions) { if (audienceRestriction.getAudiences() != null) { for (org.opensaml.saml.saml2.core.Audience audience : audienceRestriction.getAudiences()) { if (appliesTo.equals(audience.getAudienceURI())) { return true; } } } } } return found; } private void signAssertion( SamlAssertionWrapper assertion, TokenRenewerParameters tokenParameters ) throws Exception { if (signToken) { STSPropertiesMBean stsProperties = tokenParameters.getStsProperties(); String realm = tokenParameters.getRealm(); RealmProperties samlRealm = null; if (realm != null && realmMap.containsKey(realm)) { samlRealm = realmMap.get(realm); } signToken(assertion, samlRealm, stsProperties, tokenParameters.getKeyRequirements()); } else { if (assertion.getSaml1().getSignature() != null) { assertion.getSaml1().setSignature(null); } else if (assertion.getSaml2().getSignature() != null) { assertion.getSaml2().setSignature(null); } } } private void createNewConditions(SamlAssertionWrapper assertion, TokenRenewerParameters tokenParameters) { ConditionsBean conditions = conditionsProvider.getConditions(convertToProviderParameters(tokenParameters)); if (assertion.getSaml1() != null) { org.opensaml.saml.saml1.core.Assertion saml1Assertion = assertion.getSaml1(); saml1Assertion.setIssueInstant(new DateTime()); org.opensaml.saml.saml1.core.Conditions saml1Conditions = SAML1ComponentBuilder.createSamlv1Conditions(conditions); saml1Assertion.setConditions(saml1Conditions); } else { org.opensaml.saml.saml2.core.Assertion saml2Assertion = assertion.getSaml2(); saml2Assertion.setIssueInstant(new DateTime()); org.opensaml.saml.saml2.core.Conditions saml2Conditions = SAML2ComponentBuilder.createConditions(conditions); saml2Assertion.setConditions(saml2Conditions); } } private TokenProviderParameters convertToProviderParameters( TokenRenewerParameters renewerParameters ) { TokenProviderParameters providerParameters = new TokenProviderParameters(); providerParameters.setAppliesToAddress(renewerParameters.getAppliesToAddress()); providerParameters.setEncryptionProperties(renewerParameters.getEncryptionProperties()); providerParameters.setKeyRequirements(renewerParameters.getKeyRequirements()); providerParameters.setPrincipal(renewerParameters.getPrincipal()); providerParameters.setRealm(renewerParameters.getRealm()); providerParameters.setStsProperties(renewerParameters.getStsProperties()); providerParameters.setTokenRequirements(renewerParameters.getTokenRequirements()); providerParameters.setTokenStore(renewerParameters.getTokenStore()); providerParameters.setMessageContext(renewerParameters.getMessageContext()); // Store token to renew in the additional properties in case you want to base some // Conditions on the token Map<String, Object> additionalProperties = renewerParameters.getAdditionalProperties(); if (additionalProperties == null) { additionalProperties = new HashMap<>(1); } additionalProperties.put(ReceivedToken.class.getName(), renewerParameters.getToken()); providerParameters.setAdditionalProperties(additionalProperties); return providerParameters; } private String createNewId(SamlAssertionWrapper assertion) { if (assertion.getSaml1() != null) { org.opensaml.saml.saml1.core.Assertion saml1Assertion = assertion.getSaml1(); String oldId = saml1Assertion.getID(); saml1Assertion.setID(IDGenerator.generateID("_")); return oldId; } else { org.opensaml.saml.saml2.core.Assertion saml2Assertion = assertion.getSaml2(); String oldId = saml2Assertion.getID(); saml2Assertion.setID(IDGenerator.generateID("_")); return oldId; } } private void storeTokenInCache( TokenStore tokenStore, SamlAssertionWrapper assertion, Principal principal, TokenRenewerParameters tokenParameters ) throws WSSecurityException { // Store the successfully renewed token in the cache byte[] signatureValue = assertion.getSignatureValue(); if (tokenStore != null && signatureValue != null && signatureValue.length > 0) { SecurityToken securityToken = CacheUtils.createSecurityTokenForStorage(assertion.getElement(), assertion.getId(), assertion.getNotOnOrAfter(), tokenParameters.getPrincipal(), tokenParameters.getRealm(), tokenParameters.getTokenRequirements().getRenewing()); CacheUtils.storeTokenInCache( securityToken, tokenParameters.getTokenStore(), signatureValue); } } private DateTime getExpiryDate(SamlAssertionWrapper assertion) { if (assertion.getSamlVersion().equals(SAMLVersion.VERSION_20)) { return assertion.getSaml2().getConditions().getNotOnOrAfter(); } else { return assertion.getSaml1().getConditions().getNotOnOrAfter(); } } private static class ProofOfPossessionValidator { public boolean checkProofOfPossession( TokenRenewerParameters tokenParameters, SAMLKeyInfo subjectKeyInfo ) { Map<String, Object> messageContext = tokenParameters.getMessageContext(); final List<WSHandlerResult> handlerResults = CastUtils.cast((List<?>) messageContext.get(WSHandlerConstants.RECV_RESULTS)); List<WSSecurityEngineResult> signedResults = new ArrayList<>(); if (handlerResults != null && !handlerResults.isEmpty()) { WSHandlerResult handlerResult = handlerResults.get(0); if (handlerResult.getActionResults().containsKey(WSConstants.SIGN)) { signedResults.addAll(handlerResult.getActionResults().get(WSConstants.SIGN)); } if (handlerResult.getActionResults().containsKey(WSConstants.UT_SIGN)) { signedResults.addAll(handlerResult.getActionResults().get(WSConstants.UT_SIGN)); } } TLSSessionInfo tlsInfo = (TLSSessionInfo)messageContext.get(TLSSessionInfo.class.getName()); Certificate[] tlsCerts = null; if (tlsInfo != null) { tlsCerts = tlsInfo.getPeerCertificates(); } return DOMSAMLUtil.compareCredentials(subjectKeyInfo, signedResults, tlsCerts); } } }