package com.haskforce.utils;
import com.google.common.collect.Lists;
import com.haskforce.index.HaskellModuleIndex;
import com.haskforce.psi.*;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.project.Project;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiElementResolveResult;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiNamedElement;
import com.intellij.psi.search.GlobalSearchScope;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
/**
* General util class. Provides methods for finding named nodes in the Psi tree.
*/
public class HaskellUtil {
/**
* Finds name definition across all Haskell files in the project. All
* definitions are found when name is null.
*/
@NotNull
public static List<FoundDefinition> findDefinitionNode(@NotNull Project project, @Nullable String name, @NotNull PsiNamedElement e) {
// Guess where the name could be defined by lookup up potential modules.
// TODO This removing duplicates, for example importing the same module twice. Fair enough
final List<HaskellPsiUtil.Import> potentialModules =
getPotentialDefinitionModuleNames(e, HaskellPsiUtil.parseImports(e.getContainingFile()));
final Set<String> potentialModuleNames = new HashSet<String>();
for (HaskellPsiUtil.Import i : potentialModules) {
potentialModuleNames.add(i.module);
}
List<FoundDefinition> results = ContainerUtil.newArrayList();
final String qPrefix = getQualifiedPrefix(e);
final PsiFile psiFile = e.getContainingFile().getOriginalFile();
if (psiFile instanceof HaskellFile) {
List<PsiNamedElement> result = ContainerUtil.newArrayList();
findDefinitionNode((HaskellFile)psiFile, name, e, result);
addFoundDefinition(result, null, results);
}
for (HaskellPsiUtil.Import potentialModule : potentialModules) {
List<PsiNamedElement> result = ContainerUtil.newArrayList();
List<HaskellFile> files = HaskellModuleIndex.getFilesByModuleName(project, potentialModule.module, GlobalSearchScope.allScope(project));
for (HaskellFile f : files) {
final boolean returnAllReferences = name == null;
final boolean inLocalModule = f != null && qPrefix == null && f.equals(psiFile);
final boolean inImportedModule = f != null && potentialModuleNames.contains(f.getModuleName());
if (returnAllReferences || inLocalModule || inImportedModule) {
findDefinitionNode(f, name, e, result);
findDefinitionNodeInExport(project, f, name, e, result);
}
}
addFoundDefinition(result, potentialModule, results);
}
return results;
}
/**
* Find definitions that have been re-exported.
*
* <code>
* module Foo (module Bar, foo) where
* import Bar
* import Baz (foo)
* </code>
*/
private static void findDefinitionNodeInExport(@NotNull Project project, HaskellFile f, @Nullable String name,
@Nullable PsiNamedElement e, List<PsiNamedElement> result) {
List<HaskellPsiUtil.Import> imports = HaskellPsiUtil.parseImports(f);
for (HaskellExport export : PsiTreeUtil.findChildrenOfType(f, HaskellExport.class)) {
boolean exportFn = export.getQvar() != null && export.getQvar().getQvarid() != null
&& export.getQvar().getQvarid().getVarid().getName().equals(name);
String moduleName = exportFn
? getModule(export.getQvar().getQvarid().getConidList())
: export.getModuletoken() != null && export.getQconid() != null ? export.getQconid().getText() : null;
if (!exportFn && moduleName == null) continue;
for (HaskellPsiUtil.Import imprt : imports) {
if (moduleName != null && !moduleName.equals(imprt.module) && !moduleName.equals(imprt.alias)) continue;
boolean hidden = imprt.getHidingNames() != null && ArrayUtil.contains(name, imprt.getHidingNames());
boolean notImported = imprt.getImportedNames() != null && !ArrayUtil.contains(name, imprt.getImportedNames());
if (hidden || notImported) continue;
for (HaskellFile f2 : HaskellModuleIndex.getFilesByModuleName(project, imprt.module, GlobalSearchScope.allScope(project))) {
findDefinitionNode(f2, name, e, result);
findDefinitionNodeInExport(project, f2, name, e, result);
}
}
}
}
/**
* Finds a name definition inside a Haskell file. All definitions are found when name
* is null.
*/
public static void findDefinitionNode(@Nullable HaskellFile file, @Nullable String name, @Nullable PsiNamedElement e, @NotNull List<PsiNamedElement> result) {
if (file == null) return;
// We only want to look for classes that match the element we are resolving (e.g. varid, conid, etc.)
final Class<? extends PsiNamedElement> elementClass;
if (e instanceof HaskellVarid) {
elementClass = HaskellVarid.class;
} else if (e instanceof HaskellConid) {
elementClass = HaskellConid.class;
} else {
elementClass = PsiNamedElement.class;
}
final boolean isType = PsiTreeUtil.getParentOfType(e, HaskellGendecl.class) != null;
Collection<PsiNamedElement> namedElements = PsiTreeUtil.findChildrenOfType(file, elementClass);
for (PsiNamedElement namedElement : namedElements) {
if ((name == null || name.equals(namedElement.getName())) && definitionNode(namedElement)) {
result.add(namedElement);
} else if (isType && name != null && name.equals(namedElement.getName()) && typeNode(name, namedElement)) {
result.add(namedElement);
}
}
}
private static boolean typeNode(@NotNull String name, @NotNull PsiNamedElement e) {
HaskellDatadecl datadecl = PsiTreeUtil.getParentOfType(e, HaskellDatadecl.class);
if (datadecl != null) {
return datadecl.getTypeeList().get(0).getAtypeList().get(0).getText().equals(name);
}
HaskellNewtypedecl newtypedecl = PsiTreeUtil.getParentOfType(e, HaskellNewtypedecl.class);
if (newtypedecl != null && newtypedecl.getTycon() != null) {
return name.equals(newtypedecl.getTycon().getConid().getName());
}
HaskellTypedecl typedecl = PsiTreeUtil.getParentOfType(e, HaskellTypedecl.class);
if (typedecl != null) {
return name.equals(typedecl.getTypeeList().get(0).getAtypeList().get(0).getText());
}
HaskellClassdecl classdecl = PsiTreeUtil.getParentOfType(e, HaskellClassdecl.class);
if (classdecl != null && classdecl.getCtype() != null) {
HaskellCtype ctype = classdecl.getCtype();
while (ctype.getCtype() != null) {
ctype = ctype.getCtype();
}
if (ctype.getTypee() == null) return false;
HaskellAtype haskellAtype = ctype.getTypee().getAtypeList().get(0);
return haskellAtype.getOqtycon() != null && haskellAtype.getOqtycon().getQtycon() != null &&
name.equals(haskellAtype.getOqtycon().getQtycon().getTycon().getConid().getName());
}
return false;
}
/**
* Finds a name definition inside a Haskell file. All definitions are found when name
* is null.
*/
@NotNull
public static List<PsiNamedElement> findDefinitionNodes(@Nullable HaskellFile haskellFile, @Nullable String name) {
List<PsiNamedElement> ret = ContainerUtil.newArrayList();
findDefinitionNode(haskellFile, name, null, ret);
return ret;
}
/**
* Finds name definitions that are within the scope of a file, including imports (to some degree).
*/
@NotNull
public static List<PsiNamedElement> findDefinitionNodes(@NotNull HaskellFile psiFile) {
return findDefinitionNodes(psiFile, null);
}
/**
* Tells whether a named node is a definition node based on its context.
*
* Precondition: Element is in a Haskell file.
*/
public static boolean definitionNode(@NotNull PsiNamedElement e) {
if (e instanceof HaskellVarid) return definitionNode((HaskellVarid)e);
if (e instanceof HaskellConid) return definitionNode((HaskellConid)e);
return false;
}
public static boolean definitionNode(@NotNull HaskellConid e) {
final HaskellConstr constr = PsiTreeUtil.getParentOfType(e, HaskellConstr.class);
final HaskellCon con;
if (constr != null) {
con = constr.getCon();
} else {
final HaskellNewconstr newconstr = PsiTreeUtil.getParentOfType(e, HaskellNewconstr.class);
con = newconstr == null ? null : newconstr.getCon();
}
final HaskellConid conid = con == null ? null : con.getConid();
return e.equals(conid);
}
public static boolean definitionNode(@NotNull HaskellVarid e) {
final PsiElement parent = e.getParent();
if (parent == null) return false;
// If we are in a variable declaration (which has a type signature), return true.
if (HaskellPsiUtil.isType(parent, HaskellTypes.VARS)) return true;
// Now we have to figure out if the current varid, e, is the first top-level declaration in the file.
// Check each top-level declaration. When we find the first one that matches our element's name we'll return
// true if the elements are equal, false otherwise.
final String name = e.getName();
final PsiFile file = e.getContainingFile();
if (!(file instanceof HaskellFile)) return false;
final HaskellBody body = ((HaskellFile)file).getBody();
if (body == null) return false;
for (PsiElement child : body.getChildren()) {
// If we hit a declaration with a type signature, this shouldn't match our element's name.
if (child instanceof HaskellGendecl) {
final HaskellVars vars = ((HaskellGendecl)child).getVars();
if (vars == null) continue;
// If it matches our elements name, return false.
for (HaskellVarid varid : vars.getVaridList()) {
if (name.equals(varid.getName())) return false;
}
} else if (child instanceof HaskellFunorpatdecl) {
final HaskellFunorpatdecl f = (HaskellFunorpatdecl)child;
final HaskellVarop varop = f.getVarop();
// Check if the function is defined as infix.
if (varop != null) {
final HaskellVarid varid = varop.getVarid();
if (varid != null && name.equals(varid.getName())) {
return e.equals(varid);
}
} else {
// If there is a pat in the declaration then there should only be one since the only case of having
// more than one is when using a varop, which was already accounted for above.
List<HaskellPat> pats = f.getPatList();
if (pats.size() == 1 && pats.get(0).getVaridList().contains(e)) return true;
// There can be multiple varids in a declaration, so we'll need to grab the first one.
List<HaskellVarid> varids = f.getVaridList();
if (varids.size() > 0) {
final HaskellVarid varid = varids.get(0);
if (name.equals(varid.getName())) {
return e.equals(varid);
}
}
}
}
}
return false;
}
/**
* Tells whether a node is a definition node based on its context.
*/
public static boolean definitionNode(@NotNull ASTNode node) {
final PsiElement element = node.getPsi();
return element instanceof PsiNamedElement && definitionNode((PsiNamedElement)element);
}
@Nullable
public static String getQualifiedPrefix(@NotNull PsiElement e) {
final PsiElement q = getParentOfType(e, HaskellQcon.class, HaskellQvar.class);
if (q == null) { return null; }
final String qText = q.getText();
final int lastDotPos = qText.lastIndexOf('.');
if (lastDotPos == -1) { return null; }
return qText.substring(0, lastDotPos);
}
/**
* Helper method to avoid the compiler warning.
* See https://youtrack.jetbrains.com/issue/IDEA-157225
*/
@SafeVarargs
@Nullable
public static <T extends PsiElement> T getParentOfType(@Nullable final PsiElement element,
@NotNull final Class<? extends T>... classes) {
return PsiTreeUtil.getParentOfType(element, classes);
}
@NotNull
public static List<HaskellPsiUtil.Import> getPotentialDefinitionModuleNames(@NotNull PsiElement e, @NotNull List<HaskellPsiUtil.Import> imports) {
final String qPrefix = getQualifiedPrefix(e);
if (qPrefix == null) { return imports; }
List<HaskellPsiUtil.Import> result = new ArrayList<HaskellPsiUtil.Import>(2);
for (HaskellPsiUtil.Import anImport : imports) {
if (qPrefix.equals(anImport.module) || qPrefix.equals(anImport.alias)) {
result.add(anImport);
}
}
return result;
}
public static @Nullable PsiElement lookForFunOrPatDeclWithCorrectName(
@NotNull PsiElement element,
@NotNull String matcher){
/**
* A FunOrPatDecl with as parent haskellbody is one of the 'leftmost' function declarations.
* Those should not be taken into account, the definition will already be found from the stub.
* It will cause problems if we also start taking those into account over here.
*/
if (element instanceof HaskellFunorpatdecl &&
! (element.getParent() instanceof HaskellBody)) {
PsiElement[] children = element.getChildren();
for (PsiElement child : children) {
if (child instanceof HaskellVarid) {
PsiElement psiElement = checkForMatchingVariable(child,matcher);
if (psiElement != null){
return psiElement;
}
}
if (child instanceof HaskellPat){
HaskellPat pat = (HaskellPat)child;
List<HaskellVarid> varIds = extractAllHaskellVarids(pat);
for (HaskellVarid varId : varIds) {
if (varId.getName().matches(matcher)){
return varId;
};
}
}
}
}
return null;
}
public static List<HaskellVarid> extractAllHaskellVarids(HaskellPat pat) {
List<HaskellVarid> varidList = pat.getVaridList();
List<HaskellPat> patList = pat.getPatList();
for (HaskellPat haskellPat : patList) {
varidList.addAll(haskellPat.getVaridList());
}
return varidList;
}
private static PsiElement checkForMatchingVariable(PsiElement child, String matcher) {
HaskellVarid haskellVarid = (HaskellVarid) child;
if (haskellVarid.getName().matches(matcher)) {
return child;
} else {
return null;
}
}
public static boolean isInsideBody(@NotNull PsiElement position) {
HaskellGendecl haskellGendecl = PsiTreeUtil.getParentOfType(position, HaskellGendecl.class);
return haskellGendecl != null;
}
public static @NotNull List<PsiElement> matchWhereClausesInScope(
@NotNull PsiNamedElement myElement,
String name) {
return checkWhereClausesInScopeForVariableDeclaration(myElement, name);
}
public static @NotNull List<PsiElement> getAllDefinitionsInWhereClausesInScope(
@NotNull PsiElement myElement) {
return checkWhereClausesInScopeForVariableDeclaration(myElement, ".+");
}
private static @NotNull List<PsiElement> checkWhereClausesInScopeForVariableDeclaration(
@NotNull PsiElement myElement,
String matcher) {
List<PsiElement> results = Lists.newArrayList();
PsiElement parent = myElement.getParent();
do {
if (parent instanceof HaskellRhs) {
HaskellRhs rhs = (HaskellRhs) parent;
PsiElement where = rhs.getWhere();
if (where == null) {
parent = parent.getParent();
continue;
} else {
PsiElement psiElement = checkWhereClause(where, matcher);
if (psiElement != null) {
results.add(psiElement);
}
}
}
parent = parent.getParent();
} while (! (parent instanceof HaskellBody) && ! (parent == null));
return results;
}
private static @Nullable PsiElement checkWhereClause(@NotNull PsiElement where, String matcher) {
PsiElement nextSibling = where.getNextSibling();
while(nextSibling != null){
if(nextSibling instanceof HaskellFunorpatdecl) {
PsiElement psiElement = HaskellUtil.lookForFunOrPatDeclWithCorrectName(nextSibling, matcher);
if (psiElement != null){
return psiElement;
}
}
nextSibling = nextSibling.getNextSibling();
}
return null;
}
public static @NotNull List<PsiElement> matchLocalDefinitionsInScope(PsiElement element, String name){
return checkLocalDefinitionsForVariableDeclarations(element,name);
}
public static @NotNull List<PsiElement> getAllDefinitionsInScope(PsiElement element){
return checkLocalDefinitionsForVariableDeclarations(element,".+");
}
private static @NotNull List<PsiElement> checkLocalDefinitionsForVariableDeclarations(PsiElement element, String matcher){
List<PsiElement> results = Lists.newArrayList();
PsiElement parent = element;
do {
/**
* This whole function needs to be re-evaluated, it's getting too much if,if,if. The logic
* is getting extremely unclear. There should be tests for all (identified) cases so the refactor
* should be feasible.
*/
if (parent instanceof HaskellNewtypedecl){
HaskellNewtypedecl haskellNewtypedecl = (HaskellNewtypedecl) parent;
List<HaskellTyvar> tyvarList = haskellNewtypedecl.getTyvarList();
for (HaskellTyvar haskellTyvar : tyvarList) {
HaskellVarid varId = haskellTyvar.getVarid();
if (varId.getName().matches(matcher)){
results.add(varId);
}
}
}
PsiElement prevSibling = parent.getPrevSibling();
while (prevSibling != null) {
PsiElement possibleMatch = HaskellUtil.lookForFunOrPatDeclWithCorrectName(prevSibling, matcher);
if (possibleMatch != null) {
results.add(possibleMatch);
}
if (prevSibling instanceof HaskellPat && parent instanceof HaskellExp) {
List<HaskellVarid> varIds = HaskellUtil.extractAllHaskellVarids((HaskellPat) prevSibling);
for (HaskellVarid varId : varIds) {
if (varId.getName().matches(matcher)) {
results.add(varId);
}
}
}
if (prevSibling instanceof HaskellVarid){
HaskellVarid varId = (HaskellVarid) prevSibling;
if (varId.getName().matches(matcher)){
results.add(varId);
}
}
prevSibling = prevSibling.getPrevSibling();
}
parent = parent.getParent();
} while(! (parent instanceof PsiFile));
return results;
}
public static List<PsiElement> matchGlobalNamesUnqualified(List<FoundDefinition> namedElements) {
List<PsiElement> results = Lists.newArrayList();
for (FoundDefinition possibleReferences : namedElements) {
if (possibleReferences.imprt == null || !possibleReferences.imprt.isQualified) {
//noinspection ObjectAllocationInLoop
results.add(possibleReferences.element);
}
}
return results;
}
public static List<PsiElementResolveResult> matchGlobalNamesQualified(
List<FoundDefinition> namedElements, String qualifiedCallName){
List<PsiElementResolveResult> results = Lists.newArrayList();
for (FoundDefinition possibleReference : namedElements) {
if(possibleReference.imprt != null && possibleReference.imprt.alias != null &&
possibleReference.imprt.alias.equals(qualifiedCallName)){
results.add(new PsiElementResolveResult(possibleReference.element));
}
}
return results;
}
public static @NotNull String getModuleName(@NotNull PsiElement element) {
HaskellFile containingFile = (HaskellFile)element.getContainingFile();
if (containingFile == null){
return "";
}
String moduleName = containingFile.getModuleName();
if(moduleName != null){
return moduleName;
} else {
return "";
}
}
private static void addFoundDefinition(List<PsiNamedElement> result, HaskellPsiUtil.Import imprt, List<FoundDefinition> results) {
for (PsiNamedElement element : result) {
results.add(new FoundDefinition(element, imprt));
}
}
/**
* Returns the textual representation of a qualified module.
*
* eg. From {@code A.B.C.d} return {@code A.B.C}
*/
@Nullable
private static String getModule(@NotNull List<HaskellConid> conids) {
if (conids.isEmpty()) return null;
StringBuilder b = new StringBuilder();
for (HaskellConid cid : conids) {
b.append(cid.getName());
b.append(".");
}
b.setLength(b.length() - 1);
return b.toString();
}
public static class FoundDefinition {
@NotNull
public PsiNamedElement element;
@Nullable
public HaskellPsiUtil.Import imprt;
public FoundDefinition(@NotNull PsiNamedElement element, @Nullable HaskellPsiUtil.Import imprt) {
this.element = element;
this.imprt = imprt;
}
}
}