/*******************************************************************************
 * Copyright (c) 2012-2016 Codenvy, S.A.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *   Codenvy, S.A. - initial API and implementation
 *******************************************************************************/
package org.everrest.groovy;

import groovy.lang.GroovyClassLoader;

import org.codehaus.groovy.ast.ClassNode;
import org.codehaus.groovy.ast.ModuleNode;
import org.codehaus.groovy.control.CompilationFailedException;
import org.codehaus.groovy.control.CompilationUnit;
import org.codehaus.groovy.control.CompilerConfiguration;
import org.codehaus.groovy.control.Phases;
import org.codehaus.groovy.control.SourceUnit;

import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.CodeSource;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * A {@link GroovyClassLoader} that can compile one "target" script together with a set of
 * additional source files, or compile a whole set of source files at once.
 *
 * <p>All compilation units created here use a {@link CodeSource} whose location is the
 * {@code file:} URL built from {@link #CODE_BASE} (see {@link #getCodeSource()}).
 *
 * @author andrew00x
 */
@SuppressWarnings({"unchecked"})
public class ExtendedGroovyClassLoader extends GroovyClassLoader {
    /** Path used to build the {@code file:} URL that serves as code source location for compiled scripts. */
    public static final String CODE_BASE = "/groovy/script/jaxrs";

    /**
     * Class-generation callback that defines every generated class and additionally remembers
     * the single "target" class: the first class declared in the module that belongs to the
     * source unit given at construction time.
     */
    public static class SingleClassCollector extends GroovyClassLoader.ClassCollector {
        protected final CompilationUnit cunit;
        protected final SourceUnit      sunit;
        /** The target class, or {@code null} while no generated class has matched yet. */
        protected       Class           target;

        /**
         * @param cl
         *         loader used to define the generated classes
         * @param cunit
         *         compilation unit whose code source is applied to the defined classes
         * @param sunit
         *         source unit whose first class is reported by {@link #getTarget()}
         */
        protected SingleClassCollector(ExtendedInnerLoader cl, CompilationUnit cunit, SourceUnit sunit) {
            super(cl, cunit, sunit);
            this.cunit = cunit;
            this.sunit = sunit;
        }

        /**
         * Defines the generated class in the {@link ExtendedInnerLoader}, records it among the
         * loaded classes and, if no target has been found yet, checks whether it is the target.
         *
         * @param code
         *         generated bytecode
         * @param classNode
         *         AST node the bytecode was generated from
         * @return the newly defined class
         */
        @Override
        protected Class createClass(byte[] code, ClassNode classNode) {
            ExtendedInnerLoader cl = (ExtendedInnerLoader)getDefiningClassLoader();
            Class clazz = cl.defineClass(classNode.getName(), code, cunit.getAST().getCodeSource());
            getLoadedClasses().add(clazz);
            if (target == null) {
                ClassNode targetClassNode = null;
                SourceUnit targetSunit = null;
                ModuleNode module = classNode.getModule();
                if (module != null) {
                    targetClassNode = module.getClasses().get(0);
                    targetSunit = module.getContext();
                }
                // Identity comparison on purpose: only the first class of the module that was
                // produced from the target source unit qualifies.
                if (targetSunit == sunit && targetClassNode == classNode) {
                    target = clazz;
                }
            }
            return clazz;
        }

        /**
         * @return the target class, or {@code null} if none of the generated classes matched
         */
        public Class getTarget() {
            return target;
        }
    }

    /**
     * Class-generation callback that defines every generated class and additionally collects
     * those classes whose module originates from one of the given source units.
     */
    public static class MultipleClassCollector extends GroovyClassLoader.ClassCollector {
        protected final CompilationUnit cunit;
        protected final Set<SourceUnit> sunitSet;
        private final   List<Class>     compiledClasses;

        /**
         * @param cl
         *         loader used to define the generated classes
         * @param cunit
         *         compilation unit whose code source is applied to the defined classes
         * @param sunitSet
         *         source units whose classes are reported by {@link #getCompiledClasses()}
         */
        protected MultipleClassCollector(ExtendedInnerLoader cl, CompilationUnit cunit, Set<SourceUnit> sunitSet) {
            // There is no single target source unit here, hence null is passed to the parent collector.
            super(cl, cunit, null);
            this.cunit = cunit;
            this.sunitSet = sunitSet;
            this.compiledClasses = new ArrayList<>();
        }

        /**
         * Defines the generated class in the {@link ExtendedInnerLoader}, records it among the
         * loaded classes and, if its module comes from one of the tracked source units, adds it
         * to the list of compiled classes.
         *
         * @param code
         *         generated bytecode
         * @param classNode
         *         AST node the bytecode was generated from
         * @return the newly defined class
         */
        @Override
        protected Class createClass(byte[] code, ClassNode classNode) {
            ExtendedInnerLoader cl = (ExtendedInnerLoader)getDefiningClassLoader();
            Class clazz = cl.defineClass(classNode.getName(), code, cunit.getAST().getCodeSource());
            getLoadedClasses().add(clazz);
            ModuleNode module = classNode.getModule();
            if (module != null) {
                SourceUnit currentSunit = module.getContext();
                if (sunitSet.contains(currentSunit)) {
                    compiledClasses.add(clazz);
                }
            }
            return clazz;
        }

        /**
         * @return classes generated from the tracked source units, in the order they were created
         */
        public List<Class> getCompiledClasses() {
            return compiledClasses;
        }
    }

    /**
     * Inner loader exposing {@code defineClass} with an explicit {@link CodeSource} and a
     * single-argument {@code definePackage} to the collectors of this class.
     */
    public static class ExtendedInnerLoader extends GroovyClassLoader.InnerLoader {
        public ExtendedInnerLoader(ExtendedGroovyClassLoader parent) {
            super(parent);
        }

        /**
         * Defines a class from the whole {@code code} array with the given code source.
         *
         * @param name
         *         class name
         * @param code
         *         class bytecode
         * @param cs
         *         code source to associate with the class
         * @return the defined class
         */
        protected Class defineClass(String name, byte[] code, CodeSource cs) {
            return super.defineClass(name, code, 0, code.length, cs);
        }

        /**
         * Defines the package {@code name} with no specification or implementation attributes,
         * unless {@code getPackage(name)} already returns one.
         *
         * @param name
         *         package name
         * @throws IllegalArgumentException
         *         if thrown by the underlying {@code definePackage} call
         */
        protected void definePackage(String name) throws IllegalArgumentException {
            Package pkg = getPackage(name);
            if (pkg == null) {
                super.definePackage(name, null, null, null, null, null, null, null);
            }
        }
    }

    public ExtendedGroovyClassLoader(ClassLoader classLoader) {
        super(classLoader);
    }

    public ExtendedGroovyClassLoader(GroovyClassLoader parent) {
        super(parent);
    }

    /**
     * Compiles the script read from {@code in} together with the additional {@code files} up to
     * {@link Phases#CLASS_GENERATION}, without putting the result into the source cache.
     *
     * @param in
     *         stream with the source of the target script
     * @param fileName
     *         name under which the target script is added to the compilation unit
     * @param files
     *         additional sources compiled together with the target script, may be {@code null}
     * @return the target class; {@code null} if the collector did not identify one
     * @throws CompilationFailedException
     *         if compilation fails
     */
    public Class parseClass(InputStream in, String fileName, SourceFile[] files) throws CompilationFailedException {
        return doParseClass(in, fileName, files, Phases.CLASS_GENERATION, null, false);
    }

    /**
     * Returns the class cached in {@code sourceCache} under {@code fileName} or, if there is none,
     * compiles the script from {@code in} together with {@code files}, registers every generated
     * class in the class cache and returns the target class. The whole operation runs while
     * holding the monitor of {@code sourceCache}.
     *
     * @param in
     *         stream with the source of the target script
     * @param fileName
     *         name of the target script, also used as the source cache key
     * @param files
     *         additional sources compiled together with the target script, may be {@code null}
     * @param phase
     *         compilation phase to run to
     * @param config
     *         compiler configuration handed to {@link #createCompilationUnit}
     * @param shouldCacheSource
     *         whether to store the result in {@code sourceCache} under {@code fileName}
     * @return the target class; {@code null} if the collector did not identify one
     * @throws CompilationFailedException
     *         if compilation fails
     */
    protected Class doParseClass(InputStream in,
                                 String fileName,
                                 SourceFile[] files,
                                 int phase,
                                 CompilerConfiguration config,
                                 boolean shouldCacheSource) throws CompilationFailedException {
        synchronized (sourceCache) {
            Class target = sourceCache.get(fileName);
            if (target == null) {
                CodeSource cs = new CodeSource(getCodeSource(), (java.security.cert.Certificate[])null);
                CompilationUnit cunit = createCompilationUnit(config, cs);
                SourceUnit targetSunit = cunit.addSource(fileName, in);
                if (files != null) {
                    for (SourceFile file : files) {
                        cunit.addSource(file.getPath());
                    }
                }
                SingleClassCollector collector = createSingleCollector(cunit, targetSunit);
                cunit.setClassgenCallback(collector);
                cunit.compile(phase);
                registerLoadedClasses(collector.getLoadedClasses());
                target = collector.getTarget();
                if (shouldCacheSource) {
                    sourceCache.put(fileName, target);
                }
            }
            return target;
        }
    }

    /**
     * Compiles all given source files up to {@link Phases#CLASS_GENERATION}.
     *
     * @param files
     *         sources to compile
     * @return classes generated from the given sources
     */
    public Class[] parseClasses(SourceFile[] files) {
        return doParseClasses(files, Phases.CLASS_GENERATION, null);
    }

    /**
     * Compiles all given sources in one compilation unit, registers every generated class in the
     * class cache and returns the classes that originate from the given sources. The whole
     * operation runs while holding the monitor of {@code classCache}.
     *
     * @param sources
     *         sources to compile
     * @param phase
     *         compilation phase to run to
     * @param config
     *         compiler configuration handed to {@link #createCompilationUnit}
     * @return classes generated from the given sources
     */
    protected Class[] doParseClasses(SourceFile[] sources, int phase, CompilerConfiguration config) {
        synchronized (classCache) {
            CodeSource cs = new CodeSource(getCodeSource(), (java.security.cert.Certificate[])null);
            CompilationUnit cunit = createCompilationUnit(config, cs);
            Set<SourceUnit> setSunit = new HashSet<>();
            for (SourceFile source : sources) {
                setSunit.add(cunit.addSource(source.getPath()));
            }
            MultipleClassCollector collector = createMultipleCollector(cunit, setSunit);
            cunit.setClassgenCallback(collector);
            cunit.compile(phase);
            registerLoadedClasses(collector.getLoadedClasses());
            List<Class> compiledClasses = collector.getCompiledClasses();
            return compiledClasses.toArray(new Class[compiledClasses.size()]);
        }
    }

    /**
     * Defines the package of every given class (when the class is in a named package and that
     * package is not known to this loader yet) and puts the class into the class cache.
     *
     * @param loadedClasses
     *         classes produced by a class collector; every element must be a {@link Class}
     */
    private void registerLoadedClasses(Iterable<?> loadedClasses) {
        for (Object loaded : loadedClasses) {
            Class clazz = (Class)loaded;
            String classname = clazz.getName();
            int i = classname.lastIndexOf('.');
            if (i != -1) {
                String pkgname = classname.substring(0, i);
                Package pkg = getPackage(pkgname);
                if (pkg == null) {
                    definePackage(pkgname, null, null, null, null, null, null, null);
                }
            }
            setClassCacheEntry(clazz);
        }
    }

    /**
     * @see groovy.lang.GroovyClassLoader#createCompilationUnit(org.codehaus.groovy.control.CompilerConfiguration,
     *      java.security.CodeSource)
     */
    @Override
    protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource cs) {
        return new CompilationUnit(config, cs, this);
    }

    /**
     * Creates a collector that tracks the first class of {@code sunit}, backed by a fresh
     * {@link ExtendedInnerLoader} whose parent is this loader.
     */
    protected SingleClassCollector createSingleCollector(CompilationUnit unit, SourceUnit sunit) {
        ExtendedInnerLoader loader = new ExtendedInnerLoader(this);
        return new SingleClassCollector(loader, unit, sunit);
    }

    /**
     * Creates a collector that tracks the classes of all units in {@code setSunit}, backed by a
     * fresh {@link ExtendedInnerLoader} whose parent is this loader.
     */
    protected MultipleClassCollector createMultipleCollector(CompilationUnit unit, Set<SourceUnit> setSunit) {
        ExtendedInnerLoader loader = new ExtendedInnerLoader(this);
        return new MultipleClassCollector(loader, unit, setSunit);
    }

    /**
     * @return the {@code file:} URL built from {@link #CODE_BASE}
     */
    protected URL getCodeSource() {
        return getCodeSource(CODE_BASE);
    }

    /**
     * Builds a {@code file:} URL with an empty host and {@code codeBase} as its file part.
     *
     * @param codeBase
     *         file part of the URL
     * @return the URL
     * @throws IllegalArgumentException
     *         if the URL cannot be created; the original {@link MalformedURLException} is kept as the cause
     */
    private URL getCodeSource(String codeBase) {
        try {
            return new URL("file", "", codeBase);
        } catch (MalformedURLException e) {
            // Keep the original exception as the cause so the stack trace is not lost.
            throw new IllegalArgumentException(
                    String.format("Unable create code source URL from: %s. %s", codeBase, e.getMessage()), e);
        }
    }
}