From e8ad02f3dde41d227319944f52cc850bcc077fb7 Mon Sep 17 00:00:00 2001 From: Victor Lomuller Date: Thu, 21 Dec 2017 14:47:25 +0000 Subject: [PATCH] Add loop descriptors and some required dominator tree extensions. Add post-order tree iterator. Add DominatorTreeNode extensions: - Add begin/end methods to do pre-order and post-order tree traversal from a given DominatorTreeNode Add DominatorTree extensions: - Add begin/end methods to do pre-order and post-order tree traversal - Tree traversal ignore by default the pseudo entry block - Retrieve a DominatorTreeNode from a basic block Add loop descriptor: - Add a LoopDescriptor class to register all loops in a given function. - Add a Loop class to describe a loop: - Loop parent - Nested loops - Loop depth - Loop header, merge, continue and preheader - Basic blocks that belong to the loop Correct a bug that forced dominator tree to be constantly rebuilt. --- Android.mk | 1 + source/opt/CMakeLists.txt | 2 + source/opt/dominator_tree.cpp | 27 +- source/opt/dominator_tree.h | 85 +++- source/opt/ir_context.cpp | 19 +- source/opt/ir_context.h | 9 + source/opt/loop_descriptor.cpp | 157 ++++++ source/opt/loop_descriptor.h | 262 ++++++++++ source/opt/tree_iterator.h | 125 +++++ test/opt/CMakeLists.txt | 1 + test/opt/dominator_tree/generated.cpp | 63 ++- test/opt/loop_optimizations/CMakeLists.txt | 27 + test/opt/loop_optimizations/loop_descriptions.cpp | 205 ++++++++ test/opt/loop_optimizations/nested_loops.cpp | 595 ++++++++++++++++++++++ 14 files changed, 1553 insertions(+), 25 deletions(-) create mode 100644 source/opt/loop_descriptor.cpp create mode 100644 source/opt/loop_descriptor.h create mode 100644 test/opt/loop_optimizations/CMakeLists.txt create mode 100644 test/opt/loop_optimizations/loop_descriptions.cpp create mode 100644 test/opt/loop_optimizations/nested_loops.cpp diff --git a/Android.mk b/Android.mk index 48fb087..2aeae42 100644 --- a/Android.mk +++ b/Android.mk @@ -90,6 +90,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/local_single_block_elim_pass.cpp \ source/opt/local_single_store_elim_pass.cpp \ source/opt/local_ssa_elim_pass.cpp \ + source/opt/loop_descriptor.cpp \ source/opt/mem_pass.cpp \ source/opt/merge_return_pass.cpp \ source/opt/module.cpp \ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 0b49842..680edaf 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -49,6 +49,7 @@ add_library(SPIRV-Tools-opt local_single_store_elim_pass.h local_ssa_elim_pass.h log.h + loop_descriptor.h mem_pass.h merge_return_pass.h module.h @@ -107,6 +108,7 @@ add_library(SPIRV-Tools-opt local_single_block_elim_pass.cpp local_single_store_elim_pass.cpp local_ssa_elim_pass.cpp + loop_descriptor.cpp mem_pass.cpp merge_return_pass.cpp module.cpp diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp index 9dfc559..552a80d 100644 --- a/source/opt/dominator_tree.cpp +++ b/source/opt/dominator_tree.cpp @@ -227,23 +227,28 @@ bool DominatorTree::StrictlyDominates(const ir::BasicBlock* a, return DominatorTree::StrictlyDominates(a->id(), b->id()); } +bool DominatorTree::StrictlyDominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const { + if (a == b) return false; + return Dominates(a, b); +} + bool DominatorTree::Dominates(uint32_t a, uint32_t b) const { // Check that both of the inputs are actual nodes. - auto a_itr = nodes_.find(a); - auto b_itr = nodes_.find(b); - if (a_itr == nodes_.end() || b_itr == nodes_.end()) return false; + const DominatorTreeNode* a_node = GetTreeNode(a); + const DominatorTreeNode* b_node = GetTreeNode(b); + if (!a_node || !b_node) return false; + return Dominates(a_node, b_node); +} + +bool DominatorTree::Dominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const { // Node A dominates node B if they are the same. if (a == b) return true; - const DominatorTreeNode* nodeA = &a_itr->second; - const DominatorTreeNode* nodeB = &b_itr->second; - - if (nodeA->dfs_num_pre_ < nodeB->dfs_num_pre_ && - nodeA->dfs_num_post_ > nodeB->dfs_num_post_) { - return true; - } - return false; + return a->dfs_num_pre_ < b->dfs_num_pre_ && + a->dfs_num_post_ > b->dfs_num_post_; } bool DominatorTree::Dominates(const ir::BasicBlock* A, diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h index d3cdcdf..0be4951 100644 --- a/source/opt/dominator_tree.h +++ b/source/opt/dominator_tree.h @@ -40,6 +40,13 @@ struct DominatorTreeNode { using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; + // depth first preorder iterator. + using df_iterator = TreeDFIterator; + using const_df_iterator = TreeDFIterator; + // depth first postorder iterator. + using post_iterator = PostOrderTreeDFIterator; + using const_post_iterator = PostOrderTreeDFIterator; + iterator begin() { return children_.begin(); } iterator end() { return children_.end(); } const_iterator begin() const { return cbegin(); } @@ -47,6 +54,26 @@ struct DominatorTreeNode { const_iterator cbegin() const { return children_.begin(); } const_iterator cend() const { return children_.end(); } + // Depth first preorder iterator using this node as root. + df_iterator df_begin() { return df_iterator(this); } + df_iterator df_end() { return df_iterator(); } + const_df_iterator df_begin() const { return df_cbegin(); } + const_df_iterator df_end() const { return df_cend(); } + const_df_iterator df_cbegin() const { return const_df_iterator(this); } + const_df_iterator df_cend() const { return const_df_iterator(); } + + // Depth first postorder iterator using this node as root. + post_iterator post_begin() { return post_iterator::begin(this); } + post_iterator post_end() { return post_iterator::end(nullptr); } + const_post_iterator post_begin() const { return post_cbegin(); } + const_post_iterator post_end() const { return post_cend(); } + const_post_iterator post_cbegin() const { + return const_post_iterator::begin(this); + } + const_post_iterator post_cend() const { + return const_post_iterator::end(nullptr); + } + inline uint32_t id() const { return bb_->id(); } ir::BasicBlock* bb_; @@ -69,6 +96,8 @@ class DominatorTree { using DominatorTreeNodeMap = std::map; using iterator = TreeDFIterator; using const_iterator = TreeDFIterator; + using post_iterator = PostOrderTreeDFIterator; + using const_post_iterator = PostOrderTreeDFIterator; // List of DominatorTreeNode to define the list of roots using DominatorTreeNodeList = std::vector; @@ -80,13 +109,27 @@ class DominatorTree { // Depth first iterators. // Traverse the dominator tree in a depth first pre-order. - iterator begin() { return iterator(GetRoot()); } + // The pseudo-block is ignored. + iterator begin() { return ++iterator(GetRoot()); } iterator end() { return iterator(); } const_iterator begin() const { return cbegin(); } const_iterator end() const { return cend(); } - const_iterator cbegin() const { return const_iterator(GetRoot()); } + const_iterator cbegin() const { return ++const_iterator(GetRoot()); } const_iterator cend() const { return const_iterator(); } + // Traverse the dominator tree in a depth first post-order. + // The pseudo-block is ignored. + post_iterator post_begin() { return post_iterator::begin(GetRoot()); } + post_iterator post_end() { return post_iterator::end(GetRoot()); } + const_post_iterator post_begin() const { return post_cbegin(); } + const_post_iterator post_end() const { return post_cend(); } + const_post_iterator post_cbegin() const { + return const_post_iterator::begin(GetRoot()); + } + const_post_iterator post_cend() const { + return const_post_iterator::end(GetRoot()); + } + roots_iterator roots_begin() { return roots_.begin(); } roots_iterator roots_end() { return roots_.end(); } roots_const_iterator roots_begin() const { return roots_cbegin(); } @@ -122,6 +165,9 @@ class DominatorTree { // Check if the basic block id |a| dominates the basic block id |b|. bool Dominates(uint32_t a, uint32_t b) const; + // Check if the dominator tree node |a| dominates the dominator tree node |b|. + bool Dominates(const DominatorTreeNode* a, const DominatorTreeNode* b) const; + // Check if the basic block |a| strictly dominates the basic block |b|. bool StrictlyDominates(const ir::BasicBlock* a, const ir::BasicBlock* b) const; @@ -129,6 +175,11 @@ class DominatorTree { // Check if the basic block id |a| strictly dominates the basic block id |b|. bool StrictlyDominates(uint32_t a, uint32_t b) const; + // Check if the dominator tree node |a| strictly dominates the dominator tree + // node |b|. + bool StrictlyDominates(const DominatorTreeNode* a, + const DominatorTreeNode* b) const; + // Returns the immediate dominator of basic block |a|. ir::BasicBlock* ImmediateDominator(const ir::BasicBlock* A) const; @@ -173,6 +224,36 @@ class DominatorTree { return true; } + // Returns the DominatorTreeNode associated with the basic block |bb|. + // If the |bb| is unknown to the dominator tree, it returns null. + inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) { + return GetTreeNode(bb->id()); + } + // Returns the DominatorTreeNode associated with the basic block |bb|. + // If the |bb| is unknown to the dominator tree, it returns null. + inline const DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) const { + return GetTreeNode(bb->id()); + } + + // Returns the DominatorTreeNode associated with the basic block id |id|. + // If the id |id| is unknown to the dominator tree, it returns null. + inline DominatorTreeNode* GetTreeNode(uint32_t id) { + DominatorTreeNodeMap::iterator node_iter = nodes_.find(id); + if (node_iter == nodes_.end()) { + return nullptr; + } + return &node_iter->second; + } + // Returns the DominatorTreeNode associated with the basic block id |id|. + // If the id |id| is unknown to the dominator tree, it returns null. + inline const DominatorTreeNode* GetTreeNode(uint32_t id) const { + DominatorTreeNodeMap::const_iterator node_iter = nodes_.find(id); + if (node_iter == nodes_.end()) { + return nullptr; + } + return &node_iter->second; + } + private: // Adds the basic block |bb| to the tree structure if it doesn't already // exist. diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index 38038d7..9ab969a 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -38,10 +38,7 @@ void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { BuildCFG(); } if (set & kAnalysisDominatorAnalysis) { - // An invalid dominator tree analysis will be empty so rebuilding it just - // means marking it as valid. Each tree will be initalisalised when - // requested on a per function basis. - valid_analyses_ |= kAnalysisDominatorAnalysis; + ResetDominatorAnalysis(); } } @@ -478,8 +475,11 @@ void IRContext::InitializeCombinators() { // Gets the dominator analysis for function |f|. opt::DominatorAnalysis* IRContext::GetDominatorAnalysis(const ir::Function* f, const ir::CFG& in_cfg) { - if (dominator_trees_.find(f) == dominator_trees_.end() || - !AreAnalysesValid(kAnalysisDominatorAnalysis)) { + if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { + ResetDominatorAnalysis(); + } + + if (dominator_trees_.find(f) == dominator_trees_.end()) { dominator_trees_[f].InitializeTree(f, in_cfg); } @@ -489,8 +489,11 @@ opt::DominatorAnalysis* IRContext::GetDominatorAnalysis(const ir::Function* f, // Gets the postdominator analysis for function |f|. opt::PostDominatorAnalysis* IRContext::GetPostDominatorAnalysis( const ir::Function* f, const ir::CFG& in_cfg) { - if (post_dominator_trees_.find(f) == post_dominator_trees_.end() || - !AreAnalysesValid(kAnalysisDominatorAnalysis)) { + if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) { + ResetDominatorAnalysis(); + } + + if (post_dominator_trees_.find(f) == post_dominator_trees_.end()) { post_dominator_trees_[f].InitializeTree(f, in_cfg); } diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 67f5c0e..9a56053 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -423,6 +423,15 @@ class IRContext { valid_analyses_ = valid_analyses_ | kAnalysisCFG; } + // Removes all computed dominator and post-dominator trees. This will force + // the context to rebuild the trees on demand. + void ResetDominatorAnalysis() { + // Clear the cache. + dominator_trees_.clear(); + post_dominator_trees_.clear(); + valid_analyses_ = valid_analyses_ | kAnalysisDominatorAnalysis; + } + // Analyzes the features in the owned module. Builds the manager if required. void AnalyzeFeatures() { feature_mgr_.reset(new opt::FeatureManager(grammar_)); diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp new file mode 100644 index 0000000..e1bb0c7 --- /dev/null +++ b/source/opt/loop_descriptor.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opt/loop_descriptor.h" +#include +#include +#include +#include + +#include "opt/iterator.h" +#include "opt/loop_descriptor.h" +#include "opt/make_unique.h" +#include "opt/tree_iterator.h" + +namespace spvtools { +namespace ir { + +Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis, + BasicBlock* header, BasicBlock* continue_target, + BasicBlock* merge_target) + : loop_header_(header), + loop_continue_(continue_target), + loop_merge_(merge_target), + loop_preheader_(nullptr), + parent_(nullptr) { + assert(context); + assert(dom_analysis); + loop_preheader_ = FindLoopPreheader(context, dom_analysis); + AddBasicBlockToLoop(header); + AddBasicBlockToLoop(continue_target); +} + +BasicBlock* Loop::FindLoopPreheader(IRContext* ir_context, + opt::DominatorAnalysis* dom_analysis) { + CFG* cfg = ir_context->cfg(); + opt::DominatorTree& dom_tree = dom_analysis->GetDomTree(); + opt::DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_); + + // The loop predecessor. + BasicBlock* loop_pred = nullptr; + + auto header_pred = cfg->preds(loop_header_->id()); + for (uint32_t p_id : header_pred) { + opt::DominatorTreeNode* node = dom_tree.GetTreeNode(p_id); + if (node && !dom_tree.Dominates(header_node, node)) { + // The predecessor is not part of the loop, so potential loop preheader. + if (loop_pred && node->bb_ != loop_pred) { + // If we saw 2 distinct predecessors that are outside the loop, we don't + // have a loop preheader. + return nullptr; + } + loop_pred = node->bb_; + } + } + // Safe guard against invalid code, SPIR-V spec forbids loop with the entry + // node as header. + assert(loop_pred && "The header node is the entry block ?"); + + // So we have a unique basic block that can enter this loop. + // If this loop is the unique successor of this block, then it is a loop + // preheader. + bool is_preheader = true; + uint32_t loop_header_id = loop_header_->id(); + loop_pred->ForEachSuccessorLabel( + [&is_preheader, loop_header_id](const uint32_t id) { + if (id != loop_header_id) is_preheader = false; + }); + if (is_preheader) return loop_pred; + return nullptr; +} + +LoopDescriptor::LoopDescriptor(const Function* f) { PopulateList(f); } + +void LoopDescriptor::PopulateList(const Function* f) { + IRContext* context = f->GetParent()->context(); + + opt::DominatorAnalysis* dom_analysis = + context->GetDominatorAnalysis(f, *context->cfg()); + + loops_.clear(); + + // Post-order traversal of the dominator tree to find all the OpLoopMerge + // instructions. + opt::DominatorTree& dom_tree = dom_analysis->GetDomTree(); + for (opt::DominatorTreeNode& node : + ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) { + Instruction* merge_inst = node.bb_->GetLoopMergeInst(); + if (merge_inst) { + // The id of the merge basic block of this loop. + uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0); + + // The id of the continue basic block of this loop. + uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1); + + // The merge target of this loop. + BasicBlock* merge_bb = context->cfg()->block(merge_bb_id); + + // The continue target of this loop. + BasicBlock* continue_bb = context->cfg()->block(continue_bb_id); + + // The basic block containing the merge instruction. + BasicBlock* header_bb = context->get_instr_block(merge_inst); + + // Add the loop to the list of all the loops in the function. + loops_.emplace_back(MakeUnique(context, dom_analysis, header_bb, + continue_bb, merge_bb)); + Loop* current_loop = loops_.back().get(); + + // We have a bottom-up construction, so if this loop has nested-loops, + // they are by construction at the tail of the loop list. + for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) { + Loop* previous_loop = itr->get(); + + // If the loop already has a parent, then it has been processed. + if (previous_loop->HasParent()) continue; + + // If the current loop does not dominates the previous loop then it is + // not nested loop. + if (!dom_analysis->Dominates(header_bb, + previous_loop->GetHeaderBlock())) + continue; + // If the current loop merge dominates the previous loop then it is + // not nested loop. + if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock())) + continue; + + current_loop->AddNestedLoop(previous_loop); + } + opt::DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb); + for (opt::DominatorTreeNode& loop_node : + make_range(node.df_begin(), node.df_end())) { + // Check if we are in the loop. + if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue; + current_loop->AddBasicBlockToLoop(loop_node.bb_); + basic_block_to_loop_.insert( + std::make_pair(loop_node.bb_->id(), current_loop)); + } + } + } + for (std::unique_ptr& loop : loops_) { + if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop.get()); + } +} + +} // namespace ir +} // namespace spvtools diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h new file mode 100644 index 0000000..b00c18e --- /dev/null +++ b/source/opt/loop_descriptor.h @@ -0,0 +1,262 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ +#define LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "opt/module.h" +#include "opt/pass.h" +#include "opt/tree_iterator.h" + +namespace spvtools { +namespace ir { +class CFG; +class LoopDescriptor; + +// A class to represent and manipulate a loop in structured control flow. +class Loop { + // The type used to represent nested child loops. + using ChildrenList = std::vector; + + public: + using iterator = ChildrenList::iterator; + using const_iterator = ChildrenList::const_iterator; + using BasicBlockListTy = std::unordered_set; + + Loop() + : loop_header_(nullptr), + loop_continue_(nullptr), + loop_merge_(nullptr), + loop_preheader_(nullptr), + parent_(nullptr) {} + + Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header, + BasicBlock* continue_target, BasicBlock* merge_target); + + // Iterators over the immediate sub-loops. + inline iterator begin() { return nested_loops_.begin(); } + inline iterator end() { return nested_loops_.end(); } + inline const_iterator begin() const { return cbegin(); } + inline const_iterator end() const { return cend(); } + inline const_iterator cbegin() const { return nested_loops_.begin(); } + inline const_iterator cend() const { return nested_loops_.end(); } + + // Returns the header (first basic block of the loop). This block contains the + // OpLoopMerge instruction. + inline BasicBlock* GetHeaderBlock() { return loop_header_; } + inline const BasicBlock* GetHeaderBlock() const { return loop_header_; } + + // Returns the latch basic block (basic block that holds the back-edge). + inline BasicBlock* GetLatchBlock() { return loop_continue_; } + inline const BasicBlock* GetLatchBlock() const { return loop_continue_; } + + // Returns the BasicBlock which marks the end of the loop. + inline BasicBlock* GetMergeBlock() { return loop_merge_; } + inline const BasicBlock* GetMergeBlock() const { return loop_merge_; } + + // Returns the loop pre-header, nullptr means that the loop predecessor does + // not qualify as a preheader. + // The preheader is the unique predecessor that: + // - Dominates the loop header; + // - Has only the loop header as successor. + inline BasicBlock* GetPreHeaderBlock() { return loop_preheader_; } + + // Returns the loop pre-header. + inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; } + + // Returns true if this loop contains any nested loops. + inline bool HasNestedLoops() const { return nested_loops_.size() != 0; } + + // Returns the depth of this loop in the loop nest. + // The outer-most loop has a depth of 1. + inline size_t GetDepth() const { + size_t lvl = 1; + for (const Loop* loop = GetParent(); loop; loop = loop->GetParent()) lvl++; + return lvl; + } + + // Adds |nested| as a nested loop of this loop. Automatically register |this| + // as the parent of |nested|. + inline void AddNestedLoop(Loop* nested) { + assert(!nested->GetParent() && "The loop has another parent."); + nested_loops_.push_back(nested); + nested->SetParent(this); + } + + inline Loop* GetParent() { return parent_; } + inline const Loop* GetParent() const { return parent_; } + + inline bool HasParent() const { return parent_; } + + // Returns true if this loop is itself nested within another loop. + inline bool IsNested() const { return parent_ != nullptr; } + + // Returns the set of all basic blocks contained within the loop. Will be all + // BasicBlocks dominated by the header which are not also dominated by the + // loop merge block. + inline const BasicBlockListTy& GetBlocks() const { + return loop_basic_blocks_; + } + + // Returns true if the basic block |bb| is inside this loop. + inline bool IsInsideLoop(const BasicBlock* bb) const { + return IsInsideLoop(bb->id()); + } + + // Returns true if the basic block id |bb_id| is inside this loop. + inline bool IsInsideLoop(uint32_t bb_id) const { + return loop_basic_blocks_.count(bb_id); + } + + // Returns true if the instruction |inst| is inside this loop. + inline bool IsInsideLoop(Instruction* inst) const { + const BasicBlock* parent_block = inst->context()->get_instr_block(inst); + if (!parent_block) return true; + return IsInsideLoop(parent_block); + } + + // Adds the Basic Block |bb| this loop and its parents. + void AddBasicBlockToLoop(const BasicBlock* bb) { +#ifndef NDEBUG + assert(bb->GetParent() && "The basic block does not belong to a function"); + IRContext* context = bb->GetParent()->GetParent()->context(); + + opt::DominatorAnalysis* dom_analysis = + context->GetDominatorAnalysis(bb->GetParent(), *context->cfg()); + assert(dom_analysis->Dominates(GetHeaderBlock(), bb)); + + opt::PostDominatorAnalysis* postdom_analysis = + context->GetPostDominatorAnalysis(bb->GetParent(), *context->cfg()); + assert(postdom_analysis->Dominates(GetMergeBlock(), bb)); +#endif // NDEBUG + + for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { + loop_basic_blocks_.insert(bb->id()); + } + } + + private: + // The block which marks the start of the loop. + BasicBlock* loop_header_; + + // The block which begins the body of the loop. + BasicBlock* loop_continue_; + + // The block which marks the end of the loop. + BasicBlock* loop_merge_; + + // The block immediately before the loop header. + BasicBlock* loop_preheader_; + + // A parent of a loop is the loop which contains it as a nested child loop. + Loop* parent_; + + // Nested child loops of this loop. + ChildrenList nested_loops_; + + // A set of all the basic blocks which comprise the loop structure. Will be + // computed only when needed on demand. + BasicBlockListTy loop_basic_blocks_; + + // Sets the parent loop of this loop, that is, a loop which contains this loop + // as a nested child loop. + inline void SetParent(Loop* parent) { parent_ = parent; } + + // Returns the loop preheader if it exists, returns nullptr otherwise. + BasicBlock* FindLoopPreheader(IRContext* context, + opt::DominatorAnalysis* dom_analysis); + + // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level + // loops as child. + friend class LoopDescriptor; +}; + +// Loop descriptions class for a given function. +// For a given function, the class builds loop nests information. +// The analysis expects a structured control flow. +class LoopDescriptor { + public: + // Iterator interface (depth first postorder traversal). + using iterator = opt::PostOrderTreeDFIterator; + using const_iterator = opt::PostOrderTreeDFIterator; + + // Creates a loop object for all loops found in |f|. + explicit LoopDescriptor(const Function* f); + + // Returns the number of loops found in the function. + inline size_t NumLoops() const { return loops_.size(); } + + // Returns the loop at a particular |index|. The |index| must be in bounds, + // check with NumLoops before calling. + inline Loop& GetLoopByIndex(size_t index) const { + assert(loops_.size() > index && + "Index out of range (larger than loop count)"); + return *loops_[index].get(); + } + + // Returns the inner most loop that contains the basic block id |block_id|. + inline Loop* operator[](uint32_t block_id) const { + return FindLoopForBasicBlock(block_id); + } + + // Returns the inner most loop that contains the basic block |bb|. + inline Loop* operator[](const BasicBlock* bb) const { + return (*this)[bb->id()]; + } + + // Iterators for post order depth first traversal of the loops. + // Inner most loops will be visited first. + inline iterator begin() { return iterator::begin(&dummy_top_loop_); } + inline iterator end() { return iterator::end(&dummy_top_loop_); } + inline const_iterator begin() const { return cbegin(); } + inline const_iterator end() const { return cend(); } + inline const_iterator cbegin() const { + return const_iterator::begin(&dummy_top_loop_); + } + inline const_iterator cend() const { + return const_iterator::end(&dummy_top_loop_); + } + + private: + using LoopContainerType = std::vector>; + + // Creates loop descriptors for the function |f|. + void PopulateList(const Function* f); + + // Returns the inner most loop that contains the basic block id |block_id|. + inline Loop* FindLoopForBasicBlock(uint32_t block_id) const { + std::unordered_map::const_iterator it = + basic_block_to_loop_.find(block_id); + return it != basic_block_to_loop_.end() ? it->second : nullptr; + } + + // A list of all the loops in the function. + LoopContainerType loops_; + // Dummy root: this "loop" is only there to help iterators creation. + Loop dummy_top_loop_; + std::unordered_map basic_block_to_loop_; +}; + +} // namespace ir +} // namespace spvtools + +#endif // LIBSPIRV_OPT_LOOP_DESCRIPTORS_H_ diff --git a/source/opt/tree_iterator.h b/source/opt/tree_iterator.h index 4a24a01..ba724df 100644 --- a/source/opt/tree_iterator.h +++ b/source/opt/tree_iterator.h @@ -115,6 +115,131 @@ class TreeDFIterator { std::stack> parent_iterators_; }; +// Helper class to iterate over a tree in a depth first post-order. +// The class assumes the data structure is a tree, tree node type implements a +// forward iterator. +// At each step, the iterator holds the pointer to the current node and state of +// the walk. +// The state is recorded by stacking the iteration position of the node +// children. To move to the next node, the iterator: +// - Looks at the top of the stack; +// - If the children iterator has reach the end, then the node become the +// current one and we pop the stack; +// - Otherwise, we save the child and increment the iterator; +// - We walk the child sub-tree until we find a leaf, stacking all non-leaves +// states (pair of node pointer and child iterator) as we walk it. +template +class PostOrderTreeDFIterator { + static_assert(!std::is_pointer::value && + !std::is_reference::value, + "NodeTy should be a class"); + // Type alias to keep track of the const qualifier. + using NodeIterator = + typename std::conditional::value, + typename NodeTy::const_iterator, + typename NodeTy::iterator>::type; + + // Type alias to keep track of the const qualifier. + using NodePtr = NodeTy*; + + public: + // Standard iterator interface. + using reference = NodeTy&; + using value_type = NodeTy; + + static inline PostOrderTreeDFIterator begin(NodePtr top_node) { + return PostOrderTreeDFIterator(top_node); + } + + static inline PostOrderTreeDFIterator end(NodePtr sentinel_node) { + return PostOrderTreeDFIterator(sentinel_node, false); + } + + bool operator==(const PostOrderTreeDFIterator& x) const { + return current_ == x.current_; + } + + bool operator!=(const PostOrderTreeDFIterator& x) const { + return !(*this == x); + } + + reference operator*() const { return *current_; } + + NodePtr operator->() const { return current_; } + + PostOrderTreeDFIterator& operator++() { + MoveToNextNode(); + return *this; + } + + PostOrderTreeDFIterator operator++(int) { + PostOrderTreeDFIterator tmp = *this; + ++*this; + return tmp; + } + + private: + explicit inline PostOrderTreeDFIterator(NodePtr top_node) + : current_(top_node) { + if (current_) WalkToLeaf(); + } + + // Constructor for the "end()" iterator. + // |end_sentinel| is the value that acts as end value (can be null). The bool + // parameters is to distinguish from the start() Ctor. + inline PostOrderTreeDFIterator(NodePtr sentinel_node, bool) + : current_(sentinel_node) {} + + // Moves the iterator to the next node in the tree. + // If we are at the end, do nothing, otherwise + // if our current node has children, use the children iterator and push the + // current node into the stack. + // If we reach the end of the local iterator, pop it. + inline void MoveToNextNode() { + if (!current_) return; + if (parent_iterators_.empty()) { + current_ = nullptr; + return; + } + std::pair& next_it = parent_iterators_.top(); + // If we visited all children, the current node is the top of the stack. + if (next_it.second == next_it.first->end()) { + // Set the new node. + current_ = next_it.first; + parent_iterators_.pop(); + return; + } + // We have more children to visit, set the current node to the first child + // and dive to leaf. + current_ = *next_it.second; + // Update the iterator for the next child (avoid unneeded pop). + ++next_it.second; + WalkToLeaf(); + } + + // Moves the iterator to the next node in the tree. + // If we are at the end, do nothing, otherwise + // if our current node has children, use the children iterator and push the + // current node into the stack. + // If we reach the end of the local iterator, pop it. + inline void WalkToLeaf() { + while (current_->begin() != current_->end()) { + NodeIterator next = ++current_->begin(); + parent_iterators_.emplace(make_pair(current_, next)); + // Set the first child as the new node. + current_ = *current_->begin(); + } + } + + // The current node of the tree. + NodePtr current_; + // State of the tree walk: each pair contains the parent node and the iterator + // of the next children to visit. + // When all the children has been visited, we pop the first entry and the + // parent node become the current node. + std::stack> parent_iterators_; +}; + } // namespace opt } // namespace spvtools diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index d8a5b53..bd83f94 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -13,6 +13,7 @@ # limitations under the License. add_subdirectory(dominator_tree) +add_subdirectory(loop_optimizations) add_spvtools_unittest(TARGET instruction SRCS instruction_test.cpp diff --git a/test/opt/dominator_tree/generated.cpp b/test/opt/dominator_tree/generated.cpp index e9d9827..786dbfc 100644 --- a/test/opt/dominator_tree/generated.cpp +++ b/test/opt/dominator_tree/generated.cpp @@ -25,6 +25,7 @@ #include "../pass_fixture.h" #include "../pass_utils.h" #include "opt/dominator_analysis.h" +#include "opt/iterator.h" #include "opt/pass.h" namespace { @@ -431,8 +432,7 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), spvtest::GetBasicBlock(fn, 11)); - uint32_t entry_id = cfg.pseudo_entry_block()->id(); - std::array node_order = {{entry_id, 10, 11, 12}}; + std::array node_order = {{10, 11, 12}}; { // Test dominator tree iteration order. opt::DominatorTree::iterator node_it = dom_tree.GetDomTree().begin(); @@ -457,6 +457,34 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } EXPECT_EQ(node_it, node_end); } + { + // Test dominator tree iteration order. + opt::DominatorTree::post_iterator node_it = + dom_tree.GetDomTree().post_begin(); + opt::DominatorTree::post_iterator node_end = + dom_tree.GetDomTree().post_end(); + for (uint32_t id : + ir::make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. + opt::DominatorTree::const_post_iterator node_it = + dom_tree.GetDomTree().post_cbegin(); + opt::DominatorTree::const_post_iterator node_end = + dom_tree.GetDomTree().post_cend(); + for (uint32_t id : + ir::make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } } // Check post dominator tree @@ -488,8 +516,7 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { EXPECT_EQ(dom_tree.ImmediateDominator(spvtest::GetBasicBlock(fn, 12)), cfg.pseudo_exit_block()); - uint32_t entry_id = cfg.pseudo_exit_block()->id(); - std::array node_order = {{entry_id, 12, 11, 10}}; + std::array node_order = {{12, 11, 10}}; { // Test dominator tree iteration order. opt::DominatorTree::iterator node_it = tree.begin(); @@ -512,6 +539,34 @@ TEST_F(PassClassTest, DominatorLoopToSelf) { } EXPECT_EQ(node_it, node_end); } + { + // Test dominator tree iteration order. + opt::DominatorTree::post_iterator node_it = + dom_tree.GetDomTree().post_begin(); + opt::DominatorTree::post_iterator node_end = + dom_tree.GetDomTree().post_end(); + for (uint32_t id : + ir::make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } + { + // Same as above, but with const iterators. + opt::DominatorTree::const_post_iterator node_it = + dom_tree.GetDomTree().post_cbegin(); + opt::DominatorTree::const_post_iterator node_end = + dom_tree.GetDomTree().post_cend(); + for (uint32_t id : + ir::make_range(node_order.rbegin(), node_order.rend())) { + EXPECT_NE(node_it, node_end); + EXPECT_EQ(node_it->id(), id); + node_it++; + } + EXPECT_EQ(node_it, node_end); + } } } diff --git a/test/opt/loop_optimizations/CMakeLists.txt b/test/opt/loop_optimizations/CMakeLists.txt new file mode 100644 index 0000000..b5360c2 --- /dev/null +++ b/test/opt/loop_optimizations/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +add_spvtools_unittest(TARGET loop_descriptor_simple + SRCS ../function_utils.h + loop_descriptions.cpp + LIBS SPIRV-Tools-opt +) + +add_spvtools_unittest(TARGET loop_descriptor_nested + SRCS ../function_utils.h + nested_loops.cpp + LIBS SPIRV-Tools-opt +) + diff --git a/test/opt/loop_optimizations/loop_descriptions.cpp b/test/opt/loop_optimizations/loop_descriptions.cpp new file mode 100644 index 0000000..d54b789 --- /dev/null +++ b/test/opt/loop_optimizations/loop_descriptions.cpp @@ -0,0 +1,205 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include + +#include "../assembly_builder.h" +#include "../function_utils.h" +#include "../pass_fixture.h" +#include "../pass_utils.h" +#include "opt/loop_descriptor.h" +#include "opt/pass.h" + +namespace { + +using namespace spvtools; +using ::testing::UnorderedElementsAre; + +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for(; i < 10; ++i) { + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "i" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeVector %14 4 + %16 = OpTypePointer Output %15 + %3 = OpVariable %16 Output + %2 = OpFunction %6 None %7 + %17 = OpLabel + %5 = OpVariable %9 Function + OpStore %5 %10 + OpBranch %18 + %18 = OpLabel + OpLoopMerge %19 %20 None + OpBranch %21 + %21 = OpLabel + %22 = OpLoad %8 %5 + %23 = OpSLessThan %12 %22 %11 + OpBranchConditional %23 %24 %19 + %24 = OpLabel + OpBranch %20 + %20 = OpLabel + %25 = OpLoad %8 %5 + %26 = OpIAdd %8 %25 %13 + OpStore %5 %26 + OpBranch %18 + %19 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const ir::Function* f = spvtest::GetFunction(module, 2); + ir::LoopDescriptor ld{f}; + + EXPECT_EQ(ld.NumLoops(), 1u); + + ir::Loop& loop = ld.GetLoopByIndex(0); + EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18)); + EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20)); + EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19)); + + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_FALSE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 1u); +} + +/* +Generated from the following GLSL: +#version 330 core +layout(location = 0) out vec4 c; +void main() { + for(int i = 0; i < 10; ++i) {} + for(int i = 0; i < 10; ++i) {} +} + +But it was "hacked" to make the first loop merge block the second loop header. +*/ +TEST_F(PassClassTest, LoopWithNoPreHeader) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "i" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpConstant %8 1 + %14 = OpTypeFloat 32 + %15 = OpTypeVector %14 4 + %16 = OpTypePointer Output %15 + %3 = OpVariable %16 Output + %2 = OpFunction %6 None %7 + %17 = OpLabel + %4 = OpVariable %9 Function + %5 = OpVariable %9 Function + OpStore %4 %10 + OpStore %5 %10 + OpBranch %18 + %18 = OpLabel + OpLoopMerge %27 %20 None + OpBranch %21 + %21 = OpLabel + %22 = OpLoad %8 %4 + %23 = OpSLessThan %12 %22 %11 + OpBranchConditional %23 %24 %27 + %24 = OpLabel + OpBranch %20 + %20 = OpLabel + %25 = OpLoad %8 %4 + %26 = OpIAdd %8 %25 %13 + OpStore %4 %26 + OpBranch %18 + %27 = OpLabel + OpLoopMerge %28 %29 None + OpBranch %30 + %30 = OpLabel + %31 = OpLoad %8 %5 + %32 = OpSLessThan %12 %31 %11 + OpBranchConditional %32 %33 %28 + %33 = OpLabel + OpBranch %29 + %29 = OpLabel + %34 = OpLoad %8 %5 + %35 = OpIAdd %8 %34 %13 + OpStore %5 %35 + OpBranch %27 + %28 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const ir::Function* f = spvtest::GetFunction(module, 2); + ir::LoopDescriptor ld{f}; + + EXPECT_EQ(ld.NumLoops(), 2u); + + ir::Loop* loop = ld[27]; + EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr); +} + +} // namespace diff --git a/test/opt/loop_optimizations/nested_loops.cpp b/test/opt/loop_optimizations/nested_loops.cpp new file mode 100644 index 0000000..d635586 --- /dev/null +++ b/test/opt/loop_optimizations/nested_loops.cpp @@ -0,0 +1,595 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include + +#include "../assembly_builder.h" +#include "../function_utils.h" +#include "../pass_fixture.h" +#include "../pass_utils.h" + +#include "opt/iterator.h" +#include "opt/loop_descriptor.h" +#include "opt/pass.h" +#include "opt/tree_iterator.h" + +namespace { + +using namespace spvtools; +using ::testing::UnorderedElementsAre; + +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; ++i) { + int j = 0; + int k = 0; + for (; j < 11; ++j) {} + for (; k < 12; ++k) {} + } +} +*/ +TEST_F(PassClassTest, BasicVisitFromEntryPoint) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %7 = OpTypeVoid + %8 = OpTypeFunction %7 + %9 = OpTypeInt 32 1 + %10 = OpTypePointer Function %9 + %11 = OpConstant %9 0 + %12 = OpConstant %9 10 + %13 = OpTypeBool + %14 = OpConstant %9 11 + %15 = OpConstant %9 1 + %16 = OpConstant %9 12 + %17 = OpTypeFloat 32 + %18 = OpTypeVector %17 4 + %19 = OpTypePointer Output %18 + %3 = OpVariable %19 Output + %2 = OpFunction %7 None %8 + %20 = OpLabel + %4 = OpVariable %10 Function + %5 = OpVariable %10 Function + %6 = OpVariable %10 Function + OpStore %4 %11 + OpBranch %21 + %21 = OpLabel + OpLoopMerge %22 %23 None + OpBranch %24 + %24 = OpLabel + %25 = OpLoad %9 %4 + %26 = OpSLessThan %13 %25 %12 + OpBranchConditional %26 %27 %22 + %27 = OpLabel + OpStore %5 %11 + OpStore %6 %11 + OpBranch %28 + %28 = OpLabel + OpLoopMerge %29 %30 None + OpBranch %31 + %31 = OpLabel + %32 = OpLoad %9 %5 + %33 = OpSLessThan %13 %32 %14 + OpBranchConditional %33 %34 %29 + %34 = OpLabel + OpBranch %30 + %30 = OpLabel + %35 = OpLoad %9 %5 + %36 = OpIAdd %9 %35 %15 + OpStore %5 %36 + OpBranch %28 + %29 = OpLabel + OpBranch %37 + %37 = OpLabel + OpLoopMerge %38 %39 None + OpBranch %40 + %40 = OpLabel + %41 = OpLoad %9 %6 + %42 = OpSLessThan %13 %41 %16 + OpBranchConditional %42 %43 %38 + %43 = OpLabel + OpBranch %39 + %39 = OpLabel + %44 = OpLoad %9 %6 + %45 = OpIAdd %9 %44 %15 + OpStore %6 %45 + OpBranch %37 + %38 = OpLabel + OpBranch %23 + %23 = OpLabel + %46 = OpLoad %9 %4 + %47 = OpIAdd %9 %46 %15 + OpStore %4 %47 + OpBranch %21 + %22 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const ir::Function* f = spvtest::GetFunction(module, 2); + ir::LoopDescriptor ld{f}; + + EXPECT_EQ(ld.NumLoops(), 3u); + + // Invalid basic block id. + EXPECT_EQ(ld[0u], nullptr); + // Not a loop header. + EXPECT_EQ(ld[20], nullptr); + + ir::Loop& parent_loop = *ld[21]; + EXPECT_TRUE(parent_loop.HasNestedLoops()); + EXPECT_FALSE(parent_loop.IsNested()); + EXPECT_EQ(parent_loop.GetDepth(), 1u); + EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u); + EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21)); + EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23)); + EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22)); + + ir::Loop& child_loop_1 = *ld[28]; + EXPECT_FALSE(child_loop_1.HasNestedLoops()); + EXPECT_TRUE(child_loop_1.IsNested()); + EXPECT_EQ(child_loop_1.GetDepth(), 2u); + EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u); + EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28)); + EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30)); + EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29)); + + ir::Loop& child_loop_2 = *ld[37]; + EXPECT_FALSE(child_loop_2.HasNestedLoops()); + EXPECT_TRUE(child_loop_2.IsNested()); + EXPECT_EQ(child_loop_2.GetDepth(), 2u); + EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u); + EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37)); + EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39)); + EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38)); +} + +static void CheckLoopBlocks(ir::Loop* loop, + std::unordered_set* expected_ids) { + SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id())); + for (uint32_t bb_id : loop->GetBlocks()) { + EXPECT_EQ(expected_ids->count(bb_id), 1u); + expected_ids->erase(bb_id); + } + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_EQ(expected_ids->size(), 0u); +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + int i = 0; + for (; i < 10; ++i) { + for (int j = 0; j < 11; ++j) { + if (j < 5) { + for (int k = 0; k < 12; ++k) {} + } + else {} + for (int k = 0; k < 12; ++k) {} + } + } +}*/ +TEST_F(PassClassTest, TripleNestedLoop) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %7 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpConstant %10 11 + %16 = OpConstant %10 5 + %17 = OpConstant %10 12 + %18 = OpConstant %10 1 + %19 = OpTypeFloat 32 + %20 = OpTypeVector %19 4 + %21 = OpTypePointer Output %20 + %3 = OpVariable %21 Output + %2 = OpFunction %8 None %9 + %22 = OpLabel + %4 = OpVariable %11 Function + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %11 Function + OpStore %4 %12 + OpBranch %23 + %23 = OpLabel + OpLoopMerge %24 %25 None + OpBranch %26 + %26 = OpLabel + %27 = OpLoad %10 %4 + %28 = OpSLessThan %14 %27 %13 + OpBranchConditional %28 %29 %24 + %29 = OpLabel + OpStore %5 %12 + OpBranch %30 + %30 = OpLabel + OpLoopMerge %31 %32 None + OpBranch %33 + %33 = OpLabel + %34 = OpLoad %10 %5 + %35 = OpSLessThan %14 %34 %15 + OpBranchConditional %35 %36 %31 + %36 = OpLabel + %37 = OpLoad %10 %5 + %38 = OpSLessThan %14 %37 %16 + OpSelectionMerge %39 None + OpBranchConditional %38 %40 %39 + %40 = OpLabel + OpStore %6 %12 + OpBranch %41 + %41 = OpLabel + OpLoopMerge %42 %43 None + OpBranch %44 + %44 = OpLabel + %45 = OpLoad %10 %6 + %46 = OpSLessThan %14 %45 %17 + OpBranchConditional %46 %47 %42 + %47 = OpLabel + OpBranch %43 + %43 = OpLabel + %48 = OpLoad %10 %6 + %49 = OpIAdd %10 %48 %18 + OpStore %6 %49 + OpBranch %41 + %42 = OpLabel + OpBranch %39 + %39 = OpLabel + OpStore %7 %12 + OpBranch %50 + %50 = OpLabel + OpLoopMerge %51 %52 None + OpBranch %53 + %53 = OpLabel + %54 = OpLoad %10 %7 + %55 = OpSLessThan %14 %54 %17 + OpBranchConditional %55 %56 %51 + %56 = OpLabel + OpBranch %52 + %52 = OpLabel + %57 = OpLoad %10 %7 + %58 = OpIAdd %10 %57 %18 + OpStore %7 %58 + OpBranch %50 + %51 = OpLabel + OpBranch %32 + %32 = OpLabel + %59 = OpLoad %10 %5 + %60 = OpIAdd %10 %59 %18 + OpStore %5 %60 + OpBranch %30 + %31 = OpLabel + OpBranch %25 + %25 = OpLabel + %61 = OpLoad %10 %4 + %62 = OpIAdd %10 %61 %18 + OpStore %4 %62 + OpBranch %23 + %24 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const ir::Function* f = spvtest::GetFunction(module, 2); + ir::LoopDescriptor ld{f}; + + EXPECT_EQ(ld.NumLoops(), 4u); + + // Invalid basic block id. + EXPECT_EQ(ld[0u], nullptr); + // Not in a loop. + EXPECT_EQ(ld[22], nullptr); + + // Check that we can map basic block to the correct loop. + // The following block ids do not belong to a loop. + for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr); + + { + std::unordered_set basic_block_in_loop = { + {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43, + 42, 39, 50, 53, 56, 52, 51, 32, 31, 25}}; + ir::Loop* loop = ld[23]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_TRUE(loop->HasNestedLoops()); + EXPECT_FALSE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 1u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = { + {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}}; + ir::Loop* loop = ld[30]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_TRUE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 2u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = {{41, 44, 47, 43}}; + ir::Loop* loop = ld[41]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_FALSE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 3u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + { + std::unordered_set basic_block_in_loop = {{50, 53, 56, 52}}; + ir::Loop* loop = ld[50]; + CheckLoopBlocks(loop, &basic_block_in_loop); + + EXPECT_FALSE(loop->HasNestedLoops()); + EXPECT_TRUE(loop->IsNested()); + EXPECT_EQ(loop->GetDepth(), 3u); + EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u); + EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39)); + EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50)); + EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52)); + EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51)); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock())); + EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock())); + } + + // Make sure LoopDescriptor gives us the inner most loop when we query for + // loops. + for (const ir::BasicBlock& bb : *f) { + if (ir::Loop* loop = ld[&bb]) { + for (ir::Loop& sub_loop : + ir::make_range(++opt::TreeDFIterator(loop), + opt::TreeDFIterator())) { + EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id())); + } + } + } +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 11; ++j) { + for (int k = 0; k < 11; ++k) {} + } + for (int k = 0; k < 12; ++k) {} + } +} +*/ +TEST_F(PassClassTest, LoopParentTest) { + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %4 "i" + OpName %5 "j" + OpName %6 "k" + OpName %7 "k" + OpName %3 "c" + OpDecorate %3 Location 0 + %8 = OpTypeVoid + %9 = OpTypeFunction %8 + %10 = OpTypeInt 32 1 + %11 = OpTypePointer Function %10 + %12 = OpConstant %10 0 + %13 = OpConstant %10 10 + %14 = OpTypeBool + %15 = OpConstant %10 11 + %16 = OpConstant %10 1 + %17 = OpConstant %10 12 + %18 = OpTypeFloat 32 + %19 = OpTypeVector %18 4 + %20 = OpTypePointer Output %19 + %3 = OpVariable %20 Output + %2 = OpFunction %8 None %9 + %21 = OpLabel + %4 = OpVariable %11 Function + %5 = OpVariable %11 Function + %6 = OpVariable %11 Function + %7 = OpVariable %11 Function + OpStore %4 %12 + OpBranch %22 + %22 = OpLabel + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %10 %4 + %27 = OpSLessThan %14 %26 %13 + OpBranchConditional %27 %28 %23 + %28 = OpLabel + OpStore %5 %12 + OpBranch %29 + %29 = OpLabel + OpLoopMerge %30 %31 None + OpBranch %32 + %32 = OpLabel + %33 = OpLoad %10 %5 + %34 = OpSLessThan %14 %33 %15 + OpBranchConditional %34 %35 %30 + %35 = OpLabel + OpStore %6 %12 + OpBranch %36 + %36 = OpLabel + OpLoopMerge %37 %38 None + OpBranch %39 + %39 = OpLabel + %40 = OpLoad %10 %6 + %41 = OpSLessThan %14 %40 %15 + OpBranchConditional %41 %42 %37 + %42 = OpLabel + OpBranch %38 + %38 = OpLabel + %43 = OpLoad %10 %6 + %44 = OpIAdd %10 %43 %16 + OpStore %6 %44 + OpBranch %36 + %37 = OpLabel + OpBranch %31 + %31 = OpLabel + %45 = OpLoad %10 %5 + %46 = OpIAdd %10 %45 %16 + OpStore %5 %46 + OpBranch %29 + %30 = OpLabel + OpStore %7 %12 + OpBranch %47 + %47 = OpLabel + OpLoopMerge %48 %49 None + OpBranch %50 + %50 = OpLabel + %51 = OpLoad %10 %7 + %52 = OpSLessThan %14 %51 %17 + OpBranchConditional %52 %53 %48 + %53 = OpLabel + OpBranch %49 + %49 = OpLabel + %54 = OpLoad %10 %7 + %55 = OpIAdd %10 %54 %16 + OpStore %7 %55 + OpBranch %47 + %48 = OpLabel + OpBranch %24 + %24 = OpLabel + %56 = OpLoad %10 %4 + %57 = OpIAdd %10 %56 %16 + OpStore %4 %57 + OpBranch %22 + %23 = OpLabel + OpReturn + OpFunctionEnd + )"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + const ir::Function* f = spvtest::GetFunction(module, 2); + ir::LoopDescriptor ld{f}; + + EXPECT_EQ(ld.NumLoops(), 4u); + + { + ir::Loop& loop = *ld[22]; + EXPECT_TRUE(loop.HasNestedLoops()); + EXPECT_FALSE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 1u); + EXPECT_EQ(loop.GetParent(), nullptr); + } + + { + ir::Loop& loop = *ld[29]; + EXPECT_TRUE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 2u); + EXPECT_EQ(loop.GetParent(), ld[22]); + } + + { + ir::Loop& loop = *ld[36]; + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 3u); + EXPECT_EQ(loop.GetParent(), ld[29]); + } + + { + ir::Loop& loop = *ld[47]; + EXPECT_FALSE(loop.HasNestedLoops()); + EXPECT_TRUE(loop.IsNested()); + EXPECT_EQ(loop.GetDepth(), 2u); + EXPECT_EQ(loop.GetParent(), ld[22]); + } +} + +} // namespace -- 2.7.4