[AST] Include the TranslationUnitDecl when traversing with TraversalScope
authorSam McCall <sam.mccall@gmail.com>
Thu, 10 Jun 2021 22:16:14 +0000 (00:16 +0200)
committerSam McCall <sam.mccall@gmail.com>
Fri, 11 Jun 2021 12:29:45 +0000 (14:29 +0200)
Given `int foo, bar;`, TraverseAST reveals this tree:
  TranslationUnitDecl
   - foo
   - bar

Before this patch, with the TraversalScope set to {foo}, TraverseAST yields:
  foo

After this patch it yields:
  TranslationUnitDecl
  - foo

Also, TraverseDecl(TranslationUnitDecl) now respects the traversal scope.

---

The main effect of this today is that clang-tidy checks that match the
translationUnitDecl(), either in order to traverse it or check
parentage, should work.

Differential Revision: https://reviews.llvm.org/D104071

clang-tools-extra/clangd/DumpAST.cpp
clang-tools-extra/clangd/refactor/tweaks/AddUsing.cpp
clang-tools-extra/clangd/unittests/DiagnosticsTests.cpp
clang-tools-extra/clangd/unittests/TestTU.cpp
clang/include/clang/AST/ASTContext.h
clang/include/clang/AST/RecursiveASTVisitor.h
clang/unittests/AST/ASTContextParentMapTest.cpp
clang/unittests/Tooling/RecursiveASTVisitorTests/TraversalScope.cpp

