From: Stephen Kelly Date: Wed, 27 Jan 2021 22:03:23 +0000 (+0000) Subject: [ASTMatchers] Fix traversal below range-for elements X-Git-Tag: llvmorg-14-init~16515 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=79125085f16540579d27c7e4987f63eef9c4aa23;p=platform%2Fupstream%2Fllvm.git [ASTMatchers] Fix traversal below range-for elements Differential Revision: https://reviews.llvm.org/D95562 --- diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp index 5034203..89e83ee 100644 --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -243,10 +243,14 @@ public: return true; ScopedIncrement ScopedDepth(&CurrentDepth); if (auto *Init = Node->getInit()) - if (!match(*Init)) + if (!traverse(*Init)) return false; - if (!match(*Node->getLoopVariable()) || !match(*Node->getRangeInit()) || - !match(*Node->getBody())) + if (!match(*Node->getLoopVariable())) + return false; + if (match(*Node->getRangeInit())) + if (!VisitorBase::TraverseStmt(Node->getRangeInit())) + return false; + if (!match(*Node->getBody())) return false; return VisitorBase::TraverseStmt(Node->getBody()); } @@ -488,15 +492,21 @@ public: bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { if (auto *RF = dyn_cast(S)) { - for (auto *SubStmt : RF->children()) { - if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() || - SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) { - TraverseStmt(SubStmt, Queue); - } else { - ASTNodeNotSpelledInSourceScope RAII(this, true); - TraverseStmt(SubStmt, Queue); + { + ASTNodeNotAsIsSourceScope RAII(this, true); + TraverseStmt(RF->getInit()); + // Don't traverse under the loop variable + match(*RF->getLoopVariable()); + TraverseStmt(RF->getRangeInit()); + } + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + for (auto *SubStmt : RF->children()) { + if (SubStmt != RF->getBody()) + TraverseStmt(SubStmt); } } + TraverseStmt(RF->getBody()); return true; } else if (auto *RBO = dyn_cast(S)) { { diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp index a3a09c4..cbea274 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -2821,6 +2821,36 @@ struct CtorInitsNonTrivial : NonTrivial } Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + void rangeFor() + { + for (auto i : getRange(42)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = callExpr(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = compoundStmt(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + + Code = R"cpp( void rangeFor() { int arr[2]; @@ -2891,6 +2921,40 @@ struct CtorInitsNonTrivial : NonTrivial matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), true, {"-std=c++20"})); } + + Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + int getNum(int); + + void rangeFor() + { + for (auto j = getNum(42); auto i : getRange(j)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"})); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), + true, {"-std=c++20"})); + } + { + auto M = compoundStmt(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"})); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), + true, {"-std=c++20"})); + } + Code = R"cpp( void hasDefaultArg(int i, int j = 0) {