/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.wss4j.common.token;

import java.math.BigInteger;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;

import javax.xml.namespace.QName;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.Text;
import org.apache.wss4j.common.WSS4JConstants;
import org.apache.wss4j.common.bsp.BSPEnforcer;
import org.apache.wss4j.common.bsp.BSPRule;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoType;
import org.apache.wss4j.common.crypto.Merlin;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.util.DOM2Writer;
import org.apache.wss4j.common.util.KeyUtils;
import org.apache.wss4j.common.util.XMLUtils;

/**
 * Security Token Reference.
 *
 * <p>Wraps a <code>wsse:SecurityTokenReference</code> DOM element and offers
 * accessors for the single child it is expected to carry: a
 * <code>wsse:Reference</code>, a <code>wsse:KeyIdentifier</code>, X509
 * issuer/serial data, or an arbitrary ("unknown") element.
 */
public class SecurityTokenReference {
    public static final String SECURITY_TOKEN_REFERENCE = "SecurityTokenReference";
    public static final QName STR_QNAME =
        new QName(WSS4JConstants.WSSE_NS, SECURITY_TOKEN_REFERENCE);
    public static final String SKI_URI =
        WSS4JConstants.X509TOKEN_NS + "#X509SubjectKeyIdentifier";
    public static final String THUMB_URI =
        WSS4JConstants.SOAPMESSAGE_NS11 + "#" + WSS4JConstants.THUMBPRINT;
    public static final String ENC_KEY_SHA1_URI =
        WSS4JConstants.SOAPMESSAGE_NS11 + "#" + WSS4JConstants.ENC_KEY_SHA1_URI;
    public static final String X509_V3_TYPE = WSS4JConstants.X509TOKEN_NS + "#X509v3";

    private static final org.slf4j.Logger LOG =
        org.slf4j.LoggerFactory.getLogger(SecurityTokenReference.class);

    /** The wrapped wsse:SecurityTokenReference element. */
    private final Element element;
    /** Lazily populated by {@link #getIssuerSerial()}. */
    private DOMX509IssuerSerial issuerSerial;
    /** Lazily populated by {@link #getSKIBytes()}. */
    private byte[] skiBytes;
    /** The wsse:Reference child, if one was found or set. */
    private Reference reference;

    /**
     * Constructor.
     *
     * @param elem A SecurityTokenReference element
     * @param bspEnforcer a BSPEnforcer instance to enforce BSP rules
     * @throws WSSecurityException if <code>elem</code> is not a
     *         <code>wsse:SecurityTokenReference</code> element, or if a BSP rule
     *         check or the construction of the contained Reference fails
     */
    public SecurityTokenReference(Element elem, BSPEnforcer bspEnforcer) throws WSSecurityException {
        element = elem;
        QName el = new QName(element.getNamespaceURI(), element.getLocalName());
        if (!STR_QNAME.equals(el)) {
            throw new WSSecurityException(WSSecurityException.ErrorCode.FAILURE, "badElement",
                                          new Object[] {STR_QNAME, el});
        }

        checkBSPCompliance(bspEnforcer);

        // Single pass over the children through a private helper, so that no
        // overridable method is needed to locate the Reference during construction.
        Element referenceElement = findChildElement(WSS4JConstants.WSSE_NS, "Reference");
        if (referenceElement != null) {
            reference = new Reference(referenceElement);
        }
    }

    /**
     * Constructor. Creates a new, empty <code>wsse:SecurityTokenReference</code>
     * element owned by the given document.
     *
     * @param doc The Document
     */
    public SecurityTokenReference(Document doc) {
        element = doc.createElementNS(WSS4JConstants.WSSE_NS, "wsse:SecurityTokenReference");
    }

    /**
     * Add the WSSE Namespace to this STR. The namespace is not added by default for
     * efficiency purposes.
     */
    public void addWSSENamespace() {
        XMLUtils.setNamespace(element, WSS4JConstants.WSSE_NS, WSS4JConstants.WSSE_PREFIX);
    }

