/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.saml.processing.core.util;
import org.keycloak.saml.common.PicketLinkLogger;
import org.keycloak.saml.common.PicketLinkLoggerFactory;
import org.keycloak.saml.common.constants.GeneralConstants;
import org.xml.sax.ErrorHandler;
import org.xml.sax.SAXException;
import org.xml.sax.SAXParseException;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.transform.Source;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* Utility for obtaining JAXB 2 marshallers, unmarshallers and {@link JAXBContext} instances,
* optionally configured to validate against an XML schema.
*
* @author Anil.Saldhana@redhat.com
* @since May 26, 2009
*/
public class JAXBUtil {

    private static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger();

    public static final String W3C_XML_SCHEMA_NS_URI = "http://www.w3.org/2001/XMLSchema";

    /**
     * Cache of already created contexts, keyed by JAXB context path or by class name.
     * The map is shared by every static method of this class, so a concurrent map is used
     * to keep lookups and insertions safe when several threads ask for a context at once.
     */
    private static final ConcurrentMap<String, JAXBContext> jaxbContextHash = new ConcurrentHashMap<String, JAXBContext>();

    static {
        // Useful on Sun VMs. Harmless on other VMs.
        SecurityActions.setSystemProperty("com.sun.xml.bind.v2.runtime.JAXBContextImpl.fastBoot", "true");
    }

    /**
     * Get a JAXB Marshaller that validates against the given schema.
     *
     * @param pkgName The package name for the jaxb context
     * @param schemaLocation location of the schema to validate against
     *
     * @return Marshaller with the schema set
     *
     * @throws JAXBException if the context or marshaller cannot be created
     * @throws SAXException if the schema cannot be parsed
     */
    public static Marshaller getValidatingMarshaller(String pkgName, String schemaLocation) throws JAXBException, SAXException {
        Marshaller marshaller = getMarshaller(pkgName);
        // Validate against schema
        marshaller.setSchema(getJAXPSchemaInstance(schemaLocation));
        return marshaller;
    }

