package scotch.compiler;
import static java.lang.management.ManagementFactory.getRuntimeMXBean;
import static java.util.regex.Pattern.compile;
import static java.util.stream.Collectors.toList;
import static scotch.symbol.Symbol.getPackageName;
import static scotch.symbol.Symbol.getPackagePath;
import static scotch.symbol.Symbol.toJavaName;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import com.google.common.collect.ImmutableSet;
import scotch.compiler.ModuleScanner.ScanResult;
import scotch.compiler.output.GeneratedClass;
import scotch.symbol.Symbol;
import scotch.symbol.Symbol.QualifiedSymbol;
import scotch.symbol.Symbol.SymbolVisitor;
import scotch.symbol.Symbol.UnqualifiedSymbol;
import scotch.symbol.SymbolEntry;
import scotch.symbol.SymbolResolver;
import scotch.symbol.descriptor.TypeInstanceDescriptor;
import scotch.symbol.descriptor.TypeParameterDescriptor;
import scotch.symbol.exception.SymbolResolutionError;
import scotch.symbol.type.ConstructorTypeDescriptor;
import scotch.symbol.type.FunctionTypeDescriptor;
import scotch.symbol.type.InstanceTypeDescriptor;
import scotch.symbol.type.SumTypeDescriptor;
import scotch.symbol.type.TypeDescriptor;
import scotch.symbol.type.VariableTypeDescriptor;
/**
 * A {@link URLClassLoader} that also acts as a {@link SymbolResolver}.
 *
 * <p>Classes can be defined directly from {@link GeneratedClass} bytecode (optionally
 * written to an output directory as well), and symbols / type instances are discovered
 * lazily: the first time a module is referenced, the class path resources for that
 * module's package are enumerated, the classes found are loaded and handed to a
 * {@link ModuleScanner}, and the scan results are cached in the maps below.
 *
 * <p>All caches are plain {@code HashMap}/{@code HashSet} instances with no
 * synchronization in this class.
 */
public class ClassLoaderResolver extends URLClassLoader implements SymbolResolver {

    /**
     * Creates a resolver with no extra URLs whose parent is the class loader that
     * loaded {@code ClassLoaderResolver} itself.
     *
     * @param optionalOutputPath directory that defined classes are also written to, if present
     * @return a new resolver
     */
    public static ClassLoaderResolver resolver(Optional<File> optionalOutputPath) {
        return new ClassLoaderResolver(optionalOutputPath, ClassLoaderResolver.class.getClassLoader());
    }

    /** When present, every class passed to {@code define}/{@code defineAll} is written below this directory. */
    private final Optional<File> optionalOutputPath;
    /** Symbol entries collected from module scans, keyed by their symbol. */
    private final Map<Symbol, SymbolEntry> namedSymbols;
    /** Names of classes that have already been handed to a {@link ModuleScanner}. */
    private final Set<String> searchedClasses;
    /**
     * Resource URLs that have already been enumerated.
     * NOTE(review): {@code URL#equals}/{@code hashCode} may resolve host names; the URLs
     * stored here come from {@code getResources} — confirm they are always local file/jar URLs.
     */
    private final Set<URL> searchedUrls;
    /** Type instances keyed by type class, then by instance parameter list. */
    private final Map<Symbol, Map<List<TypeParameterDescriptor>, Set<TypeInstanceDescriptor>>> typeInstances;
    /** Type instances keyed by type class; populated by scans, not read within this class. */
    private final Map<Symbol, Set<TypeInstanceDescriptor>> typeInstancesByClass;
    /** Type instances keyed by parameter list; populated by scans, not read within this class. */
    private final Map<List<TypeParameterDescriptor>, Set<TypeInstanceDescriptor>> typeInstancesByArguments;
    /** Type instances keyed by the module name reported by the instance. */
    private final Map<String, Set<TypeInstanceDescriptor>> typeInstancesByModule;
    /** Classes defined through this loader, keyed by class name with the trailing {@code .SimpleName} removed. */
    private final Map<String, Set<Class<?>>> definedClasses;
    /** Re-export information gathered from module scans. */
    private final ReExportMap reExports;

    /**
     * Creates a resolver with no extra URLs.
     *
     * @param optionalOutputPath directory that defined classes are also written to, if present
     * @param parent             parent class loader
     */
    public ClassLoaderResolver(Optional<File> optionalOutputPath, ClassLoader parent) {
        this(optionalOutputPath, new URL[0], parent);
    }

