package org.enumerable.lambda.weaving;
import org.enumerable.lambda.exception.LambdaWeavingNotEnabledException;
import org.enumerable.lambda.weaving.tree.LambdaTreeTransformer;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import static java.lang.System.*;
import static java.lang.Thread.currentThread;
import static org.enumerable.lambda.exception.UncheckedException.uncheck;
import static org.enumerable.lambda.weaving.ClassFilter.createClassFilter;
import static org.enumerable.lambda.weaving.Debug.debug;
import static org.enumerable.lambda.weaving.Version.getVersionString;
/**
 * Entry point for load-time lambda weaving. The same class serves three roles:
 * <ul>
 * <li>a {@link ClassLoader} that weaves classes as it loads them (see
 * {@link #launchApplication(String, String[])}),</li>
 * <li>a {@link ClassFileTransformer} registered through {@link #premain} /
 * {@link #agentmain} when started as a Java agent,</li>
 * <li>a command line launcher, see {@link #main(String[])}.</li>
 * </ul>
 * Whether weaving is active is tracked in static state and can be queried
 * with {@link #isEnabled()}.
 */
public class LambdaLoader extends ClassLoader implements ClassFileTransformer {
    private static boolean isEnabled;
    private static boolean transformationFailed;
    static String weavingNotEnabledMessage = "Please start the JVM with -javaagent:enumerable-java-"
            + Version.getVersion() + ".jar";

    private final ClassFilter filter;
    LambdaTreeTransformer transformer = new LambdaTreeTransformer();

    static {
        // Weaving counts as enabled from the start when the AOT marker resource is
        // visible. getClassLoader() returns null for classes loaded by the bootstrap
        // loader, so fall back to the system resource lookup in that case instead of
        // failing class initialization with a NullPointerException.
        ClassLoader ownLoader = LambdaLoader.class.getClassLoader();
        Object marker = ownLoader != null
                ? ownLoader.getResource(LambdaCompiler.AOT_COMPILED_MARKER)
                : ClassLoader.getSystemResource(LambdaCompiler.AOT_COMPILED_MARKER);
        isEnabled = marker != null;
    }

    /**
     * Creates a loader using the filter returned by
     * {@code ClassFilter.createClassFilter()}.
     */
    public LambdaLoader() {
        this(createClassFilter());
    }

    /**
     * Creates a loader that only transforms classes accepted by the given
     * filter.
     */
    public LambdaLoader(ClassFilter filter) {
        this.filter = filter;
    }

    /**
     * Allows you to query the Lambda weaver at runtime to see if it's enabled.
     * Returns false once any transformation has failed.
     */
    public static boolean isEnabled() {
        return isEnabled && !transformationFailed;
    }

    /**
     * This method can be used as a guard clause in your code, potentially
     * throwing a {@link LambdaWeavingNotEnabledException}.
     */
    @SuppressWarnings("unused")
    public static void ensureIsEnabled() {
        if (!isEnabled())
            throw new LambdaWeavingNotEnabledException();
    }

    /**
     * This method can be used as a guard clause in your code, exiting the VM
     * with status 1 if weaving isn't enabled. The not-enabled message is
     * printed to standard error first.
     */
    @SuppressWarnings("unused")
    public static void ensureIsEnabledOrExit() {
        if (!isEnabled()) {
            err.println(getNotEnabledMessage());
            exit(1);
        }
    }

    /**
     * This method can be used early in a main method to allow it to reload the
     * caller in the same process with lambda weaving enabled if it is currently
     * disabled. Control will not normally be returned to the caller as the VM
     * will be exited after the reloaded main has finished.
     * <p>
     * If weaving is already enabled, this method just returns.
     * <p>
     * This method is mainly intended to be used as a convenience in smaller
     * applications.
     *
     * @param args the arguments to hand to the reloaded main method
     * @throws IllegalStateException if the direct caller is not a method named
     *             {@code main}
     */
    public static void bootstrapMainIfNotEnabledAndExitUponItsReturn(String[] args) {
        if (isEnabled())
            return;

        // Frame 0 is getStackTrace, frame 1 is this method, frame 2 is our caller.
        // The length check keeps a truncated stack trace from surfacing as an
        // ArrayIndexOutOfBoundsException instead of the IllegalStateException below.
        StackTraceElement[] stack = currentThread().getStackTrace();
        if (stack.length > 2 && "main".equals(stack[2].getMethodName())) {
            try {
                String className = stack[2].getClassName();
                out.println(getNotEnabledMessage());
                out.println("Will try to reload " + className + " in the same process:");
                launchApplication(className, args);
                exit(0);
            } catch (Exception e) {
                throw uncheck(e);
            }
        }
        throw new IllegalStateException("Must be called from a main method.");
    }