    /**
     * Get the JAXB Marshaller. The marshaller is configured with the SAML charset and with
     * formatted output switched off.
     *
     * @param pkgName The package name for the jaxb context
     *
     * @return Marshaller
     *
     * @throws JAXBException if the context or marshaller cannot be created
     */
    public static Marshaller getMarshaller(String pkgName) throws JAXBException {
        if (pkgName == null) {
            throw logger.nullArgumentError("pkgName");
        }

        Marshaller marshaller = getJAXBContext(pkgName).createMarshaller();
        marshaller.setProperty(Marshaller.JAXB_ENCODING, GeneralConstants.SAML_CHARSET_NAME);
        // Formatted output is kept off on purpose: pretty printing breaks signatures.
        marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, Boolean.FALSE);
        return marshaller;
    }

    /**
     * Get the JAXB Unmarshaller
     *
     * @param pkgName The package name for the jaxb context
     *
     * @return unmarshaller
     *
     * @throws JAXBException if the context or unmarshaller cannot be created
     */
    public static Unmarshaller getUnmarshaller(String pkgName) throws JAXBException {
        if (pkgName == null) {
            throw logger.nullArgumentError("pkgName");
        }
        return getJAXBContext(pkgName).createUnmarshaller();
    }

    /**
     * Get the JAXB Unmarshaller for a selected set of package names
     *
     * @param pkgNames one or more package names making up the jaxb context path; must not be
     *        null or empty
     *
     * @return unmarshaller created from the context covering all given packages
     *
     * @throws JAXBException if the context or unmarshaller cannot be created
     */
    public static Unmarshaller getUnmarshaller(String... pkgNames) throws JAXBException {
        if (pkgNames == null) {
            throw logger.nullArgumentError("pkgNames");
        }
        // Empty and single-element arrays are handled by getJAXBContext(String...).
        return getJAXBContext(pkgNames).createUnmarshaller();
    }

    /**
     * Get a JAXB Unmarshaller that validates against the given schema.
     *
     * @param pkgName The package name for the jaxb context
     * @param schemaLocation location of the schema to validate against
     *
     * @return unmarshaller with the schema set
     *
     * @throws JAXBException if the context or unmarshaller cannot be created
     * @throws SAXException if the schema cannot be parsed
     */
    public static Unmarshaller getValidatingUnmarshaller(String pkgName, String schemaLocation) throws JAXBException,
            SAXException {
        Unmarshaller unmarshaller = getUnmarshaller(pkgName);
        unmarshaller.setSchema(getJAXPSchemaInstance(schemaLocation));
        return unmarshaller;
    }

    /**
     * Get a JAXB Unmarshaller for a set of packages that validates against a schema compiled
     * from several schema documents.
     *
     * @param pkgNames package names making up the jaxb context path; must not be null or empty
     * @param schemaLocations locations of the schema documents to validate against; must not be null
     *
     * @return unmarshaller with the combined schema set
     *
     * @throws JAXBException if the context or unmarshaller cannot be created
     * @throws SAXException if the schemas cannot be parsed
     * @throws IOException if a schema resource cannot be opened
     */
    public static Unmarshaller getValidatingUnmarshaller(String[] pkgNames, String[] schemaLocations) throws JAXBException,
            SAXException, IOException {
        if (pkgNames == null) {
            throw logger.nullArgumentError("pkgNames");
        }
        if (schemaLocations == null) {
            throw logger.nullArgumentError("schemaLocations");
        }
        if (pkgNames.length == 0) {
            throw logger.nullValueError("Packages are empty");
        }

        Unmarshaller unmarshaller = getUnmarshaller(pkgNames);
        SchemaFactory schemaFactory = getSchemaFactory();

        // Get the sources. Every stream opened here is closed in the finally block, whether or
        // not the schema compilation succeeds.
        StreamSource[] schemaSources = new StreamSource[schemaLocations.length];
        try {
            for (int i = 0; i < schemaLocations.length; i++) {
                String schemaLocation = schemaLocations[i];
                URL schemaURL = SecurityActions.loadResource(JAXBUtil.class, schemaLocation);
                if (schemaURL == null) {
                    throw logger.nullValueError("Schema URL :" + schemaLocation);
                }
                schemaSources[i] = new StreamSource(schemaURL.openStream());
            }
            unmarshaller.setSchema(schemaFactory.newSchema(schemaSources));
        } finally {
            closeQuietly(schemaSources);
        }
        return unmarshaller;
    }

    /**
     * Closes the input streams held by the given sources. Entries that were never populated
     * are skipped. A failure to close is only traced so that it cannot hide the outcome of the
     * schema compilation.
     */
    private static void closeQuietly(StreamSource[] sources) {
        for (StreamSource source : sources) {
            if (source == null || source.getInputStream() == null) {
                continue;
            }
            try {
                source.getInputStream().close();
            } catch (IOException e) {
                logger.trace("Unable to close schema stream:" + e);
            }
        }
    }

    /**
     * Loads the schema found at the given resource location.
     *
     * @param schemaLocation location of the schema resource
     *
     * @return compiled schema
     *
     * @throws SAXException if the schema cannot be parsed
     */
    private static Schema getJAXPSchemaInstance(String schemaLocation) throws SAXException {
        URL schemaURL = SecurityActions.loadResource(JAXBUtil.class, schemaLocation);
        if (schemaURL == null) {
            throw logger.nullValueError("Schema URL :" + schemaLocation);
        }
        return getSchemaFactory().newSchema(schemaURL);
    }

    /**
     * Creates a W3C XML Schema factory. An {@link IDFedLSInputResolver} is installed unless the
     * system property <code>org.picketlink.identity.federation.jaxb.ls</code> is set. The
     * installed error handler only logs what it is told about: errors and warnings at trace
     * level, fatal errors at error level. It does not rethrow any of them.
     */
    private static SchemaFactory getSchemaFactory() {
        SchemaFactory scFact = SchemaFactory.newInstance(W3C_XML_SCHEMA_NS_URI);

        // Always install the resolver unless the system property is set
        if (SecurityActions.getSystemProperty("org.picketlink.identity.federation.jaxb.ls", null) == null) {
            scFact.setResourceResolver(new IDFedLSInputResolver());
        }

        scFact.setErrorHandler(new ErrorHandler() {
            @Override
            public void error(SAXParseException exception) throws SAXException {
                logger.trace("SAX Error:" + describe(exception));
            }

            @Override
            public void fatalError(SAXParseException exception) throws SAXException {
                logger.error("SAX Fatal Error:" + describe(exception));
            }

            @Override
            public void warning(SAXParseException exception) throws SAXException {
                logger.trace("SAX Warn:" + describe(exception));
            }
        });
        return scFact;
    }

    /**
     * Renders the location and message of a parse problem for logging.
     */
    private static String describe(SAXParseException exception) {
        StringBuilder builder = new StringBuilder();
        builder.append("Line Number=").append(exception.getLineNumber());
        builder.append(" Col Number=").append(exception.getColumnNumber());
        builder.append(" Public ID=").append(exception.getPublicId());
        builder.append(" System ID=").append(exception.getSystemId());
        builder.append(" exc=").append(exception.getLocalizedMessage());
        return builder.toString();
    }

    /**
     * Get the (cached) JAXB context for a context path.
     *
     * @param path jaxb context path, i.e. a package name or a colon separated list of them
     *
     * @return context for the path; created and cached on first use
     *
     * @throws JAXBException if the context cannot be created
     */
    public static JAXBContext getJAXBContext(String path) throws JAXBException {
        if (path == null) {
            throw logger.nullArgumentError("path");
        }

        JAXBContext jx = jaxbContextHash.get(path);
        if (jx == null) {
            jx = cache(path, JAXBContext.newInstance(path));
        }
        return jx;
    }

    /**
     * Get the (cached) JAXB context covering several packages.
     *
     * @param paths package names to combine into one context path; must not be null or empty
     *
     * @return context for the combined path; created and cached on first use
     *
     * @throws JAXBException if the context cannot be created
     */
    public static JAXBContext getJAXBContext(String... paths) throws JAXBException {
        if (paths == null) {
            throw logger.nullArgumentError("paths");
        }
        if (paths.length == 0) {
            throw logger.nullValueError("Packages are empty");
        }
        if (paths.length == 1) {
            // A single package needs no joining; share the cache entry of the plain lookup.
            return getJAXBContext(paths[0]);
        }

        StringBuilder builder = new StringBuilder();
        for (String path : paths) {
            builder.append(path).append(":");
        }
        String finalPath = builder.toString();

        JAXBContext jx = jaxbContextHash.get(finalPath);
        if (jx == null) {
            jx = cache(finalPath, JAXBContext.newInstance(finalPath));
        }
        return jx;
    }

    /**
     * Get the (cached) JAXB context for a single class. The cache key is the class name.
     *
     * @param clazz class to create the context for
     *
     * @return context for the class; created and cached on first use
     *
     * @throws JAXBException if the context cannot be created
     */
    public static JAXBContext getJAXBContext(Class<?> clazz) throws JAXBException {
        if (clazz == null) {
            throw logger.nullArgumentError("clazz");
        }

        String clazzName = clazz.getName();
        JAXBContext jx = jaxbContextHash.get(clazzName);
        if (jx == null) {
            jx = cache(clazzName, JAXBContext.newInstance(clazz));
        }
        return jx;
    }

    /**
     * Stores a freshly created context under the given key unless another thread stored one
     * first, and returns whichever instance ended up in the cache.
     */
    private static JAXBContext cache(String key, JAXBContext created) {
        JAXBContext existing = jaxbContextHash.putIfAbsent(key, created);
        return existing != null ? existing : created;
    }
}