Add loop descriptors and some required dominator tree extensions.
authorVictor Lomuller <victor@codeplay.com>
Thu, 21 Dec 2017 14:47:25 +0000 (14:47 +0000)
committerSteven Perron <stevenperron@google.com>
Mon, 8 Jan 2018 14:31:13 +0000 (09:31 -0500)
Add post-order tree iterator.

Add DominatorTreeNode extensions:
 - Add begin/end methods to do pre-order and post-order tree traversal from a given DominatorTreeNode

Add DominatorTree extensions:
  - Add begin/end methods to do pre-order and post-order tree traversal
  - Tree traversal ignore by default the pseudo entry block
  - Retrieve a DominatorTreeNode from a basic block

Add loop descriptor:
  - Add a LoopDescriptor class to register all loops in a given function.
  - Add a Loop class to describe a loop:
    - Loop parent
    - Nested loops
    - Loop depth
    - Loop header, merge, continue and preheader
    - Basic blocks that belong to the loop

Correct a bug that forced dominator tree to be constantly rebuilt.

14 files changed:
Android.mk
source/opt/CMakeLists.txt
source/opt/dominator_tree.cpp
source/opt/dominator_tree.h
source/opt/ir_context.cpp
source/opt/ir_context.h
source/opt/loop_descriptor.cpp [new file with mode: 0644]
source/opt/loop_descriptor.h [new file with mode: 0644]
source/opt/tree_iterator.h
test/opt/CMakeLists.txt
test/opt/dominator_tree/generated.cpp
test/opt/loop_optimizations/CMakeLists.txt [new file with mode: 0644]
test/opt/loop_optimizations/loop_descriptions.cpp [new file with mode: 0644]
test/opt/loop_optimizations/nested_loops.cpp [new file with mode: 0644]

index 48fb087..2aeae42 100644 (file)
@@ -90,6 +90,7 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/local_single_block_elim_pass.cpp \
                source/opt/local_single_store_elim_pass.cpp \
                source/opt/local_ssa_elim_pass.cpp \
+               source/opt/loop_descriptor.cpp \
                source/opt/mem_pass.cpp \
                source/opt/merge_return_pass.cpp \
                source/opt/module.cpp \
