/*
* Copyright 2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.glowroot.common.repo.util;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.CodeSource;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.StandardLocation;
import javax.tools.ToolProvider;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
 * Compiles a single Java compilation unit in memory and loads its public top-level class.
 *
 * <p>Compilation uses an explicit class path made of the jars that contain Selenium (api and
 * support) and Guava, because the default system class path is not usable inside a servlet
 * container. Compiled bytecode never touches the file system: it is captured in memory and
 * defined by a fresh {@link IsolatedClassLoader} created for each call to {@link #compile}.
 */
public class Compilations {

    // Finds a public class declaration and captures its simple name. The optional modifiers
    // are those that may legally accompany "public" on a top-level class, so that e.g.
    // "public final class Foo" is recognized. A type parameter list ("Foo<T>") terminates the
    // name instead of being captured as part of it.
    private static final Pattern CLASS_NAME_PATTERN = Pattern.compile(
            "\\bpublic\\s+(?:(?:final|abstract|strictfp)\\s+)*class\\s+([^\\s{<]+)[\\s{<]");

    // Finds the package declaration and captures the package name without surrounding
    // whitespace. MULTILINE lets "^" match at the start of any line, so a package declaration
    // preceded by a comment (e.g. a license header) is still found.
    private static final Pattern PACKAGE_NAME_PATTERN =
            Pattern.compile("^\\s*package\\s+([^\\s;]+)\\s*;", Pattern.MULTILINE);

    /**
     * Compiles {@code source} and loads the public class it declares.
     *
     * <p>Warnings and notes reported by javac (e.g. use of a deprecated API) do not fail the
     * compilation; only an unsuccessful compilation does.
     *
     * @param source the complete source of one compilation unit, which must declare a public
     *        class
     * @return the compiled class, loaded by a new isolated class loader
     * @throws CompilationException if javac reports that compilation failed (the exception
     *         carries every diagnostic javac reported), or if no public class is declared
     * @throws IllegalStateException if no system Java compiler is available, or if the
     *         compilation class path cannot be determined
     * @throws Exception if resolving the class path or loading the compiled class fails
     */
    public static Class<?> compile(String source) throws Exception {
        JavaCompiler javaCompiler = ToolProvider.getSystemJavaCompiler();
        if (javaCompiler == null) {
            // ToolProvider returns null when the platform provides no compiler (e.g. a JRE)
            throw new IllegalStateException("No system Java compiler is available,"
                    + " compiling requires running on a JDK");
        }
        DiagnosticCollector<JavaFileObject> diagnosticCollector =
                new DiagnosticCollector<JavaFileObject>();
        IsolatedClassLoader isolatedClassLoader = new IsolatedClassLoader();
        StandardJavaFileManager standardFileManager = javaCompiler
                .getStandardFileManager(diagnosticCollector, Locale.ENGLISH, Charsets.UTF_8);
        JavaFileManager fileManager =
                new IsolatedJavaFileManager(standardFileManager, isolatedClassLoader);
        try {
            // done inside the try so that the file manager is still closed if the class path
            // cannot be resolved
            standardFileManager.setLocation(StandardLocation.CLASS_PATH,
                    getCompilationClassPath());
            String className = getPublicClassName(source);
            // javac requires a public class to live in a file named after its simple name;
            // lastIndexOf returns -1 for an unqualified name, so substring(0) keeps it whole
            String simpleName = className.substring(className.lastIndexOf('.') + 1);
            JavaCompiler.CompilationTask task = javaCompiler.getTask(null, fileManager,
                    diagnosticCollector, null, null,
                    ImmutableList.of(new SourceJavaFileObject(simpleName, source)));
            boolean success = task.call();
            if (!success) {
                throw new CompilationException(
                        toCompilationErrors(diagnosticCollector.getDiagnostics()));
            }
            if (className.isEmpty()) {
                throw new CompilationException(ImmutableList.of("Class must be public"));
            }
            return isolatedClassLoader.loadClass(className);
        } finally {
            fileManager.close();
        }
    }