    /**
     * Add the WSU Namespace to this STR. The namespace is not added by default for
     * efficiency purposes.
     */
    public void addWSUNamespace() {
        XMLUtils.setNamespace(element, WSS4JConstants.WSU_NS, WSS4JConstants.WSU_PREFIX);
    }

    /**
     * Add a wsse11:TokenType attribute to this SecurityTokenReference.
     * Does nothing when <code>tokenType</code> is null.
     *
     * @param tokenType the wsse11:TokenType attribute to add
     */
    public void addTokenType(String tokenType) {
        if (tokenType != null) {
            XMLUtils.setNamespace(element, WSS4JConstants.WSSE11_NS, WSS4JConstants.WSSE11_PREFIX);
            element.setAttributeNS(
                WSS4JConstants.WSSE11_NS,
                WSS4JConstants.WSSE11_PREFIX + ":" + WSS4JConstants.TOKEN_TYPE,
                tokenType
            );
        }
    }

    /**
     * Get the wsse11:TokenType attribute of this SecurityTokenReference.
     *
     * @return the value of the wsse11:TokenType attribute
     */
    public String getTokenType() {
        return element.getAttributeNS(
            WSS4JConstants.WSSE11_NS, WSS4JConstants.TOKEN_TYPE
        );
    }

    /**
     * Set the reference. Replaces the current first child element of this STR,
     * or appends the reference element if there is none.
     *
     * @param ref the Reference to store in this SecurityTokenReference
     */
    public void setReference(Reference ref) {
        replaceFirstElementOrAppend(ref.getElement());
        this.reference = ref;
    }

    /**
     * Gets the Reference.
     *
     * @return the <code>Reference</code> element contained in this
     *         SecurityTokenReference, or null if there is none
     * @throws WSSecurityException declared for API compatibility
     */
    public Reference getReference() throws WSSecurityException {
        return reference;
    }

    /**
     * Sets the KeyIdentifier Element as a X509 certificate.
     * Takes a X509 certificate, converts its data into base 64 and inserts
     * it into a <code>wsse:KeyIdentifier</code> element, which is placed
     * in the <code>wsse:SecurityTokenReference</code> element.
     *
     * @param cert is the X509 certificate to be inserted as key identifier
     * @throws WSSecurityException if the certificate cannot be encoded
     */
    public void setKeyIdentifier(X509Certificate cert) throws WSSecurityException {
        Document doc = element.getOwnerDocument();
        byte[] data;
        try {
            data = cert.getEncoded();
        } catch (CertificateEncodingException e) {
            throw new WSSecurityException(
                WSSecurityException.ErrorCode.SECURITY_TOKEN_UNAVAILABLE, e, "encodeError"
            );
        }
        Text text = doc.createTextNode(Base64.getMimeEncoder().encodeToString(data));
        createKeyIdentifier(doc, X509_V3_TYPE, text, true);
    }

    /**
     * Sets the KeyIdentifier Element as a X509 Subject-Key-Identifier (SKI).
     * Takes a X509 certificate, gets the SKI data, converts it into base 64 and
     * inserts it into a <code>wsse:KeyIdentifier</code> element, which is placed
     * in the <code>wsse:SecurityTokenReference</code> element.
     *
     * @param cert is the X509 certificate to get the SKI
     * @param crypto is the Crypto implementation. Used to read SKI info bytes from
     *        certificate; a new {@link Merlin} instance is used when this is null
     * @throws WSSecurityException if the certificate is not a version 3 certificate
     */
    public void setKeyIdentifierSKI(X509Certificate cert, Crypto crypto)
        throws WSSecurityException {
        //
        // As per the 1.1 specification, SKI can only be used for a V3 certificate
        //
        if (cert.getVersion() != 3) {
            throw new WSSecurityException(
                WSSecurityException.ErrorCode.UNSUPPORTED_SECURITY_TOKEN,
                "invalidCertForSKI",
                new Object[] {cert.getVersion()});
        }

        Document doc = element.getOwnerDocument();
        // Fall back to Merlin if crypto parameter is null
        Crypto skiCrypto = crypto != null ? crypto : new Merlin();
        byte[] data = skiCrypto.getSKIBytesFromCert(cert);

        Text text = doc.createTextNode(Base64.getMimeEncoder().encodeToString(data));
        createKeyIdentifier(doc, SKI_URI, text, true);
    }