    /**
     * Creates a resolver.
     *
     * @param optionalOutputPath directory that defined classes are also written to, if present
     * @param urls               URLs passed to {@link URLClassLoader}
     * @param parent             parent class loader
     */
    public ClassLoaderResolver(Optional<File> optionalOutputPath, URL[] urls, ClassLoader parent) {
        super(urls, parent);
        this.optionalOutputPath = optionalOutputPath;
        this.namedSymbols = new HashMap<>();
        this.searchedClasses = new HashSet<>();
        this.searchedUrls = new HashSet<>();
        this.typeInstances = new HashMap<>();
        this.typeInstancesByClass = new HashMap<>();
        this.typeInstancesByArguments = new HashMap<>();
        this.typeInstancesByModule = new HashMap<>();
        this.definedClasses = new HashMap<>();
        this.reExports = new ReExportMap();
    }

    /**
     * Writes the class to the output path (if one is configured) and then defines it
     * in this class loader.
     *
     * @param generatedClass name and bytecode of the class to define
     * @return the defined class
     */
    public Class<?> define(GeneratedClass generatedClass) {
        writeClass(generatedClass);
        return define_(generatedClass);
    }

    /** Writes the class file only when an output path was supplied. */
    private void writeClass(GeneratedClass generatedClass) {
        optionalOutputPath.ifPresent(outputPath -> writeClass(generatedClass, generatedClass.getBytes(), outputPath));
    }

    /**
     * Defines the class from its bytes and records it in {@code definedClasses} so that
     * a later {@code search} of the owning module can pick it up.
     */
    private Class<?> define_(GeneratedClass generatedClass) {
        byte[] bytes = generatedClass.getBytes();
        Class<?> clazz = defineClass(generatedClass.getClassName(), bytes, 0, bytes.length);
        // The argument is a regular expression (quoted ".SimpleName" anchored at the end),
        // so replaceAll is required; String.replace would look for the pattern text
        // literally, never match, and leave the full class name as the key.
        String owner = clazz.getName().replaceAll(Pattern.quote("." + clazz.getSimpleName()) + "$", "");
        definedClasses
            .computeIfAbsent(owner, k -> new HashSet<>())
            .add(clazz);
        return clazz;
    }

    /**
     * Writes all classes first (if an output path is configured), then defines them in
     * list order.
     *
     * @param generatedClasses classes to write and define
     * @return the defined classes, in the same order
     */
    public List<Class<?>> defineAll(List<GeneratedClass> generatedClasses) {
        generatedClasses.forEach(this::writeClass);
        return generatedClasses.stream()
            .map(this::define_)
            .collect(toList());
    }

    /**
     * Looks up the entry for a symbol, scanning the symbol's module first if needed.
     *
     * <p>If no entry is cached for the symbol, a qualified symbol is passed through
     * {@code reExports.qualify} and the lookup is retried with the result; an
     * unqualified symbol yields an empty result.
     *
     * @param symbol the symbol to look up
     * @return the entry, or empty when none is known
     */
    @Override
    public Optional<SymbolEntry> getEntry(Symbol symbol) {
        search(symbol);
        if (namedSymbols.containsKey(symbol)) {
            return Optional.ofNullable(namedSymbols.get(symbol));
        } else {
            return symbol.accept(new SymbolVisitor<Optional<SymbolEntry>>() {
                @Override
                public Optional<SymbolEntry> visit(QualifiedSymbol symbol) {
                    return reExports.qualify(symbol).flatMap(ClassLoaderResolver.this::getEntry);
                }

                @Override
                public Optional<SymbolEntry> visit(UnqualifiedSymbol symbol) {
                    return Optional.empty();
                }
            });
        }
    }

    /**
     * Returns the instances of the given type class whose parameter list matches the
     * given types, scanning the relevant modules first.
     *
     * <p>Only the first matching parameter list (in the map's key iteration order) is used.
     *
     * @param symbol type class symbol
     * @param types  argument types to match against instance parameters
     * @return the matching instances, or an empty set
     */
    @Override
    public Set<TypeInstanceDescriptor> getTypeInstances(Symbol symbol, List<TypeDescriptor> types) {
        search(symbol);
        search(types);
        return Optional.ofNullable(typeInstances.get(symbol))
            .flatMap(instances -> instances.keySet().stream()
                .filter(parameters -> parametersMatch(parameters, types))
                .map(instances::get)
                .findFirst())
            .orElse(ImmutableSet.of());
    }

    /**
     * Returns the type instances recorded for a module, scanning the module first.
     *
     * @param moduleName the module to look up
     * @return the module's instances, or an empty set
     */
    @Override
    public Set<TypeInstanceDescriptor> getTypeInstancesByModule(String moduleName) {
        search(moduleName);
        return typeInstancesByModule.getOrDefault(moduleName, ImmutableSet.of());
    }