index 0b49842..680edaf 100644 (file)
@@ -49,6 +49,7 @@ add_library(SPIRV-Tools-opt
   local_single_store_elim_pass.h
   local_ssa_elim_pass.h
   log.h
+  loop_descriptor.h
   mem_pass.h
   merge_return_pass.h
   module.h
@@ -107,6 +108,7 @@ add_library(SPIRV-Tools-opt
   local_single_block_elim_pass.cpp
   local_single_store_elim_pass.cpp
   local_ssa_elim_pass.cpp
+  loop_descriptor.cpp
   mem_pass.cpp
   merge_return_pass.cpp
   module.cpp
index 9dfc559..552a80d 100644 (file)
@@ -227,23 +227,28 @@ bool DominatorTree::StrictlyDominates(const ir::BasicBlock* a,
   return DominatorTree::StrictlyDominates(a->id(), b->id());
 }
 
+bool DominatorTree::StrictlyDominates(const DominatorTreeNode* a,
+                                      const DominatorTreeNode* b) const {
+  if (a == b) return false;
+  return Dominates(a, b);
+}
+
 bool DominatorTree::Dominates(uint32_t a, uint32_t b) const {
   // Check that both of the inputs are actual nodes.
-  auto a_itr = nodes_.find(a);
-  auto b_itr = nodes_.find(b);
-  if (a_itr == nodes_.end() || b_itr == nodes_.end()) return false;
+  const DominatorTreeNode* a_node = GetTreeNode(a);
+  const DominatorTreeNode* b_node = GetTreeNode(b);
+  if (!a_node || !b_node) return false;
 
+  return Dominates(a_node, b_node);
+}
+
+bool DominatorTree::Dominates(const DominatorTreeNode* a,
+                              const DominatorTreeNode* b) const {
   // Node A dominates node B if they are the same.
   if (a == b) return true;
-  const DominatorTreeNode* nodeA = &a_itr->second;
-  const DominatorTreeNode* nodeB = &b_itr->second;
-
-  if (nodeA->dfs_num_pre_ < nodeB->dfs_num_pre_ &&
-      nodeA->dfs_num_post_ > nodeB->dfs_num_post_) {
-    return true;
-  }
 
-  return false;
+  return a->dfs_num_pre_ < b->dfs_num_pre_ &&
+         a->dfs_num_post_ > b->dfs_num_post_;
 }
 
 bool DominatorTree::Dominates(const ir::BasicBlock* A,
index d3cdcdf..0be4951 100644 (file)
@@ -40,6 +40,13 @@ struct DominatorTreeNode {
   using iterator = std::vector<DominatorTreeNode*>::iterator;
   using const_iterator = std::vector<DominatorTreeNode*>::const_iterator;
 
+  // depth first preorder iterator.
+  using df_iterator = TreeDFIterator<DominatorTreeNode>;
+  using const_df_iterator = TreeDFIterator<const DominatorTreeNode>;
+  // depth first postorder iterator.
+  using post_iterator = PostOrderTreeDFIterator<DominatorTreeNode>;
+  using const_post_iterator = PostOrderTreeDFIterator<const DominatorTreeNode>;
+
   iterator begin() { return children_.begin(); }
   iterator end() { return children_.end(); }
   const_iterator begin() const { return cbegin(); }
@@ -47,6 +54,26 @@ struct DominatorTreeNode {
   const_iterator cbegin() const { return children_.begin(); }
   const_iterator cend() const { return children_.end(); }
 
+  // Depth first preorder iterator using this node as root.
+  df_iterator df_begin() { return df_iterator(this); }
+  df_iterator df_end() { return df_iterator(); }
+  const_df_iterator df_begin() const { return df_cbegin(); }
+  const_df_iterator df_end() const { return df_cend(); }
+  const_df_iterator df_cbegin() const { return const_df_iterator(this); }
+  const_df_iterator df_cend() const { return const_df_iterator(); }
+
+  // Depth first postorder iterator using this node as root.
+  post_iterator post_begin() { return post_iterator::begin(this); }
+  post_iterator post_end() { return post_iterator::end(nullptr); }
+  const_post_iterator post_begin() const { return post_cbegin(); }
+  const_post_iterator post_end() const { return post_cend(); }
+  const_post_iterator post_cbegin() const {
+    return const_post_iterator::begin(this);
+  }
+  const_post_iterator post_cend() const {
+    return const_post_iterator::end(nullptr);
+  }
+
   inline uint32_t id() const { return bb_->id(); }
 
   ir::BasicBlock* bb_;
@@ -69,6 +96,8 @@ class DominatorTree {
   using DominatorTreeNodeMap = std::map<uint32_t, DominatorTreeNode>;
   using iterator = TreeDFIterator<DominatorTreeNode>;
   using const_iterator = TreeDFIterator<const DominatorTreeNode>;
+  using post_iterator = PostOrderTreeDFIterator<DominatorTreeNode>;
+  using const_post_iterator = PostOrderTreeDFIterator<const DominatorTreeNode>;
 
   // List of DominatorTreeNode to define the list of roots
   using DominatorTreeNodeList = std::vector<DominatorTreeNode*>;
@@ -80,13 +109,27 @@ class DominatorTree {
 
   // Depth first iterators.
   // Traverse the dominator tree in a depth first pre-order.
-  iterator begin() { return iterator(GetRoot()); }
+  // The pseudo-block is ignored.
+  iterator begin() { return ++iterator(GetRoot()); }
   iterator end() { return iterator(); }
   const_iterator begin() const { return cbegin(); }
   const_iterator end() const { return cend(); }
-  const_iterator cbegin() const { return const_iterator(GetRoot()); }
+  const_iterator cbegin() const { return ++const_iterator(GetRoot()); }
   const_iterator cend() const { return const_iterator(); }
 
+  // Traverse the dominator tree in a depth first post-order.
+  // The pseudo-block is ignored.
+  post_iterator post_begin() { return post_iterator::begin(GetRoot()); }
+  post_iterator post_end() { return post_iterator::end(GetRoot()); }
+  const_post_iterator post_begin() const { return post_cbegin(); }
+  const_post_iterator post_end() const { return post_cend(); }
+  const_post_iterator post_cbegin() const {
+    return const_post_iterator::begin(GetRoot());
+  }
+  const_post_iterator post_cend() const {
+    return const_post_iterator::end(GetRoot());
+  }
+
   roots_iterator roots_begin() { return roots_.begin(); }
   roots_iterator roots_end() { return roots_.end(); }
   roots_const_iterator roots_begin() const { return roots_cbegin(); }
@@ -122,6 +165,9 @@ class DominatorTree {
   // Check if the basic block id |a| dominates the basic block id |b|.
   bool Dominates(uint32_t a, uint32_t b) const;
 
+  // Check if the dominator tree node |a| dominates the dominator tree node |b|.
+  bool Dominates(const DominatorTreeNode* a, const DominatorTreeNode* b) const;
+
   // Check if the basic block |a| strictly dominates the basic block |b|.
   bool StrictlyDominates(const ir::BasicBlock* a,
                          const ir::BasicBlock* b) const;
@@ -129,6 +175,11 @@ class DominatorTree {
   // Check if the basic block id |a| strictly dominates the basic block id |b|.
   bool StrictlyDominates(uint32_t a, uint32_t b) const;
 
+  // Check if the dominator tree node |a| strictly dominates the dominator tree
+  // node |b|.
+  bool StrictlyDominates(const DominatorTreeNode* a,
+                         const DominatorTreeNode* b) const;
+
   // Returns the immediate dominator of basic block |a|.
   ir::BasicBlock* ImmediateDominator(const ir::BasicBlock* A) const;
 
@@ -173,6 +224,36 @@ class DominatorTree {
     return true;
   }
 
+  // Returns the DominatorTreeNode associated with the basic block |bb|.
+  // If the |bb| is unknown to the dominator tree, it returns null.
+  inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) {
+    return GetTreeNode(bb->id());
+  }
+  // Returns the DominatorTreeNode associated with the basic block |bb|.
+  // If the |bb| is unknown to the dominator tree, it returns null.
+  inline const DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) const {
+    return GetTreeNode(bb->id());
+  }
+
+  // Returns the DominatorTreeNode associated with the basic block id |id|.
+  // If the id |id| is unknown to the dominator tree, it returns null.
+  inline DominatorTreeNode* GetTreeNode(uint32_t id) {
+    DominatorTreeNodeMap::iterator node_iter = nodes_.find(id);
+    if (node_iter == nodes_.end()) {
+      return nullptr;
+    }
+    return &node_iter->second;
+  }
+  // Returns the DominatorTreeNode associated with the basic block id |id|.
+  // If the id |id| is unknown to the dominator tree, it returns null.
+  inline const DominatorTreeNode* GetTreeNode(uint32_t id) const {
+    DominatorTreeNodeMap::const_iterator node_iter = nodes_.find(id);
+    if (node_iter == nodes_.end()) {
+      return nullptr;
+    }
+    return &node_iter->second;
+  }
+
  private:
   // Adds the basic block |bb| to the tree structure if it doesn't already
   // exist.
index 38038d7..9ab969a 100644 (file)
@@ -38,10 +38,7 @@ void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) {
     BuildCFG();
   }
   if (set & kAnalysisDominatorAnalysis) {
-    // An invalid dominator tree analysis will be empty so rebuilding it just
-    // means marking it as valid. Each tree will be initalisalised when
-    // requested on a per function basis.
-    valid_analyses_ |= kAnalysisDominatorAnalysis;
+    ResetDominatorAnalysis();
   }
 }
 
@@ -478,8 +475,11 @@ void IRContext::InitializeCombinators() {
 // Gets the dominator analysis for function |f|.
 opt::DominatorAnalysis* IRContext::GetDominatorAnalysis(const ir::Function* f,
                                                         const ir::CFG& in_cfg) {
-  if (dominator_trees_.find(f) == dominator_trees_.end() ||
-      !AreAnalysesValid(kAnalysisDominatorAnalysis)) {
+  if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
+    ResetDominatorAnalysis();
+  }
+
+  if (dominator_trees_.find(f) == dominator_trees_.end()) {
     dominator_trees_[f].InitializeTree(f, in_cfg);
   }
 
@@ -489,8 +489,11 @@ opt::DominatorAnalysis* IRContext::GetDominatorAnalysis(const ir::Function* f,
 // Gets the postdominator analysis for function |f|.
 opt::PostDominatorAnalysis* IRContext::GetPostDominatorAnalysis(
     const ir::Function* f, const ir::CFG& in_cfg) {
-  if (post_dominator_trees_.find(f) == post_dominator_trees_.end() ||
-      !AreAnalysesValid(kAnalysisDominatorAnalysis)) {
+  if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
+    ResetDominatorAnalysis();
+  }
+
+  if (post_dominator_trees_.find(f) == post_dominator_trees_.end()) {
     post_dominator_trees_[f].InitializeTree(f, in_cfg);
   }
 
index 67f5c0e..9a56053 100644 (file)
@@ -423,6 +423,15 @@ class IRContext {
     valid_analyses_ = valid_analyses_ | kAnalysisCFG;
   }
 
+  // Removes all computed dominator and post-dominator trees. This will force
+  // the context to rebuild the trees on demand.
+  void ResetDominatorAnalysis() {
+    // Clear the cache.
+    dominator_trees_.clear();
+    post_dominator_trees_.clear();
+    valid_analyses_ = valid_analyses_ | kAnalysisDominatorAnalysis;
+  }
+
   // Analyzes the features in the owned module. Builds the manager if required.
   void AnalyzeFeatures() {
     feature_mgr_.reset(new opt::FeatureManager(grammar_));
diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp
new file mode 100644 (file)
index 0000000..e1bb0c7
--- /dev/null
@@ -0,0 +1,157 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "opt/loop_descriptor.h"
+#include <iostream>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "opt/iterator.h"
+#include "opt/loop_descriptor.h"
+#include "opt/make_unique.h"
+#include "opt/tree_iterator.h"
+
+namespace spvtools {
+namespace ir {
+
+Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
+           BasicBlock* header, BasicBlock* continue_target,
+           BasicBlock* merge_target)
+    : loop_header_(header),
+      loop_continue_(continue_target),
+      loop_merge_(merge_target),
+      loop_preheader_(nullptr),
+      parent_(nullptr) {
+  assert(context);
+  assert(dom_analysis);
+  loop_preheader_ = FindLoopPreheader(context, dom_analysis);
+  AddBasicBlockToLoop(header);
+  AddBasicBlockToLoop(continue_target);
+}
+
+BasicBlock* Loop::FindLoopPreheader(IRContext* ir_context,
+                                    opt::DominatorAnalysis* dom_analysis) {
+  CFG* cfg = ir_context->cfg();
+  opt::DominatorTree& dom_tree = dom_analysis->GetDomTree();
+  opt::DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
+
+  // The loop predecessor.
+  BasicBlock* loop_pred = nullptr;
+
+  auto header_pred = cfg->preds(loop_header_->id());
+  for (uint32_t p_id : header_pred) {
+    opt::DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
+    if (node && !dom_tree.Dominates(header_node, node)) {
+      // The predecessor is not part of the loop, so potential loop preheader.
+      if (loop_pred && node->bb_ != loop_pred) {
+        // If we saw 2 distinct predecessors that are outside the loop, we don't
+        // have a loop preheader.
+        return nullptr;
+      }
+      loop_pred = node->bb_;
+    }
+  }
+  // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
+  // node as header.
+  assert(loop_pred && "The header node is the entry block ?");
+
+  // So we have a unique basic block that can enter this loop.
+  // If this loop is the unique successor of this block, then it is a loop
+  // preheader.
+  bool is_preheader = true;
+  uint32_t loop_header_id = loop_header_->id();
+  loop_pred->ForEachSuccessorLabel(
+      [&is_preheader, loop_header_id](const uint32_t id) {
+        if (id != loop_header_id) is_preheader = false;
+      });
+  if (is_preheader) return loop_pred;
+  return nullptr;
+}
+
+LoopDescriptor::LoopDescriptor(const Function* f) { PopulateList(f); }
+
+void LoopDescriptor::PopulateList(const Function* f) {
+  IRContext* context = f->GetParent()->context();
+
+  opt::DominatorAnalysis* dom_analysis =
+      context->GetDominatorAnalysis(f, *context->cfg());
+
+  loops_.clear();
+
+  // Post-order traversal of the dominator tree to find all the OpLoopMerge
+  // instructions.
+  opt::DominatorTree& dom_tree = dom_analysis->GetDomTree();
+  for (opt::DominatorTreeNode& node :
+       ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) {
+    Instruction* merge_inst = node.bb_->GetLoopMergeInst();
+    if (merge_inst) {
+      // The id of the merge basic block of this loop.
+      uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);
+
+      // The id of the continue basic block of this loop.
+      uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);
+
+      // The merge target of this loop.
+      BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);
+
+      // The continue target of this loop.
+      BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);
+
+      // The basic block containing the merge instruction.
+      BasicBlock* header_bb = context->get_instr_block(merge_inst);
+
+      // Add the loop to the list of all the loops in the function.
+      loops_.emplace_back(MakeUnique<Loop>(context, dom_analysis, header_bb,
+                                           continue_bb, merge_bb));
+      Loop* current_loop = loops_.back().get();
+
+      // We have a bottom-up construction, so if this loop has nested-loops,
+      // they are by construction at the tail of the loop list.
+      for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
+        Loop* previous_loop = itr->get();
+
+        // If the loop already has a parent, then it has been processed.
+        if (previous_loop->HasParent()) continue;
+
+        // If the current loop does not dominates the previous loop then it is
+        // not nested loop.
+        if (!dom_analysis->Dominates(header_bb,
+                                     previous_loop->GetHeaderBlock()))
+          continue;
+        // If the current loop merge dominates the previous loop then it is
+        // not nested loop.
+        if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
+          continue;
+
+        current_loop->AddNestedLoop(previous_loop);
+      }
+      opt::DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
+      for (opt::DominatorTreeNode& loop_node :
+           make_range(node.df_begin(), node.df_end())) {
+        // Check if we are in the loop.
+        if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
+        current_loop->AddBasicBlockToLoop(loop_node.bb_);
+        basic_block_to_loop_.insert(
+            std::make_pair(loop_node.bb_->id(), current_loop));
+      }
+    }
+  }
+  for (std::unique_ptr<Loop>& loop : loops_) {
+    if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop.get());
+  }
+}
+
+}  // namespace ir
+}  // namespace spvtools
diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h
new file mode 100644 (file)
index 0000000..b00c18e
--- /dev/null
@@ -0,0 +1,262 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_
+#define LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "opt/module.h"
+#include "opt/pass.h"
+#include "opt/tree_iterator.h"
+
+namespace spvtools {
+namespace ir {
+class CFG;
+class LoopDescriptor;
+
+// A class to represent and manipulate a loop in structured control flow.
+class Loop {
+  // The type used to represent nested child loops.
+  using ChildrenList = std::vector<Loop*>;
+
+ public:
+  using iterator = ChildrenList::iterator;
+  using const_iterator = ChildrenList::const_iterator;
+  using BasicBlockListTy = std::unordered_set<uint32_t>;
+
+  Loop()
+      : loop_header_(nullptr),
+        loop_continue_(nullptr),
+        loop_merge_(nullptr),
+        loop_preheader_(nullptr),
+        parent_(nullptr) {}
+
+  Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header,
+       BasicBlock* continue_target, BasicBlock* merge_target);
+
+  // Iterators over the immediate sub-loops.
+  inline iterator begin() { return nested_loops_.begin(); }
+  inline iterator end() { return nested_loops_.end(); }
+  inline const_iterator begin() const { return cbegin(); }
+  inline const_iterator end() const { return cend(); }
+  inline const_iterator cbegin() const { return nested_loops_.begin(); }
+  inline const_iterator cend() const { return nested_loops_.end(); }
+
+  // Returns the header (first basic block of the loop). This block contains the
+  // OpLoopMerge instruction.
+  inline BasicBlock* GetHeaderBlock() { return loop_header_; }
+  inline const BasicBlock* GetHeaderBlock() const { return loop_header_; }
+
+  // Returns the latch basic block (basic block that holds the back-edge).
+  inline BasicBlock* GetLatchBlock() { return loop_continue_; }
+  inline const BasicBlock* GetLatchBlock() const { return loop_continue_; }
+
+  // Returns the BasicBlock which marks the end of the loop.
+  inline BasicBlock* GetMergeBlock() { return loop_merge_; }
+  inline const BasicBlock* GetMergeBlock() const { return loop_merge_; }
+
+  // Returns the loop pre-header, nullptr means that the loop predecessor does
+  // not qualify as a preheader.
+  // The preheader is the unique predecessor that:
+  //   - Dominates the loop header;
+  //   - Has only the loop header as successor.
+  inline BasicBlock* GetPreHeaderBlock() { return loop_preheader_; }
+
+  // Returns the loop pre-header.
+  inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; }
+
+  // Returns true if this loop contains any nested loops.
+  inline bool HasNestedLoops() const { return nested_loops_.size() != 0; }
+
+  // Returns the depth of this loop in the loop nest.
+  // The outer-most loop has a depth of 1.
+  inline size_t GetDepth() const {
+    size_t lvl = 1;
+    for (const Loop* loop = GetParent(); loop; loop = loop->GetParent()) lvl++;
+    return lvl;
+  }
+
+  // Adds |nested| as a nested loop of this loop. Automatically register |this|
+  // as the parent of |nested|.
+  inline void AddNestedLoop(Loop* nested) {
+    assert(!nested->GetParent() && "The loop has another parent.");
+    nested_loops_.push_back(nested);
+    nested->SetParent(this);
+  }
+
+  inline Loop* GetParent() { return parent_; }
+  inline const Loop* GetParent() const { return parent_; }
+
+  inline bool HasParent() const { return parent_; }
+
+  // Returns true if this loop is itself nested within another loop.
+  inline bool IsNested() const { return parent_ != nullptr; }
+
+  // Returns the set of all basic blocks contained within the loop. Will be all
+  // BasicBlocks dominated by the header which are not also dominated by the
+  // loop merge block.
+  inline const BasicBlockListTy& GetBlocks() const {
+    return loop_basic_blocks_;
+  }
+
+  // Returns true if the basic block |bb| is inside this loop.
+  inline bool IsInsideLoop(const BasicBlock* bb) const {
+    return IsInsideLoop(bb->id());
+  }
+
+  // Returns true if the basic block id |bb_id| is inside this loop.
+  inline bool IsInsideLoop(uint32_t bb_id) const {
+    return loop_basic_blocks_.count(bb_id);
+  }
+
+  // Returns true if the instruction |inst| is inside this loop.
+  inline bool IsInsideLoop(Instruction* inst) const {
+    const BasicBlock* parent_block = inst->context()->get_instr_block(inst);
+    if (!parent_block) return true;
+    return IsInsideLoop(parent_block);
+  }
+
+  // Adds the Basic Block |bb| this loop and its parents.
+  void AddBasicBlockToLoop(const BasicBlock* bb) {
+#ifndef NDEBUG
+    assert(bb->GetParent() && "The basic block does not belong to a function");
+    IRContext* context = bb->GetParent()->GetParent()->context();
+
+    opt::DominatorAnalysis* dom_analysis =
+        context->GetDominatorAnalysis(bb->GetParent(), *context->cfg());
+    assert(dom_analysis->Dominates(GetHeaderBlock(), bb));
+
+    opt::PostDominatorAnalysis* postdom_analysis =
+        context->GetPostDominatorAnalysis(bb->GetParent(), *context->cfg());
+    assert(postdom_analysis->Dominates(GetMergeBlock(), bb));
+#endif  // NDEBUG
+
+    for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
+      loop_basic_blocks_.insert(bb->id());
+    }
+  }
+
+ private:
+  // The block which marks the start of the loop.
+  BasicBlock* loop_header_;
+
+  // The block which begins the body of the loop.
+  BasicBlock* loop_continue_;
+
+  // The block which marks the end of the loop.
+  BasicBlock* loop_merge_;
+
+  // The block immediately before the loop header.
+  BasicBlock* loop_preheader_;
+
+  // A parent of a loop is the loop which contains it as a nested child loop.
+  Loop* parent_;
+
+  // Nested child loops of this loop.
+  ChildrenList nested_loops_;
+
+  // A set of all the basic blocks which comprise the loop structure. Will be
+  // computed only when needed on demand.
+  BasicBlockListTy loop_basic_blocks_;
+
+  // Sets the parent loop of this loop, that is, a loop which contains this loop
+  // as a nested child loop.
+  inline void SetParent(Loop* parent) { parent_ = parent; }
+
+  // Returns the loop preheader if it exists, returns nullptr otherwise.
+  BasicBlock* FindLoopPreheader(IRContext* context,
+                                opt::DominatorAnalysis* dom_analysis);
+
+  // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level
+  // loops as child.
+  friend class LoopDescriptor;
+};
+
+// Loop descriptions class for a given function.
+// For a given function, the class builds loop nests information.
+// The analysis expects a structured control flow.
+class LoopDescriptor {
+ public:
+  // Iterator interface (depth first postorder traversal).
+  using iterator = opt::PostOrderTreeDFIterator<Loop>;
+  using const_iterator = opt::PostOrderTreeDFIterator<const Loop>;
+
+  // Creates a loop object for all loops found in |f|.
+  explicit LoopDescriptor(const Function* f);
+
+  // Returns the number of loops found in the function.
+  inline size_t NumLoops() const { return loops_.size(); }
+
+  // Returns the loop at a particular |index|. The |index| must be in bounds,
+  // check with NumLoops before calling.
+  inline Loop& GetLoopByIndex(size_t index) const {
+    assert(loops_.size() > index &&
+           "Index out of range (larger than loop count)");
+    return *loops_[index].get();
+  }
+
+  // Returns the inner most loop that contains the basic block id |block_id|.
+  inline Loop* operator[](uint32_t block_id) const {
+    return FindLoopForBasicBlock(block_id);
+  }
+
+  // Returns the inner most loop that contains the basic block |bb|.
+  inline Loop* operator[](const BasicBlock* bb) const {
+    return (*this)[bb->id()];
+  }
+
+  // Iterators for post order depth first traversal of the loops.
+  // Inner most loops will be visited first.
+  inline iterator begin() { return iterator::begin(&dummy_top_loop_); }
+  inline iterator end() { return iterator::end(&dummy_top_loop_); }
+  inline const_iterator begin() const { return cbegin(); }
+  inline const_iterator end() const { return cend(); }
+  inline const_iterator cbegin() const {
+    return const_iterator::begin(&dummy_top_loop_);
+  }
+  inline const_iterator cend() const {
+    return const_iterator::end(&dummy_top_loop_);
+  }
+
+ private:
+  using LoopContainerType = std::vector<std::unique_ptr<Loop>>;
+
+  // Creates loop descriptors for the function |f|.
+  void PopulateList(const Function* f);
+
+  // Returns the inner most loop that contains the basic block id |block_id|.
+  inline Loop* FindLoopForBasicBlock(uint32_t block_id) const {
+    std::unordered_map<uint32_t, Loop*>::const_iterator it =
+        basic_block_to_loop_.find(block_id);
+    return it != basic_block_to_loop_.end() ? it->second : nullptr;
+  }
+
+  // A list of all the loops in the function.
+  LoopContainerType loops_;
+  // Dummy root: this "loop" is only there to help iterators creation.
+  Loop dummy_top_loop_;
+  std::unordered_map<uint32_t, Loop*> basic_block_to_loop_;
+};
+
+}  // namespace ir
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_
index 4a24a01..ba724df 100644 (file)
@@ -115,6 +115,131 @@ class TreeDFIterator {
   std::stack<std::pair<NodePtr, NodeIterator>> parent_iterators_;
 };
 
+// Helper class to iterate over a tree in a depth first post-order.
+// The class assumes the data structure is a tree, tree node type implements a
+// forward iterator.
+// At each step, the iterator holds the pointer to the current node and state of
+// the walk.
+// The state is recorded by stacking the iteration position of the node
+// children. To move to the next node, the iterator:
+//  - Looks at the top of the stack;
+//  - If the children iterator has reach the end, then the node become the
+//    current one and we pop the stack;
+//  - Otherwise, we save the child and increment the iterator;
+//  - We walk the child sub-tree until we find a leaf, stacking all non-leaves
+//    states (pair of node pointer and child iterator) as we walk it.
+template <typename NodeTy>
+class PostOrderTreeDFIterator {
+  static_assert(!std::is_pointer<NodeTy>::value &&
+                    !std::is_reference<NodeTy>::value,
+                "NodeTy should be a class");
+  // Type alias to keep track of the const qualifier.
+  using NodeIterator =
+      typename std::conditional<std::is_const<NodeTy>::value,
+                                typename NodeTy::const_iterator,
+                                typename NodeTy::iterator>::type;
+
+  // Type alias to keep track of the const qualifier.
+  using NodePtr = NodeTy*;
+
+ public:
+  // Standard iterator interface.
+  using reference = NodeTy&;
+  using value_type = NodeTy;
+
+  static inline PostOrderTreeDFIterator begin(NodePtr top_node) {
+    return PostOrderTreeDFIterator(top_node);
+  }
+
+  static inline PostOrderTreeDFIterator end(NodePtr sentinel_node) {
+    return PostOrderTreeDFIterator(sentinel_node, false);
+  }
+
+  bool operator==(const PostOrderTreeDFIterator& x) const {
+    return current_ == x.current_;
+  }
+
+  bool operator!=(const PostOrderTreeDFIterator& x) const {
+    return !(*this == x);
+  }
+
+  reference operator*() const { return *current_; }
+
+  NodePtr operator->() const { return current_; }
+
+  PostOrderTreeDFIterator& operator++() {
+    MoveToNextNode();
+    return *this;
+  }
+
+  PostOrderTreeDFIterator operator++(int) {
+    PostOrderTreeDFIterator tmp = *this;
+    ++*this;
+    return tmp;
+  }
+
+ private:
+  explicit inline PostOrderTreeDFIterator(NodePtr top_node)
+      : current_(top_node) {
+    if (current_) WalkToLeaf();
+  }
+
+  // Constructor for the "end()" iterator.
+  // |end_sentinel| is the value that acts as end value (can be null). The bool
+  // parameters is to distinguish from the start() Ctor.
+  inline PostOrderTreeDFIterator(NodePtr sentinel_node, bool)
+      : current_(sentinel_node) {}
+
+  // Moves the iterator to the next node in the tree.
+  // If we are at the end, do nothing, otherwise
+  // if our current node has children, use the children iterator and push the
+  // current node into the stack.
+  // If we reach the end of the local iterator, pop it.
+  inline void MoveToNextNode() {
+    if (!current_) return;
+    if (parent_iterators_.empty()) {
+      current_ = nullptr;
+      return;
+    }
+    std::pair<NodePtr, NodeIterator>& next_it = parent_iterators_.top();
+    // If we visited all children, the current node is the top of the stack.
+    if (next_it.second == next_it.first->end()) {
+      // Set the new node.
+      current_ = next_it.first;
+      parent_iterators_.pop();
+      return;
+    }
+    // We have more children to visit, set the current node to the first child
+    // and dive to leaf.
+    current_ = *next_it.second;
+    // Update the iterator for the next child (avoid unneeded pop).
+    ++next_it.second;
+    WalkToLeaf();
+  }
+
+  // Moves the iterator to the next node in the tree.
+  // If we are at the end, do nothing, otherwise
+  // if our current node has children, use the children iterator and push the
+  // current node into the stack.
+  // If we reach the end of the local iterator, pop it.
+  inline void WalkToLeaf() {
+    while (current_->begin() != current_->end()) {
+      NodeIterator next = ++current_->begin();
+      parent_iterators_.emplace(make_pair(current_, next));
+      // Set the first child as the new node.
+      current_ = *current_->begin();
+    }
+  }
+
+  // The current node of the tree.
+  NodePtr current_;
+  // State of the tree walk: each pair contains the parent node and the iterator
+  // of the next children to visit.
+  // When all the children has been visited, we pop the first entry and the
+  // parent node become the current node.
+  std::stack<std::pair<NodePtr, NodeIterator>> parent_iterators_;
+};
+
 }  // namespace opt
 }  // namespace spvtools
 
