package com.bagri.xquery.saxon; import static com.bagri.xquery.saxon.SaxonUtils.*; import java.io.Reader; import java.io.StringReader; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import javax.xml.transform.ErrorListener; import javax.xml.transform.TransformerException; import javax.xml.transform.stream.StreamSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import net.sf.saxon.Configuration; import net.sf.saxon.expr.instruct.UserFunction; import net.sf.saxon.expr.instruct.UserFunctionParameter; import net.sf.saxon.functions.ExecutableFunctionLibrary; import net.sf.saxon.functions.FunctionLibrary; import net.sf.saxon.functions.FunctionLibraryList; import net.sf.saxon.lib.ExtensionFunctionDefinition; import net.sf.saxon.lib.ModuleURIResolver; import net.sf.saxon.lib.UnfailingErrorListener; import net.sf.saxon.lib.Validation; import net.sf.saxon.om.StructuredQName; import net.sf.saxon.query.Annotation; import net.sf.saxon.query.StaticQueryContext; import net.sf.saxon.query.XQueryExpression; import net.sf.saxon.query.XQueryFunction; import net.sf.saxon.query.XQueryFunctionLibrary; import net.sf.saxon.trans.XPathException; import net.sf.saxon.value.AtomicValue; import com.bagri.core.api.BagriException; import com.bagri.core.system.DataType; import com.bagri.core.system.Function; import com.bagri.core.system.Library; import com.bagri.core.system.Module; import com.bagri.core.system.Parameter; import com.bagri.core.system.XQueryTrigger; import com.bagri.core.xquery.api.XQCompiler; import com.bagri.xquery.saxon.ext.doc.GetDocumentContent; import com.bagri.xquery.saxon.ext.doc.QueryDocumentUris; import com.bagri.xquery.saxon.ext.doc.RemoveCollectionDocuments; import com.bagri.xquery.saxon.ext.doc.RemoveDocument; import com.bagri.xquery.saxon.ext.doc.StoreDocument; import com.bagri.xquery.saxon.ext.http.HttpGet; import com.bagri.xquery.saxon.ext.tx.BeginTransaction; import com.bagri.xquery.saxon.ext.tx.CommitTransaction; import com.bagri.xquery.saxon.ext.tx.RollbackTransaction; import com.bagri.xquery.saxon.ext.util.GetUuid; import com.bagri.xquery.saxon.ext.util.LogOutput; import com.bagri.xquery.saxon.ext.util.StaticFunctionExtension; public class XQCompilerImpl implements XQCompiler { private static final Logger logger = LoggerFactory.getLogger(XQCompilerImpl.class); private Properties props = new Properties(); private Configuration config; private List<Library> libraries = new ArrayList<>(); public XQCompilerImpl() { initializeConfig(); } @Override public Properties getProperties() { return props; } @Override public void setProperty(String name, Object value) { props.setProperty(name, value.toString()); } private String getError(XPathException ex, StaticQueryContext sqc) { StringBuffer buff = new StringBuffer(); if (sqc.getErrorListener() instanceof LocalErrorListener) { List<TransformerException> errors = ((LocalErrorListener) sqc.getErrorListener()).getErrors(); for (TransformerException tex: errors) { buff.append(tex.getMessageAndLocation()).append("\n"); } } else { Throwable err = ex; while (err != null) { buff.append(err.getMessage()).append("\n"); err = err.getCause(); } } return buff.toString(); } @Override public void compileQuery(String query) throws BagriException { long stamp = System.currentTimeMillis(); logger.trace("compileQuery.enter; got query: {}", query); StaticQueryContext sqc = null; try { sqc = prepareStaticContext(null); sqc.compileQuery(query); } catch (XPathException ex) { String error = getError(ex, sqc); logger.info("compileQuery.error; message: {}", error); throw new BagriException(error, BagriException.ecQueryCompile); } stamp = System.currentTimeMillis() - stamp; logger.trace("compileQuery.exit; time taken: {}", stamp); } @Override public void compileModule(Module module) throws BagriException { long stamp = System.currentTimeMillis(); logger.trace("compileModule.enter; got module: {}", module); getModuleExpression(module); stamp = System.currentTimeMillis() - stamp; logger.trace("compileModule.exit; time taken: {}", stamp); } @Override public String compileTrigger(Module module, XQueryTrigger trigger) throws BagriException { long stamp = System.currentTimeMillis(); logger.trace("compileTrigger.enter; got trigger: {}", trigger); String query = "import module namespace " + module.getPrefix() + "=\"" + module.getNamespace() + "\" at \"" + module.getName() + "\";\n" + "declare variable $doc external;\n\n" + trigger.getFunction() + "($doc)\n"; StaticQueryContext sqc = prepareStaticContext(module.getBody()); logger.trace("getModuleExpression; compiling query: {}", query); try { sqc.compileQuery(query); } catch (XPathException ex) { String error = getError(ex, sqc); //logger.error("compileQuery.error", ex); logger.info("compileTrigger.error; message: {}", error); throw new BagriException(error, BagriException.ecQueryCompile); } stamp = System.currentTimeMillis() - stamp; logger.trace("compileTrigger.exit; time taken: {}", stamp); return query; } @Override public List<String> getModuleFunctions(Module module) throws BagriException { long stamp = System.currentTimeMillis(); logger.trace("getModuleFunctions.enter; got module: {}", module); XQueryExpression exp = getModuleExpression(module); List<String> result = lookupFunctions(exp.getExecutable().getFunctionLibrary(), new FunctionExtractor<String>() { @Override public String extractFunction(UserFunction fn) { String decl = getFunctionDeclaration(fn); List<Annotation> atns = fn.getAnnotations(); logger.trace("lookupFunctions; fn annotations: {}", atns); StringBuffer buff = new StringBuffer(); for (Annotation atn: atns) { if (Annotation.PRIVATE.equals(atn.getAnnotationQName())) { // do not expose private functions return null; } buff.append(atn.getAnnotationQName().getDisplayName()); if (atn.getAnnotationParameters() != null) { buff.append("("); int cnt = 0; for (AtomicValue av: atn.getAnnotationParameters()) { if (cnt > 0) { buff.append(", "); } buff.append("\"").append(av.getStringValue()).append("\""); cnt++; } buff.append(")"); } buff.append("\n"); } decl = buff.toString() + decl; return decl; } }); stamp = System.currentTimeMillis() - stamp; logger.trace("getModuleFunctions.exit; time taken: {}; returning: {}", stamp, result); return result; } private String getFunctionDeclaration(UserFunction function) { //declare function hw:helloworld($name as xs:string) logger.trace("getFunctionDeclaration.enter; function: {}", function); StringBuffer buff = new StringBuffer("function "); buff.append(function.getFunctionName()); buff.append("("); int idx =0; for (UserFunctionParameter ufp: function.getParameterDefinitions()) { if (idx > 0) { buff.append(", "); } buff.append("$"); buff.append(ufp.getVariableQName()); buff.append(" as "); buff.append(ufp.getRequiredType().toString()); idx++; } buff.append(") as "); // TODO: get rid of Q{} notation.. buff.append(function.getDeclaredResultType().toString()); String result = buff.toString(); logger.trace("getFunctionDeclaration.exit; returning: {}", result); return result; } @Override public boolean getModuleState(Module module) { try { String query = "import module namespace test=\"" + module.getNamespace() + "\" at \"" + module.getName() + "\";\n\n"; query += "1213"; StaticQueryContext sqc = prepareStaticContext(module.getBody()); logger.trace("getModuleExpression; compiling query: {}", query); sqc.compileQuery(query); return true; } catch (XPathException ex) { return false; } } @Override public void setLibraries(Collection<Library> libraries) { this.libraries.clear(); this.libraries.addAll(libraries); //config.registerExtensionFunction(function); initializeConfig(); } private void initializeConfig() { logger.trace("initializeConfig.enter; current config: {}", config); config = Configuration.newConfiguration(); //config.setHostLanguage(Configuration.XQUERY); config.setSchemaValidationMode(Validation.STRIP); //config.setConfigurationProperty(FeatureKeys.ALLOW_EXTERNAL_FUNCTIONS, Boolean.TRUE); config.registerExtensionFunction(new GetUuid()); config.registerExtensionFunction(new LogOutput()); config.registerExtensionFunction(new HttpGet()); config.registerExtensionFunction(new GetDocumentContent(null)); config.registerExtensionFunction(new RemoveDocument(null)); config.registerExtensionFunction(new StoreDocument(null)); config.registerExtensionFunction(new RemoveCollectionDocuments(null)); config.registerExtensionFunction(new QueryDocumentUris(null)); config.registerExtensionFunction(new BeginTransaction(null)); config.registerExtensionFunction(new CommitTransaction(null)); config.registerExtensionFunction(new RollbackTransaction(null)); if (libraries != null) { registerExtensions(config, libraries); } logger.trace("initializeConfig.exit; new config: {}", config); } static void registerExtensions(Configuration config, Collection<Library> libraries) { for (Library lib: libraries) { for (Function func: lib.getFunctions()) { try { ExtensionFunctionDefinition efd = new StaticFunctionExtension(func, config); logger.trace("registerExtensions; funtion {} registered as {}", func.toString(), efd.getFunctionQName()); config.registerExtensionFunction(efd); } catch (Exception ex) { logger.warn("registerExtensions; error registering function {}: {}; skipped", func.toString(), ex.getMessage()); } } } } private StaticQueryContext prepareStaticContext(String body) { StaticQueryContext sqc = config.newStaticQueryContext(); sqc.setErrorListener(new LocalErrorListener()); //sqc.setSchemaAware(true); - requires Saxon-EE sqc.setLanguageVersion(saxon_xquery_version); if (body != null) { sqc.setModuleURIResolver(new LocalModuleURIResolver(body)); } return sqc; } private XQueryExpression getModuleExpression(Module module) throws BagriException { //logger.trace("getModuleExpression.enter; got namespace: {}, name: {}, body: {}", namespace, name, body); String query = "import module namespace test=\"" + module.getNamespace() + "\" at \"" + module.getName() + "\";\n\n1213"; StaticQueryContext sqc = null; try { //sqc.compileLibrary(query); - works in Saxon-EE only sqc = prepareStaticContext(module.getBody()); logger.trace("getModuleExpression; compiling query: {}", query); //logger.trace("getModuleExpression.exit; time taken: {}", stamp); return sqc.compileQuery(query); //sqc.getCompiledLibrary("test")... } catch (XPathException ex) { String error = getError(ex, sqc); logger.error("getModuleExpression.error; " + error, ex); //logger.info("getModuleExpression.error; message: {}", error); throw new BagriException(error, BagriException.ecQueryCompile); } } private <R> List<R> lookupFunctions(FunctionLibraryList fll, FunctionExtractor<R> ext) { List<R> fl = new ArrayList<>(); for (FunctionLibrary lib: fll.getLibraryList()) { logger.trace("lookupFunctions; function library: {}; class: {}", lib.toString(), lib.getClass().getName()); if (lib instanceof FunctionLibraryList) { fl.addAll(lookupFunctions((FunctionLibraryList) lib, ext)); //} else if (lib instanceof ExecutableFunctionLibrary) { // ExecutableFunctionLibrary efl = (ExecutableFunctionLibrary) lib; // Iterator<UserFunction> itr = efl.iterateFunctions(); // while (itr.hasNext()) { // fl.add(getFunctionDeclaration(itr.next())); // } } else if (lib instanceof XQueryFunctionLibrary) { XQueryFunctionLibrary xqfl = (XQueryFunctionLibrary) lib; Iterator<XQueryFunction> itr = xqfl.getFunctionDefinitions(); while (itr.hasNext()) { XQueryFunction fn = itr.next(); logger.trace("lookupFunctions; fn: {}", fn.getDisplayName()); R result = ext.extractFunction(fn.getUserFunction()); if (result != null) { fl.add(result); } } } } return fl; } @Override public List<Function> getRestFunctions(Module module) throws BagriException { long stamp = System.currentTimeMillis(); logger.trace("getRestFunctions.enter; got module: {}", module); XQueryExpression exp = getModuleExpression(module); List<Function> result = lookupFunctions(exp.getExecutable().getFunctionLibrary(), new FunctionExtractor<Function>() { @Override public Function extractFunction(UserFunction fn) { logger.trace("extractFunction.enter; function: {}", fn); List<Annotation> atns = fn.getAnnotations(); if (!hasRestAnnotations(atns)) { logger.debug("extractFunction; no REST annotations found for function {}, skipping it", fn.getFunctionName().getDisplayName()); return null; } DataType type = new DataType(getTypeName(fn.getResultType().getPrimaryType()), getCardinality(fn.getResultType().getCardinality())); Function result = new Function(null, fn.getFunctionName().getLocalPart(), type, null, fn.getFunctionName().getPrefix()); for (UserFunctionParameter ufp: fn.getParameterDefinitions()) { Parameter param = new Parameter(ufp.getVariableQName().getLocalPart(), getTypeName(ufp.getRequiredType().getPrimaryType()), getCardinality(ufp.getRequiredType().getCardinality())); result.getParameters().add(param); } for (Annotation atn: atns) { String aName = atn.getAnnotationQName().getDisplayName(); if (aName.startsWith("rest:")) { result.addAnnotation(aName, null); if (atn.getAnnotationParameters() != null) { for (AtomicValue av: atn.getAnnotationParameters()) { result.addAnnotation(aName, av.getStringValue()); } } } } logger.trace("extractFunction.exit; returning: {}", result); return result; } }); stamp = System.currentTimeMillis() - stamp; logger.trace("getRestFunctions.exit; time taken: {}; returning: {}", stamp, result); return result; } private boolean hasRestAnnotations(List<Annotation> annotations) { for (Annotation atn: annotations) { if ("rest".equalsIgnoreCase(atn.getAnnotationQName().getPrefix())) { return true; } } return false; } private interface FunctionExtractor<R> { R extractFunction(UserFunction fn); } private class LocalErrorListener implements UnfailingErrorListener { private List<TransformerException> errors = new ArrayList<>(); public List<TransformerException> getErrors() { return errors; } @Override public void error(TransformerException txEx) { errors.add(txEx); } @Override public void fatalError(TransformerException txEx) { errors.add(txEx); } @Override public void warning(TransformerException txEx) { errors.add(txEx); } } private class LocalModuleURIResolver implements ModuleURIResolver { private String body; LocalModuleURIResolver(String body) { this.body = body; } @Override public StreamSource[] resolve(String moduleURI, String baseURI, String[] locations) throws XPathException { logger.trace("resolve.enter; got moduleURI: {}, baseURI: {}, locations: {}, body: {}", moduleURI, baseURI, locations, body); Reader reader = new StringReader(body); return new StreamSource[] {new StreamSource(reader)}; } } }