[Syntax] Add iterators over children of syntax trees.
authorSam McCall <sam.mccall@gmail.com>
Fri, 23 Oct 2020 10:23:29 +0000 (12:23 +0200)
committerSam McCall <sam.mccall@gmail.com>
Wed, 28 Oct 2020 11:37:57 +0000 (12:37 +0100)
This gives us slightly nicer syntax (foreach) for idioms currently expressed
as a loop, and the option to use range algorithms where it makes sense
(e.g. llvm::all_of et al encapsulate the needed flow control in a useful way).

It's also a building block for iteration over filtered views (e.g. iterate over
all Stmt children, with the right type):
for (const Statement &S : filter<Statement>(N.children()))
  ...

I realize the recent direction has been mostly towards strongly-typed
node-specific facilities, but I think it's important we have convenient
generic facilities too.

Differential Revision: https://reviews.llvm.org/D90023

clang/include/clang/Tooling/Syntax/Tree.h
clang/lib/Tooling/Syntax/Tree.cpp
clang/unittests/Tooling/Syntax/TreeTest.cpp
clang/unittests/Tooling/Syntax/TreeTestBase.h

index e1fd3a2..23e6081 100644 (file)
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/iterator.h"
 #include "llvm/Support/Allocator.h"
 #include <cstdint>