index d8a5b53..bd83f94 100644 (file)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 add_subdirectory(dominator_tree)
+add_subdirectory(loop_optimizations)
 
 add_spvtools_unittest(TARGET instruction
   SRCS instruction_test.cpp
index e9d9827..786dbfc 100644 (file)
@@ -25,6 +25,7 @@
 #include "../pass_fixture.h"
 #include "../pass_utils.h"
 #include "opt/dominator_analysis.h"
+#include "opt/iterator.h"
 #include "opt/pass.h"
 
 namespace {
@@ -431,8 +432,7 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
     EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)),
               spvtest::GetBasicBlock(fn, 11));
 
-    uint32_t entry_id = cfg.pseudo_entry_block()->id();
-    std::array<uint32_t, 4> node_order = {{entry_id, 10, 11, 12}};
+    std::array<uint32_t, 3> node_order = {{10, 11, 12}};
     {
       // Test dominator tree iteration order.
       opt::DominatorTree::iterator node_it = dom_tree.GetDomTree().begin();
@@ -457,6 +457,34 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
       }
       EXPECT_EQ(node_it, node_end);
     }
+    {
+      // Test dominator tree iteration order.
+      opt::DominatorTree::post_iterator node_it =
+          dom_tree.GetDomTree().post_begin();
+      opt::DominatorTree::post_iterator node_end =
+          dom_tree.GetDomTree().post_end();
+      for (uint32_t id :
+           ir::make_range(node_order.rbegin(), node_order.rend())) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
+    {
+      // Same as above, but with const iterators.
+      opt::DominatorTree::const_post_iterator node_it =
+          dom_tree.GetDomTree().post_cbegin();
+      opt::DominatorTree::const_post_iterator node_end =
+          dom_tree.GetDomTree().post_cend();
+      for (uint32_t id :
+           ir::make_range(node_order.rbegin(), node_order.rend())) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
   }
 
   // Check post dominator tree
@@ -488,8 +516,7 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
     EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)),
               cfg.pseudo_exit_block());
 