    /**
     * Sets the KeyIdentifier Element as a Thumbprint.
     *
     * Takes a X509 certificate, computes the digest of its encoded form via
     * <code>KeyUtils.generateDigest</code> (SHA-1 according to the original
     * documentation of this method), converts it into base 64 and inserts it
     * into a <code>wsse:KeyIdentifier</code> element, which is placed in the
     * <code>wsse:SecurityTokenReference</code> element.
     *
     * @param cert is the X509 certificate to get the thumbprint
     * @throws WSSecurityException if the certificate cannot be encoded or the
     *         digest cannot be computed
     */
    public void setKeyIdentifierThumb(X509Certificate cert) throws WSSecurityException {
        Document doc = element.getOwnerDocument();
        byte[] encodedCert;
        try {
            encodedCert = cert.getEncoded();
        } catch (CertificateEncodingException e) {
            throw new WSSecurityException(
                WSSecurityException.ErrorCode.SECURITY_TOKEN_UNAVAILABLE, e, "encodeError"
            );
        }

        byte[] digest;
        try {
            digest = KeyUtils.generateDigest(encodedCert);
        } catch (WSSecurityException e) {
            throw new WSSecurityException(
                WSSecurityException.ErrorCode.FAILURE, e, "decoding.general"
            );
        }
        Text text = doc.createTextNode(Base64.getMimeEncoder().encodeToString(digest));
        createKeyIdentifier(doc, THUMB_URI, text, true);
    }

    /**
     * Sets the KeyIdentifier Element with the {@link #ENC_KEY_SHA1_URI} value type.
     * The value is inserted as given and the Base64 EncodingType attribute is set.
     *
     * @param value the text content of the KeyIdentifier
     * @throws WSSecurityException declared for API compatibility
     */
    public void setKeyIdentifierEncKeySHA1(String value) throws WSSecurityException {
        Document doc = element.getOwnerDocument();
        Text text = doc.createTextNode(value);
        createKeyIdentifier(doc, ENC_KEY_SHA1_URI, text, true);
    }

    /**
     * Sets the KeyIdentifier Element with an arbitrary value type and no
     * EncodingType attribute.
     *
     * @param valueType the ValueType attribute of the KeyIdentifier
     * @param keyIdVal the text content of the KeyIdentifier
     * @throws WSSecurityException declared for API compatibility
     */
    public void setKeyIdentifier(String valueType, String keyIdVal) throws WSSecurityException {
        setKeyIdentifier(valueType, keyIdVal, false);
    }

    /**
     * Sets the KeyIdentifier Element with an arbitrary value type.
     *
     * @param valueType the ValueType attribute of the KeyIdentifier
     * @param keyIdVal the text content of the KeyIdentifier
     * @param base64 whether to add the Base64 EncodingType attribute; the value
     *        itself is inserted as given and is not encoded by this method
     * @throws WSSecurityException declared for API compatibility
     */
    public void setKeyIdentifier(String valueType, String keyIdVal, boolean base64)
        throws WSSecurityException {
        Document doc = element.getOwnerDocument();
        createKeyIdentifier(doc, valueType, doc.createTextNode(keyIdVal), base64);
    }

    /**
     * Builds a <code>wsse:KeyIdentifier</code> element around <code>node</code> and
     * installs it as the child of this STR.
     */
    private void createKeyIdentifier(Document doc, String uri, Node node, boolean base64) {
        Element keyId = doc.createElementNS(WSS4JConstants.WSSE_NS, "wsse:KeyIdentifier");
        keyId.setAttributeNS(null, "ValueType", uri);
        if (base64) {
            keyId.setAttributeNS(null, "EncodingType", WSS4JConstants.BASE64_ENCODING);
        }
        keyId.appendChild(node);
        replaceFirstElementOrAppend(keyId);
    }

