[clangd] Implement cross reference request for #include lines.
authorViktoriia Bakalova <bakalova@google.com>
Fri, 31 Mar 2023 14:28:57 +0000 (14:28 +0000)
committerViktoriia Bakalova <bakalova@google.com>
Thu, 20 Apr 2023 07:11:48 +0000 (07:11 +0000)
Differential Revision: https://reviews.llvm.org/D147044

clang-tools-extra/clangd/Hover.cpp
clang-tools-extra/clangd/IncludeCleaner.cpp
clang-tools-extra/clangd/IncludeCleaner.h
clang-tools-extra/clangd/XRefs.cpp
clang-tools-extra/clangd/unittests/HoverTests.cpp
clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp
clang-tools-extra/clangd/unittests/XRefsTests.cpp

index 6a56553..3eb0900 100644 (file)
@@ -1172,20 +1172,12 @@ void maybeAddUsedSymbols(ParsedAST &AST, HoverInfo &HI, const Inclusion &Inc) {
             UsedSymbols.contains(Ref.Target))
           return;
 
-        for (const include_cleaner::Header &H : Providers) {
-          auto MatchingIncludes = ConvertedMainFileIncludes.match(H);
-          // No match for this provider in the main file.
-          if (MatchingIncludes.empty())
-            continue;
-
-          // Check if the hovered include matches this provider.
-          if (!HoveredInclude.match(H).empty())
-            UsedSymbols.insert(Ref.Target);
-
-          // Don't look for rest of the providers once we've found a match
-          // in the main file.
-          break;
-        }
+        auto Provider =
+            firstMatchedProvider(ConvertedMainFileIncludes, Providers);
+        if (!Provider || HoveredInclude.match(*Provider).empty())
+          return;
+
+        UsedSymbols.insert(Ref.Target);
       });
 
   for (const auto &UsedSymbolDecl : UsedSymbols)
index 168471a..d15dd70 100644 (file)
@@ -444,5 +444,15 @@ std::vector<Diag> issueIncludeCleanerDiagnostics(ParsedAST &AST,
   return Result;
 }
 
+std::optional<include_cleaner::Header>
+firstMatchedProvider(const include_cleaner::Includes &Includes,
+                     llvm::ArrayRef<include_cleaner::Header> Providers) {
+  for (const auto &H : Providers) {
+    if (!Includes.match(H).empty())
+      return H;
+  }
+  // No match for this provider in the includes list.
+  return std::nullopt;
+}
 } // namespace clangd
 } // namespace clang
index 035142c..675c05a 100644 (file)
@@ -81,6 +81,11 @@ std::string spellHeader(ParsedAST &AST, const FileEntry *MainFile,
 
 std::vector<include_cleaner::SymbolReference>
 collectMacroReferences(ParsedAST &AST);
+
+/// Find the first provider in the list that is matched by the includes.
+std::optional<include_cleaner::Header>
+firstMatchedProvider(const include_cleaner::Includes &Includes,
+                     llvm::ArrayRef<include_cleaner::Header> Providers);
 } // namespace clangd
 } // namespace clang
 
index 23dd72d..51a3ef8 100644 (file)
@@ -9,13 +9,17 @@
 #include "AST.h"
 #include "FindSymbols.h"
 #include "FindTarget.h"
+#include "Headers.h"
 #include "HeuristicResolver.h"
+#include "IncludeCleaner.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
 #include "Selection.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "clang-include-cleaner/Analysis.h"
+#include "clang-include-cleaner/Types.h"
 #include "index/Index.h"
 #include "index/Merge.h"
 #include "index/Relation.h"
@@ -48,6 +52,7 @@
 #include "clang/Index/IndexingAction.h"
 #include "clang/Index/IndexingOptions.h"
 #include "clang/Index/USRGeneration.h"