-    uint32_t entry_id = cfg.pseudo_exit_block()->id();
-    std::array<uint32_t, 4> node_order = {{entry_id, 12, 11, 10}};
+    std::array<uint32_t, 3> node_order = {{12, 11, 10}};
     {
       // Test dominator tree iteration order.
       opt::DominatorTree::iterator node_it = tree.begin();
@@ -512,6 +539,34 @@ TEST_F(PassClassTest, DominatorLoopToSelf) {
       }
       EXPECT_EQ(node_it, node_end);
     }
+    {
+      // Test dominator tree iteration order.
+      opt::DominatorTree::post_iterator node_it =
+          dom_tree.GetDomTree().post_begin();
+      opt::DominatorTree::post_iterator node_end =
+          dom_tree.GetDomTree().post_end();
+      for (uint32_t id :
+           ir::make_range(node_order.rbegin(), node_order.rend())) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
+    {
+      // Same as above, but with const iterators.
+      opt::DominatorTree::const_post_iterator node_it =
+          dom_tree.GetDomTree().post_cbegin();
+      opt::DominatorTree::const_post_iterator node_end =
+          dom_tree.GetDomTree().post_cend();
+      for (uint32_t id :
+           ir::make_range(node_order.rbegin(), node_order.rend())) {
+        EXPECT_NE(node_it, node_end);
+        EXPECT_EQ(node_it->id(), id);
+        node_it++;
+      }
+      EXPECT_EQ(node_it, node_end);
+    }
   }
 }
 
