package com.taobao.yugong.common.utils.compile;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaCompiler.CompilationTask;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.JavaFileObject.Kind;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.StandardLocation;
import javax.tools.ToolProvider;
import com.taobao.yugong.exception.YuGongException;
/**
* @author agapple 2014年2月25日 下午11:38:06
* @since 1.0.0
*/
public class JdkCompiler implements JavaSourceCompiler {
/** javac options handed to every compilation; starts empty and is never reassigned in this class. */
private final List<String> options = new ArrayList<String>();

/**
 * Creates a compiler that invokes javac without extra command-line options.
 */
public JdkCompiler(){
    super();
}
/**
 * Compiles Java source text and returns the loaded class.
 *
 * @param sourceString the source text; it is wrapped in a {@link JavaSource}, whose package and class
 *        names decide which class is loaded after compilation
 * @return the compiled class
 * @throws YuGongException if compilation or class loading fails
 */
public Class compile(String sourceString) {
    JavaSource javaSource = new JavaSource(sourceString);
    return this.compile(javaSource);
}
/**
 * Compiles {@code javaSource} using a fresh {@link JdkCompilerClassLoader} whose parent is the loader of
 * this object's class, and returns the loaded result.
 *
 * @param javaSource parsed source supplying package name, class name and text
 * @return the compiled class
 * @throws YuGongException wrapping any failure; when javac itself reported errors, the collected
 *         diagnostics are appended to the message
 */
private Class compile(JavaSource javaSource) {
    try {
        ClassLoader parentLoader = this.getClass().getClassLoader();
        JdkCompileTask<?> task = new JdkCompileTask<Object>(new JdkCompilerClassLoader(parentLoader), options);
        return task.compile(javaSource.getPackageName(), javaSource.getClassName(), javaSource.getSource());
    } catch (JdkCompileException compileFailure) {
        DiagnosticCollector<JavaFileObject> collected = compileFailure.getDiagnostics();
        String message = "source:" + javaSource + ", " + collected.getDiagnostics();
        throw new YuGongException(message, compileFailure);
    } catch (Throwable unexpected) {
        // Anything else (missing compiler, loader failure, ...) is reported with the offending source.
        throw new YuGongException("source:" + javaSource, unexpected);
    }
}
/**
 * Parses {@code name} into a {@link URI}.
 *
 * @param name URI text
 * @return the parsed URI
 * @throws YuGongException wrapping the {@link URISyntaxException} when {@code name} is not a valid URI
 */
public static URI toURI(String name) {
    URI parsed;
    try {
        parsed = new URI(name);
    } catch (URISyntaxException syntaxError) {
        throw new YuGongException(syntaxError);
    }
    return parsed;
}
/**
 * Compiles one in-memory Java source with the system {@link JavaCompiler} and loads the resulting class
 * through a {@link JdkCompilerClassLoader}.
 *
 * @param <T> nominal type of the produced class; the cast in {@link #compile} is unchecked
 */
public static class JdkCompileTask<T> {

    /** File-name suffix given to in-memory source objects. */
    public static final String EXTENSION = ".java";
    /** The platform compiler, or {@code null} when the running Java installation provides none. */
    public static final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();

    /** javac command-line options, passed through to {@link JavaCompiler#getTask} unchanged. */
    private List<String> options;
    /** Receives the generated bytecode and loads the compiled class. */
    private JdkCompilerClassLoader classLoader;

    /**
     * @param classLoader loader that collects compiler output and loads the result; when its parent is a
     *        {@link URLClassLoader} (other than the JDK 8 application loader) the parent's URLs become
     *        javac's class path
     * @param options javac options
     * @throws YuGongException if no system Java compiler is available
     */
    public JdkCompileTask(JdkCompilerClassLoader classLoader, List<String> options){
        if (compiler == null) {
            throw new YuGongException("Can't find java compiler , pls check tools.jar");
        }
        this.classLoader = classLoader;
        this.options = options;
    }

