package org.atricore.idbus.capabilities.oauth2.common.util; import com.sun.xml.bind.marshaller.NamespacePrefixMapper; import org.apache.commons.codec.binary.Base64; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.atricore.idbus.capabilities.oauth2.common.OAuth2Constants; import org.atricore.idbus.common.oauth._2_0.protocol.OAuthRequestAbstractType; import org.atricore.idbus.common.oauth._2_0.protocol.OAuthResponseAbstractType; import org.atricore.idbus.kernel.main.databinding.JAXBUtils; import org.xml.sax.SAXException; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBElement; import javax.xml.bind.Marshaller; import javax.xml.namespace.QName; import javax.xml.parsers.ParserConfigurationException; import javax.xml.parsers.SAXParserFactory; import javax.xml.stream.XMLInputFactory; import javax.xml.stream.XMLOutputFactory; import javax.xml.stream.XMLStreamException; import javax.xml.stream.XMLStreamWriter; import javax.xml.ws.Holder; import java.io.ByteArrayOutputStream; import java.io.StringWriter; import java.io.Writer; import java.util.HashMap; import java.util.TreeSet; /** * @author <a href=mailto:sgonzalez@atricore.org>Sebastian Gonzalez Oyuela</a> */ public class XmlUtils { private static final Log logger = LogFactory.getLog(XmlUtils.class); private static final TreeSet<String> oauthContextPackages = new TreeSet<String>(); private static final Holder<JAXBUtils.CONSTRUCTION_TYPE> constructionType = new Holder<JAXBUtils.CONSTRUCTION_TYPE>(); private static final XMLInputFactory staxIF = XMLInputFactory.newInstance(); private static final XMLOutputFactory staxOF = XMLOutputFactory.newInstance(); static { oauthContextPackages.add(OAuth2Constants.OAUTH2_PROTOCOL_PKG); javax.xml.parsers.DocumentBuilderFactory dbf = javax.xml.parsers.DocumentBuilderFactory.newInstance(); javax.xml.parsers.SAXParserFactory saxf = SAXParserFactory.newInstance(); try { logger.debug("DocumentBuilder = " + dbf.newDocumentBuilder()); logger.debug("SAXParser = " + saxf.newSAXParser()); logger.debug("XMLEventReader = " + staxIF.createXMLEventReader(new StringSource("<a>Hello</a>"))); logger.debug("XMLEventWriter = " + staxOF.createXMLEventWriter(new ByteArrayOutputStream())); } catch (ParserConfigurationException e) { logger.error(e.getMessage(), e); } catch (SAXException e) { logger.error(e.getMessage(), e); } catch (XMLStreamException e) { logger.error(e.getMessage(), e); } } public static String marshalOAuth2Request(OAuthRequestAbstractType request, boolean encode) throws Exception { String type = request.getClass().getSimpleName(); if (type.endsWith("Type")) type = type.substring(0, type.length() - 4); return marshalOAuth2Request(request, type, encode); } public static String marshalOAuth2Request(OAuthRequestAbstractType request, String requestType, boolean encode) throws Exception { String marshaledRequest = marshalOAuth2( request, OAuth2Constants.OAUTH2_PROTOCOL_NS, requestType ); return encode ? new String(new Base64().encode(marshaledRequest.getBytes())) : marshaledRequest; } public static String marshalOAuth2Response(OAuthResponseAbstractType response, boolean encode) throws Exception { String type = response.getClass().getSimpleName(); if (type.endsWith("Type")) type = type.substring(0, type.length() - 4); return marshalOAuth2Response(response, type, encode); } public static String marshalOAuth2Response(OAuthResponseAbstractType response, String responseType, boolean encode) throws Exception { String marshaledResponse = XmlUtils.marshalOAuth2( response, OAuth2Constants.OAUTH2_PROTOCOL_NS, responseType ); return encode ? new String(new Base64().encode(marshaledResponse.getBytes())) : marshaledResponse; } public static String marshalOAuth2(Object msg, String msgQName, String msgLocalName) throws Exception { //JAXBContext jaxbContext = createJAXBContext(userPackages); JAXBContext jaxbContext = JAXBUtils.getJAXBContext(oauthContextPackages, constructionType, oauthContextPackages.toString(), XmlUtils.class.getClassLoader(), new HashMap<String, Object>()); Marshaller m = JAXBUtils.getJAXBMarshaller(jaxbContext); JAXBElement jaxbRequest = new JAXBElement(new QName(msgQName, msgLocalName), msg.getClass(), msg ); Writer writer = new StringWriter(); XMLStreamWriter xmlStreamWriter = new NamespaceFilterXMLStreamWriter(writer); // Support XMLDsig // TODO : What about non-sun XML Bind stacks! m.setProperty("com.sun.xml.bind.namespacePrefixMapper", new NamespacePrefixMapper() { @Override public String[] getPreDeclaredNamespaceUris() { return new String[] { OAuth2Constants.OAUTH2_PROTOCOL_NS, "http://www.w3.org/2000/09/xmldsig#", "http://www.w3.org/2001/04/xmlenc#", "http://www.w3.org/2001/XMLSchema" }; } @Override public String getPreferredPrefix(String nsUri, String suggestion, boolean requirePrefix) { if (nsUri.equals(OAuth2Constants.OAUTH2_PROTOCOL_NS)) return "oauth2p"; else if (nsUri.equals("http://www.w3.org/2000/09/xmldsig#")) return "ds"; else if (nsUri.equals("http://www.w3.org/2001/04/xmlenc#")) return "enc"; else if (nsUri.equals("http://www.w3.org/2001/XMLSchema")) return "xsd"; return suggestion; } }); m.marshal(jaxbRequest, xmlStreamWriter); xmlStreamWriter.flush(); JAXBUtils.releaseJAXBMarshaller(jaxbContext, m); return writer.toString(); } }