[clangd] Extend find-refs to include overrides.
authorHaojian Wu <hokein.wu@gmail.com>
Fri, 8 Jan 2021 10:56:30 +0000 (11:56 +0100)
committerHaojian Wu <hokein.wu@gmail.com>
Wed, 20 Jan 2021 12:23:20 +0000 (13:23 +0100)
find-references on `virtual void meth^od() = 0` will include override references.

Differential Revision: https://reviews.llvm.org/D94390

clang-tools-extra/clangd/XRefs.cpp
clang-tools-extra/clangd/unittests/XRefsTests.cpp

index 8027e05..d4dc621 100644 (file)
@@ -879,11 +879,8 @@ public:
   };
 
   ReferenceFinder(const ParsedAST &AST,
-                  const std::vector<const NamedDecl *> &TargetDecls)
-      : AST(AST) {
-    for (const NamedDecl *D : TargetDecls)
-      CanonicalTargets.insert(D->getCanonicalDecl());
-  }
+                  const llvm::DenseSet<SymbolID> &TargetIDs)
+      : AST(AST), TargetIDs(TargetIDs) {}
 
   std::vector<Reference> take() && {
     llvm::sort(References, [](const Reference &L, const Reference &R) {
@@ -908,9 +905,9 @@ public:
                        llvm::ArrayRef<index::SymbolRelation> Relations,
                        SourceLocation Loc,
                        index::IndexDataConsumer::ASTNodeInfo ASTNode) override {
-    assert(D->isCanonicalDecl() && "expect D to be a canonical declaration");
     const SourceManager &SM = AST.getSourceManager();
-    if (!CanonicalTargets.count(D) || !isInsideMainFile(Loc, SM))
+    if (!isInsideMainFile(Loc, SM) ||
+        TargetIDs.find(getSymbolID(D)) == TargetIDs.end())
       return true;
     const auto &TB = AST.getTokens();
     Loc = SM.getFileLoc(Loc);
@@ -920,14 +917,14 @@ public:
   }
 
 private:
-  llvm::SmallSet<const Decl *, 4> CanonicalTargets;
   std::vector<Reference> References;
   const ParsedAST &AST;
+  const llvm::DenseSet<SymbolID> &TargetIDs;
 };
 
 std::vector<ReferenceFinder::Reference>
-findRefs(const std::vector<const NamedDecl *> &Decls, ParsedAST &AST) {
-  ReferenceFinder RefFinder(AST, Decls);
+findRefs(const llvm::DenseSet<SymbolID> &IDs, ParsedAST &AST) {
+  ReferenceFinder RefFinder(AST, IDs);
   index::IndexingOptions IndexOpts;
   IndexOpts.SystemSymbolFilter =
       index::IndexingOptions::SystemSymbolFilterKind::All;
@@ -1217,7 +1214,11 @@ std::vector<DocumentHighlight> findDocumentHighlights(ParsedAST &AST,
       if (!Decls.empty()) {
         // FIXME: we may get multiple DocumentHighlights with the same location
         // and different kinds, deduplicate them.
-        for (const auto &Ref : findRefs({Decls.begin(), Decls.end()}, AST))
+        llvm::DenseSet<SymbolID> Targets;
+        for (const NamedDecl *ND : Decls)
+          if (auto ID = getSymbolID(ND))
+            Targets.insert(ID);
+        for (const auto &Ref : findRefs(Targets, AST))
           Result.push_back(toHighlight(Ref, SM));
         return true;
       }
@@ -1295,13 +1296,14 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
     llvm::consumeError(CurLoc.takeError());
     return {};
   }
-  llvm::Optional<DefinedMacro> Macro;
-  if (const auto *IdentifierAtCursor =
-          syntax::spelledIdentifierTouching(*CurLoc, AST.getTokens())) {
-    Macro = locateMacroAt(*IdentifierAtCursor, AST.getPreprocessor());
-  }
 
   RefsRequest Req;
+
+  const auto *IdentifierAtCursor =
+      syntax::spelledIdentifierTouching(*CurLoc, AST.getTokens());
+  llvm::Optional<DefinedMacro> Macro;
+  if (IdentifierAtCursor)
+    Macro = locateMacroAt(*IdentifierAtCursor, AST.getPreprocessor());
   if (Macro) {
     // Handle references to macro.
     if (auto MacroSID = getSymbolID(Macro->Name, Macro->Info, SM)) {
@@ -1325,9 +1327,35 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
         DeclRelation::TemplatePattern | DeclRelation::Alias;
     std::vector<const NamedDecl *> Decls =
         getDeclAtPosition(AST, *CurLoc, Relations);
+    llvm::DenseSet<SymbolID> Targets;
+    for (const NamedDecl *D : Decls)
+      if (auto ID = getSymbolID(D))
+        Targets.insert(ID);
+
+    llvm::DenseSet<SymbolID> Overrides;
+    if (Index) {
+      RelationsRequest FindOverrides;
+      FindOverrides.Predicate = RelationKind::OverriddenBy;
+      for (const NamedDecl *ND : Decls) {
+        // Special case: virtual void meth^od() = 0 includes refs of overrides.
+        if (const auto *CMD = llvm::dyn_cast<CXXMethodDecl>(ND)) {
+          if (CMD->isPure())
+            if (IdentifierAtCursor && SM.getSpellingLoc(CMD->getLocation()) ==
+                                          IdentifierAtCursor->location())
+              if (auto ID = getSymbolID(CMD))
+                FindOverrides.Subjects.insert(ID);
+        }
+      }
+      if (!FindOverrides.Subjects.empty())
+        Index->relations(FindOverrides,
+                         [&](const SymbolID &Subject, const Symbol &Object) {
+                           Overrides.insert(Object.ID);
+                         });
+      Targets.insert(Overrides.begin(), Overrides.end());
+    }
 
     // We traverse the AST to find references in the main file.
-    auto MainFileRefs = findRefs(Decls, AST);
+    auto MainFileRefs = findRefs(Targets, AST);
     // We may get multiple refs with the same location and different Roles, as
     // cross-reference is only interested in locations, we deduplicate them
     // by the location to avoid emitting duplicated locations.
@@ -1345,6 +1373,7 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
       Results.References.push_back(std::move(Result));
     }
     if (Index && Results.References.size() <= Limit) {
+      Req.IDs = std::move(Overrides);
       for (const Decl *D : Decls) {
         // Not all symbols can be referenced from outside (e.g.
         // function-locals).
index 178e630..449dbe3 100644 (file)
@@ -1845,6 +1845,31 @@ TEST(FindReferences, WithinAST) {
   }
 }
 
+TEST(FindReferences, IncludeOverrides) {
+  llvm::StringRef Test =
+      R"cpp(
+        class Base {
+        public:
+          virtual void [[f^unc]]() = 0;
+        };
+        class Derived : public Base {
+        public:
+          void [[func]]() override;
+        };
+        void test(Derived* D) {
+          D->[[func]]();
+        })cpp";
+  Annotations T(Test);
+  auto TU = TestTU::withCode(T.code());
+  auto AST = TU.build();
+  std::vector<Matcher<Location>> ExpectedLocations;
+  for (const auto &R : T.ranges())
+    ExpectedLocations.push_back(RangeIs(R));
+  EXPECT_THAT(findReferences(AST, T.point(), 0, TU.index().get()).References,
+              ElementsAreArray(ExpectedLocations))
+      << Test;
+}
+
 TEST(FindReferences, MainFileReferencesOnly) {
   llvm::StringRef Test =
       R"cpp(