    /**
     * Compiles {@code javaSource} as {@code packageName.className} and loads the resulting class.
     *
     * @param packageName package of the class; {@code null} or empty means the default package
     * @param className simple name of the class to load after compilation
     * @param javaSource complete source text
     * @return the loaded class
     * @throws JdkCompileException if javac reports failure or the compiled class cannot be loaded; it
     *         carries the collected diagnostics
     */
    @SuppressWarnings("unchecked")
    public synchronized Class compile(String packageName, String className, final CharSequence javaSource)
        throws JdkCompileException, ClassCastException {
        DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>();
        JavaFileManagerImpl javaFileManager = buildFileManager(classLoader, classLoader.getParent(), diagnostics);
        try {
            JavaFileObjectImpl source = new JavaFileObjectImpl(className, javaSource);
            javaFileManager.putFileForInput(StandardLocation.SOURCE_PATH, packageName, className + EXTENSION,
                source);
            CompilationTask task = compiler.getTask(null, javaFileManager, diagnostics, options, null,
                Arrays.asList(source));
            final Boolean result = task.call();
            if (result == null || !result.booleanValue()) {
                throw new JdkCompileException("Compilation failed.", diagnostics);
            }
            try {
                return (Class<T>) classLoader.loadClass(qualifiedName(packageName, className));
            } catch (Throwable e) {
                throw new JdkCompileException(e, diagnostics);
            }
        } finally {
            // The bytecode already lives in memory (JavaFileObjectImpl inside classLoader), so the
            // delegate file manager — and whatever class-path handles it holds — can be released now
            // instead of leaking one file manager per compilation.
            closeQuietly(javaFileManager);
        }
    }

    /**
     * Creates javac's file manager. When {@code loader} is a {@link URLClassLoader} other than the
     * JDK 8 application loader, its URLs are installed as the compilation class path.
     *
     * @throws YuGongException if the class path cannot be configured
     */
    private JavaFileManagerImpl buildFileManager(JdkCompilerClassLoader classLoader, ClassLoader loader,
                                                 DiagnosticCollector<JavaFileObject> diagnostics) {
        StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null);
        if (loader instanceof URLClassLoader
            && (!"sun.misc.Launcher$AppClassLoader".equalsIgnoreCase(loader.getClass().getName()))) {
            try {
                URLClassLoader urlClassLoader = (URLClassLoader) loader;
                List<File> paths = new ArrayList<File>();
                for (URL url : urlClassLoader.getURLs()) {
                    paths.add(toFile(url));
                }
                fileManager.setLocation(StandardLocation.CLASS_PATH, paths);
            } catch (Throwable e) {
                // Do not leak the freshly created manager when configuration fails.
                closeQuietly(fileManager);
                throw new YuGongException(e);
            }
        }
        return new JavaFileManagerImpl(fileManager, classLoader);
    }

    /**
     * Converts a class-path URL into a file, decoding percent-escapes (e.g. {@code %20} for a space)
     * that {@link URL#getFile()} would leave in place. Falls back to the raw {@code getFile()} path
     * when the URL is not a plain hierarchical {@code file:} URI.
     */
    private static File toFile(URL url) {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException e) {
            return new File(url.getFile());
        } catch (IllegalArgumentException e) {
            return new File(url.getFile());
        }
    }

    /** Joins package and simple name, treating a {@code null} or empty package as the default package. */
    private static String qualifiedName(String packageName, String className) {
        if (packageName == null || packageName.length() == 0) {
            return className;
        }
        return packageName + "." + className;
    }

    /** Closes {@code fileManager}, ignoring I/O errors so they cannot mask the compilation outcome. */
    private static void closeQuietly(JavaFileManager fileManager) {
        try {
            fileManager.close();
        } catch (IOException ignored) {
            // Best effort: releasing resources must not hide the real result or failure.
        }
    }
}
/**
 * {@link JavaFileManager} decorator that serves registered in-memory files to javac, routes every class
 * file javac writes into a {@link JdkCompilerClassLoader}, and otherwise delegates to the wrapped manager.
 */
public static class JavaFileManagerImpl extends ForwardingJavaFileManager<JavaFileManager> {

    /** Collects compiler output and is exposed to javac through {@link #getClassLoader}. */
    private final JdkCompilerClassLoader classLoader;
    /** Files registered via {@link #putFileForInput}, keyed by a location/package/name URI. */
    private final Map<URI, JavaFileObject> fileDatas = new HashMap<URI, JavaFileObject>();

