Add depth first iterator for trees
authorVictor Lomuller <victor@codeplay.com>
Mon, 4 Dec 2017 14:36:05 +0000 (14:36 +0000)
committerDiego Novillo <dnovillo@google.com>
Thu, 7 Dec 2017 15:07:56 +0000 (10:07 -0500)
 - Add generic depth first iterator
 - Update the dominator tree to use this iterator instead of "randomly"
   iterate over the nodes

source/opt/CMakeLists.txt
source/opt/dominator_tree.cpp
source/opt/dominator_tree.h
source/opt/tree_iterator.h [new file with mode: 0644]
test/opt/dominator_tree/generated.cpp

index 0759029..278adae 100644 (file)
@@ -60,6 +60,7 @@ add_library(SPIRV-Tools-opt
   set_spec_constant_default_value_pass.h
   strength_reduction_pass.h
   strip_debug_info_pass.h
+  tree_iterator.h
   type_manager.h
   types.h
   unify_const_pass.h
index 5417dd1..9dfc559 100644 (file)
@@ -368,7 +368,6 @@ void DominatorTree::InitializeTree(const ir::Function* f, const ir::CFG& cfg) {
 void DominatorTree::DumpTreeAsDot(std::ostream& out_stream) const {
   out_stream << "digraph {\n";
   Visit([&out_stream](const DominatorTreeNode* node) {
-
     // Print the node.
     if (node->bb_) {
       out_stream << node->bb_->id() << "[label=\"" << node->bb_->id()
@@ -388,32 +387,5 @@ void DominatorTree::DumpTreeAsDot(std::ostream& out_stream) const {
   out_stream << "}\n";
 }
 
-bool DominatorTree::Visit(DominatorTreeNode* node,
-                          std::function<bool(DominatorTreeNode*)> func) {
-  // Apply the function to the node.
-  if (!func(node)) return false;
-
-  // Apply the function to every child node.
-  for (DominatorTreeNode* child : node->children_) {
-    if (!Visit(child, func)) return false;
-  }
-
-  return true;
-}
-
-bool DominatorTree::Visit(
-    const DominatorTreeNode* node,
-    std::function<bool(const DominatorTreeNode*)> func) const {
-  // Apply the function to the node.
-  if (!func(node)) return false;
-
-  // Apply the function to every child node.
-  for (const DominatorTreeNode* child : node->children_) {
-    if (!Visit(child, func)) return false;
-  }
-
-  return true;
-}
-
 }  // namespace opt
 }  // namespace spvtools
index 2670384..d3cdcdf 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "cfg.h"
 #include "module.h"
+#include "tree_iterator.h"
 
 namespace spvtools {
 namespace opt {
@@ -36,6 +37,16 @@ struct DominatorTreeNode {
         dfs_num_pre_(-1),
         dfs_num_post_(-1) {}
 
+  using iterator = std::vector<DominatorTreeNode*>::iterator;
+  using const_iterator = std::vector<DominatorTreeNode*>::const_iterator;
+
+  iterator begin() { return children_.begin(); }
+  iterator end() { return children_.end(); }
+  const_iterator begin() const { return cbegin(); }
+  const_iterator end() const { return cend(); }
+  const_iterator cbegin() const { return children_.begin(); }
+  const_iterator cend() const { return children_.end(); }
+
   inline uint32_t id() const { return bb_->id(); }
 
   ir::BasicBlock* bb_;
@@ -56,8 +67,8 @@ class DominatorTree {
  public:
   // Map OpLabel ids to dominator tree nodes
   using DominatorTreeNodeMap = std::map<uint32_t, DominatorTreeNode>;
-  using iterator = DominatorTreeNodeMap::iterator;
-  using const_iterator = DominatorTreeNodeMap::const_iterator;
+  using iterator = TreeDFIterator<DominatorTreeNode>;
+  using const_iterator = TreeDFIterator<const DominatorTreeNode>;
 
   // List of DominatorTreeNode to define the list of roots
   using DominatorTreeNodeList = std::vector<DominatorTreeNode*>;
@@ -67,12 +78,14 @@ class DominatorTree {
   DominatorTree() : postdominator_(false) {}
   explicit DominatorTree(bool post) : postdominator_(post) {}
 
-  iterator begin() { return nodes_.begin(); }
-  iterator end() { return nodes_.end(); }
+  // Depth first iterators.
+  // Traverse the dominator tree in a depth first pre-order.
+  iterator begin() { return iterator(GetRoot()); }
+  iterator end() { return iterator(); }
   const_iterator begin() const { return cbegin(); }
   const_iterator end() const { return cend(); }
-  const_iterator cbegin() const { return nodes_.begin(); }
-  const_iterator cend() const { return nodes_.end(); }
+  const_iterator cbegin() const { return const_iterator(GetRoot()); }
+  const_iterator cend() const { return const_iterator(); }
 
   roots_iterator roots_begin() { return roots_.begin(); }
   roots_iterator roots_end() { return roots_.end(); }
@@ -143,31 +156,23 @@ class DominatorTree {
   }
 
   // Applies the std::function |func| to all nodes in the dominator tree.
+  // Tree nodes are visited in a depth first pre-order.
   bool Visit(std::function<bool(DominatorTreeNode*)> func) {
-    for (auto n : roots_) {
-      if (!Visit(n, func)) return false;
+    for (auto n : *this) {
+      if (!func(&n)) return false;
     }
     return true;
   }
 
   // Applies the std::function |func| to all nodes in the dominator tree.
+  // Tree nodes are visited in a depth first pre-order.
   bool Visit(std::function<bool(const DominatorTreeNode*)> func) const {
-    for (auto n : roots_) {
-      if (!Visit(n, func)) return false;
+    for (auto n : *this) {
+      if (!func(&n)) return false;
     }
     return true;
   }
 
-  // Applies the std::function |func| to |node| then applies it to nodes
-  // children.
-  bool Visit(DominatorTreeNode* node,
-             std::function<bool(DominatorTreeNode*)> func);
-
-  // Applies the std::function |func| to |node| then applies it to nodes
-  // children.
-  bool Visit(const DominatorTreeNode* node,
-             std::function<bool(const DominatorTreeNode*)> func) const;
-
  private:
   // Adds the basic block |bb| to the tree structure if it doesn't already
   // exist.
diff --git a/source/opt/tree_iterator.h b/source/opt/tree_iterator.h
new file mode 100644 (file)
index 0000000..4a24a01
--- /dev/null
@@ -0,0 +1,121 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_TREE_ITERATOR_H_
+#define LIBSPIRV_OPT_TREE_ITERATOR_H_
+
+#include <stack>
+#include <type_traits>
+#include <utility>
+
+namespace spvtools {
+namespace opt {
+
+// Helper class to iterate over a tree in a depth first order.
+// The class assumes the data structure is a tree, tree node type implements a
+// forward iterator.
+// At each step, the iterator holds the pointer to the current node and state of
+// the walk.
+// The state is recorded by stacking the iteration position of the node
+// children. To move to the next node, the iterator:
+//  - Looks at the top of the stack;
+//  - Sets the node behind the iterator as the current node;
+//  - Increments the iterator if it has more children to visit, pops otherwise;
+//  - If the current node has children, the children iterator is pushed into the
+//    stack.
+template <typename NodeTy>
+class TreeDFIterator {
+  static_assert(!std::is_pointer<NodeTy>::value &&
+                    !std::is_reference<NodeTy>::value,
+                "NodeTy should be a class");
+  // Type alias to keep track of the const qualifier.
+  using NodeIterator =
+      typename std::conditional<std::is_const<NodeTy>::value,
+                                typename NodeTy::const_iterator,
+                                typename NodeTy::iterator>::type;
+
+  // Type alias to keep track of the const qualifier.
+  using NodePtr = NodeTy*;
+
+ public:
+  // Standard iterator interface.
+  using reference = NodeTy&;
+  using value_type = NodeTy;
+
+  explicit inline TreeDFIterator(NodePtr top_node) : current_(top_node) {
+    if (current_ && current_->begin() != current_->end())
+      parent_iterators_.emplace(make_pair(current_, current_->begin()));
+  }
+
+  // end() iterator.
+  inline TreeDFIterator() : TreeDFIterator(nullptr) {}
+
+  bool operator==(const TreeDFIterator& x) const {
+    return current_ == x.current_;
+  }
+
+  bool operator!=(const TreeDFIterator& x) const { return !(*this == x); }
+
+  reference operator*() const { return *current_; }
+
+  NodePtr operator->() const { return current_; }
+
+  TreeDFIterator& operator++() {
+    MoveToNextNode();
+    return *this;
+  }
+
+  TreeDFIterator operator++(int) {
+    TreeDFIterator tmp = *this;
+    ++*this;
+    return tmp;
+  }
+
+ private:
+  // Moves the iterator to the next node in the tree.
+  // If we are at the end, do nothing, otherwise
+  // if our current node has children, use the children iterator and push the
+  // current node into the stack.
+  // If we reach the end of the local iterator, pop it.
+  inline void MoveToNextNode() {
+    if (!current_) return;
+    if (parent_iterators_.empty()) {
+      current_ = nullptr;
+      return;
+    }
+    std::pair<NodePtr, NodeIterator>& next_it = parent_iterators_.top();
+    // Set the new node.
+    current_ = *next_it.second;
+    // Update the iterator for the next child.
+    ++next_it.second;
+    // If we finished with node, pop it.
+    if (next_it.first->end() == next_it.second) parent_iterators_.pop();
+    // If our current node is not a leaf, store the iteration state for later.
+    if (current_->begin() != current_->end())
+      parent_iterators_.emplace(make_pair(current_, current_->begin()));
+  }
+
+  // The current node of the tree.
+  NodePtr current_;
+  // State of the tree walk: each pair contains the parent node (which has been
+  // already visited) and the iterator of the next children to visit.
+  // When all the children has been visited, we pop the entry, get the next
+  // child and push back the pair if the children iterator is not end().
+  std::stack<std::pair<NodePtr, NodeIterator>> parent_iterators_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_TREE_ITERATOR_H_
index 0ea55db..e9d9827 100644 (file)
@@ -14,6 +14,7 @@
 
 #include <gmock/gmock.h>
 
+#include <array>
 #include <memory>
 #include <set>
 #include <string>
@@ -429,6 +430,33 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
               spvtest::GetBasicBlock(fn, 10));
     EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)),
               spvtest::GetBasicBlock(fn, 11));
+
+    uint32_t entry_id = cfg.pseudo_entry_block()->id();
+    std::array<uint32_t, 4> node_order = {{entry_id, 10, 11, 12}};
+    {
+      // Test dominator tree iteration order.
+      opt::DominatorTree::iterator node_it = dom_tree.GetDomTree().begin();
+      opt::DominatorTree::iterator node_end = dom_tree.GetDomTree().end();
+      for (uint32_t id : node_order) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
+    {
+      // Same as above, but with const iterators.
+      opt::DominatorTree::const_iterator node_it =
+          dom_tree.GetDomTree().cbegin();
+      opt::DominatorTree::const_iterator node_end =
+          dom_tree.GetDomTree().cend();
+      for (uint32_t id : node_order) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
   }
 
   // Check post dominator tree
@@ -459,6 +487,31 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
 
     EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)),
               cfg.pseudo_exit_block());
+
+    uint32_t entry_id = cfg.pseudo_exit_block()->id();
+    std::array<uint32_t, 4> node_order = {{entry_id, 12, 11, 10}};
+    {
+      // Test dominator tree iteration order.
+      opt::DominatorTree::iterator node_it = tree.begin();
+      opt::DominatorTree::iterator node_end = tree.end();
+      for (uint32_t id : node_order) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
+    {
+      // Same as above, but with const iterators.
+      opt::DominatorTree::const_iterator node_it = tree.cbegin();
+      opt::DominatorTree::const_iterator node_end = tree.cend();
+      for (uint32_t id : node_order) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
   }
 }