package restx.classloader;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLConnection;
import java.util.Collection;
/**
* The {@link ClassLoader} for hot reloading.
*
* @author higa
* @since 1.0.0
*/
/**
 * The {@link ClassLoader} for hot reloading.
 *
 * <p>Classes whose binary name starts with {@code rootPackageName + "."} are loaded by this
 * loader first: an already-defined class is reused, a registered "cold" class is returned as-is,
 * and otherwise the class file is read through {@link #getResource(String)} (with URL caching
 * disabled) and defined by this loader rather than by the parent. Every other class, and any
 * target class whose class file cannot be found, is delegated to
 * {@link ClassLoader#loadClass(String, boolean)}.
 *
 * @author higa
 * @since 1.0.0
 */
public class HotReloadingClassLoader extends ClassLoader {
    /**
     * The root package name; only classes in this package or its sub-packages are defined by
     * this loader.
     */
    private final String rootPackageName;

    /** Classes returned as-is instead of being redefined, keyed by {@link Class#getName()}. */
    private final ImmutableMap<String, Class<?>> coldClasses;

    /**
     * Constructor.
     *
     * @param parentClassLoader
     *            the parent class loader.
     * @param rootPackageName
     *            the root package name; classes in this package or its sub-packages are
     *            defined by this class loader
     * @param coldClasses
     *            classes that are returned as-is (never redefined) when requested by name;
     *            only consulted for names under the root package
     * @throws NullPointerException
     *             if the rootPackageName parameter is null or if the coldClasses parameter is
     *             null
     */
    public HotReloadingClassLoader(
            ClassLoader parentClassLoader, String rootPackageName, ImmutableSet<Class> coldClasses)
            throws NullPointerException {
        super(parentClassLoader);
        if (rootPackageName == null) {
            throw new NullPointerException(
                    "The rootPackageName parameter is null.");
        }
        if (coldClasses == null) {
            throw new NullPointerException(
                    "The coldClasses parameter is null.");
        }
        this.rootPackageName = rootPackageName;
        ImmutableMap.Builder<String, Class<?>> builder = ImmutableMap.builder();
        for (Class<?> coldClass : coldClasses) {
            builder.put(coldClass.getName(), coldClass);
        }
        this.coldClasses = builder.build();
    }

    /**
     * Loads the class with the given binary name.
     *
     * <p>For target classes (see {@link #isTarget(String)}) this loader is consulted before the
     * parent. The lookup-then-define sequence runs under {@link #getClassLoadingLock(String)} so
     * that two threads requesting the same class cannot both reach {@code defineClass}, which
     * would fail with a {@link LinkageError} (duplicate class definition).
     *
     * @param className
     *            the binary name of the class
     * @param resolve
     *            whether a class newly defined by this loader should be resolved
     * @return the resulting class
     * @throws ClassNotFoundException
     *             if delegation to the parent cannot find the class
     */
    @Override
    public Class<?> loadClass(String className, boolean resolve)
            throws ClassNotFoundException {
        if (isTarget(className)) {
            synchronized (getClassLoadingLock(className)) {
                Class<?> clazz = findLoadedClass(className);
                if (clazz != null) {
                    return clazz;
                }
                clazz = coldClasses.get(className);
                if (clazz != null) {
                    return clazz;
                }
                definePackageFor(className);
                clazz = defineClass(className, resolve);
                if (clazz != null) {
                    return clazz;
                }
            }
        }
        // Not a target, or no class file was found for it: use standard delegation.
        return super.loadClass(className, resolve);
    }

    /**
     * Defines the package of {@code className} in this loader if {@link #getPackage(String)}
     * does not already know it. Does nothing for classes in the unnamed package.
     *
     * @param className
     *            the binary name of the class about to be defined
     */
    private void definePackageFor(String className) {
        int index = className.lastIndexOf('.');
        if (index < 0) {
            return;
        }
        String packageName = className.substring(0, index);
        if (getPackage(packageName) == null) {
            try {
                definePackage(packageName, null, null, null, null, null, null, null);
            } catch (IllegalArgumentException ignored) {
                // definePackage throws when a package of this name already exists; in that case
                // the package is available and there is nothing left to do.
            }
        }
    }

    /**
     * Defines the class by reading its class file from this loader's resources.
     *
     * @param className
     *            the class name
     * @param resolve
     *            whether the class is resolved
     * @return the class, or {@code null} if the class file could not be opened
     */
    protected Class<?> defineClass(String className, boolean resolve) {
        String path = className.replace('.', '/') + ".class";
        InputStream is = getInputStream(path);
        if (is == null) {
            return null;
        }
        Class<?> clazz = defineClass(className, is);
        if (resolve) {
            resolveClass(clazz);
        }
        return clazz;
    }

    /**
     * Defines the class from the bytes of the given stream. The stream is closed.
     *
     * @param className
     *            the class name
     * @param is
     *            the input stream
     * @return the class
     */
    protected Class<?> defineClass(String className, InputStream is) {
        return defineClass(className, getBytes(is));
    }

    /**
     * Defines the class.
     *
     * @param className
     *            the class name
     * @param bytes
     *            the array of bytes.
     * @return the class
     */
    protected Class<?> defineClass(String className, byte[] bytes) {
        return defineClass(className, bytes, 0, bytes.length);
    }

    /**
     * Returns an input stream on the given resource, opened with URL caching disabled.
     *
     * <p>Caching is disabled presumably so that a class file that changed is re-read rather than
     * served from a cached connection (e.g. a cached jar) — TODO confirm intent.
     *
     * @param path
     *            the resource path
     * @return the input stream, or {@code null} if the resource is missing or cannot be opened
     *         (which makes {@link #loadClass(String, boolean)} fall back to delegation)
     */
    protected InputStream getInputStream(String path) {
        try {
            URL url = getResource(path);
            if (url == null) {
                return null;
            }
            URLConnection connection = url.openConnection();
            connection.setUseCaches(false);
            return connection.getInputStream();
        } catch (IOException ignored) {
            // Treated like a missing resource: the caller falls back to parent delegation.
            return null;
        }
    }

    /**
     * Returns input stream data as the array of bytes. The stream is always closed.
     *
     * @param is
     *            the input stream
     * @return the array of bytes
     * @throws RuntimeException
     *             if {@link IOException} is encountered
     */
    protected byte[] getBytes(InputStream is) throws RuntimeException {
        byte[] buf = new byte[8192];
        try (InputStream in = is) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            int n;
            while ((n = in.read(buf, 0, buf.length)) != -1) {
                baos.write(buf, 0, n);
            }
            return baos.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Determines if the class is the target of hot deployment, i.e. whether its name starts with
     * the root package name followed by a dot.
     *
     * @param className
     *            the class name
     * @return whether the class is the target of hot deployment
     */
    protected boolean isTarget(String className) {
        return className.startsWith(rootPackageName + ".");
    }
}