/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.rs.security.saml.sso;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ResourceBundle;
import java.util.UUID;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.DataFormatException;
import javax.annotation.PreDestroy;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import org.w3c.dom.Document;
import org.apache.cxf.Bus;
import org.apache.cxf.common.i18n.BundleUtils;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.util.Base64Exception;
import org.apache.cxf.common.util.Base64Utility;
import org.apache.cxf.common.util.StringUtils;
import org.apache.cxf.jaxrs.ext.MessageContext;
import org.apache.cxf.jaxrs.utils.ExceptionUtils;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.rs.security.saml.DeflateEncoderDecoder;
import org.apache.cxf.rs.security.saml.sso.state.RequestState;
import org.apache.cxf.rs.security.saml.sso.state.ResponseState;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.saml.OpenSAMLUtil;
import org.apache.wss4j.common.util.DOM2Writer;
import org.opensaml.core.xml.XMLObject;
public abstract class AbstractRequestAssertionConsumerHandler extends AbstractSSOSpHandler {
private static final Logger LOG =
LogUtils.getL7dLogger(AbstractRequestAssertionConsumerHandler.class);
private static final ResourceBundle BUNDLE =
BundleUtils.getBundle(AbstractRequestAssertionConsumerHandler.class);
private boolean supportDeflateEncoding = true;
private boolean supportBase64Encoding = true;
private boolean enforceAssertionsSigned = true;
private boolean enforceKnownIssuer = true;
private boolean keyInfoMustBeAvailable = true;
private boolean enforceResponseSigned;
private TokenReplayCache<String> replayCache;
private MessageContext messageContext;
private String applicationURL;
private boolean parseApplicationURLFromRelayState;
private String assertionConsumerServiceAddress;
@Context
public void setMessageContext(MessageContext mc) {
this.messageContext = mc;
}
public void setSupportDeflateEncoding(boolean deflate) {
supportDeflateEncoding = deflate;
}
public boolean isSupportDeflateEncoding() {
return supportDeflateEncoding;
}
public void setReplayCache(TokenReplayCache<String> replayCache) {
this.replayCache = replayCache;
}
public TokenReplayCache<String> getReplayCache() {
if (replayCache == null) {
Bus bus = (Bus)messageContext.getContextualProperty(Bus.class.getName());
replayCache = new EHCacheTokenReplayCache(bus);
}
return replayCache;
}
/**
* Enforce that Assertions must be signed if the POST binding was used. The default is true.
*/
public void setEnforceAssertionsSigned(boolean enforceAssertionsSigned) {
this.enforceAssertionsSigned = enforceAssertionsSigned;
}
/**
* Enforce that the Issuer of the received Response/Assertion is known to this RACS. The
* default is true.
*/
public void setEnforceKnownIssuer(boolean enforceKnownIssuer) {
this.enforceKnownIssuer = enforceKnownIssuer;
}
public void setSupportBase64Encoding(boolean supportBase64Encoding) {
this.supportBase64Encoding = supportBase64Encoding;
}
public boolean isSupportBase64Encoding() {
return supportBase64Encoding;
}
@PreDestroy
@Override
public void close() {
if (replayCache != null) {
try {
replayCache.close();
} catch (IOException ex) {
LOG.warning("Replay cache can not be closed: " + ex.getMessage());
}
}
super.close();
}
protected Response doProcessSamlResponse(String encodedSamlResponse,
String relayState,
boolean postBinding) {
RequestState requestState = processRelayState(relayState);
String contextCookie = createSecurityContext(requestState,
encodedSamlResponse,
relayState,
postBinding);
// Finally, redirect to the service provider endpoint
URI targetURI = getTargetURI(requestState.getTargetAddress());
return Response.seeOther(targetURI).header("Set-Cookie", contextCookie).build();
}
private URI getTargetURI(String targetAddress) {
if (targetAddress != null) {
try {
return URI.create(targetAddress);
} catch (IllegalArgumentException ex) {
reportError("INVALID_TARGET_URI");
}
} else {
reportError("MISSING_TARGET_URI");
}
throw ExceptionUtils.toBadRequestException(null, null);
}
protected String createSecurityContext(RequestState requestState,
String encodedSamlResponse,
String relayState,
boolean postBinding) {
org.opensaml.saml.saml2.core.Response samlResponse =
readSAMLResponse(postBinding, encodedSamlResponse);
// Validate the Response
validateSamlResponseProtocol(samlResponse);
SSOValidatorResponse validatorResponse =
validateSamlSSOResponse(postBinding, samlResponse, requestState);
// Set the security context
String securityContextKey = UUID.randomUUID().toString();
long currentTime = System.currentTimeMillis();
Instant notOnOrAfter = validatorResponse.getSessionNotOnOrAfter();
long expiresAt = 0;
if (notOnOrAfter != null) {
expiresAt = notOnOrAfter.toEpochMilli();
} else {
expiresAt = currentTime + getStateTimeToLive();
}
ResponseState responseState =
new ResponseState(validatorResponse.getAssertion(),
relayState,
requestState.getWebAppContext(),
requestState.getWebAppDomain(),
currentTime,
expiresAt);
getStateProvider().setResponseState(securityContextKey, responseState);
return createCookie(SSOConstants.SECURITY_CONTEXT_TOKEN,
securityContextKey,
requestState.getWebAppContext(),
requestState.getWebAppDomain());
}
protected RequestState processRelayState(String relayState) {
if (isSupportUnsolicited()) {
String urlToForwardTo = applicationURL;
if (relayState != null && relayState.getBytes().length > 0 && relayState.getBytes().length < 80) {
// First see if we have a valid RequestState
RequestState requestState = getStateProvider().removeRequestState(relayState);
if (requestState != null && !isStateExpired(requestState.getCreatedAt(), 0)) {
return requestState;
}
// Otherwise get the application URL from the RelayState if supported
if (parseApplicationURLFromRelayState) {
urlToForwardTo = relayState;
}
}
// Otherwise create a new one for the IdP initiated case
Instant now = Instant.now();
return new RequestState(urlToForwardTo,
getIdpServiceAddress(),
null,
getIssuerId(JAXRSUtils.getCurrentMessage()),
"/",
null,
now.toEpochMilli());
}
if (relayState == null) {
reportError("MISSING_RELAY_STATE");
throw ExceptionUtils.toBadRequestException(null, null);
}
if (relayState.getBytes().length == 0 || relayState.getBytes().length > 80) {
reportError("INVALID_RELAY_STATE");
throw ExceptionUtils.toBadRequestException(null, null);
}
RequestState requestState = getStateProvider().removeRequestState(relayState);
if (requestState == null) {
reportError("MISSING_REQUEST_STATE");
throw ExceptionUtils.toBadRequestException(null, null);
}
if (isStateExpired(requestState.getCreatedAt(), 0)) {
reportError("EXPIRED_REQUEST_STATE");
throw ExceptionUtils.toBadRequestException(null, null);
}
return requestState;
}
private org.opensaml.saml.saml2.core.Response readSAMLResponse(
boolean postBinding,
String samlResponse
) {
if (StringUtils.isEmpty(samlResponse)) {
reportError("MISSING_SAML_RESPONSE");
throw ExceptionUtils.toBadRequestException(null, null);
}
String samlResponseDecoded = samlResponse;
/*
// URL Decoding only applies for the re-direct binding
if (!postBinding) {
try {
samlResponseDecoded = URLDecoder.decode(samlResponse, StandardCharsets.UTF_8);
} catch (UnsupportedEncodingException e) {
throw ExceptionUtils.toBadRequestException(null, null);
}
}
*/
InputStream tokenStream = null;
if (isSupportBase64Encoding()) {
try {
byte[] deflatedToken = Base64Utility.decode(samlResponseDecoded);
tokenStream = !postBinding && isSupportDeflateEncoding()
? new DeflateEncoderDecoder().inflateToken(deflatedToken)
: new ByteArrayInputStream(deflatedToken);
} catch (Base64Exception ex) {
throw ExceptionUtils.toBadRequestException(ex, null);
} catch (DataFormatException ex) {
throw ExceptionUtils.toBadRequestException(ex, null);
}
} else {
tokenStream = new ByteArrayInputStream(samlResponseDecoded.getBytes(StandardCharsets.UTF_8));
}
Document responseDoc = null;
try {
responseDoc = StaxUtils.read(new InputStreamReader(tokenStream, StandardCharsets.UTF_8));
} catch (Exception ex) {
throw new WebApplicationException(400);
}
if (LOG.isLoggable(Level.FINE)) {
LOG.fine("Received response: " + DOM2Writer.nodeToString(responseDoc.getDocumentElement()));
}
XMLObject responseObject = null;
try {
responseObject = OpenSAMLUtil.fromDom(responseDoc.getDocumentElement());
} catch (WSSecurityException ex) {
throw ExceptionUtils.toBadRequestException(ex, null);
}
if (!(responseObject instanceof org.opensaml.saml.saml2.core.Response)) {
throw ExceptionUtils.toBadRequestException(null, null);
}
return (org.opensaml.saml.saml2.core.Response)responseObject;
}
/**
* Validate the received SAML Response as per the protocol
*/
protected void validateSamlResponseProtocol(
org.opensaml.saml.saml2.core.Response samlResponse
) {
try {
SAMLProtocolResponseValidator protocolValidator = new SAMLProtocolResponseValidator();
protocolValidator.setKeyInfoMustBeAvailable(keyInfoMustBeAvailable);
protocolValidator.validateSamlResponse(samlResponse, getSignatureCrypto(), getCallbackHandler());
} catch (WSSecurityException ex) {
LOG.log(Level.FINE, ex.getMessage(), ex);
reportError("INVALID_SAML_RESPONSE");
throw ExceptionUtils.toBadRequestException(null, null);
}
}
/**
* Validate the received SAML Response as per the Web SSO profile
*/
protected SSOValidatorResponse validateSamlSSOResponse(
boolean postBinding,
org.opensaml.saml.saml2.core.Response samlResponse,
RequestState requestState
) {
try {
SAMLSSOResponseValidator ssoResponseValidator = new SAMLSSOResponseValidator();
String racsAddress = assertionConsumerServiceAddress;
if (racsAddress == null) {
racsAddress = messageContext.getUriInfo().getAbsolutePath().toString();
}
ssoResponseValidator.setAssertionConsumerURL(racsAddress);
ssoResponseValidator.setClientAddress(
messageContext.getHttpServletRequest().getRemoteAddr());
ssoResponseValidator.setIssuerIDP(requestState.getIdpServiceAddress());
ssoResponseValidator.setRequestId(requestState.getSamlRequestId());
ssoResponseValidator.setSpIdentifier(requestState.getIssuerId());
ssoResponseValidator.setEnforceAssertionsSigned(enforceAssertionsSigned);
ssoResponseValidator.setEnforceResponseSigned(enforceResponseSigned);
ssoResponseValidator.setEnforceKnownIssuer(enforceKnownIssuer);
if (postBinding) {
ssoResponseValidator.setReplayCache(getReplayCache());
}
return ssoResponseValidator.validateSamlResponse(samlResponse, postBinding);
} catch (WSSecurityException ex) {
reportError("INVALID_SAML_RESPONSE");
throw ExceptionUtils.toBadRequestException(ex, null);
}
}
protected void reportError(String code) {
org.apache.cxf.common.i18n.Message errorMsg =
new org.apache.cxf.common.i18n.Message(code, BUNDLE);
LOG.warning(errorMsg.toString());
}
public void setKeyInfoMustBeAvailable(boolean keyInfoMustBeAvailable) {
this.keyInfoMustBeAvailable = keyInfoMustBeAvailable;
}
public boolean isEnforceResponseSigned() {
return enforceResponseSigned;
}
/**
* Enforce that a SAML Response must be signed.
*/
public void setEnforceResponseSigned(boolean enforceResponseSigned) {
this.enforceResponseSigned = enforceResponseSigned;
}
public String getApplicationURL() {
return applicationURL;
}
/**
* Set the Application URL to forward to, for the unsolicited IdP case.
* @param applicationURL
*/
public void setApplicationURL(String applicationURL) {
this.applicationURL = applicationURL;
}
public boolean isParseApplicationURLFromRelayState() {
return parseApplicationURLFromRelayState;
}
/**
* Whether to parse the application URL to forward to from the RelayState, for the unsolicted IdP case.
* @param parseApplicationURLFromRelayState
*/
public void setParseApplicationURLFromRelayState(boolean parseApplicationURLFromRelayState) {
this.parseApplicationURLFromRelayState = parseApplicationURLFromRelayState;
}
public String getAssertionConsumerServiceAddress() {
return assertionConsumerServiceAddress;
}
public void setAssertionConsumerServiceAddress(String assertionConsumerServiceAddress) {
this.assertionConsumerServiceAddress = assertionConsumerServiceAddress;
}
}