[ASTMatchers] Fix child traversal over range-for loops
authorStephen Kelly <steveire@gmail.com>
Sat, 26 Dec 2020 21:07:14 +0000 (21:07 +0000)
committerStephen Kelly <steveire@gmail.com>
Tue, 5 Jan 2021 21:29:37 +0000 (21:29 +0000)
Differential Revision: https://reviews.llvm.org/D94031

clang/lib/ASTMatchers/ASTMatchFinder.cpp
clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

index 99d95838af61555dcaa268414cbd36aef8bf6bd1..39bdb94e62c67d232f9f06b753b9d2e65026eeca 100644 (file)
@@ -236,6 +236,20 @@ public:
     ScopedIncrement ScopedDepth(&CurrentDepth);
     return traverse(TAL);
   }
+  bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) {
+    if (!Finder->isTraversalIgnoringImplicitNodes())
+      return VisitorBase::TraverseCXXForRangeStmt(Node);
+    if (!Node)
+      return true;
+    ScopedIncrement ScopedDepth(&CurrentDepth);
+    if (auto *Init = Node->getInit())
+      if (!match(*Init))
+        return false;
+    if (!match(*Node->getLoopVariable()) || !match(*Node->getRangeInit()) ||
+        !match(*Node->getBody()))
+      return false;
+    return VisitorBase::TraverseStmt(Node->getBody());
+  }
   bool TraverseLambdaExpr(LambdaExpr *Node) {
     if (!Finder->isTraversalIgnoringImplicitNodes())
       return VisitorBase::TraverseLambdaExpr(Node);
@@ -575,8 +589,6 @@ public:
 
     if (isTraversalIgnoringImplicitNodes()) {
       IgnoreImplicitChildren = true;
-      if (Node.get<CXXForRangeStmt>())
-        ScopedTraversal = true;
     }
 
     ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
index e706ea4b2a54136aada5d1b8c662429e4e4e5e3d..19ab6187d96003b52e3ddc63ad37327237b72d81 100644 (file)
@@ -2553,7 +2553,9 @@ struct CtorInitsNonTrivial : NonTrivial
     int arr[2];
     for (auto i : arr)
     {
-
+      if (true)
+      {
+      }
     }
   }
   )cpp";
@@ -2596,6 +2598,33 @@ struct CtorInitsNonTrivial : NonTrivial
     EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
     EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
   }
+  {
+    auto M = cxxForRangeStmt(hasDescendant(ifStmt()));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    EXPECT_TRUE(matches(
+        Code, traverse(TK_AsIs, cxxForRangeStmt(has(declStmt(
+                                    hasSingleDecl(varDecl(hasName("i")))))))));
+    EXPECT_TRUE(
+        matches(Code, traverse(TK_IgnoreUnlessSpelledInSource,
+                               cxxForRangeStmt(has(varDecl(hasName("i")))))));
+  }
+  {
+    EXPECT_TRUE(matches(
+        Code, traverse(TK_AsIs, cxxForRangeStmt(has(declStmt(hasSingleDecl(
+                                    varDecl(hasInitializer(declRefExpr(
+                                        to(varDecl(hasName("arr")))))))))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource,
+                                       cxxForRangeStmt(has(declRefExpr(
+                                           to(varDecl(hasName("arr")))))))));
+  }
+  {
+    auto M = cxxForRangeStmt(has(compoundStmt()));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
   {
     auto M = binaryOperator(hasOperatorName("!="));
     EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
@@ -2659,7 +2688,8 @@ struct CtorInitsNonTrivial : NonTrivial
                              true, {"-std=c++20"}));
   }
   {
-    auto M = cxxForRangeStmt(has(declStmt()));
+    auto M =
+        cxxForRangeStmt(has(declStmt(hasSingleDecl(varDecl(hasName("i"))))));
     EXPECT_TRUE(
         matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
     EXPECT_FALSE(
@@ -2679,6 +2709,19 @@ struct CtorInitsNonTrivial : NonTrivial
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxForRangeStmt(
+        has(declStmt(hasSingleDecl(varDecl(
+            hasName("a"),
+            hasInitializer(declRefExpr(to(varDecl(hasName("arr"))))))))),
+        hasLoopVariable(varDecl(hasName("i"))),
+        hasRangeInit(declRefExpr(to(varDecl(hasName("a"))))));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
   Code = R"cpp(
 void hasDefaultArg(int i, int j = 0)
 {