    // Renders every diagnostic of a failed compilation (errors plus any warnings or notes
    // reported alongside them), never returning an empty list.
    private static List<String> toCompilationErrors(
            List<Diagnostic<? extends JavaFileObject>> diagnostics) {
        List<String> compilationErrors = Lists.newArrayList();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnostics) {
            compilationErrors.add(diagnostic.toString());
        }
        if (compilationErrors.isEmpty()) {
            // javac reported failure without any diagnostic; still tell the caller something
            compilationErrors.add("Compilation failed");
        }
        return compilationErrors;
    }

    // Returns the jars needed to compile the submitted source.
    private static List<File> getCompilationClassPath() throws Exception {
        // selenium-api, selenium-support and guava are needed for compilation
        // cannot use default system classpath when running in a servlet container
        return ImmutableList.of(getJarFile("org.openqa.selenium.WebDriver"),
                getJarFile("org.openqa.selenium.support.ui.ExpectedConditions"),
                getJarFile("com.google.common.base.Function"));
    }

    // Returns the file (jar or directory) that the named class was loaded from.
    private static File getJarFile(String className) throws Exception {
        CodeSource codeSource = Class.forName(className).getProtectionDomain().getCodeSource();
        if (codeSource == null) {
            throw new IllegalStateException("Code source is null for class: " + className);
        }
        URL location = codeSource.getLocation();
        if (location == null) {
            throw new IllegalStateException(
                    "Code source location is null for class: " + className);
        }
        try {
            return new File(location.toURI());
        } catch (URISyntaxException e) {
            // some class loaders report locations containing unescaped characters (e.g.
            // spaces), which URL.toURI() rejects; fall back to the URL's raw path
            return new File(location.getPath());
        }
    }

    /**
     * Returns the fully qualified name of the first public class declared in {@code source}
     * (package prefix included when a package declaration is found), or {@code ""} if no
     * public class declaration is found.
     */
    @VisibleForTesting
    static String getPublicClassName(String source) {
        Matcher matcher = CLASS_NAME_PATTERN.matcher(source);
        if (matcher.find()) {
            return getPackagePrefix(source) + matcher.group(1);
        } else {
            return "";
        }
    }

    // Returns "<package>." for the package declared in source, or "" if none is found.
    private static String getPackagePrefix(String source) {
        Matcher matcher = PACKAGE_NAME_PATTERN.matcher(source);
        if (matcher.find()) {
            return matcher.group(1) + ".";
        } else {
            return "";
        }
    }

    /**
     * Thrown when the submitted source fails to compile or does not declare a public class.
     */
    @SuppressWarnings("serial")
    public static class CompilationException extends Exception {

        private final List<String> compilationErrors;

        public CompilationException(List<String> compilationErrors) {
            this.compilationErrors = compilationErrors;
        }

        /** Returns the messages describing why compilation failed. */
        public List<String> getCompilationErrors() {
            return compilationErrors;
        }
    }

    // File manager that hands out the isolated class loader and captures javac's class file
    // output in memory instead of writing it to disk.
    private static class IsolatedJavaFileManager
            extends ForwardingJavaFileManager<JavaFileManager> {

        private final IsolatedClassLoader loader;

        private IsolatedJavaFileManager(JavaFileManager fileManager, IsolatedClassLoader loader) {
            super(fileManager);
            this.loader = loader;
        }

        @Override
        public ClassLoader getClassLoader(Location location) {
            return loader;
        }

        // Registers an in-memory output object under the class name javac is writing, so the
        // isolated class loader can later define the class from the captured bytes.
        @Override
        public JavaFileObject getJavaFileForOutput(JavaFileManager.Location location,
                String className, JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            CompiledJavaFileObject javaFileObject = new CompiledJavaFileObject();
            loader.compiledJavaFileObjects.put(className, javaFileObject);
            return javaFileObject;
        }
    }

    // Class loader that defines classes from the bytecode captured during compilation.
    private static class IsolatedClassLoader extends ClassLoader {

        // compiled class name -> captured class file output
        private final Map<String, CompiledJavaFileObject> compiledJavaFileObjects =
                Maps.newHashMap();

        private IsolatedClassLoader() {
            super(IsolatedClassLoader.class.getClassLoader());
        }

        // Defines a class from captured bytecode; names that were not compiled here fall
        // through to the default findClass.
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            CompiledJavaFileObject compiledJavaFileObject = compiledJavaFileObjects.get(name);
            if (compiledJavaFileObject == null) {
                return super.findClass(name);
            }
            byte[] byteCode = compiledJavaFileObject.baos.toByteArray();
            return defineClass(name, byteCode, 0, byteCode.length);
        }
    }

    // In-memory source file named "<simpleClassName>.java" whose content is the given source.
    private static class SourceJavaFileObject extends SimpleJavaFileObject {

        private final String source;

        private SourceJavaFileObject(String simpleClassName, String source)
                throws URISyntaxException {
            super(URI.create(simpleClassName + JavaFileObject.Kind.SOURCE.extension),
                    JavaFileObject.Kind.SOURCE);
            this.source = source;
        }

        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) {
            return source;
        }
    }

    // In-memory class file: everything javac writes is collected in a byte array.
    private static class CompiledJavaFileObject extends SimpleJavaFileObject {

        private final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        protected CompiledJavaFileObject() {
            super(URI.create(""), Kind.CLASS);
        }

        @Override
        public OutputStream openOutputStream() throws IOException {
            return baos;
        }
    }
}