/*
* Copyright, Aspect Security, Inc.
*
* This file is part of JavaSnoop.
*
* JavaSnoop is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JavaSnoop is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with JavaSnoop. If not, see <http://www.gnu.org/licenses/>.
*/
package com.aspect.snoop.agent.manager;
import com.aspect.snoop.agent.AgentLogger;
import com.aspect.snoop.util.ReflectionUtil;
import java.io.IOException;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Method;
import java.net.URL;
import java.security.CodeSource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.NotFoundException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import javassist.ByteArrayClassPath;
import javassist.CtBehavior;
import javassist.CtMethod;
import javassist.CtConstructor;
import javassist.expr.ExprEditor;
import javassist.expr.MethodCall;
import javassist.LoaderClassPath;
/**
 * Inspects and rewrites classes loaded in this JVM.
 *
 * <p>Bytecode is obtained and edited through Javassist's default {@link ClassPool} and
 * swapped in with {@link Instrumentation#redefineClasses}. The original bytes of every
 * class this manager has instrumented are remembered, so the class can later be
 * re-instrumented from a clean copy or restored.
 */
public class InstrumentationManager {

    /**
     * Instrumentation history per class. The key is the {@code Class} object itself
     * ({@code Class} uses identity equality). Identity hash codes are not unique, so keying
     * by {@code hashCode()} could confuse two classes and redefine one with the other's bytes.
     */
    private Map<Class<?>, ClassHistory> modifiedClasses;
    private Instrumentation inst;
    /** Every distinct non-bootstrap class loader seen by {@link #updateClassPool()}. */
    private List<ClassLoader> classloaders;
    /** Code-source locations already appended to the default ClassPool. */
    HashMap<URL, SmartURLClassPath> urlSources;
    /** Cache for {@link #getClassBytes(String)}, keyed by class name. */
    Map<String, byte[]> classBytes = new HashMap<String, byte[]>();

    /**
     * Creates a manager and seeds the Javassist default class pool with the locations
     * and class loaders of every class currently loaded.
     *
     * @param inst the agent's instrumentation handle
     */
    public InstrumentationManager(Instrumentation inst) {
        this.inst = inst;
        this.modifiedClasses = new HashMap<Class<?>, ClassHistory>();
        this.classloaders = new ArrayList<ClassLoader>();
        this.urlSources = new HashMap<URL, SmartURLClassPath>();
        updateClassPool();
    }

    /**
     * Returns the names of all loaded classes, excluding array, primitive and
     * synthetic classes.
     */
    public List<String> getLoadedClassesAsStrings() {
        List<Class> loaded = getLoadedClasses();
        List<String> names = new ArrayList<String>(loaded.size());
        for (Class c : loaded) {
            names.add(c.getName());
        }
        return names;
    }

    /**
     * Returns all loaded classes, excluding array, primitive and synthetic classes.
     */
    public List<Class> getLoadedClasses() {
        List<Class> classes = new ArrayList<Class>();
        for (Class c : inst.getAllLoadedClasses()) {
            if (!c.isArray() && !c.isPrimitive() && !c.isSynthetic()) {
                classes.add(c);
            }
        }
        return classes;
    }

    /**
     * Returns a fresh list of the code-source URLs registered with the class pool so far.
     */
    public List<URL> getCodeSourceURLs() {
        return new ArrayList<URL>(urlSources.keySet());
    }