index 30b90d2..36d61ac 100644 (file)
@@ -337,12 +337,8 @@ public:
   // Generally, these are nodes with position information (TypeLoc, not Type).
 
   bool TraverseDecl(Decl *D) {
-    return !D || isInjectedClassName(D) || traverseNode("declaration", D, [&] {
-      if (isa<TranslationUnitDecl>(D))
-        Base::TraverseAST(const_cast<ASTContext &>(Ctx));
-      else
-        Base::TraverseDecl(D);
-    });
+    return !D || isInjectedClassName(D) ||
+           traverseNode("declaration", D, [&] { Base::TraverseDecl(D); });
   }
   bool TraverseTypeLoc(TypeLoc TL) {
     return !TL || traverseNode("type", TL, [&] { Base::TraverseTypeLoc(TL); });
index d6a57ef..a8c937a 100644 (file)
@@ -82,7 +82,8 @@ public:
     // There is no need to go deeper into nodes that do not enclose selection,
     // since "using" there will not affect selection, nor would it make a good
     // insertion point.
-    if (Node->getDeclContext()->Encloses(SelectionDeclContext)) {
+    if (!Node->getDeclContext() ||
+        Node->getDeclContext()->Encloses(SelectionDeclContext)) {
       return RecursiveASTVisitor<UsingFinder>::TraverseDecl(Node);
     }
     return true;
index 4aa9cb7..87f3c87 100644 (file)
@@ -248,13 +248,23 @@ TEST(DiagnosticsTest, ClangTidy) {
       return SQUARE($macroarg[[++]]y);
       return $doubled[[sizeof]](sizeof(int));
     }
+
+    // misc-no-recursion uses a custom traversal from the TUDecl
+    void foo();
+    void $bar[[bar]]() {
+      foo();
+    }
+    void $foo[[foo]]() {
+      bar();
+    }
   )cpp");
   auto TU = TestTU::withCode(Test.code());
   TU.HeaderFilename = "assert.h"; // Suppress "not found" error.
   TU.ClangTidyProvider = addTidyChecks("bugprone-sizeof-expression,"
                                        "bugprone-macro-repeated-side-effects,"
                                        "modernize-deprecated-headers,"
-                                       "modernize-use-trailing-return-type");
+                                       "modernize-use-trailing-return-type,"
+                                       "misc-no-recursion");
   EXPECT_THAT(
       *TU.build().getDiagnostics(),
       UnorderedElementsAre(
@@ -283,8 +293,12 @@ TEST(DiagnosticsTest, ClangTidy) {
               DiagSource(Diag::ClangTidy),
               DiagName("modernize-use-trailing-return-type"),
               // Verify that we don't have "[check-name]" suffix in the message.
-              WithFix(FixMessage(
-                  "use a trailing return type for this function")))));
+              WithFix(
+                  FixMessage("use a trailing return type for this function"))),
+          Diag(Test.range("foo"),
+               "function 'foo' is within a recursive call chain"),
+          Diag(Test.range("bar"),
+               "function 'bar' is within a recursive call chain")));
 }
 
 TEST(DiagnosticsTest, ClangTidyEOF) {
index 335993a..cb258a5 100644 (file)
@@ -195,6 +195,19 @@ const Symbol &findSymbol(const SymbolSlab &Slab, llvm::StringRef QName) {
   return *Result;
 }
 
+// RAII scoped class to disable TraversalScope for a ParsedAST.
+class TraverseHeadersToo {
+  ASTContext &Ctx;
+  std::vector<Decl *> ScopeToRestore;
+
+public:
+  TraverseHeadersToo(ParsedAST &AST)
+      : Ctx(AST.getASTContext()), ScopeToRestore(Ctx.getTraversalScope()) {
+    Ctx.setTraversalScope({Ctx.getTranslationUnitDecl()});
+  }
+  ~TraverseHeadersToo() { Ctx.setTraversalScope(std::move(ScopeToRestore)); }
+};
+
 const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName) {
   auto &Ctx = AST.getASTContext();
   auto LookupDecl = [&Ctx](const DeclContext &Scope,
@@ -217,6 +230,7 @@ const NamedDecl &findDecl(ParsedAST &AST, llvm::StringRef QName) {
 
 const NamedDecl &findDecl(ParsedAST &AST,
                           std::function<bool(const NamedDecl &)> Filter) {
+  TraverseHeadersToo Too(AST);
   struct Visitor : RecursiveASTVisitor<Visitor> {
     decltype(Filter) F;
     llvm::SmallVector<const NamedDecl *, 1> Decls;
index f103ec6..5032f31 100644 (file)
@@ -635,11 +635,22 @@ public:
   ParentMapContext &getParentMapContext();
 
   // A traversal scope limits the parts of the AST visible to certain analyses.
-  // RecursiveASTVisitor::TraverseAST will only visit reachable nodes, and
+  // RecursiveASTVisitor only visits specified children of TranslationUnitDecl.
   // getParents() will only observe reachable parent edges.
   //
-  // The scope is defined by a set of "top-level" declarations.
-  // Initially, it is the entire TU: {getTranslationUnitDecl()}.
+  // The scope is defined by a set of "top-level" declarations which will be
+  // visible under the TranslationUnitDecl.
+  // Initially, it is the entire TU, represented by {getTranslationUnitDecl()}.
+  //
+  // After setTraversalScope({foo, bar}), the exposed AST looks like:
+  // TranslationUnitDecl
+  //  - foo
+  //    - ...
+  //  - bar
+  //    - ...
+  // All other siblings of foo and bar are pruned from the tree.
+  // (However they are still accessible via TranslationUnitDecl->decls())
+  //
   // Changing the scope clears the parent cache, which is expensive to rebuild.
   std::vector<Decl *> getTraversalScope() const { return TraversalScope; }
   void setTraversalScope(const std::vector<Decl *> &);
index a29559e..9bfa5b9 100644 (file)
@@ -192,14 +192,12 @@ public:
   /// Return whether this visitor should traverse post-order.
   bool shouldTraversePostOrder() const { return false; }
 
-  /// Recursively visits an entire AST, starting from the top-level Decls
-  /// in the AST traversal scope (by default, the TranslationUnitDecl).
+  /// Recursively visits an entire AST, starting from the TranslationUnitDecl.
   /// \returns false if visitation was terminated early.
   bool TraverseAST(ASTContext &AST) {
-    for (Decl *D : AST.getTraversalScope())
-      if (!getDerived().TraverseDecl(D))
-        return false;
-    return true;
+    // Currently just an alias for TraverseDecl(TUDecl), but kept in case
+    // we change the implementation again.
+    return getDerived().TraverseDecl(AST.getTranslationUnitDecl());
   }
 
   /// Recursively visit a statement or expression, by
@@ -1495,12 +1493,24 @@ DEF_TRAVERSE_DECL(StaticAssertDecl, {
   TRY_TO(TraverseStmt(D->getMessage()));
 })
 
-DEF_TRAVERSE_DECL(
-    TranslationUnitDecl,
-    {// Code in an unnamed namespace shows up automatically in
-     // decls_begin()/decls_end().  Thus we don't need to recurse on
-     // D->getAnonymousNamespace().
-    })
+DEF_TRAVERSE_DECL(TranslationUnitDecl, {
+  // Code in an unnamed namespace shows up automatically in
+  // decls_begin()/decls_end().  Thus we don't need to recurse on
+  // D->getAnonymousNamespace().
+
+  // If the traversal scope is set, then consider them to be the children of
+  // the TUDecl, rather than traversing (and loading?) all top-level decls.
+  auto Scope = D->getASTContext().getTraversalScope();
+  bool HasLimitedScope =
+      Scope.size() != 1 || !isa<TranslationUnitDecl>(Scope.front());
+  if (HasLimitedScope) {
+    ShouldVisitChildren = false; // we'll do that here instead
+    for (auto *Child : Scope) {
+      if (!canIgnoreChildDeclWhileTraversingDeclContext(Child))
+        TRY_TO(TraverseDecl(Child));
+    }
+  }
+})
 
 DEF_TRAVERSE_DECL(PragmaCommentDecl, {})
 
index 855d970..4d11ef0 100644 (file)
@@ -81,27 +81,31 @@ TEST(GetParents, ReturnsMultipleParentsInTemplateInstantiations) {
 }
 
 TEST(GetParents, RespectsTraversalScope) {
-  auto AST =
-      tooling::buildASTFromCode("struct foo { int bar; };", "foo.cpp",
-                                std::make_shared<PCHContainerOperations>());
+  auto AST = tooling::buildASTFromCode(
+      "struct foo { int bar; }; struct baz{};", "foo.cpp",
+      std::make_shared<PCHContainerOperations>());
   auto &Ctx = AST->getASTContext();
   auto &TU = *Ctx.getTranslationUnitDecl();
   auto &Foo = *TU.lookup(&Ctx.Idents.get("foo")).front();
   auto &Bar = *cast<DeclContext>(Foo).lookup(&Ctx.Idents.get("bar")).front();
+  auto &Baz = *TU.lookup(&Ctx.Idents.get("baz")).front();
 
   // Initially, scope is the whole TU.
   EXPECT_THAT(Ctx.getParents(Bar), ElementsAre(DynTypedNode::create(Foo)));
   EXPECT_THAT(Ctx.getParents(Foo), ElementsAre(DynTypedNode::create(TU)));
+  EXPECT_THAT(Ctx.getParents(Baz), ElementsAre(DynTypedNode::create(TU)));
 
   // Restrict the scope, now some parents are gone.
   Ctx.setTraversalScope({&Foo});
   EXPECT_THAT(Ctx.getParents(Bar), ElementsAre(DynTypedNode::create(Foo)));
-  EXPECT_THAT(Ctx.getParents(Foo), ElementsAre());
+  EXPECT_THAT(Ctx.getParents(Foo), ElementsAre(DynTypedNode::create(TU)));
+  EXPECT_THAT(Ctx.getParents(Baz), ElementsAre());
 
   // Reset the scope, we get back the original results.
   Ctx.setTraversalScope({&TU});
   EXPECT_THAT(Ctx.getParents(Bar), ElementsAre(DynTypedNode::create(Foo)));
   EXPECT_THAT(Ctx.getParents(Foo), ElementsAre(DynTypedNode::create(TU)));
+  EXPECT_THAT(Ctx.getParents(Baz), ElementsAre(DynTypedNode::create(TU)));
 }
 
 TEST(GetParents, ImplicitLambdaNodes) {
index c05be7f..9e71f95 100644 (file)
@@ -16,6 +16,12 @@ class Visitor : public ExpectedLocationVisitor<Visitor, clang::TestVisitor> {
 public:
   Visitor(ASTContext *Context) { this->Context = Context; }
 
+  bool VisitTranslationUnitDecl(TranslationUnitDecl *D) {
+    auto &SM = D->getParentASTContext().getSourceManager();
+    Match("TU", SM.getLocForStartOfFile(SM.getMainFileID()));
+    return true;
+  }
+
   bool VisitNamedDecl(NamedDecl *D) {
     if (!D->isImplicit())
       Match(D->getName(), D->getLocation());
@@ -41,6 +47,7 @@ struct foo {
   Ctx.setTraversalScope({&Bar});
 
   Visitor V(&Ctx);
+  V.ExpectMatch("TU", 1, 1);
   V.DisallowMatch("foo", 2, 8);
   V.ExpectMatch("bar", 3, 10);
   V.ExpectMatch("baz", 4, 12);