diff --git a/test/opt/loop_optimizations/CMakeLists.txt b/test/opt/loop_optimizations/CMakeLists.txt
new file mode 100644 (file)
index 0000000..b5360c2
--- /dev/null
@@ -0,0 +1,27 @@
+# Copyright (c) 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+add_spvtools_unittest(TARGET loop_descriptor_simple
+    SRCS ../function_utils.h
+        loop_descriptions.cpp
+    LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET loop_descriptor_nested
+    SRCS ../function_utils.h
+        nested_loops.cpp
+    LIBS SPIRV-Tools-opt
+)
+
diff --git a/test/opt/loop_optimizations/loop_descriptions.cpp b/test/opt/loop_optimizations/loop_descriptions.cpp
new file mode 100644 (file)
index 0000000..d54b789
--- /dev/null
@@ -0,0 +1,205 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+#include "opt/loop_descriptor.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  for(; i < 10; ++i) {
+  }
+}
+*/
+TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
+  const std::string text = R"(
+                OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %5 "i"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %6 = OpTypeVoid
+          %7 = OpTypeFunction %6
+          %8 = OpTypeInt 32 1
+          %9 = OpTypePointer Function %8
+         %10 = OpConstant %8 0
+         %11 = OpConstant %8 10
+         %12 = OpTypeBool
+         %13 = OpConstant %8 1
+         %14 = OpTypeFloat 32
+         %15 = OpTypeVector %14 4
+         %16 = OpTypePointer Output %15
+          %3 = OpVariable %16 Output
+          %2 = OpFunction %6 None %7
+         %17 = OpLabel
+          %5 = OpVariable %9 Function
+               OpStore %5 %10
+               OpBranch %18
+         %18 = OpLabel
+               OpLoopMerge %19 %20 None
+               OpBranch %21
+         %21 = OpLabel
+         %22 = OpLoad %8 %5
+         %23 = OpSLessThan %12 %22 %11
+               OpBranchConditional %23 %24 %19
+         %24 = OpLabel
+               OpBranch %20
+         %20 = OpLabel
+         %25 = OpLoad %8 %5
+         %26 = OpIAdd %8 %25 %13
+               OpStore %5 %26
+               OpBranch %18
+         %19 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 1u);
+
+  ir::Loop& loop = ld.GetLoopByIndex(0);
+  EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18));
+  EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20));
+  EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19));
+
+  EXPECT_FALSE(loop.HasNestedLoops());
+  EXPECT_FALSE(loop.IsNested());
+  EXPECT_EQ(loop.GetDepth(), 1u);
+}
+
+/*
+Generated from the following GLSL:
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  for(int i = 0; i < 10; ++i) {}
+  for(int i = 0; i < 10; ++i) {}
+}
+
+But it was "hacked" to make the first loop merge block the second loop header.
+*/
+TEST_F(PassClassTest, LoopWithNoPreHeader) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %4 "i"
+               OpName %5 "i"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %6 = OpTypeVoid
+          %7 = OpTypeFunction %6
+          %8 = OpTypeInt 32 1
+          %9 = OpTypePointer Function %8
+         %10 = OpConstant %8 0
+         %11 = OpConstant %8 10
+         %12 = OpTypeBool
+         %13 = OpConstant %8 1
+         %14 = OpTypeFloat 32
+         %15 = OpTypeVector %14 4
+         %16 = OpTypePointer Output %15
+          %3 = OpVariable %16 Output
+          %2 = OpFunction %6 None %7
+         %17 = OpLabel
+          %4 = OpVariable %9 Function
+          %5 = OpVariable %9 Function
+               OpStore %4 %10
+               OpStore %5 %10
+               OpBranch %18
+         %18 = OpLabel
+               OpLoopMerge %27 %20 None
+               OpBranch %21
+         %21 = OpLabel
+         %22 = OpLoad %8 %4
+         %23 = OpSLessThan %12 %22 %11
+               OpBranchConditional %23 %24 %27
+         %24 = OpLabel
+               OpBranch %20
+         %20 = OpLabel
+         %25 = OpLoad %8 %4
+         %26 = OpIAdd %8 %25 %13
+               OpStore %4 %26
+               OpBranch %18
+         %27 = OpLabel
+               OpLoopMerge %28 %29 None
+               OpBranch %30
+         %30 = OpLabel
+         %31 = OpLoad %8 %5
+         %32 = OpSLessThan %12 %31 %11
+               OpBranchConditional %32 %33 %28
+         %33 = OpLabel
+               OpBranch %29
+         %29 = OpLabel
+         %34 = OpLoad %8 %5
+         %35 = OpIAdd %8 %34 %13
+               OpStore %5 %35
+               OpBranch %27
+         %28 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 2u);
+
+  ir::Loop* loop = ld[27];
+  EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr);
+}
+
+}  // namespace
diff --git a/test/opt/loop_optimizations/nested_loops.cpp b/test/opt/loop_optimizations/nested_loops.cpp
new file mode 100644 (file)
index 0000000..d635586
--- /dev/null
@@ -0,0 +1,595 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+
+#include "opt/iterator.h"
+#include "opt/loop_descriptor.h"
+#include "opt/pass.h"
+#include "opt/tree_iterator.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  for (; i < 10; ++i) {
+    int j = 0;
+    int k = 0;
+    for (; j < 11; ++j) {}
+    for (; k < 12; ++k) {}
+  }
+}
+*/
+TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %4 "i"
+               OpName %5 "j"
+               OpName %6 "k"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %7 = OpTypeVoid
+          %8 = OpTypeFunction %7
+          %9 = OpTypeInt 32 1
+         %10 = OpTypePointer Function %9
+         %11 = OpConstant %9 0
+         %12 = OpConstant %9 10
+         %13 = OpTypeBool
+         %14 = OpConstant %9 11
+         %15 = OpConstant %9 1
+         %16 = OpConstant %9 12
+         %17 = OpTypeFloat 32
+         %18 = OpTypeVector %17 4
+         %19 = OpTypePointer Output %18
+          %3 = OpVariable %19 Output
+          %2 = OpFunction %7 None %8
+         %20 = OpLabel
+          %4 = OpVariable %10 Function
+          %5 = OpVariable %10 Function
+          %6 = OpVariable %10 Function
+               OpStore %4 %11
+               OpBranch %21
+         %21 = OpLabel
+               OpLoopMerge %22 %23 None
+               OpBranch %24
+         %24 = OpLabel
+         %25 = OpLoad %9 %4
+         %26 = OpSLessThan %13 %25 %12
+               OpBranchConditional %26 %27 %22
+         %27 = OpLabel
+               OpStore %5 %11
+               OpStore %6 %11
+               OpBranch %28
+         %28 = OpLabel
+               OpLoopMerge %29 %30 None
+               OpBranch %31
+         %31 = OpLabel
+         %32 = OpLoad %9 %5
+         %33 = OpSLessThan %13 %32 %14
+               OpBranchConditional %33 %34 %29
+         %34 = OpLabel
+               OpBranch %30
+         %30 = OpLabel
+         %35 = OpLoad %9 %5
+         %36 = OpIAdd %9 %35 %15
+               OpStore %5 %36
+               OpBranch %28
+         %29 = OpLabel
+               OpBranch %37
+         %37 = OpLabel
+               OpLoopMerge %38 %39 None
+               OpBranch %40
+         %40 = OpLabel
+         %41 = OpLoad %9 %6
+         %42 = OpSLessThan %13 %41 %16
+               OpBranchConditional %42 %43 %38
+         %43 = OpLabel
+               OpBranch %39
+         %39 = OpLabel
+         %44 = OpLoad %9 %6
+         %45 = OpIAdd %9 %44 %15
+               OpStore %6 %45
+               OpBranch %37
+         %38 = OpLabel
+               OpBranch %23
+         %23 = OpLabel
+         %46 = OpLoad %9 %4
+         %47 = OpIAdd %9 %46 %15
+               OpStore %4 %47
+               OpBranch %21
+         %22 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 3u);
+
+  // Invalid basic block id.
+  EXPECT_EQ(ld[0u], nullptr);
+  // Not a loop header.
+  EXPECT_EQ(ld[20], nullptr);
+
+  ir::Loop& parent_loop = *ld[21];
+  EXPECT_TRUE(parent_loop.HasNestedLoops());
+  EXPECT_FALSE(parent_loop.IsNested());
+  EXPECT_EQ(parent_loop.GetDepth(), 1u);
+  EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
+  EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
+  EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
+  EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));
+
+  ir::Loop& child_loop_1 = *ld[28];
+  EXPECT_FALSE(child_loop_1.HasNestedLoops());
+  EXPECT_TRUE(child_loop_1.IsNested());
+  EXPECT_EQ(child_loop_1.GetDepth(), 2u);
+  EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
+  EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
+  EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
+  EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));
+
+  ir::Loop& child_loop_2 = *ld[37];
+  EXPECT_FALSE(child_loop_2.HasNestedLoops());
+  EXPECT_TRUE(child_loop_2.IsNested());
+  EXPECT_EQ(child_loop_2.GetDepth(), 2u);
+  EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
+  EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
+  EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
+  EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
+}
+
+static void CheckLoopBlocks(ir::Loop* loop,
+                            std::unordered_set<uint32_t>* expected_ids) {
+  SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
+  for (uint32_t bb_id : loop->GetBlocks()) {
+    EXPECT_EQ(expected_ids->count(bb_id), 1u);
+    expected_ids->erase(bb_id);
+  }
+  EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
+  EXPECT_EQ(expected_ids->size(), 0u);
+}
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  for (; i < 10; ++i) {
+    for (int j = 0; j < 11; ++j) {
+      if (j < 5) {
+        for (int k = 0; k < 12; ++k) {}
+      }
+      else {}
+      for (int k = 0; k < 12; ++k) {}
+    }
+  }
+}*/
+TEST_F(PassClassTest, TripleNestedLoop) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %4 "i"
+               OpName %5 "j"
+               OpName %6 "k"
+               OpName %7 "k"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %8 = OpTypeVoid
+          %9 = OpTypeFunction %8
+         %10 = OpTypeInt 32 1
+         %11 = OpTypePointer Function %10
+         %12 = OpConstant %10 0
+         %13 = OpConstant %10 10
+         %14 = OpTypeBool
+         %15 = OpConstant %10 11
+         %16 = OpConstant %10 5
+         %17 = OpConstant %10 12
+         %18 = OpConstant %10 1
+         %19 = OpTypeFloat 32
+         %20 = OpTypeVector %19 4
+         %21 = OpTypePointer Output %20
+          %3 = OpVariable %21 Output
+          %2 = OpFunction %8 None %9
+         %22 = OpLabel
+          %4 = OpVariable %11 Function
+          %5 = OpVariable %11 Function
+          %6 = OpVariable %11 Function
+          %7 = OpVariable %11 Function
+               OpStore %4 %12
+               OpBranch %23
+         %23 = OpLabel
+               OpLoopMerge %24 %25 None
+               OpBranch %26
+         %26 = OpLabel
+         %27 = OpLoad %10 %4
+         %28 = OpSLessThan %14 %27 %13
+               OpBranchConditional %28 %29 %24
+         %29 = OpLabel
+               OpStore %5 %12
+               OpBranch %30
+         %30 = OpLabel
+               OpLoopMerge %31 %32 None
+               OpBranch %33
+         %33 = OpLabel
+         %34 = OpLoad %10 %5
+         %35 = OpSLessThan %14 %34 %15
+               OpBranchConditional %35 %36 %31
+         %36 = OpLabel
+         %37 = OpLoad %10 %5
+         %38 = OpSLessThan %14 %37 %16
+               OpSelectionMerge %39 None
+               OpBranchConditional %38 %40 %39
+         %40 = OpLabel
+               OpStore %6 %12
+               OpBranch %41
+         %41 = OpLabel
+               OpLoopMerge %42 %43 None
+               OpBranch %44
+         %44 = OpLabel
+         %45 = OpLoad %10 %6
+         %46 = OpSLessThan %14 %45 %17
+               OpBranchConditional %46 %47 %42
+         %47 = OpLabel
+               OpBranch %43
+         %43 = OpLabel
+         %48 = OpLoad %10 %6
+         %49 = OpIAdd %10 %48 %18
+               OpStore %6 %49
+               OpBranch %41
+         %42 = OpLabel
+               OpBranch %39
+         %39 = OpLabel
+               OpStore %7 %12
+               OpBranch %50
+         %50 = OpLabel
+               OpLoopMerge %51 %52 None
+               OpBranch %53
+         %53 = OpLabel
+         %54 = OpLoad %10 %7
+         %55 = OpSLessThan %14 %54 %17
+               OpBranchConditional %55 %56 %51
+         %56 = OpLabel
+               OpBranch %52
+         %52 = OpLabel
+         %57 = OpLoad %10 %7
+         %58 = OpIAdd %10 %57 %18
+               OpStore %7 %58
+               OpBranch %50
+         %51 = OpLabel
+               OpBranch %32
+         %32 = OpLabel
+         %59 = OpLoad %10 %5
+         %60 = OpIAdd %10 %59 %18
+               OpStore %5 %60
+               OpBranch %30
+         %31 = OpLabel
+               OpBranch %25
+         %25 = OpLabel
+         %61 = OpLoad %10 %4
+         %62 = OpIAdd %10 %61 %18
+               OpStore %4 %62
+               OpBranch %23
+         %24 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 4u);
+
+  // Invalid basic block id.
+  EXPECT_EQ(ld[0u], nullptr);
+  // Not in a loop.
+  EXPECT_EQ(ld[22], nullptr);
+
+  // Check that we can map basic block to the correct loop.
+  // The following block ids do not belong to a loop.
+  for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr);
+
+  {
+    std::unordered_set<uint32_t> basic_block_in_loop = {
+        {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
+         42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
+    ir::Loop* loop = ld[23];
+    CheckLoopBlocks(loop, &basic_block_in_loop);
+
+    EXPECT_TRUE(loop->HasNestedLoops());
+    EXPECT_FALSE(loop->IsNested());
+    EXPECT_EQ(loop->GetDepth(), 1u);
+    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
+    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
+    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
+    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
+    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
+  }
+
+  {
+    std::unordered_set<uint32_t> basic_block_in_loop = {
+        {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
+    ir::Loop* loop = ld[30];
+    CheckLoopBlocks(loop, &basic_block_in_loop);
+
+    EXPECT_TRUE(loop->HasNestedLoops());
+    EXPECT_TRUE(loop->IsNested());
+    EXPECT_EQ(loop->GetDepth(), 2u);
+    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
+    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
+    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
+    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
+    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
+  }
+
+  {
+    std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
+    ir::Loop* loop = ld[41];
+    CheckLoopBlocks(loop, &basic_block_in_loop);
+
+    EXPECT_FALSE(loop->HasNestedLoops());
+    EXPECT_TRUE(loop->IsNested());
+    EXPECT_EQ(loop->GetDepth(), 3u);
+    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
+    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
+    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
+    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
+    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
+  }
+
+  {
+    std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
+    ir::Loop* loop = ld[50];
+    CheckLoopBlocks(loop, &basic_block_in_loop);
+
+    EXPECT_FALSE(loop->HasNestedLoops());
+    EXPECT_TRUE(loop->IsNested());
+    EXPECT_EQ(loop->GetDepth(), 3u);
+    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
+    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
+    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
+    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
+    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
+    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
+  }
+
+  // Make sure LoopDescriptor gives us the inner most loop when we query for
+  // loops.
+  for (const ir::BasicBlock& bb : *f) {
+    if (ir::Loop* loop = ld[&bb]) {
+      for (ir::Loop& sub_loop :
+           ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop),
+                          opt::TreeDFIterator<ir::Loop>())) {
+        EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
+      }
+    }
+  }
+}
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  for (int i = 0; i < 10; ++i) {
+    for (int j = 0; j < 11; ++j) {
+      for (int k = 0; k < 11; ++k) {}
+    }
+    for (int k = 0; k < 12; ++k) {}
+  }
+}
+*/
+TEST_F(PassClassTest, LoopParentTest) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %4 "i"
+               OpName %5 "j"
+               OpName %6 "k"
+               OpName %7 "k"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %8 = OpTypeVoid
+          %9 = OpTypeFunction %8
+         %10 = OpTypeInt 32 1
+         %11 = OpTypePointer Function %10
+         %12 = OpConstant %10 0
+         %13 = OpConstant %10 10
+         %14 = OpTypeBool
+         %15 = OpConstant %10 11
+         %16 = OpConstant %10 1
+         %17 = OpConstant %10 12
+         %18 = OpTypeFloat 32
+         %19 = OpTypeVector %18 4
+         %20 = OpTypePointer Output %19
+          %3 = OpVariable %20 Output
+          %2 = OpFunction %8 None %9
+         %21 = OpLabel
+          %4 = OpVariable %11 Function
+          %5 = OpVariable %11 Function
+          %6 = OpVariable %11 Function
+          %7 = OpVariable %11 Function
+               OpStore %4 %12
+               OpBranch %22
+         %22 = OpLabel
+               OpLoopMerge %23 %24 None
+               OpBranch %25
+         %25 = OpLabel
+         %26 = OpLoad %10 %4
+         %27 = OpSLessThan %14 %26 %13
+               OpBranchConditional %27 %28 %23
+         %28 = OpLabel
+               OpStore %5 %12
+               OpBranch %29
+         %29 = OpLabel
+               OpLoopMerge %30 %31 None
+               OpBranch %32
+         %32 = OpLabel
+         %33 = OpLoad %10 %5
+         %34 = OpSLessThan %14 %33 %15
+               OpBranchConditional %34 %35 %30
+         %35 = OpLabel
+               OpStore %6 %12
+               OpBranch %36
+         %36 = OpLabel
+               OpLoopMerge %37 %38 None
+               OpBranch %39
+         %39 = OpLabel
+         %40 = OpLoad %10 %6
+         %41 = OpSLessThan %14 %40 %15
+               OpBranchConditional %41 %42 %37
+         %42 = OpLabel
+               OpBranch %38
+         %38 = OpLabel
+         %43 = OpLoad %10 %6
+         %44 = OpIAdd %10 %43 %16
+               OpStore %6 %44
+               OpBranch %36
+         %37 = OpLabel
+               OpBranch %31
+         %31 = OpLabel
+         %45 = OpLoad %10 %5
+         %46 = OpIAdd %10 %45 %16
+               OpStore %5 %46
+               OpBranch %29
+         %30 = OpLabel
+               OpStore %7 %12
+               OpBranch %47
+         %47 = OpLabel
+               OpLoopMerge %48 %49 None
+               OpBranch %50
+         %50 = OpLabel
+         %51 = OpLoad %10 %7
+         %52 = OpSLessThan %14 %51 %17
+               OpBranchConditional %52 %53 %48
+         %53 = OpLabel
+               OpBranch %49
+         %49 = OpLabel
+         %54 = OpLoad %10 %7
+         %55 = OpIAdd %10 %54 %16
+               OpStore %7 %55
+               OpBranch %47
+         %48 = OpLabel
+               OpBranch %24
+         %24 = OpLabel
+         %56 = OpLoad %10 %4
+         %57 = OpIAdd %10 %56 %16
+               OpStore %4 %57
+               OpBranch %22
+         %23 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  EXPECT_EQ(ld.NumLoops(), 4u);
+
+  {
+    ir::Loop& loop = *ld[22];
+    EXPECT_TRUE(loop.HasNestedLoops());
+    EXPECT_FALSE(loop.IsNested());
+    EXPECT_EQ(loop.GetDepth(), 1u);
+    EXPECT_EQ(loop.GetParent(), nullptr);
+  }
+
+  {
+    ir::Loop& loop = *ld[29];
+    EXPECT_TRUE(loop.HasNestedLoops());
+    EXPECT_TRUE(loop.IsNested());
+    EXPECT_EQ(loop.GetDepth(), 2u);
+    EXPECT_EQ(loop.GetParent(), ld[22]);
+  }
+
+  {
+    ir::Loop& loop = *ld[36];
+    EXPECT_FALSE(loop.HasNestedLoops());
+    EXPECT_TRUE(loop.IsNested());
+    EXPECT_EQ(loop.GetDepth(), 3u);
+    EXPECT_EQ(loop.GetParent(), ld[29]);
+  }
+
+  {
+    ir::Loop& loop = *ld[47];
+    EXPECT_FALSE(loop.HasNestedLoops());
+    EXPECT_TRUE(loop.IsNested());
+    EXPECT_EQ(loop.GetDepth(), 2u);
+    EXPECT_EQ(loop.GetParent(), ld[22]);
+  }
+}
+
+}  // namespace