/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.rs.security.xml;
import java.io.IOException;
import java.io.InputStream;
import java.security.Key;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ReaderInterceptor;
import javax.ws.rs.ext.ReaderInterceptorContext;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.interceptor.StaxInInterceptor;
import org.apache.cxf.jaxrs.impl.ReaderInterceptorContextImpl;
import org.apache.cxf.jaxrs.utils.ExceptionUtils;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageUtils;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.cxf.phase.Phase;
import org.apache.cxf.rs.security.common.CryptoLoader;
import org.apache.cxf.rs.security.common.RSSecurityUtils;
import org.apache.cxf.rs.security.common.TrustValidator;
import org.apache.cxf.rt.security.SecurityConstants;
import org.apache.cxf.rt.security.utils.SecurityUtils;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoType;
import org.apache.wss4j.common.ext.WSPasswordCallback;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.stax.ext.InboundXMLSec;
import org.apache.xml.security.stax.ext.XMLSec;
import org.apache.xml.security.stax.ext.XMLSecurityConstants;
import org.apache.xml.security.stax.ext.XMLSecurityProperties;
import org.apache.xml.security.stax.impl.securityToken.KeyNameSecurityToken;
import org.apache.xml.security.stax.securityEvent.AlgorithmSuiteSecurityEvent;
import org.apache.xml.security.stax.securityEvent.SecurityEvent;
import org.apache.xml.security.stax.securityEvent.SecurityEventConstants;
import org.apache.xml.security.stax.securityEvent.SecurityEventConstants.Event;
import org.apache.xml.security.stax.securityEvent.SecurityEventListener;
import org.apache.xml.security.stax.securityEvent.TokenSecurityEvent;
import org.apache.xml.security.stax.securityToken.SecurityToken;
/**
* A new StAX-based interceptor for processing messages with XML Signature + Encryption content.
*/
public class XmlSecInInterceptor extends AbstractPhaseInterceptor<Message> implements ReaderInterceptor {
private static final Logger LOG = LogUtils.getL7dLogger(XmlSecInInterceptor.class);
private EncryptionProperties encryptionProperties;
private SignatureProperties sigProps;
private String decryptionAlias;
private String signatureVerificationAlias;
private boolean persistSignature = true;
private boolean requireSignature;
private boolean requireEncryption;
/**
* a collection of compiled regular expression patterns for the subject DN
*/
private Collection<Pattern> subjectDNPatterns = new ArrayList<>();
public XmlSecInInterceptor() {
super(Phase.POST_STREAM);
getAfter().add(StaxInInterceptor.class.getName());
}
public void handleMessage(Message message) throws Fault {
if (!canDocumentBeRead(message)) {
return;
}
prepareMessage(message);
message.getInterceptorChain().add(
new StaxActionInInterceptor(requireSignature, requireEncryption));
}
private void prepareMessage(Message inMsg) throws Fault {
XMLStreamReader originalXmlStreamReader = inMsg.getContent(XMLStreamReader.class);
if (originalXmlStreamReader == null) {
InputStream is = inMsg.getContent(InputStream.class);
if (is != null) {
originalXmlStreamReader = StaxUtils.createXMLStreamReader(is);
}
}
try {
XMLSecurityProperties properties = new XMLSecurityProperties();
configureDecryptionKeys(inMsg, properties);
Crypto signatureCrypto = getSignatureCrypto(inMsg);
configureSignatureKeys(signatureCrypto, inMsg, properties);
SecurityEventListener securityEventListener =
configureSecurityEventListener(signatureCrypto, inMsg, properties);
InboundXMLSec inboundXMLSec = XMLSec.getInboundWSSec(properties);
XMLStreamReader newXmlStreamReader =
inboundXMLSec.processInMessage(originalXmlStreamReader, null, securityEventListener);
inMsg.setContent(XMLStreamReader.class, newXmlStreamReader);
} catch (XMLStreamException e) {
throwFault(e.getMessage(), e);
} catch (XMLSecurityException e) {
throwFault(e.getMessage(), e);
} catch (IOException e) {
throwFault(e.getMessage(), e);
} catch (UnsupportedCallbackException e) {
throwFault(e.getMessage(), e);
}
}
private boolean canDocumentBeRead(Message message) {
if (isServerGet(message)) {
return false;
} else {
Integer responseCode = (Integer)message.get(Message.RESPONSE_CODE);
if (responseCode != null && responseCode != 200) {
return false;
}
}
return true;
}
private boolean isServerGet(Message message) {
String method = (String)message.get(Message.HTTP_REQUEST_METHOD);
return "GET".equals(method) && !MessageUtils.isRequestor(message);
}
private void configureDecryptionKeys(Message message, XMLSecurityProperties properties)
throws IOException,
UnsupportedCallbackException, WSSecurityException {
String cryptoKey = null;
String propKey = null;
if (RSSecurityUtils.isSignedAndEncryptedTwoWay(message)) {
cryptoKey = SecurityConstants.SIGNATURE_CRYPTO;
propKey = SecurityConstants.SIGNATURE_PROPERTIES;
} else {
cryptoKey = SecurityConstants.ENCRYPT_CRYPTO;
propKey = SecurityConstants.ENCRYPT_PROPERTIES;
}
Crypto crypto = null;
try {
crypto = new CryptoLoader().getCrypto(message, cryptoKey, propKey);
} catch (Exception ex) {
throwFault("Crypto can not be loaded", ex);
}
if (crypto != null) {
String alias = decryptionAlias;
if (alias == null) {
alias = crypto.getDefaultX509Identifier();
}
if (alias != null) {
CallbackHandler callback = RSSecurityUtils.getCallbackHandler(message, this.getClass());
WSPasswordCallback passwordCallback =
new WSPasswordCallback(alias, WSPasswordCallback.DECRYPT);
callback.handle(new Callback[] {passwordCallback});
Key privateKey = crypto.getPrivateKey(alias, passwordCallback.getPassword());
properties.setDecryptionKey(privateKey);
}
}
}
private Crypto getSignatureCrypto(Message message) {
String cryptoKey = null;
String propKey = null;
if (RSSecurityUtils.isSignedAndEncryptedTwoWay(message)) {
cryptoKey = SecurityConstants.ENCRYPT_CRYPTO;
propKey = SecurityConstants.ENCRYPT_PROPERTIES;
} else {
cryptoKey = SecurityConstants.SIGNATURE_CRYPTO;
propKey = SecurityConstants.SIGNATURE_PROPERTIES;
}
try {
return new CryptoLoader().getCrypto(message, cryptoKey, propKey);
} catch (Exception ex) {
throwFault("Crypto can not be loaded", ex);
return null;
}
}
private void configureSignatureKeys(
Crypto sigCrypto, Message message, XMLSecurityProperties properties
) throws IOException,
UnsupportedCallbackException, WSSecurityException {
if (sigCrypto != null && signatureVerificationAlias != null) {
CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ALIAS);
cryptoType.setAlias(signatureVerificationAlias);
X509Certificate[] certs = sigCrypto.getX509Certificates(cryptoType);
if (certs != null && certs.length > 0) {
properties.setSignatureVerificationKey(certs[0].getPublicKey());
}
} else if (sigCrypto != null && sigProps != null && sigProps.getKeyNameAliasMap() != null) {
Map<String, String> keyNameAliasMap = sigProps.getKeyNameAliasMap();
for (Map.Entry<String, String> mapping: keyNameAliasMap.entrySet()) {
CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ALIAS);
cryptoType.setAlias(mapping.getValue());
X509Certificate[] certs = sigCrypto.getX509Certificates(cryptoType);
if (certs != null && certs.length > 0) {
properties.addKeyNameMapping(mapping.getKey(), certs[0].getPublicKey());
}
}
}
}
protected SecurityEventListener configureSecurityEventListener(
final Crypto sigCrypto, final Message msg, XMLSecurityProperties securityProperties
) {
final List<SecurityEvent> incomingSecurityEventList = new LinkedList<>();
SecurityEventListener securityEventListener = new SecurityEventListener() {
@Override
public void registerSecurityEvent(SecurityEvent securityEvent) throws XMLSecurityException {
if (securityEvent.getSecurityEventType() == SecurityEventConstants.AlgorithmSuite) {
if (encryptionProperties != null) {
checkEncryptionAlgorithms((AlgorithmSuiteSecurityEvent)securityEvent);
}
if (sigProps != null) {
checkSignatureAlgorithms((AlgorithmSuiteSecurityEvent)securityEvent);
}
} else if (securityEvent.getSecurityEventType() != SecurityEventConstants.EncryptedKeyToken
&& securityEvent instanceof TokenSecurityEvent<?>) {
checkSignatureTrust(sigCrypto, msg, (TokenSecurityEvent<?>)securityEvent);
}
incomingSecurityEventList.add(securityEvent);
}
};
msg.getExchange().put(SecurityEvent.class.getName() + ".in", incomingSecurityEventList);
msg.put(SecurityEvent.class.getName() + ".in", incomingSecurityEventList);
return securityEventListener;
}
private void checkEncryptionAlgorithms(AlgorithmSuiteSecurityEvent event)
throws XMLSecurityException {
if (XMLSecurityConstants.Enc.equals(event.getAlgorithmUsage())
&& encryptionProperties.getEncryptionSymmetricKeyAlgo() != null
&& !encryptionProperties.getEncryptionSymmetricKeyAlgo().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The symmetric encryption algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
} else if ((XMLSecurityConstants.Sym_Key_Wrap.equals(event.getAlgorithmUsage())
|| XMLSecurityConstants.Asym_Key_Wrap.equals(event.getAlgorithmUsage()))
&& encryptionProperties.getEncryptionKeyTransportAlgo() != null
&& !encryptionProperties.getEncryptionKeyTransportAlgo().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The key transport algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
} else if (XMLSecurityConstants.EncDig.equals(event.getAlgorithmUsage())
&& encryptionProperties.getEncryptionDigestAlgo() != null
&& !encryptionProperties.getEncryptionDigestAlgo().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The encryption digest algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
}
}
private void checkSignatureAlgorithms(AlgorithmSuiteSecurityEvent event)
throws XMLSecurityException {
if ((XMLSecurityConstants.Asym_Sig.equals(event.getAlgorithmUsage())
|| XMLSecurityConstants.Sym_Sig.equals(event.getAlgorithmUsage()))
&& sigProps.getSignatureAlgo() != null
&& !sigProps.getSignatureAlgo().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The signature algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
} else if (XMLSecurityConstants.SigDig.equals(event.getAlgorithmUsage())
&& sigProps.getSignatureDigestAlgo() != null
&& !sigProps.getSignatureDigestAlgo().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The signature digest algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
} else if (XMLSecurityConstants.SigC14n.equals(event.getAlgorithmUsage())
&& sigProps.getSignatureC14nMethod() != null
&& !sigProps.getSignatureC14nMethod().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The signature c14n algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
} else if (XMLSecurityConstants.SigTransform.equals(event.getAlgorithmUsage())
&& !XMLSecurityConstants.NS_XMLDSIG_ENVELOPED_SIGNATURE.equals(event.getAlgorithmURI())
&& sigProps.getSignatureC14nTransform() != null
&& !sigProps.getSignatureC14nTransform().equals(event.getAlgorithmURI())) {
throw new XMLSecurityException("empty", new Object[] {"The signature transformation algorithm "
+ event.getAlgorithmURI() + " is not allowed"});
}
}
private void checkSignatureTrust(
Crypto sigCrypto, Message msg, TokenSecurityEvent<?> event
) throws XMLSecurityException {
SecurityToken token = event.getSecurityToken();
if (token != null) {
X509Certificate[] certs = token.getX509Certificates();
if (certs == null && token.getPublicKey() == null && token instanceof KeyNameSecurityToken) {
certs = getX509CertificatesForKeyName(sigCrypto, msg, (KeyNameSecurityToken)token);
}
PublicKey publicKey = token.getPublicKey();
X509Certificate cert = null;
if (certs != null && certs.length > 0) {
cert = certs[0];
}
// validate trust
try {
new TrustValidator().validateTrust(sigCrypto, cert, publicKey,
getSubjectContraints(msg));
} catch (WSSecurityException e) {
String error = "Signature validation failed";
throw new XMLSecurityException("empty", new Object[] {error});
}
if (persistSignature) {
msg.setContent(X509Certificate.class, cert);
}
}
}
private X509Certificate[] getX509CertificatesForKeyName(Crypto sigCrypto, Message msg, KeyNameSecurityToken token)
throws XMLSecurityException {
X509Certificate[] certs;
KeyNameSecurityToken keyNameSecurityToken = token;
String keyName = keyNameSecurityToken.getKeyName();
String alias = null;
if (sigProps != null && sigProps.getKeyNameAliasMap() != null) {
alias = sigProps.getKeyNameAliasMap().get(keyName);
}
try {
certs = RSSecurityUtils.getCertificates(sigCrypto, alias);
} catch (Exception e) {
throw new XMLSecurityException("empty", new Object[] {"Error during Signature Trust "
+ "validation"});
}
return certs;
}
protected void throwFault(String error, Exception ex) {
LOG.warning(error);
Response response = JAXRSUtils.toResponseBuilder(400).entity(error).type("text/plain").build();
throw ExceptionUtils.toBadRequestException(null, response);
}
public void setEncryptionProperties(EncryptionProperties properties) {
this.encryptionProperties = properties;
}
public void setSignatureProperties(SignatureProperties properties) {
this.sigProps = properties;
}
public String getDecryptionAlias() {
return decryptionAlias;
}
public void setDecryptionAlias(String decryptionAlias) {
this.decryptionAlias = decryptionAlias;
}
public String getSignatureVerificationAlias() {
return signatureVerificationAlias;
}
public void setSignatureVerificationAlias(String signatureVerificationAlias) {
this.signatureVerificationAlias = signatureVerificationAlias;
}
public void setPersistSignature(boolean persist) {
this.persistSignature = persist;
}
public boolean isRequireSignature() {
return requireSignature;
}
public void setRequireSignature(boolean requireSignature) {
this.requireSignature = requireSignature;
}
public boolean isRequireEncryption() {
return requireEncryption;
}
public void setRequireEncryption(boolean requireEncryption) {
this.requireEncryption = requireEncryption;
}
/**
* Set a list of Strings corresponding to regular expression constraints on the subject DN
* of a certificate
*/
public void setSubjectConstraints(List<String> constraints) {
if (constraints != null) {
subjectDNPatterns = new ArrayList<>();
for (String constraint : constraints) {
try {
subjectDNPatterns.add(Pattern.compile(constraint.trim()));
} catch (PatternSyntaxException ex) {
throw ex;
}
}
}
}
private Collection<Pattern> getSubjectContraints(Message msg) throws PatternSyntaxException {
String certConstraints =
(String)SecurityUtils.getSecurityPropertyValue(SecurityConstants.SUBJECT_CERT_CONSTRAINTS, msg);
// Check the message property first. If this is not null then use it. Otherwise pick up
// the constraints set as a property
if (certConstraints != null) {
String[] certConstraintsList = certConstraints.split(",");
if (certConstraintsList != null) {
subjectDNPatterns.clear();
for (String certConstraint : certConstraintsList) {
subjectDNPatterns.add(Pattern.compile(certConstraint.trim()));
}
}
}
return subjectDNPatterns;
}
@Override
public Object aroundReadFrom(ReaderInterceptorContext ctx) throws IOException, WebApplicationException {
Message message = ((ReaderInterceptorContextImpl)ctx).getMessage();
if (!canDocumentBeRead(message)) {
return ctx.proceed();
} else {
prepareMessage(message);
Object object = ctx.proceed();
new StaxActionInInterceptor(requireSignature,
requireEncryption).handleMessage(message);
return object;
}
}
/**
* This interceptor handles parsing the StaX results (events) + checks to see whether the
* required (if any) Actions (signature or encryption) were fulfilled.
*/
private static class StaxActionInInterceptor extends AbstractPhaseInterceptor<Message> {
private static final Logger LOG =
LogUtils.getL7dLogger(StaxActionInInterceptor.class);
private final boolean signatureRequired;
private final boolean encryptionRequired;
StaxActionInInterceptor(boolean signatureRequired, boolean encryptionRequired) {
super(Phase.PRE_LOGICAL);
this.signatureRequired = signatureRequired;
this.encryptionRequired = encryptionRequired;
}
@Override
public void handleMessage(Message message) throws Fault {
if (!(signatureRequired || encryptionRequired)) {
return;
}
@SuppressWarnings("unchecked")
final List<SecurityEvent> incomingSecurityEventList =
(List<SecurityEvent>)message.get(SecurityEvent.class.getName() + ".in");
if (incomingSecurityEventList == null) {
LOG.warning("Security processing failed (actions mismatch)");
XMLSecurityException ex =
new XMLSecurityException("empty", new Object[] {"The request was not signed or encrypted"});
throwFault(ex.getMessage(), ex);
}
if (signatureRequired) {
Event requiredEvent = SecurityEventConstants.SignatureValue;
if (!isEventInResults(requiredEvent, incomingSecurityEventList)) {
LOG.warning("The request was not signed");
XMLSecurityException ex =
new XMLSecurityException("empty", new Object[] {"The request was not signed"});
throwFault(ex.getMessage(), ex);
}
}
if (encryptionRequired) {
boolean foundEncryptionPart =
isEventInResults(SecurityEventConstants.EncryptedElement, incomingSecurityEventList);
if (!foundEncryptionPart) {
LOG.warning("The request was not encrypted");
XMLSecurityException ex =
new XMLSecurityException("empty", new Object[] {"The request was not encrypted"});
throwFault(ex.getMessage(), ex);
}
}
}
private boolean isEventInResults(Event event, List<SecurityEvent> incomingSecurityEventList) {
for (SecurityEvent incomingEvent : incomingSecurityEventList) {
if (event == incomingEvent.getSecurityEventType()) {
return true;
}
}
return false;
}
protected void throwFault(String error, Exception ex) {
LOG.warning(error);
Response response = JAXRSUtils.toResponseBuilder(400).entity(error).build();
throw ExceptionUtils.toBadRequestException(null, response);
}
}
}