/* * Copyright 2016 Red Hat, Inc. and/or its affiliates * and other contributors as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.keycloak.saml.processing.core.saml.v2.util; import org.keycloak.dom.xmlsec.w3.xmldsig.DSAKeyValueType; import org.keycloak.dom.xmlsec.w3.xmldsig.KeyInfoType; import org.keycloak.dom.xmlsec.w3.xmldsig.KeyValueType; import org.keycloak.dom.xmlsec.w3.xmldsig.RSAKeyValueType; import org.keycloak.dom.xmlsec.w3.xmldsig.X509CertificateType; import org.keycloak.dom.xmlsec.w3.xmldsig.X509DataType; import org.keycloak.saml.common.ErrorCodes; import org.keycloak.saml.common.PicketLinkLogger; import org.keycloak.saml.common.PicketLinkLoggerFactory; import org.keycloak.saml.common.constants.GeneralConstants; import org.keycloak.saml.common.constants.WSTrustConstants; import org.keycloak.saml.common.exceptions.ProcessingException; import org.keycloak.saml.common.util.StaxUtil; import org.w3c.dom.Element; import javax.xml.stream.XMLStreamWriter; /** * Utility methods for stax writing * * @author anil saldhana * @since Jan 28, 2013 */ public class StaxWriterUtil { private static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger(); /** * Write the {@link org.keycloak.dom.xmlsec.w3.xmldsig.KeyInfoType} * * @param writer * @param keyInfo * * @throws org.keycloak.saml.common.exceptions.ProcessingException */ public static void writeKeyInfo(XMLStreamWriter writer, KeyInfoType keyInfo) throws ProcessingException { if (keyInfo.getContent() == null || keyInfo.getContent().size() == 0) throw logger.writerInvalidKeyInfoNullContentError(); StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.KEYINFO, WSTrustConstants.XMLDSig.DSIG_NS); StaxUtil.writeNameSpace(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.DSIG_NS); // write the keyInfo content. Object content = keyInfo.getContent().get(0); if (content instanceof Element) { Element element = (Element) keyInfo.getContent().get(0); StaxUtil.writeDOMNode(writer, element); } else if (content instanceof X509DataType) { X509DataType type = (X509DataType) content; if (type.getDataObjects().size() == 0) throw logger.writerNullValueError("X509Data"); StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.X509DATA, WSTrustConstants.XMLDSig.DSIG_NS); Object obj = type.getDataObjects().get(0); if (obj instanceof Element) { Element element = (Element) obj; StaxUtil.writeDOMElement(writer, element); } else if (obj instanceof X509CertificateType) { X509CertificateType cert = (X509CertificateType) obj; StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.X509CERT, WSTrustConstants.XMLDSig.DSIG_NS); StaxUtil.writeCharacters(writer, new String(cert.getEncodedCertificate(), GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } StaxUtil.writeEndElement(writer); } else if (content instanceof KeyValueType) { KeyValueType keyvalueType = (KeyValueType) content; StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.KEYVALUE, WSTrustConstants.XMLDSig.DSIG_NS); if (keyvalueType instanceof DSAKeyValueType) { writeDSAKeyValueType(writer, (DSAKeyValueType) keyvalueType); } if (keyvalueType instanceof RSAKeyValueType) { writeRSAKeyValueType(writer, (RSAKeyValueType) keyvalueType); } StaxUtil.writeEndElement(writer); } else throw new ProcessingException(ErrorCodes.UNSUPPORTED_TYPE + content); StaxUtil.writeEndElement(writer); } public static void writeRSAKeyValueType(XMLStreamWriter writer, RSAKeyValueType type) throws ProcessingException { String prefix = WSTrustConstants.XMLDSig.DSIG_PREFIX; StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.RSA_KEYVALUE, WSTrustConstants.DSIG_NS); // write the rsa key modulus. byte[] modulus = type.getModulus(); StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.MODULUS, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(modulus, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); // write the rsa key exponent. byte[] exponent = type.getExponent(); StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.EXPONENT, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(exponent, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); StaxUtil.writeEndElement(writer); } public static void writeDSAKeyValueType(XMLStreamWriter writer, DSAKeyValueType type) throws ProcessingException { String prefix = WSTrustConstants.XMLDSig.DSIG_PREFIX; StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.DSA_KEYVALUE, WSTrustConstants.DSIG_NS); byte[] p = type.getP(); if (p != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.P, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(p, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } byte[] q = type.getQ(); if (q != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.Q, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(q, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } byte[] g = type.getG(); if (g != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.G, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(g, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } byte[] y = type.getY(); if (y != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.Y, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(y, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } byte[] seed = type.getSeed(); if (seed != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.SEED, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(seed, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } byte[] pgen = type.getPgenCounter(); if (pgen != null) { StaxUtil.writeStartElement(writer, prefix, WSTrustConstants.XMLDSig.PGEN_COUNTER, WSTrustConstants.DSIG_NS); StaxUtil.writeCharacters(writer, new String(pgen, GeneralConstants.SAML_CHARSET)); StaxUtil.writeEndElement(writer); } StaxUtil.writeEndElement(writer); } }