    /**
     * Replaces the first child element of this STR with <code>newChild</code>, or
     * appends <code>newChild</code> when the STR has no child element yet.
     */
    private void replaceFirstElementOrAppend(Node newChild) {
        Element existing = getFirstElement();
        if (existing != null) {
            element.replaceChild(newChild, existing);
        } else {
            element.appendChild(newChild);
        }
    }

    /**
     * Get the first child element.
     *
     * @return the first <code>Element</code> child node, or null if there is none
     */
    public Element getFirstElement() {
        for (Node currentChild = element.getFirstChild();
             currentChild != null;
             currentChild = currentChild.getNextSibling()
        ) {
            if (Node.ELEMENT_NODE == currentChild.getNodeType()) {
                return (Element) currentChild;
            }
        }
        return null;
    }

    /**
     * Gets the certificate referenced by the KeyIdentifier.
     *
     * @param crypto the Crypto instance used to look up the certificate
     * @return a single-element array holding the X509 certificate, or null if
     *         <code>crypto</code> is null, the STR has no child element, an unknown
     *         key identifier type was detected, or no certificate was found
     * @throws WSSecurityException if certificate resolution fails
     */
    public X509Certificate[] getKeyIdentifier(Crypto crypto) throws WSSecurityException {
        if (crypto == null) {
            return null;
        }
        Element elem = getFirstElement();
        if (elem == null) {
            // An STR without any child element has no key identifier to resolve.
            return null;
        }

        String value = elem.getAttributeNS(null, "ValueType");
        if (X509_V3_TYPE.equals(value)) {
            X509Security token = new X509Security(elem, new BSPEnforcer(true));
            X509Certificate cert = token.getX509Certificate(crypto);
            return new X509Certificate[]{cert};
        } else if (SKI_URI.equals(value)) {
            X509Certificate cert = getX509SKIAlias(crypto);
            if (cert != null) {
                return new X509Certificate[]{cert};
            }
        } else if (THUMB_URI.equals(value)) {
            String text = XMLUtils.getElementText(elem);
            if (text != null) {
                byte[] thumb = Base64.getMimeDecoder().decode(text);
                CryptoType cryptoType = new CryptoType(CryptoType.TYPE.THUMBPRINT_SHA1);
                cryptoType.setBytes(thumb);
                X509Certificate[] certs = crypto.getX509Certificates(cryptoType);
                // Guard against an empty result as well as a null one.
                if (certs != null && certs.length > 0) {
                    return new X509Certificate[]{certs[0]};
                }
            }
        }
        return null;
    }

    /**
     * @return the text content of the first child element if this STR contains a
     *         <code>wsse:KeyIdentifier</code>, otherwise null
     */
    public String getKeyIdentifierValue() {
        if (containsKeyIdentifier()) {
            return XMLUtils.getElementText(getFirstElement());
        }
        return null;
    }

    /**
     * @return the ValueType attribute of the first child element if this STR
     *         contains a <code>wsse:KeyIdentifier</code>, otherwise null
     */
    public String getKeyIdentifierValueType() {
        if (containsKeyIdentifier()) {
            Element elem = getFirstElement();
            return elem.getAttributeNS(null, "ValueType");
        }
        return null;
    }

    /**
     * @return the EncodingType attribute of the first child element if this STR
     *         contains a <code>wsse:KeyIdentifier</code>, otherwise null
     */
    public String getKeyIdentifierEncodingType() {
        if (containsKeyIdentifier()) {
            Element elem = getFirstElement();
            return elem.getAttributeNS(null, "EncodingType");
        }
        return null;
    }

    /**
     * Looks up the certificate matching the SKI bytes held by this STR.
     *
     * @param crypto the Crypto instance used to look up the certificate
     * @return the first matching certificate, or null if <code>crypto</code> is
     *         null, no SKI bytes are available, or no certificate was found
     * @throws WSSecurityException if the certificate lookup fails
     */
    public X509Certificate getX509SKIAlias(Crypto crypto) throws WSSecurityException {
        if (crypto == null) {
            return null;
        }

        if (skiBytes == null) {
            skiBytes = getSKIBytes();
            if (skiBytes == null) {
                return null;
            }
        }

        CryptoType cryptoType = new CryptoType(CryptoType.TYPE.SKI_BYTES);
        cryptoType.setBytes(skiBytes);
        X509Certificate[] certs = crypto.getX509Certificates(cryptoType);
        // Guard against an empty result as well as a null one.
        if (certs != null && certs.length > 0) {
            return certs[0];
        }
        return null;
    }

