Track a construct by its entry block.
authorDavid Neto <dneto@google.com>
Fri, 5 Aug 2016 20:05:44 +0000 (16:05 -0400)
committerDavid Neto <dneto@google.com>
Fri, 5 Aug 2016 20:05:44 +0000 (16:05 -0400)
source/val/Function.cpp
source/val/Function.h

index 68bf532..0240d68 100644 (file)
@@ -143,12 +143,10 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
   current_block_->set_type(kBlockTypeLoop);
   merge_block.set_type(kBlockTypeMerge);
   continue_target_block.set_type(kBlockTypeContinue);
-  cfg_constructs_.emplace_back(ConstructType::kLoop, current_block_,
-                               &merge_block);
-  Construct& loop_construct = cfg_constructs_.back();
-  cfg_constructs_.emplace_back(ConstructType::kContinue,
-                               &continue_target_block);
-  Construct& continue_construct = cfg_constructs_.back();
+  Construct& loop_construct =
+      AddConstruct({ConstructType::kLoop, current_block_, &merge_block});
+  Construct& continue_construct =
+      AddConstruct({ConstructType::kContinue, &continue_target_block});
   continue_construct.set_corresponding_constructs({&loop_construct});
   loop_construct.set_corresponding_constructs({&continue_construct});
 
@@ -161,8 +159,8 @@ spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) {
   current_block_->set_type(kBlockTypeHeader);
   merge_block.set_type(kBlockTypeMerge);
 
-  cfg_constructs_.emplace_back(ConstructType::kSelection, current_block(),
-                               &merge_block);
+  AddConstruct({ConstructType::kSelection, current_block(), &merge_block});
+
   return SPV_SUCCESS;
 }
 
@@ -223,9 +221,10 @@ void Function::RegisterBlockEnd(vector<uint32_t> next_list,
     std::vector<BasicBlock*>& next_blocks_plus_continue_target =
         loop_header_successors_plus_continue_target_map_[current_block_];
     next_blocks_plus_continue_target = next_blocks;
-    // If this block is marked as Loop-type,  then the continue construct is
-    // the most recently created CFG construct.
-    auto continue_target = cfg_constructs_.back().entry_block();
+    auto continue_target = FindConstructForEntryBlock(current_block_)
+                               .corresponding_constructs()
+                               .back()
+                               ->entry_block();
     if (continue_target != current_block_) {
       next_blocks_plus_continue_target.push_back(continue_target);
     }
@@ -368,4 +367,20 @@ void Function::ComputeAugmentedCFG() {
     augmented_succ.insert(augmented_succ.end(), succ.begin(), succ.end());
   }
 };
+
+Construct& Function::AddConstruct(const Construct& new_construct) {
+  cfg_constructs_.push_back(new_construct);
+  auto& result = cfg_constructs_.back();
+  entry_block_to_construct_[new_construct.entry_block()] = &result;
+  return result;
+}
+
+Construct& Function::FindConstructForEntryBlock(const BasicBlock* entry_block) {
+  auto where = entry_block_to_construct_.find(entry_block);
+  assert(where != entry_block_to_construct_.end());
+  auto construct_ptr = (*where).second;
+  assert(construct_ptr);
+  return *construct_ptr;
+}
+
 }  /// namespace libspirv
index d9a0ae0..1fbe113 100644 (file)
@@ -197,6 +197,14 @@ class Function {
   // Populates augmented_successors_map_ and augmented_predecessors_map_.
   void ComputeAugmentedCFG();
 
+  // Adds a copy of the given Construct, and tracks it by its entry block.
+  // Returns a reference to the stored construct.
+  Construct& AddConstruct(const Construct& new_construct);
+
+  // Returns a reference to the construct corresponding to the given entry
+  // block.
+  Construct& FindConstructForEntryBlock(const BasicBlock* entry_block);
+
   /// The result id of the OpLabel that defined this block
   uint32_t id_;
 
@@ -279,6 +287,9 @@ class Function {
 
   /// The function parameter ids of the functions
   std::vector<uint32_t> parameter_ids_;
+
+  /// Maps a construct's entry block to the construct.
+  std::unordered_map<const BasicBlock*, Construct*> entry_block_to_construct_;
 };
 
 }  /// namespace libspirv