+#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -61,6 +66,7 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
+#include <string>
 #include <vector>
 
 namespace clang {
@@ -1310,6 +1316,63 @@ stringifyContainerForMainFileRef(const Decl *Container) {
     return printQualifiedName(*ND);
   return {};
 }
+
+std::optional<ReferencesResult>
+maybeFindIncludeReferences(ParsedAST &AST, Position Pos,
+                           URIForFile URIMainFile) {
+  const auto &Includes = AST.getIncludeStructure().MainFileIncludes;
+  auto IncludeOnLine = llvm::find_if(Includes, [&Pos](const Inclusion &Inc) {
+    return Inc.HashLine == Pos.line;
+  });
+  if (IncludeOnLine == Includes.end())
+    return std::nullopt;
+
+  const auto &Inc = *IncludeOnLine;
+  const SourceManager &SM = AST.getSourceManager();
+  ReferencesResult Results;
+  auto ConvertedMainFileIncludes = convertIncludes(SM, Includes);
+  auto ReferencedInclude = convertIncludes(SM, Inc);
+  include_cleaner::walkUsed(
+      AST.getLocalTopLevelDecls(), collectMacroReferences(AST),
+      AST.getPragmaIncludes(), SM,
+      [&](const include_cleaner::SymbolReference &Ref,
+          llvm::ArrayRef<include_cleaner::Header> Providers) {
+        if (Ref.RT != include_cleaner::RefType::Explicit)
+          return;
+
+        auto Provider =
+            firstMatchedProvider(ConvertedMainFileIncludes, Providers);
+        if (!Provider || ReferencedInclude.match(*Provider).empty())
+          return;
+
+        auto Loc = SM.getFileLoc(Ref.RefLocation);
+        // File locations can be outside of the main file if macro is
+        // expanded through an #include.
+        while (SM.getFileID(Loc) != SM.getMainFileID())
+          Loc = SM.getIncludeLoc(SM.getFileID(Loc));
+
+        ReferencesResult::Reference Result;
+        const auto *Token = AST.getTokens().spelledTokenAt(Loc);
+        Result.Loc.range = Range{sourceLocToPosition(SM, Token->location()),
+                                 sourceLocToPosition(SM, Token->endLocation())};
+        Result.Loc.uri = URIMainFile;
+        Results.References.push_back(std::move(Result));
+      });
+  if (Results.References.empty())
+    return std::nullopt;
+
+  // Add the #include line to the references list.
+  auto IncludeLen = std::string{"#include"}.length() + Inc.Written.length() + 1;
+  ReferencesResult::Reference Result;
+  Result.Loc.range = clangd::Range{Position{Inc.HashLine, 0},
+                                   Position{Inc.HashLine, (int)IncludeLen}};
+  Result.Loc.uri = URIMainFile;
+  Results.References.push_back(std::move(Result));
+
+  if (Results.References.empty())
+    return std::nullopt;
+  return Results;
+}
 } // namespace
 
 ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
@@ -1324,6 +1387,11 @@ ReferencesResult findReferences(ParsedAST &AST, Position Pos, uint32_t Limit,
     return {};
   }
 
+  const auto IncludeReferences =
+      maybeFindIncludeReferences(AST, Pos, URIMainFile);
+  if (IncludeReferences)
+    return *IncludeReferences;
+
   llvm::DenseSet<SymbolID> IDsToQuery, OverriddenMethods;
 
   const auto *IdentifierAtCursor =
index 36d7348..5ad9d69 100644 (file)
@@ -2999,36 +2999,7 @@ TEST(Hover, UsedSymbols) {
                   #in^clude <vector>
                   std::vector<int> vec;
                 )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"vector"}; }},
-               {R"cpp(
-                  #in^clude "public.h"
-                  #include "private.h"
-                  int fooVar = foo();
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"foo"}; }},
-               {R"cpp(
-                  #include "bar.h"
-                  #include "for^ward.h"
-                  Bar *x;
-                )cpp",
-                [](HoverInfo &HI) {
-                  HI.UsedSymbolNames = {
-                      // No used symbols, since bar.h is a higher-ranked
-                      // provider for Bar.
-                  };
-                }},
-               {R"cpp(
-                  #include "b^ar.h"
-                  #define DEF(X) const Bar *X
-                  DEF(a);
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"Bar"}; }},
-               {R"cpp(
-                  #in^clude "bar.h"
-                  #define BAZ(X) const X x
-                  BAZ(Bar);
-                )cpp",
-                [](HoverInfo &HI) { HI.UsedSymbolNames = {"Bar"}; }}};
+                [](HoverInfo &HI) { HI.UsedSymbolNames = {"vector"}; }}};
   for (const auto &Case : Cases) {
     Annotations Code{Case.Code};
     SCOPED_TRACE(Code.code());
@@ -3042,18 +3013,12 @@ TEST(Hover, UsedSymbols) {
                                           int bar2();
                                           class Bar {};
                                         )cpp");
-    TU.AdditionalFiles["private.h"] = guard(R"cpp(
-                                              // IWYU pragma: private, include "public.h"
-                                              int foo(); 
-                                            )cpp");
-    TU.AdditionalFiles["public.h"] = guard("");
     TU.AdditionalFiles["system/vector"] = guard(R"cpp(
       namespace std {
         template<typename>
         class vector{};
       }
     )cpp");
-    TU.AdditionalFiles["forward.h"] = guard("class Bar;");
     TU.ExtraArgs.push_back("-isystem" + testPath("system"));
 
     auto AST = TU.build();
index d0b99b7..39901fd 100644 (file)
@@ -29,6 +29,7 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <cstddef>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -435,6 +436,48 @@ TEST(IncludeCleaner, NoCrash) {
       MainCode.range());
 }
 
