package xapi.javac.dev.util;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.StatementTree;
import com.sun.source.tree.Tree.Kind;
import com.sun.source.tree.VariableTree;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import xapi.javac.dev.api.CompilerService;
import xapi.javac.dev.model.HasClassLiteralReference;
import xapi.log.X_Log;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * A {@link TreeScanner} that tries to resolve {@link HasClassLiteralReference}s which are
 * not yet resolved by following them to their declaration in the compilation unit being
 * scanned:
 * <ul>
 *   <li>a reference of kind {@code IDENTIFIER} has its source replaced by the initializer
 *       of the matching variable declaration;</li>
 *   <li>a reference of kind {@code METHOD_INVOCATION} has its source replaced by the
 *       expression of the matching method's single top-level {@code return}.</li>
 * </ul>
 * References whose symbol is owned by a different unit are handed to
 * {@link CompilerService#onCompilationUnitFinished} so that unit gets scanned as well.
 */
public class ClassLiteralResolver extends TreeScanner<Void, Void> {

  /** Maximum number of times {@link #visitCompilationUnit} may re-visit before giving up. */
  private static final int MAX_DEPTH = 50;

  /** Unresolved references whose node kind is {@code IDENTIFIER}. */
  private final List<HasClassLiteralReference> variables;
  /** Unresolved references whose node kind is {@code METHOD_INVOCATION}. */
  private final List<HasClassLiteralReference> invocations;
  private final CompilerService classWorld;
  /** Name of the compilation unit currently being visited, as computed by {@code NameUtil.getName}. */
  private String currentUnit;
  /** Number of re-visits performed so far; never reset by this class. */
  int depth = 0;

  /**
   * @param items references to resolve; already-resolved ones and ones of an unsupported
   *     node kind are not tracked
   * @param classWorld service used to get notified about other compilation units
   */
  public ClassLiteralResolver(List<? extends HasClassLiteralReference> items, CompilerService classWorld) {
    this.classWorld = classWorld;
    variables = new ArrayList<>();
    invocations = new ArrayList<>();
    items.forEach(this::add);
  }

  /**
   * Files an unresolved reference into {@link #variables} or {@link #invocations} based
   * on its node kind. {@code MEMBER_SELECT} references are not tracked; any other kind
   * is logged as a warning and not tracked either.
   */
  private void add(HasClassLiteralReference r) {
    if (r.isResolved()) {
      return;
    }
    Kind kind = r.getNodeKind();
    switch (kind) {
      case IDENTIFIER:
        variables.add(r);
        break;
      case MEMBER_SELECT:
        break;
      case METHOD_INVOCATION:
        invocations.add(r);
        break;
      default:
        X_Log.warn(getClass(), "Unhandled type found while trying to resolve class literal: "+kind);
    }
  }

  /**
   * Returns the method select of a {@code METHOD_INVOCATION} reference when it is a plain
   * identifier, or {@code null} when it is any other expression (for example a qualified
   * call such as {@code this.foo()}), which this resolver cannot follow.
   */
  private static JCIdent methodIdent(HasClassLiteralReference r) {
    JCMethodInvocation method = (JCMethodInvocation) r.getSource();
    ExpressionTree select = method.getMethodSelect();
    return select instanceof JCIdent ? (JCIdent) select : null;
  }

  /** Returns the simple name of the symbol that owns the symbol of {@code ident}. */
  private static String ownerName(JCIdent ident) {
    return ident.sym.owner.name.toString();
  }

  /**
   * Scans the unit, then re-files every reference that is still unresolved. References
   * owned by the current unit trigger another visit of the same unit (bounded by
   * {@link #MAX_DEPTH}); references owned by another unit register a callback with the
   * {@link CompilerService}, once per owner name per pass.
   *
   * @throws IllegalStateException when more than {@link #MAX_DEPTH} re-visits are needed
   */
  @Override
  public Void visitCompilationUnit(CompilationUnitTree node, Void p) {
    currentUnit = NameUtil.getName((JCCompilationUnit) node);
    super.visitCompilationUnit(node, p);

    final Set<String> pending = new HashSet<>();
    // Single-element array so the lambda below can flip the flag.
    final boolean[] done = {true};

    variables.removeIf(HasClassLiteralReference::isResolved);
    invocations.removeIf(HasClassLiteralReference::isResolved);

    // Snapshot and reset: a reference's node kind may have changed after setSource, so
    // every survivor is re-filed through add(r) below.
    final List<HasClassLiteralReference> all = new ArrayList<>(variables);
    all.addAll(invocations);
    variables.clear();
    invocations.clear();

    all.forEach(r -> {
      JCIdent ident;
      switch (r.getNodeKind()) {
        case MEMBER_SELECT:
          return;
        case IDENTIFIER:
          ident = (JCIdent) r.getSource();
          break;
        case METHOD_INVOCATION:
          ident = methodIdent(r);
          if (ident == null) {
            // Previously a ClassCastException; such calls cannot be followed by name alone.
            X_Log.warn(getClass(), "Unsupported method select ",
                ((JCMethodInvocation) r.getSource()).getMethodSelect(), "in", r);
            return;
          }
          break;
        default:
          X_Log.warn(getClass(), "Unhandled node type ",r.getNodeKind(),"in",r);
          return;
      }
      add(r);
      String owner = ownerName(ident);
      if (owner.equals(currentUnit)) {
        done[0] = false;
      } else if (pending.add(owner)) {
        // NOTE(review): if this callback can run synchronously, the nested scan overwrites
        // currentUnit and clears the lists mid-iteration -- confirm against CompilerService.
        classWorld.onCompilationUnitFinished(owner, jcu -> scan(jcu, null));
      }
    });

    if (!done[0]) {
      if (++depth > MAX_DEPTH) {
        throw new IllegalStateException("Max recursion depth of " + MAX_DEPTH + " has been exceeded");
      }
      visitCompilationUnit(node, p);
    }
    return null;
  }

  /**
   * For every tracked invocation that targets a method of the current unit with the same
   * name as {@code node}, replaces the reference's source with the expression of the
   * method's single top-level {@code return} statement.
   *
   * @throws IllegalArgumentException when a matching method has no body, or does not have
   *     exactly one top-level return statement
   */
  @Override
  public Void visitMethod(MethodTree node, Void p) {
    if (!invocations.isEmpty()) {
      invocations
          .stream()
          .filter(r -> isSuppliedBy(r, node)) // TODO: also check parameter types
          .forEach(r -> r.setSource(singleReturnExpression(node)));
    }
    return super.visitMethod(node, p);
  }

  /**
   * Tells whether {@code r} is an unqualified invocation whose symbol is owned by the
   * current unit and whose name equals the name of {@code node}.
   */
  private boolean isSuppliedBy(HasClassLiteralReference r, MethodTree node) {
    if (r.getNodeKind() != Kind.METHOD_INVOCATION) {
      return false;
    }
    JCIdent select = methodIdent(r);
    if (select == null) {
      return false;
    }
    return ownerName(select).equals(currentUnit)
        && NameUtil.equals(r.getNodeName(), node.getName());
  }

  /**
   * Returns the expression of the only top-level {@code return} statement of {@code node}.
   * Return statements nested inside other statements are not considered.
   */
  private static ExpressionTree singleReturnExpression(MethodTree node) {
    if (node.getBody() == null) {
      // getBody() is null for abstract and native methods; there is nothing to inspect.
      throw new IllegalArgumentException("Method "+node+" is used to supply class literals, "
          + "but this method has no body");
    }
    List<StatementTree> body = new ArrayList<>();
    for (StatementTree statement : node.getBody().getStatements()) {
      if (statement instanceof ReturnTree) {
        body.add(statement);
      }
    }
    if (body.size() != 1) {
      // Zero or multiple return statements... things are about to get ugly
      throw new IllegalArgumentException("Method "+node+" is used to supply class literals, "
          + "but this method does not have exactly one return statement: "+body);
    }
    return ((ReturnTree) body.get(0)).getExpression();
  }

  /**
   * For every tracked identifier whose symbol is owned by the current unit and whose name
   * equals the name of {@code node}, replaces the reference's source with the variable's
   * initializer.
   */
  @Override
  public Void visitVariable(VariableTree node, Void p) {
    if (!variables.isEmpty()) {
      variables
          .stream()
          .filter(r -> {
            if (r.getNodeKind() != Kind.IDENTIFIER) {
              return false;
            }
            JCIdent select = (JCIdent) r.getSource();
            return ownerName(select).equals(currentUnit)
                && NameUtil.equals(r.getNodeName(), node.getName());
          })
          .forEach(r -> {
            // NOTE(review): the initializer is null for declarations without one (e.g.
            // parameters); confirm that setSource tolerates null.
            ExpressionTree init = node.getInitializer();
            r.setSource(init);
          });
    }
    return super.visitVariable(node, p);
  }
}