    /**
     * Registers any not-yet-seen code sources and class loaders of the currently loaded
     * classes with Javassist's default {@link ClassPool}. Safe to call repeatedly; already
     * known URLs and loaders are not added again.
     */
    public final void updateClassPool() {
        ClassPool classPool = ClassPool.getDefault();
        for (Class c : inst.getAllLoadedClasses()) {
            CodeSource cs = c.getProtectionDomain().getCodeSource();
            if (cs != null && cs.getLocation() != null) {
                URL url = cs.getLocation();
                SmartURLClassPath cp = urlSources.get(url);
                if (cp == null) {
                    cp = new SmartURLClassPath(url);
                    urlSources.put(url, cp);
                    classPool.appendClassPath(cp);
                    AgentLogger.debug("Adding " + url.toExternalForm() + " to classpath lookup");
                }
                cp.addClass(c.getName());
            }
            ClassLoader cl = c.getClassLoader();
            // null means the bootstrap loader, which the default pool already covers
            if (cl != null && !classloaders.contains(cl)) {
                classloaders.add(cl);
                classPool.insertClassPath(new LoaderClassPath(cl));
            }
        }
    }

    /**
     * Resolves {@code clazz} with {@link Class#forName(String)} and reports whether it
     * has instrumentation history.
     *
     * @throws ClassNotFoundException if the name cannot be resolved that way
     */
    public boolean hasClassBeenModified(String clazz)
            throws ClassNotFoundException {
        return hasClassBeenModified(Class.forName(clazz));
    }

    /** Returns whether this manager holds instrumentation history for {@code c}. */
    public boolean hasClassBeenModified(Class c) {
        return modifiedClasses.containsKey(c);
    }

    /**
     * Restores the original bytecode of {@code clazz} and forgets its history.
     * Does nothing if the class was never instrumented.
     */
    public void resetClass(Class clazz)
            throws ClassNotFoundException, UnmodifiableClassException {
        ClassHistory history = modifiedClasses.get(clazz);
        if (history != null) {
            // re-instrument original code back in
            ClassDefinition def = new ClassDefinition(clazz, history.getOriginalClass());
            inst.redefineClasses(def);
            modifiedClasses.remove(clazz);
        }
    }

    /** Loads and initializes {@code clazz} through {@code loader}. */
    public void ensureClassIsLoaded(String clazz, ClassLoader loader)
            throws ClassNotFoundException {
        Class.forName(clazz, true, loader);
    }

    /**
     * Redefines {@code clazz} with its original bytecode. The history entry is kept, so
     * a later {@link #instrument} still starts from the saved original bytes.
     *
     * @throws InstrumentationException if the class has no history or cannot be redefined
     */
    public void deinstrument(Class clazz)
            throws InstrumentationException {
        ClassHistory history = modifiedClasses.get(clazz);
        if (history == null) {
            throw new InstrumentationException("Class to deinstrument '" + clazz.getName() + "' not found in history");
        }
        try {
            ClassDefinition definition = new ClassDefinition(clazz, history.getOriginalClass());
            inst.redefineClasses(definition);
            AgentLogger.debug("Just de-instrumented " + clazz.getName());
        } catch (ClassNotFoundException cnfe) {
            throw new InstrumentationException(cnfe);
        } catch (UnmodifiableClassException uce) {
            throw new InstrumentationException(uce);
        }
    }