+#include <iterator>
 
 namespace clang {
 namespace syntax {
@@ -162,6 +164,34 @@ private:
 
 /// A node that has children and represents a syntactic language construct.
 class Tree : public Node {
+  /// Iterator over children (common base for const/non-const).
+  /// Not invalidated by tree mutations (holds a stable node pointer).
+  template <typename DerivedT, typename NodeT>
+  class ChildIteratorBase
+      : public llvm::iterator_facade_base<DerivedT, std::forward_iterator_tag,
+                                          NodeT> {
+  protected:
+    NodeT *N = nullptr;
+    using Base = ChildIteratorBase;
+
+  public:
+    ChildIteratorBase() = default;
+    explicit ChildIteratorBase(NodeT *N) : N(N) {}
+
+    bool operator==(const DerivedT &O) const { return O.N == N; }
+    NodeT &operator*() const { return *N; }
+    DerivedT &operator++() {
+      N = N->getNextSibling();
+      return *static_cast<DerivedT *>(this);
+    }
+
+    /// Truthy if valid (not past-the-end).
+    /// This allows: if (auto It = find_if(N.children(), ...) )
+    explicit operator bool() const { return N != nullptr; }
+    /// The element, or nullptr if past-the-end.
+    NodeT *asPointer() const { return N; }
+  };
+
 public:
   static bool classof(const Node *N);
 
@@ -178,6 +208,23 @@ public:
     return const_cast<Leaf *>(const_cast<const Tree *>(this)->findLastLeaf());
   }
 
+  /// child_iterator is not invalidated by mutations.
+  struct ChildIterator : ChildIteratorBase<ChildIterator, Node> {
+    using Base::ChildIteratorBase;
+  };
+  struct ConstChildIterator
+      : ChildIteratorBase<ConstChildIterator, const Node> {
+    using Base::ChildIteratorBase;
+    ConstChildIterator(const ChildIterator &I) : Base(I.asPointer()) {}
+  };
+
+  llvm::iterator_range<ChildIterator> getChildren() {
+    return {ChildIterator(getFirstChild()), ChildIterator()};
+  }
+  llvm::iterator_range<ConstChildIterator> getChildren() const {
+    return {ConstChildIterator(getFirstChild()), ConstChildIterator()};
+  }
+
   /// Find the first node with a corresponding role.
   const Node *findChild(NodeRole R) const;
   Node *findChild(NodeRole R) {
@@ -209,6 +256,14 @@ private:
   Node *FirstChild = nullptr;
 };
 
+// Provide missing non_const == const overload.
+// iterator_facade_base requires == to be a member, but implicit conversions
+// don't work on the LHS of a member operator.
+inline bool operator==(const Tree::ConstChildIterator &A,
+                       const Tree::ConstChildIterator &B) {
+  return A.operator==(B);
+}
+
 /// A list of Elements separated or terminated by a fixed token.
 ///
 /// This type models the following grammar construct:
index 9904d14..f910365 100644 (file)
@@ -19,8 +19,8 @@ namespace {
 static void traverse(const syntax::Node *N,
                      llvm::function_ref<void(const syntax::Node *)> Visit) {
   if (auto *T = dyn_cast<syntax::Tree>(N)) {
-    for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling())
-      traverse(C, Visit);
+    for (const syntax::Node &C : T->getChildren())
+      traverse(&C, Visit);
   }
   Visit(N);
 }
@@ -194,21 +194,21 @@ static void dumpNode(raw_ostream &OS, const syntax::Node *N,
   DumpExtraInfo(N);
   OS << "\n";
 
-  for (const auto *It = T->getFirstChild(); It; It = It->getNextSibling()) {
+  for (const syntax::Node &It : T->getChildren()) {
     for (bool Filled : IndentMask) {
       if (Filled)
         OS << "| ";
       else
         OS << "  ";
     }
-    if (!It->getNextSibling()) {
+    if (!It.getNextSibling()) {
       OS << "`-";
       IndentMask.push_back(false);
     } else {
       OS << "|-";
       IndentMask.push_back(true);
     }
-    dumpNode(OS, It, SM, IndentMask);
+    dumpNode(OS, &It, SM, IndentMask);
     IndentMask.pop_back();
   }
 }
@@ -243,22 +243,22 @@ void syntax::Node::assertInvariants() const {
   const auto *T = dyn_cast<Tree>(this);
   if (!T)
     return;
-  for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling()) {
+  for (const Node &C : T->getChildren()) {
     if (T->isOriginal())
-      assert(C->isOriginal());
-    assert(!C->isDetached());
-    assert(C->getParent() == T);
+      assert(C.isOriginal());
+    assert(!C.isDetached());
+    assert(C.getParent() == T);
   }
 
   const auto *L = dyn_cast<List>(T);
   if (!L)
     return;
-  for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling()) {
-    assert(C->getRole() == NodeRole::ListElement ||
-           C->getRole() == NodeRole::ListDelimiter);
-    if (C->getRole() == NodeRole::ListDelimiter) {
+  for (const Node &C : T->getChildren()) {
+    assert(C.getRole() == NodeRole::ListElement ||
+           C.getRole() == NodeRole::ListDelimiter);
+    if (C.getRole() == NodeRole::ListDelimiter) {
       assert(isa<Leaf>(C));
-      assert(cast<Leaf>(C)->getToken()->kind() == L->getDelimiterTokenKind());
+      assert(cast<Leaf>(C).getToken()->kind() == L->getDelimiterTokenKind());
     }
   }
 
@@ -272,10 +272,10 @@ void syntax::Node::assertInvariantsRecursive() const {
 }
 
 const syntax::Leaf *syntax::Tree::findFirstLeaf() const {
-  for (const auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    if (const auto *L = dyn_cast<syntax::Leaf>(C))
+  for (const Node &C : getChildren()) {
+    if (const auto *L = dyn_cast<syntax::Leaf>(&C))
       return L;
-    if (const auto *L = cast<syntax::Tree>(C)->findFirstLeaf())
+    if (const auto *L = cast<syntax::Tree>(C).findFirstLeaf())
       return L;
   }
   return nullptr;
@@ -283,19 +283,19 @@ const syntax::Leaf *syntax::Tree::findFirstLeaf() const {
 
 const syntax::Leaf *syntax::Tree::findLastLeaf() const {
   const syntax::Leaf *Last = nullptr;
-  for (const auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    if (const auto *L = dyn_cast<syntax::Leaf>(C))
+  for (const Node &C : getChildren()) {
+    if (const auto *L = dyn_cast<syntax::Leaf>(&C))
       Last = L;
-    else if (const auto *L = cast<syntax::Tree>(C)->findLastLeaf())
+    else if (const auto *L = cast<syntax::Tree>(C).findLastLeaf())
       Last = L;
   }
   return Last;
 }
 
 const syntax::Node *syntax::Tree::findChild(NodeRole R) const {
-  for (const auto *C = FirstChild; C; C = C->getNextSibling()) {
-    if (C->getRole() == R)
-      return C;
+  for (const Node &C : getChildren()) {
+    if (C.getRole() == R)
+      return &C;
   }
   return nullptr;
 }
@@ -318,17 +318,17 @@ syntax::List::getElementsAsNodesAndDelimiters() {
 
   std::vector<syntax::List::ElementAndDelimiter<Node>> Children;
   syntax::Node *ElementWithoutDelimiter = nullptr;
-  for (auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    switch (C->getRole()) {
+  for (Node &C : getChildren()) {
+    switch (C.getRole()) {
     case syntax::NodeRole::ListElement: {
       if (ElementWithoutDelimiter) {
         Children.push_back({ElementWithoutDelimiter, nullptr});
       }
-      ElementWithoutDelimiter = C;
+      ElementWithoutDelimiter = &C;
       break;
     }
     case syntax::NodeRole::ListDelimiter: {
-      Children.push_back({ElementWithoutDelimiter, cast<syntax::Leaf>(C)});
+      Children.push_back({ElementWithoutDelimiter, cast<syntax::Leaf>(&C)});
       ElementWithoutDelimiter = nullptr;
       break;
     }
@@ -363,13 +363,13 @@ std::vector<syntax::Node *> syntax::List::getElementsAsNodes() {
 
   std::vector<syntax::Node *> Children;
   syntax::Node *ElementWithoutDelimiter = nullptr;
-  for (auto *C = getFirstChild(); C; C = C->getNextSibling()) {
-    switch (C->getRole()) {
+  for (Node &C : getChildren()) {
+    switch (C.getRole()) {
     case syntax::NodeRole::ListElement: {
       if (ElementWithoutDelimiter) {
         Children.push_back(ElementWithoutDelimiter);
       }
-      ElementWithoutDelimiter = C;
+      ElementWithoutDelimiter = &C;
       break;
     }
     case syntax::NodeRole::ListDelimiter: {
index fba3164..ed839e2 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "clang/Tooling/Syntax/Tree.h"
 #include "TreeTestBase.h"
+#include "clang/Basic/SourceManager.h"
 #include "clang/Tooling/Syntax/BuildTree.h"
 #include "clang/Tooling/Syntax/Nodes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -17,6 +18,7 @@ using namespace clang;
 using namespace clang::syntax;
 
 namespace {
+using testing::ElementsAre;
 
 class TreeTest : public SyntaxTreeTest {
 private:
@@ -124,6 +126,56 @@ TEST_P(TreeTest, LastLeaf) {
   }
 }
 
+TEST_F(TreeTest, Iterators) {
+  buildTree("", allTestClangConfigs().front());
+  std::vector<Node *> Children = {createLeaf(*Arena, tok::identifier, "a"),
+                                  createLeaf(*Arena, tok::identifier, "b"),
+                                  createLeaf(*Arena, tok::identifier, "c")};
+  auto *Tree = syntax::createTree(*Arena,
+                                  {{Children[0], NodeRole::LeftHandSide},
+                                   {Children[1], NodeRole::OperatorToken},
+                                   {Children[2], NodeRole::RightHandSide}},
+                                  NodeKind::TranslationUnit);
+  const auto *ConstTree = Tree;
+
+  auto Range = Tree->getChildren();
+  EXPECT_THAT(Range, ElementsAre(role(NodeRole::LeftHandSide),
+                                 role(NodeRole::OperatorToken),
+                                 role(NodeRole::RightHandSide)));
+
+  auto ConstRange = ConstTree->getChildren();
+  EXPECT_THAT(ConstRange, ElementsAre(role(NodeRole::LeftHandSide),
+                                      role(NodeRole::OperatorToken),
+                                      role(NodeRole::RightHandSide)));
+
+  // FIXME: mutate and observe no invalidation. Mutations are private for now...
+  auto It = Range.begin();
+  auto CIt = ConstRange.begin();
+  static_assert(std::is_same<decltype(*It), syntax::Node &>::value,
+                "mutable range");
+  static_assert(std::is_same<decltype(*CIt), const syntax::Node &>::value,
+                "const range");
+
+  for (unsigned I = 0; I < 3; ++I) {
+    EXPECT_EQ(It, CIt);
+    EXPECT_TRUE(It);
+    EXPECT_TRUE(CIt);
+    EXPECT_EQ(It.asPointer(), Children[I]);
+    EXPECT_EQ(CIt.asPointer(), Children[I]);
+    EXPECT_EQ(&*It, Children[I]);
+    EXPECT_EQ(&*CIt, Children[I]);
+    ++It;
+    ++CIt;
+  }
+  EXPECT_EQ(It, CIt);
+  EXPECT_EQ(It, Tree::ChildIterator());
+  EXPECT_EQ(CIt, Tree::ConstChildIterator());
+  EXPECT_FALSE(It);
+  EXPECT_FALSE(CIt);
+  EXPECT_EQ(nullptr, It.asPointer());
+  EXPECT_EQ(nullptr, CIt.asPointer());
+}
+
 class ListTest : public SyntaxTreeTest {
 private:
   std::string dumpQuotedTokensOrNull(const Node *N) {
index 8b0ca97..a86b5d7 100644 (file)
@@ -20,7 +20,9 @@
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "clang/Tooling/Syntax/Tree.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ScopedPrinter.h"
 #include "llvm/Testing/Support/Annotations.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 namespace clang {
@@ -53,6 +55,14 @@ protected:
 };
 
 std::vector<TestClangConfig> allTestClangConfigs();
+
+MATCHER_P(role, R, "") {
+  if (arg.getRole() == R)
+    return true;
+  *result_listener << "role is " << llvm::to_string(arg.getRole());
+  return false;
+}
+
 } // namespace syntax
 } // namespace clang
 #endif // LLVM_CLANG_UNITTESTS_TOOLING_SYNTAX_TREETESTBASE_H