+TEST(IncludeCleaner, FirstMatchedProvider) {
+  struct {
+    const char *Code;
+    const std::vector<include_cleaner::Header> Providers;
+    const std::optional<include_cleaner::Header> ExpectedProvider;
+  } Cases[] = {
+      {R"cpp(
+        #include "bar.h"
+        #include "foo.h"
+      )cpp",
+       {include_cleaner::Header{"bar.h"}, include_cleaner::Header{"foo.h"}},
+       include_cleaner::Header{"bar.h"}},
+      {R"cpp(
+        #include "bar.h"
+        #include "foo.h"
+      )cpp",
+       {include_cleaner::Header{"foo.h"}, include_cleaner::Header{"bar.h"}},
+       include_cleaner::Header{"foo.h"}},
+      {"#include \"bar.h\"",
+       {include_cleaner::Header{"bar.h"}},
+       include_cleaner::Header{"bar.h"}},
+      {"#include \"bar.h\"", {include_cleaner::Header{"foo.h"}}, std::nullopt},
+      {"#include \"bar.h\"", {}, std::nullopt}};
+  for (const auto &Case : Cases) {
+    Annotations Code{Case.Code};
+    SCOPED_TRACE(Code.code());
+
+    TestTU TU;
+    TU.Code = Code.code();
+    TU.AdditionalFiles["bar.h"] = "";
+    TU.AdditionalFiles["foo.h"] = "";
+
+    auto AST = TU.build();
+    std::optional<include_cleaner::Header> MatchedProvider =
+        firstMatchedProvider(
+            convertIncludes(AST.getSourceManager(),
+                            AST.getIncludeStructure().MainFileIncludes),
+            Case.Providers);
+    EXPECT_EQ(MatchedProvider, Case.ExpectedProvider);
+  }
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
index 3bbcde2..1424e6a 100644 (file)
@@ -43,6 +43,10 @@ using ::testing::UnorderedElementsAre;
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
 
+std::string guard(llvm::StringRef Code) {
+  return "#pragma once\n" + Code.str();
+}
+
 MATCHER_P2(FileRange, File, Range, "") {
   return Location{URIForFile::canonicalize(File, testRoot()), Range} == arg;
 }
@@ -2293,6 +2297,50 @@ TEST(FindReferences, ExplicitSymbols) {
     checkFindRefs(Test);
 }
 
+TEST(FindReferences, UsedSymbolsFromInclude) {
+  const char *Tests[] = {
+      R"cpp([[#include ^"bar.h"]]
+        #include <vector>
+        int fstBar = [[bar1]]();
+        int sndBar = [[bar2]]();
+        [[Bar]] bar;
+        int macroBar = [[BAR]];
+        std::vector<int> vec;
+      )cpp",
+
+      R"cpp([[#in^clude <vector>]]
+        std::[[vector]]<int> vec;
+      )cpp"};
+  for (const char *Test : Tests) {
+    Annotations T(Test);
+    auto TU = TestTU::withCode(T.code());
+    TU.ExtraArgs.push_back("-std=c++20");
+    TU.AdditionalFiles["bar.h"] = guard(R"cpp(
+      #define BAR 5
+      int bar1();
+      int bar2();
+      class Bar {};            
+    )cpp");
+    TU.AdditionalFiles["system/vector"] = guard(R"cpp(
+      namespace std {
+        template<typename>
+        class vector{};
+      }
+    )cpp");
+    TU.ExtraArgs.push_back("-isystem" + testPath("system"));
+
+    auto AST = TU.build();
+    std::vector<Matcher<ReferencesResult::Reference>> ExpectedLocations;
+    for (const auto &R : T.ranges())
+      ExpectedLocations.push_back(AllOf(rangeIs(R), attrsAre(0u)));
+    for (const auto &P : T.points()) 
+      EXPECT_THAT(findReferences(AST, P, 0).References,
+                  UnorderedElementsAreArray(ExpectedLocations))
+          << "Failed for Refs at " << P << "\n"
+          << Test;
+  }
+}
+
 TEST(FindReferences, NeedsIndexForSymbols) {
   const char *Header = "int foo();";
   Annotations Main("int main() { [[f^oo]](); }");