    /** Returns the file name up to (not including) its last {@code '.'}; callers pass {@code .class} files. */
    private String baseName(File file) {
        String fileName = file.getName();
        return fileName.substring(0, fileName.lastIndexOf('.'));
    }

    /** Lists the regular {@code .class} files directly inside the directory; never returns {@code null}. */
    private File[] classFiles(File directory) {
        File[] files = directory.listFiles(pathName -> pathName.isFile() && pathName.getName().endsWith(".class"));
        // listFiles returns null when the path is not a directory or an I/O error occurs
        return files == null ? new File[0] : files;
    }

    /**
     * Returns {@code true} only when both lists have the same size and, position by
     * position, the type is a {@link SumTypeDescriptor} that the parameter matches.
     */
    private boolean parametersMatch(List<TypeParameterDescriptor> parameters, List<TypeDescriptor> types) {
        if (parameters.size() == types.size()) {
            for (int i = 0; i < parameters.size(); i++) {
                if (!(types.get(i) instanceof SumTypeDescriptor) || !parameters.get(i).matches(types.get(i))) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    /** Merges the results of one module scan into the re-export map, symbol table and instance indexes. */
    private void processScan(String moduleName, ScanResult scan) {
        reExports.addReExports(moduleName, scan.getReExports());
        scan.getEntries().forEach(entry -> namedSymbols.put(entry.getSymbol(), entry));
        scan.getInstances().forEach(typeInstance -> {
            typeInstances
                .computeIfAbsent(typeInstance.getTypeClass(), k -> new HashMap<>())
                .computeIfAbsent(typeInstance.getParameters(), k -> new HashSet<>())
                .add(typeInstance);
            typeInstancesByClass.computeIfAbsent(typeInstance.getTypeClass(), k -> new HashSet<>()).add(typeInstance);
            typeInstancesByArguments.computeIfAbsent(typeInstance.getParameters(), k -> new HashSet<>()).add(typeInstance);
            typeInstancesByModule.computeIfAbsent(typeInstance.getModuleName(), k -> new HashSet<>()).add(typeInstance);
        });
    }

    /** Loads a class by name through this loader; a missing class yields an empty result instead of an exception. */
    private Optional<Class<?>> resolveClass(String className) {
        try {
            return Optional.of(loadClass(className));
        } catch (ClassNotFoundException exception) {
            return Optional.empty();
        }
    }

    /**
     * Reads the archive at {@code resource} as a zip stream and loads every class whose
     * entry name contains {@code packagePath + "/" + <name> + ".class"}, where
     * {@code <name>} contains neither {@code '.'} nor {@code '/'}.
     *
     * @param resource    URL of the archive to read
     * @param packagePath slash-separated package path to look for
     * @return the classes that could be loaded
     * @throws SymbolResolutionError if reading the archive fails
     */
    private List<Class<?>> resolveClasses(URL resource, String packagePath) {
        List<Class<?>> classes = new ArrayList<>();
        // Compiled once per archive rather than once per entry; packagePath is quoted so
        // that any regex metacharacters in it are matched literally.
        Pattern pattern = compile("(" + Pattern.quote(packagePath) + "/[^\\./]+)\\.class");
        try (ZipInputStream zipStream = new ZipInputStream(resource.openStream())) {
            ZipEntry entry;
            while (null != (entry = zipStream.getNextEntry())) {
                try {
                    if (!entry.isDirectory()) {
                        Matcher matcher = pattern.matcher(entry.getName());
                        if (matcher.find()) {
                            resolveClass(matcher.group(1).replace('/', '.')).ifPresent(classes::add);
                        }
                    }
                } finally {
                    zipStream.closeEntry();
                }
            }
        } catch (IOException exception) {
            throw new SymbolResolutionError(exception);
        }
        return classes;
    }

    /**
     * Loads every class that has a {@code .class} file directly inside the directory,
     * using {@code packageName + '.' + <file base name>} as the class name.
     *
     * @param directory   directory to list; a non-existent directory yields an empty list
     * @param packageName package name prepended to each class file's base name
     * @return the classes that could be loaded
     */
    private List<Class<?>> resolveClasses(File directory, String packageName) {
        List<Class<?>> classes = new ArrayList<>();
        if (directory.exists()) {
            for (File file : classFiles(directory)) {
                resolveClass(packageName + '.' + baseName(file)).ifPresent(classes::add);
            }
        }
        return classes;
    }

    /**
     * Walks each type descriptor and scans the modules of the symbols it references:
     * the symbol of every sum type (recursing into its parameters) and every symbol
     * in a type variable's context. Function and instance types are traversed;
     * constructor types are ignored.
     */
    private void search(List<TypeDescriptor> parameters) {
        parameters.forEach(parameter -> parameter.accept(new TypeDescriptor.Visitor<Void>() {
            @Override
            public Void visit(ConstructorTypeDescriptor type) {
                return null;
            }

            @Override
            public Void visit(FunctionTypeDescriptor type) {
                type.getArgument().accept(this);
                type.getResult().accept(this);
                return null;
            }

            @Override
            public Void visit(InstanceTypeDescriptor type) {
                type.getBinding().accept(this);
                return null;
            }

            @Override
            public Void visit(SumTypeDescriptor type) {
                search(type.getSymbol());
                type.getParameters().forEach(parameter -> parameter.accept(this));
                return null;
            }

            @Override
            public Void visit(VariableTypeDescriptor type) {
                type.getContext().forEach(ClassLoaderResolver.this::search);
                return null;
            }
        }));
    }

    /** Scans the module of a qualified symbol; unqualified symbols carry no module and are skipped. */
    private void search(Symbol symbol) {
        symbol.accept(new SymbolVisitor<Void>() {
            @Override
            public Void visit(QualifiedSymbol symbol) {
                search(symbol.getModuleName());
                return null;
            }

            @Override
            public Void visit(UnqualifiedSymbol symbol) {
                return null;
            }
        });
    }

    /**
     * Collects the not-yet-scanned classes belonging to a module and feeds them to a
     * {@link ModuleScanner}.
     *
     * <p>Classes come from two places: class path resources for the module's package
     * path (each resource URL is visited at most once), and classes previously defined
     * through this loader under {@code toJavaName(moduleName)}. Classes already scanned
     * are filtered out by name before scanning.
     *
     * @param moduleName the module to scan
     * @throws SymbolResolutionError if enumerating or reading resources fails
     */
    private void search(String moduleName) {
        List<Class<?>> classes = new ArrayList<>();
        try {
            Enumeration<URL> resources = getResources(getPackagePath(moduleName));
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                if (!searchedUrls.contains(resource)) {
                    searchedUrls.add(resource);
                    if (resource.getFile().contains("!")) {
                        // A '!' in the file part is treated as "inside an archive": everything
                        // before the '!' is re-parsed as the URL of the archive itself.
                        // NOTE(review): round-tripping through java.io.File and using the raw
                        // (still URL-encoded) getFile() value may misbehave for paths with
                        // spaces or on platforms with a non-'/' separator — confirm.
                        String path = new File(resource.getFile()).getPath();
                        classes.addAll(resolveClasses(new URL(path.substring(0, path.indexOf('!'))), getPackagePath(moduleName)));
                    } else {
                        classes.addAll(resolveClasses(new File(resource.getFile()), getPackageName(moduleName)));
                    }
                }
            }
            Optional
                .ofNullable(definedClasses.get(toJavaName(moduleName)))
                .ifPresent(cs -> cs.forEach(classes::add));
        } catch (IOException exception) {
            throw new SymbolResolutionError(exception);
        }
        // Drop classes that were scanned before, then remember the rest as scanned.
        classes.removeIf(c -> searchedClasses.contains(c.getName()));
        classes.stream().map(Class::getName).forEach(searchedClasses::add);
        processScan(moduleName, new ModuleScanner(moduleName, classes).scan());
    }

    /**
     * Writes the bytes to {@code <outputPath>/<class name with '.' replaced by '/'>.class},
     * creating parent directories as needed.
     *
     * @param generatedClass supplies the class name
     * @param bytes          class file contents
     * @param outputPath     root output directory
     * @throws RuntimeException if the parent directory cannot be created or writing fails
     */
    private void writeClass(GeneratedClass generatedClass, byte[] bytes, File outputPath) {
        File file = new File(outputPath, generatedClass.getClassName().replace('.', '/') + ".class");
        // mkdirs() returns false when the directory already exists, so check exists() too
        if (!file.getParentFile().mkdirs() && !file.getParentFile().exists()) {
            throw new RuntimeException("Can't define " + generatedClass.getClassName()
                + ", directory " + file.getParentFile() + " could not be created");
        }
        try (OutputStream classFile = new FileOutputStream(file)) {
            classFile.write(bytes);
            classFile.flush();
            if (isDebug()) {
                System.out.println("Class file written to: " + file.getAbsolutePath());
            }
        } catch (IOException exception) {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Returns {@code true} when the string form of the JVM input arguments contains
     * {@code "jdwp"} (presumably meaning a debugger agent was requested).
     */
    private boolean isDebug() {
        return getRuntimeMXBean().getInputArguments().toString().contains("jdwp");
    }
}