    /**
     * Loads a class in a new class loader with lambda weaving enabled and
     * invokes its main method.
     *
     * @param className the fully qualified name of the class to launch
     * @param args the arguments passed to its {@code main(String[])}
     * @return the result of the reflective invocation
     */
    public static Object launchApplication(String className, String[] args) throws ClassNotFoundException,
            NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        debug("[main] " + getVersionString());
        isEnabled = true;
        Class<?> mainClass = new LambdaLoader().loadClass(className);
        Method mainMethod = mainClass.getMethod("main", String[].class);
        // Wrap args so the whole array is passed as the single String[] parameter.
        return mainMethod.invoke(null, new Object[] { args });
    }

    /**
     * Returns the message explaining why weaving is not enabled. After a failed
     * transformation this is the description of the failure.
     */
    public static String getNotEnabledMessage() {
        return weavingNotEnabledMessage;
    }

    /**
     * Loads the named class, weaving it first when {@link #transformClass}
     * produces bytes for it; otherwise the default delegation of
     * {@link ClassLoader#loadClass(String, boolean)} is used.
     * <p>
     * Any exception raised while loading is passed through {@code uncheck}.
     */
    @Override
    protected synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        // A class may only be defined once per loader; a repeated request for a class
        // we already defined must return that class rather than call defineClass again.
        Class<?> alreadyLoaded = findLoadedClass(name);
        if (alreadyLoaded != null) {
            if (resolve)
                resolveClass(alreadyLoaded);
            return alreadyLoaded;
        }

        InputStream in = getResourceAsStream(name.replace('.', '/') + ".class");
        try {
            // Without class bytes there is nothing to weave; let the parent chain decide
            // (and report ClassNotFoundException) instead of feeding null to the transformer.
            byte[] woven = in != null ? transformClass(this, name, in) : null;
            if (woven == null)
                return super.loadClass(name, resolve);

            Class<?> defined = defineClass(name, woven, 0, woven.length);
            if (resolve)
                resolveClass(defined);
            return defined;
        } catch (Exception e) {
            throw uncheck(e);
        } finally {
            try {
                if (in != null)
                    in.close();
            } catch (IOException ignored) {
                // Best effort: a failure to close the resource stream must not mask the result.
            }
        }
    }

    /**
     * Agent callback: weaves the supplied class file if it is eligible.
     * Never propagates a failure to the JVM; on any error the stack trace is
     * printed and {@code null} (meaning "no change") is returned.
     *
     * @return the transformed class file, or {@code null} if it was left alone
     */
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
            ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
        // The JVM can hand over a null name for classes defined without one; there is
        // nothing to match against the filter, so leave such classes untouched.
        if (className == null)
            return null;
        try {
            // A null loader denotes the bootstrap loader; use the system loader instead.
            ClassLoader effectiveLoader = loader != null ? loader : ClassLoader.getSystemClassLoader();
            return transformClass(effectiveLoader, className.replace('/', '.'),
                    new ByteArrayInputStream(classfileBuffer));
        } catch (Throwable t) {
            t.printStackTrace();
            return null;
        }
    }

    /**
     * Runs the lambda transformer on one class.
     * <p>
     * The first failure disables all further transformation, replaces the
     * not-enabled message with a description of the failure, logs to standard
     * error and rethrows through {@code uncheck}.
     *
     * @param loader the loader passed on to the transformer
     * @param name the binary (dot separated) class name
     * @param in the class file bytes
     * @return the transformer's result, or {@code null} when the class is
     *         filtered out or transformation has been disabled by an earlier failure
     */
    public byte[] transformClass(ClassLoader loader, String name, InputStream in) {
        try {
            if (!filter.isToBeInstrumented(name) || transformationFailed)
                return null;
            return transformer.transform(loader, filter, name, in);
        } catch (Throwable t) {
            transformationFailed = true;
            // Many throwables carry no message; keep getNotEnabledMessage() non-null.
            String reason = t.getMessage();
            weavingNotEnabledMessage = reason != null ? reason : t.toString();
            err.println(getVersionString());
            err.println("caught throwable while transforming " + name + ", transformation is disabled from here on");
            throw uncheck(t);
        }
    }

    /**
     * Java agent entry point used with {@code -javaagent}; marks weaving as
     * enabled and registers a new {@link LambdaLoader} as transformer.
     */
    public static void premain(String agentArgs, Instrumentation instrumentation) {
        installTransformer("premain", instrumentation);
    }

    /**
     * Java agent entry point used when the agent is attached to a running VM;
     * marks weaving as enabled and registers a new {@link LambdaLoader} as
     * transformer.
     */
    @SuppressWarnings({"UnusedDeclaration"})
    public static void agentmain(String agentArgs, Instrumentation instrumentation) {
        installTransformer("agentmain", instrumentation);
    }

    /** Shared body of the two agent entry points; {@code phase} only labels the debug line. */
    private static void installTransformer(String phase, Instrumentation instrumentation) {
        debug("[" + phase + "] " + getVersionString());
        isEnabled = true;
        instrumentation.addTransformer(new LambdaLoader());
    }

    /**
     * Command line launcher: the first argument is the class to launch with
     * weaving enabled, the remaining arguments are passed to its main method.
     * Prints usage and returns when called without arguments. An exception
     * thrown by the launched main method is rethrown unwrapped.
     */
    public static void main(String[] args) throws Throwable {
        try {
            if (args.length == 0) {
                out.println("[launcher] " + getVersionString());
                out.println("Usage: class [ARGS]...");
                return;
            }
            debug("[launcher] " + getVersionString());
            String[] applicationArgs = new String[args.length - 1];
            arraycopy(args, 1, applicationArgs, 0, applicationArgs.length);
            launchApplication(args[0], applicationArgs);
        } catch (InvocationTargetException e) {
            throw e.getCause();
        }
    }
}