    /**
     * Returns the Base64-decoded text content of the first child element. The
     * result is cached after the first successful decode.
     *
     * @return the decoded bytes, or null if there is no child element or it has
     *         no text
     */
    public byte[] getSKIBytes() {
        if (skiBytes != null) {
            return skiBytes;
        }

        Element firstElement = getFirstElement();
        if (firstElement == null) {
            // Nothing to decode on an STR without a child element.
            return null;
        }
        String text = XMLUtils.getElementText(firstElement);
        if (text != null) {
            skiBytes = Base64.getMimeDecoder().decode(text);
        }
        return skiBytes;
    }

    /**
     * Set an unknown element. Replaces the current first child element of this
     * STR, or appends the given element if there is none.
     *
     * @param unknownElement the org.w3c.dom.Element to put into this
     *        SecurityTokenReference
     */
    public void setUnknownElement(Element unknownElement) {
        replaceFirstElementOrAppend(unknownElement);
    }

    /**
     * Gets the certificate identified with X509 issuerSerial data.
     *
     * @param crypto the Crypto instance used to look up the certificate
     * @return a certificate array or null if nothing found
     * @throws WSSecurityException if the certificate lookup fails
     */
    public X509Certificate[] getX509IssuerSerial(Crypto crypto) throws WSSecurityException {
        if (crypto == null) {
            return null;
        }

        if (issuerSerial == null) {
            issuerSerial = getIssuerSerial();
            if (issuerSerial == null) {
                return null;
            }
        }

        CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ISSUER_SERIAL);
        cryptoType.setIssuerSerial(issuerSerial.getIssuer(), issuerSerial.getSerialNumber());
        return crypto.getX509Certificates(cryptoType);
    }

    /**
     * Returns the (cached) issuer/serial wrapper built from the first child
     * element, descending into the issuer-serial element when the first child is
     * an X509Data element.
     *
     * @return the issuer/serial wrapper, or null if there is no child element
     */
    private DOMX509IssuerSerial getIssuerSerial() throws WSSecurityException {
        if (issuerSerial != null) {
            return issuerSerial;
        }
        Element elem = getFirstElement();
        if (elem == null) {
            return null;
        }
        if (WSS4JConstants.X509_DATA_LN.equals(elem.getLocalName())) {
            elem = XMLUtils.findElement(
                elem, WSS4JConstants.X509_ISSUER_SERIAL_LN, WSS4JConstants.SIG_NS
            );
            if (elem == null) {
                // Presumably findElement yields null when the X509Data has no
                // issuer-serial descendant (TODO confirm); never wrap a null element.
                return null;
            }
        }
        issuerSerial = new DOMX509IssuerSerial(elem);
        return issuerSerial;
    }

    /**
     * Method containsReference
     *
     * @return true if the <code>SecurityTokenReference</code> contains
     *         a <code>wsse:Reference</code> element
     */
    public boolean containsReference() {
        return containsElement(WSS4JConstants.WSSE_NS, "Reference");
    }

    /**
     * Method containsX509IssuerSerial
     *
     * @return true if the <code>SecurityTokenReference</code> contains
     *         a <code>ds:IssuerSerial</code> element
     */
    public boolean containsX509IssuerSerial() {
        return containsElement(WSS4JConstants.SIG_NS, WSS4JConstants.X509_ISSUER_SERIAL_LN);
    }

    /**
     * Method containsX509Data
     *
     * @return true if the <code>SecurityTokenReference</code> contains
     *         a <code>ds:X509Data</code> element
     */
    public boolean containsX509Data() {
        return containsElement(WSS4JConstants.SIG_NS, WSS4JConstants.X509_DATA_LN);
    }

    /**
     * Method containsKeyIdentifier.
     *
     * @return true if the <code>SecurityTokenReference</code> contains
     *         a <code>wsse:KeyIdentifier</code> element
     */
    public boolean containsKeyIdentifier() {
        return containsElement(WSS4JConstants.WSSE_NS, "KeyIdentifier");
    }

    /** Tells whether a direct child element with the given qualified name exists. */
    private boolean containsElement(String namespace, String localname) {
        return findChildElement(namespace, localname) != null;
    }

    /**
     * Finds the first direct child element with the given namespace and local
     * name. A null <code>namespace</code> only matches elements without a namespace.
     *
     * @return the matching child element, or null if there is none
     */
    private Element findChildElement(String namespace, String localname) {
        for (Node node = element.getFirstChild(); node != null; node = node.getNextSibling()) {
            if (Node.ELEMENT_NODE == node.getNodeType()
                && Objects.equals(namespace, node.getNamespaceURI())
                && localname.equals(node.getLocalName())) {
                return (Element) node;
            }
        }
        return null;
    }

    /**
     * Get the DOM element.
     *
     * @return the DOM element
     */
    public Element getElement() {
        return element;
    }

    /**
     * Set the wsu:Id attribute.
     *
     * @param id the id value to set
     */
    public void setID(String id) {
        element.setAttributeNS(WSS4JConstants.WSU_NS, WSS4JConstants.WSU_PREFIX + ":Id", id);
    }

    /**
     * Get the id.
     *
     * @return the wsu ID of the element
     */
    public String getID() {
        return element.getAttributeNS(WSS4JConstants.WSU_NS, "Id");
    }

    /**
     * Return the string representation.
     *
     * @return a representation of this SecurityTokenReference element as a String
     */
    @Override
    public String toString() {
        return DOM2Writer.nodeToString(element);
    }

    /**
     * A method to check that the SecurityTokenReference is compliant with the BSP spec.
     *
     * @param bspEnforcer receives every rule that is found to be violated
     * @throws WSSecurityException if the enforcer rejects a violated rule
     */
    private void checkBSPCompliance(BSPEnforcer bspEnforcer) throws WSSecurityException {
        // We can only have one token reference
        int result = 0;
        Element child = null;
        for (Node node = element.getFirstChild(); node != null; node = node.getNextSibling()) {
            if (Node.ELEMENT_NODE == node.getNodeType()) {
                result++;
                child = (Element)node;
            }
        }
        if (result != 1) {
            bspEnforcer.handleBSPRule(BSPRule.R3061);
        }

        if (child == null) {
            // handleBSPRule returned normally (presumably because the rule is not
            // being enforced - not visible here) and there is no child element
            // left to validate.
            return;
        }

        if ("KeyIdentifier".equals(child.getLocalName())
            && WSS4JConstants.WSSE_NS.equals(child.getNamespaceURI())) {

            String valueType = getKeyIdentifierValueType();
            // ValueType cannot be null
            if (valueType == null || "".equals(valueType)) {
                bspEnforcer.handleBSPRule(BSPRule.R3054);
            }
            String encodingType = getFirstElement().getAttributeNS(null, "EncodingType");
            // Encoding Type must be equal to Base64Binary if it's specified
            if (!"".equals(encodingType) && !WSS4JConstants.BASE64_ENCODING.equals(encodingType)) {
                bspEnforcer.handleBSPRule(BSPRule.R3071);
            }
            // Encoding type must be specified other than for a SAML Assertion
            if (!WSS4JConstants.WSS_SAML_KI_VALUE_TYPE.equals(valueType)
                && !WSS4JConstants.WSS_SAML2_KI_VALUE_TYPE.equals(valueType)
                && "".equals(encodingType)) {
                bspEnforcer.handleBSPRule(BSPRule.R3070);
            }
        } else if ("Embedded".equals(child.getLocalName())) {
            int embeddedChildren = 0;
            for (Node node = child.getFirstChild(); node != null; node = node.getNextSibling()) {
                if (Node.ELEMENT_NODE == node.getNodeType()) {
                    embeddedChildren++;
                    // We cannot have a SecurityTokenReference child element
                    if ("SecurityTokenReference".equals(node.getLocalName())
                        && WSS4JConstants.WSSE_NS.equals(node.getNamespaceURI())) {
                        bspEnforcer.handleBSPRule(BSPRule.R3056);
                    }
                }
            }
            // We can only have one embedded child
            if (embeddedChildren != 1) {
                bspEnforcer.handleBSPRule(BSPRule.R3060);
            }
        }
    }

    /**
     * Combines the Reference, the KeyIdentifier encoding type, value type and
     * value, the token type, the SKI bytes and the issuer/serial data - each only
     * when available. A failure to read the Reference or the issuer/serial data is
     * logged and that component is left out.
     */
    @Override
    public int hashCode() {
        int result = 17;
        try {
            Reference ref = getReference();
            if (ref != null) {
                result = 31 * result + ref.hashCode();
            }
        } catch (WSSecurityException e) {
            LOG.error(e.getMessage(), e);
        }
        String keyIdentifierEncodingType = getKeyIdentifierEncodingType();
        if (keyIdentifierEncodingType != null) {
            result = 31 * result + keyIdentifierEncodingType.hashCode();
        }
        String keyIdentifierValueType = getKeyIdentifierValueType();
        if (keyIdentifierValueType != null) {
            result = 31 * result + keyIdentifierValueType.hashCode();
        }
        String keyIdentifierValue = getKeyIdentifierValue();
        if (keyIdentifierValue != null) {
            result = 31 * result + keyIdentifierValue.hashCode();
        }
        String tokenType = getTokenType();
        if (tokenType != null) {
            result = 31 * result + tokenType.hashCode();
        }
        byte[] ski = getSKIBytes();
        if (ski != null) {
            result = 31 * result + Arrays.hashCode(ski);
        }
        String issuer = null;
        BigInteger serialNumber = null;
        try {
            // getIssuerSerial() is null for an STR without a child element.
            DOMX509IssuerSerial serial = getIssuerSerial();
            if (serial != null) {
                issuer = serial.getIssuer();
                serialNumber = serial.getSerialNumber();
            }
        } catch (WSSecurityException e) {
            LOG.error(e.getMessage(), e);
        }
        if (issuer != null) {
            result = 31 * result + issuer.hashCode();
        }
        if (serialNumber != null) {
            result = 31 * result + serialNumber.hashCode();
        }
        return result;
    }

    /**
     * Two SecurityTokenReferences are equal when their Reference, KeyIdentifier
     * encoding type, value type and value, token type, SKI bytes and issuer/serial
     * data all match. Absent components only match absent components. A failure to
     * read the Reference or the issuer/serial data is logged and makes the objects
     * unequal.
     */
    @Override
    public boolean equals(Object object) {
        if (this == object) {
            return true;
        }
        if (!(object instanceof SecurityTokenReference)) {
            return false;
        }
        SecurityTokenReference other = (SecurityTokenReference)object;
        try {
            // The Reference is null for KeyIdentifier/IssuerSerial style STRs.
            if (!Objects.equals(getReference(), other.getReference())) {
                return false;
            }
        } catch (WSSecurityException e) {
            LOG.error(e.getMessage(), e);
            return false;
        }
        if (!Objects.equals(getKeyIdentifierEncodingType(), other.getKeyIdentifierEncodingType())
            || !Objects.equals(getKeyIdentifierValueType(), other.getKeyIdentifierValueType())
            || !Objects.equals(getKeyIdentifierValue(), other.getKeyIdentifierValue())
            || !Objects.equals(getTokenType(), other.getTokenType())
            || !Arrays.equals(getSKIBytes(), other.getSKIBytes())) {
            return false;
        }
        try {
            DOMX509IssuerSerial mine = getIssuerSerial();
            DOMX509IssuerSerial theirs = other.getIssuerSerial();
            if (mine == null || theirs == null) {
                // hashCode() mixes in the issuer/serial data, so an STR that has it
                // must not equal one that lacks it.
                return mine == theirs;
            }
            return Objects.equals(mine.getIssuer(), theirs.getIssuer())
                && Objects.equals(mine.getSerialNumber(), theirs.getSerialNumber());
        } catch (WSSecurityException e) {
            LOG.error(e.getMessage(), e);
            return false;
        }
    }
}