    public JavaFileManagerImpl(JavaFileManager fileManager, JdkCompilerClassLoader classLoader){
        super(fileManager);
        this.classLoader = classLoader;
    }

    /** Registers an in-memory file so later lookups for the same location/package/name find it. */
    public void putFileForInput(StandardLocation location, String packageName, String relativeName,
                                JavaFileObject file) {
        URI key = keyFor(location, packageName, relativeName);
        fileDatas.put(key, file);
    }

    /** Returns the registered file for this location/package/name, else asks the wrapped manager. */
    @Override
    public FileObject getFileForInput(Location location, String packageName, String relativeName)
        throws IOException {
        FileObject registered = fileDatas.get(keyFor(location, packageName, relativeName));
        return registered != null ? registered : super.getFileForInput(location, packageName, relativeName);
    }

    /** Custom implementation: our own file objects report {@code getName()}; others are delegated. */
    @Override
    public String inferBinaryName(Location loc, JavaFileObject file) {
        return (file instanceof JavaFileObjectImpl) ? file.getName() : super.inferBinaryName(loc, file);
    }

    /** Creates an in-memory output object and registers it with the class loader under its name. */
    @Override
    public JavaFileObject getJavaFileForOutput(Location location, String qualifiedName, Kind kind,
                                               FileObject outputFile) throws IOException {
        JavaFileObjectImpl output = new JavaFileObjectImpl(qualifiedName, kind);
        classLoader.add(qualifiedName, output);
        return output;
    }

    @Override
    public ClassLoader getClassLoader(JavaFileManager.Location location) {
        return classLoader;
    }

    /**
     * Lists the wrapped manager's files, then appends matching in-memory files: for CLASS_PATH/CLASS the
     * registered class files plus everything the class loader holds, for SOURCE_PATH/SOURCE the
     * registered sources.
     */
    @Override
    public Iterable<JavaFileObject> list(Location location, String packageName, Set<Kind> kinds, boolean recurse)
        throws IOException {
        List<JavaFileObject> listed = new ArrayList<JavaFileObject>();
        for (JavaFileObject delegated : super.list(location, packageName, kinds, recurse)) {
            listed.add(delegated);
        }
        if (location == StandardLocation.CLASS_PATH && kinds.contains(Kind.CLASS)) {
            appendRegistered(listed, Kind.CLASS, packageName);
            listed.addAll(classLoader.getAllFiles());
        } else if (location == StandardLocation.SOURCE_PATH && kinds.contains(Kind.SOURCE)) {
            appendRegistered(listed, Kind.SOURCE, packageName);
        }
        return listed;
    }

    /** Appends registered files of {@code kind} whose name starts with {@code packageName}. */
    private void appendRegistered(List<JavaFileObject> sink, Kind kind, String packageName) {
        for (JavaFileObject candidate : fileDatas.values()) {
            if (kind == candidate.getKind() && candidate.getName().startsWith(packageName)) {
                sink.add(candidate);
            }
        }
    }

    /** Builds the map key {@code <locationName>/<packageName>/<relativeName>}. */
    private URI keyFor(Location location, String packageName, String relativeName) {
        String path = location.getName() + '/' + packageName + '/' + relativeName;
        return toURI(path);
    }
}
public static class JavaFileObjectImpl extends SimpleJavaFileObject {
private ByteArrayOutputStream byteCode = new ByteArrayOutputStream();
private CharSequence source;
public JavaFileObjectImpl(final String baseName, final CharSequence source){
super(toURI(baseName + JdkCompileTask.EXTENSION), Kind.SOURCE);
this.source = source;
}
public JavaFileObjectImpl(final String name, final Kind kind){
super(toURI(name), kind);
source = null;
}
public JavaFileObjectImpl(URI uri, Kind kind){
super(uri, kind);
source = null;
}
@Override
public CharSequence getCharContent(final boolean ignoreEncodingErrors) throws UnsupportedOperationException {
if (source == null) {
throw new UnsupportedOperationException();
} else {
return source;
}
}
@Override
public InputStream openInputStream() {
return new ByteArrayInputStream(byteCode.toByteArray());
}
@Override
public OutputStream openOutputStream() {
return byteCode;
}
public byte[] getBytes() {
return byteCode.toByteArray();
}
}
/**
 * Child-first class loader that defines classes from bytecode captured in memory during compilation and
 * otherwise falls back to other loaders, finally to normal parent delegation.
 * <p>
 * NOTE(review): this is a non-static inner class, so every instance keeps a reference to the enclosing
 * {@link JdkCompiler}.
 */
