Refactor depth first traversal to be more generic
authorUmar Arshad <umar@arrayfire.com>
Sat, 4 Jun 2016 01:14:22 +0000 (21:14 -0400)
committerDavid Neto <dneto@google.com>
Fri, 10 Jun 2016 10:39:42 +0000 (06:39 -0400)
Refactor the way the post order vector is created. This new method
will allow for the extraction of backedges and create the post order
vector in one pass.

source/val/BasicBlock.h
source/val/Function.cpp
source/validate_cfg.cpp

index 0612db1..2366103 100644 (file)
@@ -46,18 +46,20 @@ class BasicBlock {
   uint32_t get_id() const { return id_; }
 
   /// Returns the predecessors of the BasicBlock
-  const std::vector<BasicBlock*>& get_predecessors() const {
-    return predecessors_;
+  const std::vector<BasicBlock*>* get_predecessors() const {
+    return &predecessors_;
   }
 
   /// Returns the predecessors of the BasicBlock
-  std::vector<BasicBlock*>& get_predecessors() { return predecessors_; }
+  std::vector<BasicBlock*>* get_predecessors() { return &predecessors_; }
 
   /// Returns the successors of the BasicBlock
-  const std::vector<BasicBlock*>& get_successors() const { return successors_; }
+  const std::vector<BasicBlock*>* get_successors() const {
+    return &successors_;
+  }
 
   /// Returns the successors of the BasicBlock
-  std::vector<BasicBlock*>& get_successors() { return successors_; }
+  std::vector<BasicBlock*>* get_successors() { return &successors_; }
 
   /// Returns true if the  block should be reachable in the CFG
   bool is_reachable() const { return reachable_; }
index dd8b386..3756949 100644 (file)
@@ -43,10 +43,10 @@ namespace {
 
 void printDot(const BasicBlock& other, const ValidationState_t& module) {
   string block_string;
-  if (other.get_successors().empty()) {
+  if (other.get_successors()->empty()) {
     block_string += "end ";
   } else {
-    for (auto& block : other.get_successors()) {
+    for (auto block : *other.get_successors()) {
       block_string += module.getIdOrName(block->get_id()) + " ";
     }
   }
index d6d1389..ef41db0 100644 (file)
@@ -40,6 +40,7 @@
 #include "val/ValidationState.h"
 
 using std::find;
+using std::function;
 using std::get;
 using std::make_pair;
 using std::numeric_limits;
@@ -59,56 +60,97 @@ using bb_ptr = BasicBlock*;
 using cbb_ptr = const BasicBlock*;
 using bb_iter = vector<BasicBlock*>::const_iterator;
 
-/// @brief Sorts the blocks in a CFG given the entry node
+using get_blocks_func = function<const vector<BasicBlock*>*(const BasicBlock*)>;
+
+struct block_info {
+  cbb_ptr block;  ///< pointer to the block
+  bb_iter iter;   ///< Iterator to the current child node being processed
+};
+
+/// Returns true if a block with @p id is found in the @p work_list vector
 ///
-/// Returns a vector of basic block pointers in a Control Flow Graph(CFG) which
-/// are sorted in the order they were accessed in a post order traversal.
+/// @param[in] work_list Set of blocks visited in the the depth first traversal
+///                   of the CFG
+/// @param[in] id The ID of the block being checked
+/// @return true if the edge work_list.back().block->get_id() => id is a
+/// back-edge
+bool FindInWorkList(vector<block_info> work_list, uint32_t id) {
+  for (auto b : work_list) {
+    if (b.block->get_id() == id) return true;
+  }
+  return false;
+}
+
+/// @brief Depth first traversal starting from the \p entry BasicBlock
 ///
-/// @param[in] entry the first block of a CFG
-/// @param[in] depth_hint a hint about the depth of the CFG
+/// This function performs a depth first traversal from the \p entry
+/// BasicBlock and calls the pre/postorder functions when it needs to process
+/// the node in pre order, post order. It also calls the backedge function
+/// when a back edge is encountered
 ///
-/// @return A vector of pointers in the order they were access in a post order
-/// traversal
-vector<const BasicBlock*> PostOrderSort(const BasicBlock& entry, size_t size) {
-  struct block_info {
-    cbb_ptr block;
-    bb_iter iter;
-  };
-
+/// @param[in] entry The root BasicBlock of a CFG tree
+/// @param[in] successor_func  A function which will return a pointer to the
+///                            successor nodes
+/// @param[in] preorder   A function that will be called for every block in a CFG
+///                       following preorder traversal semantics
+/// @param[in] postorder  A function that will be called for every block in a
+///                       CFG following postorder traversal semantics
+/// @param[in] backedge   A function that will be called when a backedge is
+///                       encountered during a traversal
+/// NOTE: The @p successor_func return a pointer to a collection such that
+/// iterators to that collection remain valid for the lifetime of the algorithm
+void DepthFirstTraversal(const BasicBlock& entry,
+                         get_blocks_func successor_func,
+                         function<void(cbb_ptr)> preorder,
+                         function<void(cbb_ptr)> postorder,
+                         function<void(cbb_ptr, cbb_ptr)> backedge) {
   vector<cbb_ptr> out;
-  vector<block_info> staged;
   unordered_set<uint32_t> processed;
-
-  staged.reserve(size);
-  staged.emplace_back(block_info{&entry, begin(entry.get_successors())});
-  processed.insert(entry.get_id());
-
-  while (!staged.empty()) {
-    block_info& top = staged.back();
-    if (top.iter == end(top.block->get_successors())) {
-      out.push_back(top.block);
-      staged.pop_back();
+  /// NOTE: work_list is the sequence of nodes from the entry node to the node
+  /// being processed in the traversal
+  vector<block_info> work_list;
+
+  work_list.reserve(10);
+  work_list.push_back({&entry, begin(*successor_func(&entry))});
+  preorder(&entry);
+
+  while (!work_list.empty()) {
+    block_info& top = work_list.back();
+    if (top.iter == end(*successor_func(top.block))) {
+      postorder(top.block);
+      work_list.pop_back();
     } else {
       BasicBlock* child = *top.iter;
       top.iter++;
-      if (processed.find(child->get_id()) == end(processed)) {
-        staged.emplace_back(block_info{child, begin(child->get_successors())});
+      if (FindInWorkList(work_list, child->get_id())) {
+        backedge(top.block, child);
+      }
+      if (processed.count(child->get_id()) == 0) {
+        preorder(child);
+        work_list.emplace_back(
+            block_info{child, begin(*successor_func(child))});
         processed.insert(child->get_id());
       }
     }
   }
-  return out;
 }
+
+/// Returns the successor of a basic block.
+/// NOTE: This will be passed as a function pointer to when calculating
+/// the dominator and post dominator
+const vector<BasicBlock*>* successor(const BasicBlock* b) {
+  return b->get_successors();
+}
+
 }  // namespace
 
 vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
-    const BasicBlock& first_block) {
+    vector<cbb_ptr>& postorder) {
   struct block_detail {
     size_t dominator;  ///< The index of blocks's dominator in post order array
     size_t postorder_index;  ///< The index of the block in the post order array
   };
 
-  vector<cbb_ptr> postorder = PostOrderSort(first_block, 10);
   const size_t undefined_dom = static_cast<size_t>(postorder.size());
 
   unordered_map<cbb_ptr, block_detail> idoms;
@@ -123,19 +165,19 @@ vector<pair<BasicBlock*, BasicBlock*>> CalculateDominators(
     changed = false;
     for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) {
       size_t& b_dom = idoms[*b].dominator;
-      const vector<BasicBlock*>& predecessors = (*b)->get_predecessors();
+      const vector<BasicBlock*>* predecessors = (*b)->get_predecessors();
 
       // first processed predecessor
-      auto res = find_if(begin(predecessors), end(predecessors),
+      auto res = find_if(begin(*predecessors), end(*predecessors),
                          [&idoms, undefined_dom](BasicBlock* pred) {
                            return idoms[pred].dominator != undefined_dom;
                          });
-      assert(res != end(predecessors));
+      assert(res != end(*predecessors));
       BasicBlock* idom = *res;
       size_t idom_idx = idoms[idom].postorder_index;
 
       // all other predecessors
-      for (auto p : predecessors) {
+      for (auto p : *predecessors) {
         if (idom == p || p->is_reachable() == false) {
           continue;
         }
@@ -212,8 +254,16 @@ spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) {
 spv_result_t PerformCfgChecks(ValidationState_t& _) {
   for (auto& function : _.get_functions()) {
     // Updates each blocks immediate dominators
+    vector<const BasicBlock*> postorder;
+    vector<pair<uint32_t, uint32_t>> back_edges;
     if (auto* first_block = function.get_first_block()) {
-      auto edges = libspirv::CalculateDominators(*first_block);
+      DepthFirstTraversal(*first_block, successor, [](cbb_ptr) {},
+                          [&](cbb_ptr b) { postorder.push_back(b); },
+                          [&](cbb_ptr from, cbb_ptr to) {
+                            back_edges.emplace_back(from->get_id(),
+                                                    to->get_id());
+                          });
+      auto edges = libspirv::CalculateDominators(postorder);
       libspirv::UpdateImmediateDominators(edges);
     }