    /**
     * Applies {@code methodChanges} to the original bytecode of {@code clazz} and
     * redefines the class with the result.
     *
     * <p>If the class was instrumented before, the changes are applied to the saved
     * original bytes rather than to the previously instrumented version. If the JVM
     * rejects the new bytes with a {@link VerifyError}, the failure is logged, the
     * previously loaded version stays in effect, and history records those bytes as current.
     *
     * @throws InstrumentationException wrapping any Javassist or redefinition failure
     */
    public void instrument(Class clazz, MethodChanges[] methodChanges) throws InstrumentationException {
        try {
            ClassPool classPool = ClassPool.getDefault();
            ClassHistory previous = modifiedClasses.get(clazz);

            CtClass cls;
            byte[] originalByteCode;
            byte[] lastVersionByteCode;
            if (previous != null) {
                // We've instrumented this class before. Rebuild the CtClass from the saved
                // original bytes in a child pool that looks at its own class path first.
                originalByteCode = previous.getOriginalClass();
                AgentLogger.trace("Restoring saved bytes for " + clazz.getName() + " (" + md5(originalByteCode) + ")");
                ClassPool cp = new ClassPool(classPool);
                cp.childFirstLookup = true;
                cp.insertClassPath(new ByteArrayClassPath(clazz.getName(), originalByteCode));
                cls = cp.get(clazz.getName());
                cp.childFirstLookup = false;
                AgentLogger.trace("Retrieved bytes after save: " + md5(cls.toBytecode()));
                lastVersionByteCode = previous.getCurrentClass();
            } else {
                cls = classPool.get(clazz.getName());
                originalByteCode = cls.toBytecode();
                AgentLogger.trace("Instrumenting new class " + clazz.getName() + " (" + md5(originalByteCode) + ")");
                lastVersionByteCode = originalByteCode;
            }

            // toBytecode() froze the CtClass; thaw it so it can be modified
            cls.defrost();

            for (MethodChanges change : methodChanges) {
                applyChange(classPool, cls, clazz.getName(), change);
            }

            byte[] newByteCode = cls.toBytecode();
            boolean redefined = true;
            try {
                inst.redefineClasses(new ClassDefinition(clazz, newByteCode));
            } catch (VerifyError error) {
                // Best effort: keep going, but the JVM is still running the previous bytes.
                redefined = false;
                AgentLogger.debug("Redefinition of " + clazz.getName() + " failed verification: " + error);
            }

            // Remember the original so later edits start from a clean copy
            ClassHistory history = new ClassHistory(clazz, originalByteCode,
                    redefined ? newByteCode : lastVersionByteCode);
            history.setLastClass(lastVersionByteCode);
            modifiedClasses.put(clazz, history);
        } catch (UnmodifiableClassException uce) {
            throw new InstrumentationException(uce);
        } catch (ClassNotFoundException cnfe) {
            throw new InstrumentationException(cnfe);
        } catch (IOException ioe) {
            throw new InstrumentationException(ioe);
        } catch (CannotCompileException cce) {
            throw new InstrumentationException(cce);
        } catch (NotFoundException nfe) {
            throw new InstrumentationException(nfe);
        }
    }

    /**
     * Applies one {@link MethodChanges} (new locals, prologue code, body replacement)
     * to the matching method or constructor of {@code cls}.
     */
    private static void applyChange(ClassPool classPool, CtClass cls, String className, MethodChanges change)
            throws NotFoundException, CannotCompileException, ClassNotFoundException {
        CtBehavior behavior = findBehavior(classPool, cls, change.getMethod());

        // declare any locals the injected source refers to before compiling it
        for (LocalVariable newVar : change.getNewLocalVariables()) {
            behavior.addLocalVariable(newVar.getName(), newVar.getType());
        }

        AgentLogger.trace("Adding to class " + className);
        if (change.getNewStartSrc().length() > 0) {
            AgentLogger.trace("Compiling code at beginnging of function:");
            AgentLogger.trace(change.getNewStartSrc());
            String src = " { " + change.getNewStartSrc() + " } ";
            if (behavior instanceof CtConstructor) {
                // Code placed before super()/this() may not touch 'this'; start after it.
                ((CtConstructor) behavior).insertBeforeBody(src);
            } else {
                behavior.insertBefore(src);
            }
        }
        if (change.getNewEndSrc().length() > 0) {
            AgentLogger.trace("Compiling code in place of function:");
            AgentLogger.trace(change.getNewEndSrc());
            behavior.setBody(" { " + change.getNewEndSrc() + " } ");
            AgentLogger.debug("Done bytecode modification for " + className);
        }
    }

    /**
     * Finds the declared method or constructor of {@code cls} that corresponds to the
     * reflective {@code target}. This returns the live member of {@code cls}, not a copy,
     * so edits to it end up in {@code cls.toBytecode()}.
     */
    private static CtBehavior findBehavior(ClassPool classPool, CtClass cls, AccessibleObject target)
            throws NotFoundException, ClassNotFoundException {
        Class[] parameterTypes = ReflectionUtil.getParameterTypes(target);
        CtClass[] params = new CtClass[parameterTypes.length];
        for (int i = 0; i < parameterTypes.length; i++) {
            params[i] = classPool.get(parameterTypes[i].getName());
        }
        if (target instanceof Method) {
            return cls.getDeclaredMethod(((Method) target).getName(), params);
        }
        return cls.getDeclaredConstructor(params);
    }