public final class JdkCompilerClassLoader extends ClassLoader {

    /** In-memory class files keyed by qualified class name; filled through {@link #add}. */
    private final Map<String, JavaFileObject> classes = new HashMap<String, JavaFileObject>();

    public JdkCompilerClassLoader(ClassLoader parentClassLoader){
        super(parentClassLoader);
    }

    /** Returns a live view of all in-memory class files registered so far. */
    public Collection<JavaFileObject> getAllFiles() {
        return classes.values();
    }

    /**
     * Looks a class up in this order: in-memory bytecode, {@link Class#forName(String)} (the loader that
     * loaded this class), the current thread's context class loader, then
     * {@link ClassLoader#findClass(String)}, which throws {@link ClassNotFoundException}.
     */
    @Override
    protected synchronized Class<?> findClass(String qualifiedClassName) throws ClassNotFoundException {
        JavaFileObject file = classes.get(qualifiedClassName);
        if (file != null) {
            // defineClass may run only once per name in this loader; without this check a repeated
            // lookup fails with LinkageError (duplicate class definition).
            Class<?> alreadyDefined = findLoadedClass(qualifiedClassName);
            if (alreadyDefined != null) {
                return alreadyDefined;
            }
            byte[] bytes = ((JavaFileObjectImpl) file).getBytes();
            return defineClass(qualifiedClassName, bytes, 0, bytes.length);
        }
        try {
            return Class.forName(qualifiedClassName);
        } catch (ClassNotFoundException ignored) {
            // Not visible there; try the next source.
        }
        // Thread#getContextClassLoader() is allowed to return null.
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        if (contextLoader != null) {
            try {
                return contextLoader.loadClass(qualifiedClassName);
            } catch (ClassNotFoundException ignored) {
                // Not visible there either; fall through.
            }
        }
        return super.findClass(qualifiedClassName);
    }

    /** Registers in-memory bytecode for {@code qualifiedClassName} (replacing any previous entry). */
    public void add(String qualifiedClassName, final JavaFileObject javaFile) {
        classes.put(qualifiedClassName, javaFile);
    }

    /**
     * Child-first loading: reuse a class already loaded by this loader, otherwise try {@link #findClass};
     * only when that throws {@link ClassNotFoundException} fall back to standard parent delegation.
     */
    @Override
    protected synchronized Class<?> loadClass(final String name, final boolean resolve)
        throws ClassNotFoundException {
        try {
            Class<?> c = findLoadedClass(name);
            if (c == null) {
                c = findClass(name);
            }
            if (c != null) {
                if (resolve) {
                    resolveClass(c);
                }
                return c;
            }
        } catch (ClassNotFoundException e) {
            // Ignore and fall through to parent delegation.
        }
        return super.loadClass(name, resolve);
    }

    /** Serves {@code .class} resources from in-memory bytecode first, then via the default lookup. */
    @Override
    public InputStream getResourceAsStream(final String name) {
        if (name.endsWith(".class")) {
            String qualifiedClassName = name.substring(0, name.length() - ".class".length()).replace('/', '.');
            JavaFileObjectImpl file = (JavaFileObjectImpl) classes.get(qualifiedClassName);
            if (file != null) {
                return new ByteArrayInputStream(file.getBytes());
            }
        }
        return super.getResourceAsStream(name);
    }

    /** Forgets all in-memory class files; classes already defined by this loader stay loaded. */
    public void clearCache() {
        this.classes.clear();
    }

    /** Returns the in-memory class file for {@code qualifiedClassName}, or {@code null} if none. */
    public JavaFileObject getJavaFileObject(String qualifiedClassName) {
        return classes.get(qualifiedClassName);
    }
}
}