package com.netflix.governator.lifecycle;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Opcodes;
import javax.tools.*;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.URL;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class JavaClasspath
{
Map<String, byte[]> classpath = new HashMap<>();
Map<String, File> jarByClass = new HashMap<>();
Path temp;
ClassLoader classLoader = byteClassLoader();
public JavaClasspath()
{
try
{
temp = Files.createTempDirectory("classes");
}
catch (IOException e)
{
throw new IllegalStateException(e);
}
}
public void cleanup()
{
temp.toFile().delete();
}
private ClassLoader byteClassLoader()
{
return new ClassLoader()
{
@Override protected Class<?> findClass(String name) throws ClassNotFoundException
{
byte[] b = classpath.get(name);
if(b == null)
throw new ClassNotFoundException(name);
return defineClass(name, b, 0, b.length);
}
@Override
protected Enumeration<URL> findResources(String name) throws IOException {
Set<URL> matchingJars = new HashSet<>();
for (Map.Entry<String, File> classToJar : jarByClass.entrySet()) {
if(classToJar.getKey().startsWith(name))
matchingJars.add(classToJar.getValue().toURI().toURL());
}
return Collections.enumeration(matchingJars);
}
};
}
private static String classNameFromSource(CharSequence source)
{
Matcher m = Pattern.compile("(class|interface)\\s+(\\w+)").matcher(source);
return m.find() ? m.group(2) : null;
}
public <T> Class<T> loadClass(String className)
{
try
{
return (Class<T>) classLoader.loadClass(className);
}
catch(ClassNotFoundException e)
{
throw new IllegalStateException(e);
}
}
@SuppressWarnings("unchecked")
public <T> T newInstance(String className, Object... args) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException
{
Class<T> clazz = (Class<T>) loadClass(className);
nextConstructor: for (Constructor<?> constructor : clazz.getDeclaredConstructors())
{
if(args.length != constructor.getParameterTypes().length)
continue;
int i = 0;
for (Class type : constructor.getParameterTypes())
if(!type.isAssignableFrom(args[i++].getClass()))
continue nextConstructor;
constructor.setAccessible(true);
try
{
return (T) constructor.newInstance(args);
}
catch (IllegalAccessException e)
{
throw new RuntimeException(e); // should never happen
}
}
throw new IllegalArgumentException("No matching constructor for class " + className + " found for provided args");
}
private static class InMemoryJavaFileObject extends SimpleJavaFileObject
{
private String contents = null;
public InMemoryJavaFileObject(String contents)
{
super(URI.create("string:///" + classNameFromSource(contents) + ".java"), Kind.SOURCE);
this.contents = contents;
}
public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException
{
return contents;
}
}
private static DiagnosticListener<JavaFileObject> diagnosticListener = new DiagnosticListener<JavaFileObject>()
{
@Override
public void report(Diagnostic<? extends JavaFileObject> diagnostic)
{
System.out.println("ERROR compiling " + diagnostic.getSource().getName());
System.out.println("Line " + diagnostic.getLineNumber() + ": " + diagnostic.getMessage(Locale.ENGLISH));
}
};
/**
* Compiles all java sources and adds them to the rule's classpath
* @param javaSources
* @return the set of fully qualified class names compiled just in this invocation
*/
public Collection<String> compile(String... javaSources)
{
return compile(Arrays.asList(javaSources));
}
/**
* Compiles all java sources and adds them to the rule's classpath
* @param javaSources
* @return the set of fully qualified class names compiled just in this invocation
*/
public Collection<String> compile(Collection<String> javaSources)
{
Collection<InMemoryJavaFileObject> files = new ArrayList<>();
for (String javaSource : javaSources)
files.add(new InMemoryJavaFileObject(javaSource));
return compileInternal(files);
}
private Collection<String> compileInternal(Collection<InMemoryJavaFileObject> files)
{
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnosticListener, Locale.ENGLISH, null);
try
{
JavaCompiler.CompilationTask task = compiler.getTask(null, fileManager, diagnosticListener,
Arrays.asList("-d", temp.toFile().getAbsolutePath(), "-cp", temp.toFile().getAbsolutePath(), "-g"), null, files);
task.call();
Map<String, byte[]> classes = new HashMap<>();
for (Path p : recurseListFiles(temp))
{
try
{
byte[] bytes = Files.readAllBytes(p);
classes.put(fullyQualifiedName(bytes), bytes);
}
catch (IOException e)
{
throw new IllegalStateException(e);
}
}
classpath.putAll(classes);
Set<String> classNames = new HashSet<>();
classIter: for (String c : classes.keySet())
{
String[] classNameParts = c.split("\\.");
for (InMemoryJavaFileObject file : files)
{
if(classNameFromSource(file.contents).equals(classNameParts[classNameParts.length - 1]))
{
classNames.add(c);
continue classIter;
}
}
}
return classNames;
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
private List<Path> recurseListFiles(Path path) throws IOException
{
List<Path> files = new ArrayList<>();
try (DirectoryStream<Path> stream = Files.newDirectoryStream(path))
{
for (Path entry : stream)
{
if (Files.isDirectory(entry))
files.addAll(recurseListFiles(entry));
else
files.add(entry);
}
}
return files;
}
public String fullyQualifiedName(byte[] classBytes)
{
final StringBuffer className = new StringBuffer();
ClassReader cr = new ClassReader(classBytes);
cr.accept(new ClassVisitor(Opcodes.ASM5) {
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces)
{
className.append(name);
}
}, 0);
return className.toString().replace("/", ".");
}
public File jar(File f, String... classSources)
{
f.getParentFile().mkdirs();
try
{
FileOutputStream fos = new FileOutputStream(f);
JarOutputStream jos = new JarOutputStream(fos);
for (String clazz : compile(classSources))
{
jos.putNextEntry(new JarEntry(clazz.replace(".", "/") + ".class"));
jos.write(classpath.get(clazz));
jarByClass.put(clazz.replace('.', '/'), f);
}
jos.close();
fos.close();
return f;
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
public byte[] classBytes(String className)
{
return classpath.get(className);
}
public ClassLoader getClassLoader()
{
return classLoader;
}
}