package com.taobao.yugong.common.utils.compile; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import javax.tools.DiagnosticCollector; import javax.tools.FileObject; import javax.tools.ForwardingJavaFileManager; import javax.tools.JavaCompiler; import javax.tools.JavaCompiler.CompilationTask; import javax.tools.JavaFileManager; import javax.tools.JavaFileObject; import javax.tools.JavaFileObject.Kind; import javax.tools.SimpleJavaFileObject; import javax.tools.StandardJavaFileManager; import javax.tools.StandardLocation; import javax.tools.ToolProvider; import com.taobao.yugong.exception.YuGongException; /** * @author agapple 2014年2月25日 下午11:38:06 * @since 1.0.0 */ public class JdkCompiler implements JavaSourceCompiler { private List<String> options = new ArrayList<String>(); public JdkCompiler(){ } public Class compile(String sourceString) { return compile(new JavaSource(sourceString)); } private Class compile(JavaSource javaSource) { try { JdkCompileTask compileTask = new JdkCompileTask(new JdkCompilerClassLoader(this.getClass().getClassLoader()), options); return compileTask.compile(javaSource.getPackageName(), javaSource.getClassName(), javaSource.getSource()); } catch (JdkCompileException ex) { DiagnosticCollector<JavaFileObject> diagnostics = ex.getDiagnostics(); throw new YuGongException("source:" + javaSource + ", " + diagnostics.getDiagnostics(), ex); } catch (Throwable ex) { throw new YuGongException("source:" + javaSource, ex); } } public static URI toURI(String name) { try { return new URI(name); } catch (URISyntaxException e) { throw new YuGongException(e); } } public static class JdkCompileTask<T> { public static final String EXTENSION = ".java"; public static final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); private List<String> options; private JdkCompilerClassLoader classLoader; public JdkCompileTask(JdkCompilerClassLoader classLoader, List<String> options){ if (compiler == null) { throw new YuGongException("Can't find java compiler , pls check tools.jar"); } this.classLoader = classLoader; this.options = options; } public synchronized Class compile(String packageName, String className, final CharSequence javaSource) throws JdkCompileException, ClassCastException { DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>(); JavaFileManagerImpl javaFileManager = buildFileManager(classLoader, classLoader.getParent(), diagnostics); JavaFileObjectImpl source = new JavaFileObjectImpl(className, javaSource); javaFileManager.putFileForInput(StandardLocation.SOURCE_PATH, packageName, className + EXTENSION, source); CompilationTask task = compiler.getTask(null, javaFileManager, diagnostics, options, null, Arrays.asList(source)); final Boolean result = task.call(); if (result == null || !result.booleanValue()) { throw new JdkCompileException("Compilation failed.", diagnostics); } try { return (Class<T>) classLoader.loadClass(packageName + "." + className); } catch (Throwable e) { throw new JdkCompileException(e, diagnostics); } } private JavaFileManagerImpl buildFileManager(JdkCompilerClassLoader classLoader, ClassLoader loader, DiagnosticCollector<JavaFileObject> diagnostics) { StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null); if (loader instanceof URLClassLoader && (!"sun.misc.Launcher$AppClassLoader".equalsIgnoreCase(loader.getClass().getName()))) { try { URLClassLoader urlClassLoader = (URLClassLoader) loader; List<File> paths = new ArrayList<File>(); for (URL url : urlClassLoader.getURLs()) { File file = new File(url.getFile()); paths.add(file); } fileManager.setLocation(StandardLocation.CLASS_PATH, paths); } catch (Throwable e) { throw new YuGongException(e); } } return new JavaFileManagerImpl(fileManager, classLoader); } } public static class JavaFileManagerImpl extends ForwardingJavaFileManager<JavaFileManager> { private final JdkCompilerClassLoader classLoader; private final Map<URI, JavaFileObject> fileDatas = new HashMap<URI, JavaFileObject>(); public JavaFileManagerImpl(JavaFileManager fileManager, JdkCompilerClassLoader classLoader){ super(fileManager); this.classLoader = classLoader; } public void putFileForInput(StandardLocation location, String packageName, String relativeName, JavaFileObject file) { fileDatas.put(clasURI(location, packageName, relativeName), file); } @Override public FileObject getFileForInput(Location location, String packageName, String relativeName) throws IOException { FileObject o = fileDatas.get(clasURI(location, packageName, relativeName)); if (o == null) { return super.getFileForInput(location, packageName, relativeName); } else { return o; } } @Override public String inferBinaryName(Location loc, JavaFileObject file) { // 自定义实现 if (file instanceof JavaFileObjectImpl) { return file.getName(); } else { return super.inferBinaryName(loc, file); } } @Override public JavaFileObject getJavaFileForOutput(Location location, String qualifiedName, Kind kind, FileObject outputFile) throws IOException { JavaFileObject file = new JavaFileObjectImpl(qualifiedName, kind); classLoader.add(qualifiedName, file); return file; } @Override public ClassLoader getClassLoader(JavaFileManager.Location location) { return classLoader; } @Override public Iterable<JavaFileObject> list(Location location, String packageName, Set<Kind> kinds, boolean recurse) throws IOException { Iterable<JavaFileObject> files = super.list(location, packageName, kinds, recurse); List<JavaFileObject> result = new ArrayList<JavaFileObject>(); for (JavaFileObject file : files) { result.add(file); } if (StandardLocation.CLASS_PATH == location && kinds.contains(JavaFileObject.Kind.CLASS)) { for (JavaFileObject file : fileDatas.values()) { if (Kind.CLASS == file.getKind() && file.getName().startsWith(packageName)) { result.add(file); } } result.addAll(classLoader.getAllFiles()); } if (StandardLocation.SOURCE_PATH == location && kinds.contains(JavaFileObject.Kind.SOURCE)) { for (JavaFileObject file : fileDatas.values()) { if (Kind.SOURCE == file.getKind() && file.getName().startsWith(packageName)) { result.add(file); } } } return result; } private URI clasURI(Location location, String packageName, String relativeName) { return toURI(location.getName() + '/' + packageName + '/' + relativeName); } } public static class JavaFileObjectImpl extends SimpleJavaFileObject { private ByteArrayOutputStream byteCode = new ByteArrayOutputStream(); private CharSequence source; public JavaFileObjectImpl(final String baseName, final CharSequence source){ super(toURI(baseName + JdkCompileTask.EXTENSION), Kind.SOURCE); this.source = source; } public JavaFileObjectImpl(final String name, final Kind kind){ super(toURI(name), kind); source = null; } public JavaFileObjectImpl(URI uri, Kind kind){ super(uri, kind); source = null; } @Override public CharSequence getCharContent(final boolean ignoreEncodingErrors) throws UnsupportedOperationException { if (source == null) { throw new UnsupportedOperationException(); } else { return source; } } @Override public InputStream openInputStream() { return new ByteArrayInputStream(byteCode.toByteArray()); } @Override public OutputStream openOutputStream() { return byteCode; } public byte[] getBytes() { return byteCode.toByteArray(); } } public final class JdkCompilerClassLoader extends ClassLoader { private final Map<String, JavaFileObject> classes = new HashMap<String, JavaFileObject>(); public JdkCompilerClassLoader(ClassLoader parentClassLoader){ super(parentClassLoader); } public Collection<JavaFileObject> getAllFiles() { return classes.values(); } protected synchronized Class<?> findClass(String qualifiedClassName) throws ClassNotFoundException { JavaFileObject file = classes.get(qualifiedClassName); if (file != null) { byte[] bytes = ((JavaFileObjectImpl) file).getBytes(); return defineClass(qualifiedClassName, bytes, 0, bytes.length); } try { return Class.forName(qualifiedClassName); } catch (ClassNotFoundException nf) { // Ignore } try { return Thread.currentThread().getContextClassLoader().loadClass(qualifiedClassName); } catch (ClassNotFoundException nf) { // Ignore } return super.findClass(qualifiedClassName); } public void add(String qualifiedClassName, final JavaFileObject javaFile) { classes.put(qualifiedClassName, javaFile); } protected synchronized Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException { try { Class c = findClass(name); if (c != null) { if (resolve) { resolveClass(c); } return c; } } catch (ClassNotFoundException e) { // Ignore and fall through } return super.loadClass(name, resolve); } public InputStream getResourceAsStream(final String name) { if (name.endsWith(".class")) { String qualifiedClassName = name.substring(0, name.length() - ".class".length()).replace('/', '.'); JavaFileObjectImpl file = (JavaFileObjectImpl) classes.get(qualifiedClassName); if (file != null) { return new ByteArrayInputStream(file.getBytes()); } } return super.getResourceAsStream(name); } public void clearCache() { this.classes.clear(); } public JavaFileObject getJavaFileObject(String qualifiedClassName) { return classes.get(qualifiedClassName); } } }