    /**
     * Returns the bytecode of {@code clazz} as seen by the default class pool, caching
     * the result by name.
     *
     * @return the bytes, or {@code null} if they could not be produced
     */
    public byte[] getClassBytes(String clazz) {
        byte[] cached = classBytes.get(clazz);
        if (cached != null) {
            return cached;
        }
        try {
            CtClass cls = ClassPool.getDefault().get(clazz);
            byte[] bytes = cls.toBytecode();
            classBytes.put(clazz, bytes);
            return bytes;
        } catch (IOException ex) {
            AgentLogger.debug("Could not read bytecode for " + clazz + ": " + ex);
        } catch (CannotCompileException ex) {
            AgentLogger.debug("Could not read bytecode for " + clazz + ": " + ex);
        } catch (NotFoundException ex) {
            // expected occasionally for applet-loading related classes (com.sun.deploy, sun.reflect, etc.)
        }
        return null;
    }

    /**
     * Returns the first loaded class named {@code className}. If none is loaded, falls
     * back to {@link Class#forName(String)}.
     *
     * @throws ClassNotFoundException if the class cannot be found or fails to link;
     *         the underlying failure is attached as the cause
     */
    public Class getFromAllClasses(String className) throws ClassNotFoundException {
        for (Class c : inst.getAllLoadedClasses()) {
            if (c.getName().equals(className)) {
                return c;
            }
        }
        /*
         * If the target process was started through "Start & Snoop" rather than through
         * "Attach & Snoop", the class might not have been in the initial list (in fact, it
         * probably wasn't). For now we only try a simple Class.forName(). In the future we
         * should get a fresh copy of all the loaded class loaders and try each of them
         * until we find the one responsible for the class.
         */
        try {
            return Class.forName(className);
        } catch (ClassNotFoundException e) {
            throw new ClassNotFoundException(className, e);
        } catch (LinkageError e) {
            throw new ClassNotFoundException(className, e);
        }
    }

    /**
     * Returns the first loaded class whose {@code hashCode()} equals {@code hash}.
     * Identity hash codes are not unique, so more than one class may match.
     *
     * @throws ClassNotFoundException if no loaded class has that hash code
     */
    public Class getFromAllClasses(int hash) throws ClassNotFoundException {
        for (Class c : inst.getAllLoadedClasses()) {
            if (c.hashCode() == hash) {
                return c;
            }
        }
        throw new ClassNotFoundException("For hash: " + hash);
    }

    /**
     * Redefines every instrumented class with its original bytecode. History is kept
     * (see {@link #deinstrument(Class)}).
     */
    public void resetAllClasses() throws InstrumentationException {
        // iterate over a snapshot so the loop is independent of map changes in deinstrument
        for (Class<?> c : new ArrayList<Class<?>>(modifiedClasses.keySet())) {
            deinstrument(c);
        }
    }

    /**
     * Returns the lowercase hex MD5 digest of {@code bytes}. Used only to fingerprint
     * bytecode in trace logs; not for security.
     *
     * @return 32 hex characters, or {@code ""} if MD5 is unavailable
     */
    public static String md5(byte[] bytes) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5").digest(bytes);
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                String pair = Integer.toHexString(0xFF & b);
                if (pair.length() == 1) {
                    hex.append('0');
                }
                hex.append(pair);
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException ex) {
            // Java platforms are required to provide MD5; keep the historical "" fallback.
            return "";
        }
    }

    /** Returns the live (mutable) list of class loaders registered with the class pool. */
    public List<ClassLoader> getClassLoaders() {
        return classloaders;
    }
}