Add LoopUtils class to gather some loop transformation support.
authorVictor Lomuller <victor@codeplay.com>
Fri, 26 Jan 2018 12:07:10 +0000 (12:07 +0000)
committerSteven Perron <stevenperron@google.com>
Thu, 1 Feb 2018 20:35:09 +0000 (15:35 -0500)
This patch adds LoopUtils class to handle some loop related transformations. For now it has 2 transformations that simplifies other transformations such as loop unroll or unswitch:
 - Dedicate exit blocks: this ensure that all exit basic block
   (out-of-loop basic blocks that have a predecessor in the loop)
   have all their predecessors in the loop;
 - Loop Closed SSA (LCSSA): this ensure that all definitions in a loop are used inside the loop
   or in a phi instruction in an exit basic block.

It also adds the following capabilities:
 - Loop::IsLCSSA to test if the loop is in a LCSSA form
 - Loop::GetOrCreatePreHeaderBlock that can build a loop preheader if required;
 - New methods to allow on the fly updates of the loop descriptors.
 - New methods to allow on the fly updates of the CFG analysis.
 - Instruction::SetOperand to allow expression of the index relative to Instruction::NumOperands (to be compatible with the index returned by DefUseManager::ForEachUse)

16 files changed:
source/opt/CMakeLists.txt
source/opt/basic_block.cpp
source/opt/basic_block.h
source/opt/cfg.cpp
source/opt/cfg.h
source/opt/def_use_manager.cpp
source/opt/instruction.h
source/opt/ir_builder.h
source/opt/loop_descriptor.cpp
source/opt/loop_descriptor.h
source/opt/loop_utils.cpp [new file with mode: 0644]
source/opt/loop_utils.h [new file with mode: 0644]
test/opt/loop_optimizations/CMakeLists.txt
test/opt/loop_optimizations/lcssa.cpp [new file with mode: 0644]
test/opt/loop_optimizations/loop_descriptions.cpp
test/opt/loop_optimizations/nested_loops.cpp

index bb2311a..aec805c 100644 (file)
@@ -53,6 +53,7 @@ add_library(SPIRV-Tools-opt
   local_ssa_elim_pass.h
   log.h
   loop_descriptor.h
+  loop_utils.h
   mem_pass.h
   merge_return_pass.h
   module.h
@@ -118,6 +119,7 @@ add_library(SPIRV-Tools-opt
   local_single_store_elim_pass.cpp
   local_ssa_elim_pass.cpp
   loop_descriptor.cpp
+  loop_utils.cpp
   mem_pass.cpp
   merge_return_pass.cpp
   module.cpp
index d2f4b33..b07696b 100644 (file)
@@ -109,6 +109,28 @@ void BasicBlock::ForEachSuccessorLabel(
   }
 }
 
+void BasicBlock::ForEachSuccessorLabel(
+    const std::function<void(uint32_t*)>& f) {
+  auto br = &insts_.back();
+  switch (br->opcode()) {
+    case SpvOpBranch: {
+      uint32_t tmp_id = br->GetOperand(0).words[0];
+      f(&tmp_id);
+      if (tmp_id != br->GetOperand(0).words[0]) br->SetOperand(0, {tmp_id});
+    } break;
+    case SpvOpBranchConditional:
+    case SpvOpSwitch: {
+      bool is_first = true;
+      br->ForEachInId([&is_first, &f](uint32_t* idp) {
+        if (!is_first) f(idp);
+        is_first = false;
+      });
+    } break;
+    default:
+      break;
+  }
+}
+
 bool BasicBlock::IsSuccessor(const ir::BasicBlock* block) const {
   uint32_t succId = block->id();
   bool isSuccessor = false;
index 15d8cb9..d0186e6 100644 (file)
@@ -133,6 +133,11 @@ class BasicBlock {
   void ForEachSuccessorLabel(
       const std::function<void(const uint32_t)>& f) const;
 
+  // Runs the given function |f| on each label id of each successor block.
+  // Modifying the pointed value will change the branch taken by the basic
+  // block. It is the caller responsibility to update or invalidate the CFG.
+  void ForEachSuccessorLabel(const std::function<void(uint32_t*)>& f);
+
   // Returns true if |block| is a direct successor of |this|.
   bool IsSuccessor(const ir::BasicBlock* block) const;
 
index e3acdb2..9834256 100644 (file)
@@ -35,18 +35,25 @@ CFG::CFG(ir::Module* module)
           module->context(), SpvOpLabel, 0, kInvalidId, {}))) {
   for (auto& fn : *module) {
     for (auto& blk : fn) {
-      uint32_t blkId = blk.id();
-      id2block_[blkId] = &blk;
-      // Force the creation of an entry, not all basic block have predecessors
-      // (such as the entry block and some unreachables)
-      label2preds_[blkId];
-      blk.ForEachSuccessorLabel([&blkId, this](uint32_t sbid) {
-        label2preds_[sbid].push_back(blkId);
-      });
+      RegisterBlock(&blk);
     }
   }
 }
 
+void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
+  std::vector<uint32_t> updated_pred_list;
+  for (uint32_t id : preds(blk_id)) {
+    ir::BasicBlock* pred_blk = block(id);
+    bool has_branch = false;
+    pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
+      if (succ == blk_id) has_branch = true;
+    });
+    if (has_branch) updated_pred_list.push_back(id);
+  }
+
+  label2preds_.at(blk_id) = std::move(updated_pred_list);
+}
+
 void CFG::ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root,
                                  std::list<ir::BasicBlock*>* order) {
   assert(module_->context()->get_feature_mgr()->HasCapability(
index bfeaca7..0680e39 100644 (file)
@@ -68,6 +68,34 @@ class CFG {
   void ComputeStructuredOrder(ir::Function* func, ir::BasicBlock* root,
                               std::list<ir::BasicBlock*>* order);
 
+  // Registers |blk| as a basic block in the cfg, this also updates the
+  // predecessor lists of each successor of |blk|.
+  void RegisterBlock(ir::BasicBlock* blk) {
+    uint32_t blk_id = blk->id();
+    id2block_[blk_id] = blk;
+    AddEdges(blk);
+  }
+
+  // Registers |blk| to all of its successors.
+  void AddEdges(ir::BasicBlock* blk) {
+    uint32_t blk_id = blk->id();
+    // Force the creation of an entry, not all basic block have predecessors
+    // (such as the entry blocks and some unreachables).
+    label2preds_[blk_id];
+    blk->ForEachSuccessorLabel(
+        [blk_id, this](uint32_t succ_id) { AddEdge(blk_id, succ_id); });
+  }
+
+  // Registers the basic block id |pred_blk_id| as being a predecessor of the
+  // basic block id |succ_blk_id|.
+  void AddEdge(uint32_t pred_blk_id, uint32_t succ_blk_id) {
+    label2preds_[succ_blk_id].push_back(pred_blk_id);
+  }
+
+  // Removes any edges that no longer exist from the predecessor mapping for
+  // the basic block id |blk_id|.
+  void RemoveNonExistingEdges(uint32_t blk_id);
+
  private:
   using cbb_ptr = const ir::BasicBlock*;
 
index 2b1b00c..33776ce 100644 (file)
@@ -40,8 +40,12 @@ void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) {
   // Create entry for the given instruction. Note that the instruction may
   // not have any in-operands. In such cases, we still need a entry for those
   // instructions so this manager knows it has seen the instruction later.
-  auto& used_ids = inst_to_used_ids_[inst];
-  used_ids.clear();  // It might have existed before.
+  auto* used_ids = &inst_to_used_ids_[inst];
+  if (used_ids->size()) {
+    EraseUseRecordsOfOperandIds(inst);
+    used_ids = &inst_to_used_ids_[inst];
+  }
+  used_ids->clear();  // It might have existed before.
 
   for (uint32_t i = 0; i < inst->NumOperands(); ++i) {
     switch (inst->GetOperand(i).type) {
@@ -54,7 +58,7 @@ void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) {
         ir::Instruction* def = GetDef(use_id);
         assert(def && "Definition is not registered.");
         id_to_users_.insert(UserEntry(def, inst));
-        used_ids.push_back(use_id);
+        used_ids->push_back(use_id);
       } break;
       default:
         break;
@@ -100,8 +104,10 @@ bool DefUseManager::WhileEachUser(
     const ir::Instruction* def,
     const std::function<bool(ir::Instruction*)>& f) const {
   // Ensure that |def| has been registered.
-  assert(def && def == GetDef(def->result_id()) &&
+  assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) &&
          "Definition is not registered.");
+  if (!def->HasResultId()) return true;
+
   auto end = id_to_users_.end();
   for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
     if (!f(iter->second)) return false;
@@ -132,8 +138,10 @@ bool DefUseManager::WhileEachUse(
     const ir::Instruction* def,
     const std::function<bool(ir::Instruction*, uint32_t)>& f) const {
   // Ensure that |def| has been registered.
-  assert(def && def == GetDef(def->result_id()) &&
+  assert(def && (!def->HasResultId() || def == GetDef(def->result_id())) &&
          "Definition is not registered.");
+  if (!def->HasResultId()) return true;
+
   auto end = id_to_users_.end();
   for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
     ir::Instruction* user = iter->second;
index b13a78d..a383cd5 100644 (file)
@@ -197,10 +197,15 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   uint32_t GetSingleWordOperand(uint32_t index) const;
   // Sets the |index|-th in-operand's data to the given |data|.
   inline void SetInOperand(uint32_t index, std::vector<uint32_t>&& data);
+  // Sets the |index|-th operand's data to the given |data|.
+  // This is for in-operands modification only, but with |index| expressed in
+  // terms of operand index rather than in-operand index.
+  inline void SetOperand(uint32_t index, std::vector<uint32_t>&& data);
   // Sets the result type id.
   inline void SetResultType(uint32_t ty_id);
   // Sets the result id
   inline void SetResultId(uint32_t res_id);
+  inline bool HasResultId() const { return result_id_ != 0; }
   // Remove the |index|-th operand
   void RemoveOperand(uint32_t index) {
     operands_.erase(operands_.begin() + index);
@@ -450,9 +455,14 @@ inline void Instruction::AddOperand(Operand&& operand) {
 
 inline void Instruction::SetInOperand(uint32_t index,
                                       std::vector<uint32_t>&& data) {
-  assert(index + TypeResultIdCount() < operands_.size() &&
-         "operand index out of bound");
-  operands_[index + TypeResultIdCount()].words = std::move(data);
+  SetOperand(index + TypeResultIdCount(), std::move(data));
+}
+
+inline void Instruction::SetOperand(uint32_t index,
+                                    std::vector<uint32_t>&& data) {
+  assert(index < operands_.size() && "operand index out of bound");
+  assert(index >= TypeResultIdCount() && "operand is not a in-operand");
+  operands_[index].words = std::move(data);
 }
 
 inline void Instruction::SetResultId(uint32_t res_id) {
index 1762308..b376f6d 100644 (file)
@@ -163,6 +163,20 @@ class InstructionBuilder {
   // Returns the insertion point iterator.
   InsertionPointTy GetInsertPoint() { return insert_before_; }
 
+  // Change the insertion point to insert before the instruction
+  // |insert_before|.
+  void SetInsertPoint(ir::Instruction* insert_before) {
+    parent_ = context_->get_instr_block(insert_before);
+    insert_before_ = InsertionPointTy(insert_before);
+  }
+
+  // Change the insertion point to insert at the end of the basic block
+  // |parent_block|.
+  void SetInsertPoint(ir::BasicBlock* parent_block) {
+    parent_ = parent_block;
+    insert_before_ = parent_block->end();
+  }
+
   // Returns the context which instructions are constructed for.
   ir::IRContext* GetContext() const { return context_; }
 
index 2545be3..8c9a63c 100644 (file)
@@ -20,9 +20,9 @@
 
 #include "opt/cfg.h"
 #include "opt/dominator_tree.h"
+#include "opt/ir_builder.h"
 #include "opt/ir_context.h"
 #include "opt/iterator.h"
-#include "opt/loop_descriptor.h"
 #include "opt/make_unique.h"
 #include "opt/tree_iterator.h"
 
@@ -103,6 +103,197 @@ bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
   return true;
 }
 
+BasicBlock* Loop::GetOrCreatePreHeaderBlock(ir::IRContext* context) {
+  if (loop_preheader_) return loop_preheader_;
+
+  Function* fn = loop_header_->GetParent();
+  // Find the insertion point for the preheader.
+  Function::iterator header_it =
+      std::find_if(fn->begin(), fn->end(),
+                   [this](BasicBlock& bb) { return &bb == loop_header_; });
+  assert(header_it != fn->end());
+
+  // Create the preheader basic block.
+  loop_preheader_ = &*header_it.InsertBefore(std::unique_ptr<ir::BasicBlock>(
+      new ir::BasicBlock(std::unique_ptr<ir::Instruction>(new ir::Instruction(
+          context, SpvOpLabel, 0, context->TakeNextId(), {})))));
+  loop_preheader_->SetParent(fn);
+  uint32_t loop_preheader_id = loop_preheader_->id();
+
+  // Redirect the branches and patch the phi:
+  //  - For each phi instruction in the header:
+  //    - If the header has only 1 out-of-loop incoming branch:
+  //      - Change the incomning branch to be the preheader.
+  //    - If the header has more than 1 out-of-loop incoming branch:
+  //      - Create a new phi in the preheader, gathering all out-of-loops
+  //      incoming values;
+  //      - Patch the header phi instruction to use the preheader phi
+  //      instruction;
+  //  - Redirect all edges coming from outside the loop to the preheader.
+  opt::InstructionBuilder builder(
+      context, loop_preheader_,
+      ir::IRContext::kAnalysisDefUse |
+          ir::IRContext::kAnalysisInstrToBlockMapping);
+  // Patch all the phi instructions.
+  loop_header_->ForEachPhiInst([&builder, context, this](Instruction* phi) {
+    std::vector<uint32_t> preheader_phi_ops;
+    std::vector<uint32_t> header_phi_ops;
+    for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
+      uint32_t def_id = phi->GetSingleWordInOperand(i);
+      uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
+      if (IsInsideLoop(branch_id)) {
+        header_phi_ops.push_back(def_id);
+        header_phi_ops.push_back(branch_id);
+      } else {
+        preheader_phi_ops.push_back(def_id);
+        preheader_phi_ops.push_back(branch_id);
+      }
+    }
+
+    Instruction* preheader_insn_def = nullptr;
+    // Create a phi instruction if and only if the preheader_phi_ops has more
+    // than one pair.
+    if (preheader_phi_ops.size() > 2)
+      preheader_insn_def = builder.AddPhi(phi->type_id(), preheader_phi_ops);
+    else
+      preheader_insn_def =
+          context->get_def_use_mgr()->GetDef(preheader_phi_ops[0]);
+    // Build the new incoming edge.
+    header_phi_ops.push_back(preheader_insn_def->result_id());
+    header_phi_ops.push_back(loop_preheader_->id());
+    // Rewrite operands of the header's phi instruction.
+    uint32_t idx = 0;
+    for (; idx < header_phi_ops.size(); idx++)
+      phi->SetInOperand(idx, {header_phi_ops[idx]});
+    // Remove extra operands, from last to first (more efficient).
+    for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--)
+      phi->RemoveInOperand(j);
+  });
+  // Branch from the preheader to the header.
+  builder.AddBranch(loop_header_->id());
+
+  // Redirect all out of loop branches to the header to the preheader.
+  CFG* cfg = context->cfg();
+  cfg->RegisterBlock(loop_preheader_);
+  for (uint32_t pred_id : cfg->preds(loop_header_->id())) {
+    if (pred_id == loop_preheader_->id()) continue;
+    if (IsInsideLoop(pred_id)) continue;
+    BasicBlock* pred = cfg->block(pred_id);
+    pred->ForEachSuccessorLabel([this, loop_preheader_id](uint32_t* id) {
+      if (*id == loop_header_->id()) *id = loop_preheader_id;
+    });
+    cfg->AddEdge(pred_id, loop_preheader_id);
+  }
+  // Delete predecessors that are no longer predecessors of the loop header.
+  cfg->RemoveNonExistingEdges(loop_header_->id());
+  // Update the loop descriptors.
+  if (HasParent()) {
+    GetParent()->AddBasicBlock(loop_preheader_);
+    context->GetLoopDescriptor(fn)->SetBasicBlockToLoop(loop_preheader_->id(),
+                                                        GetParent());
+  }
+
+  context->InvalidateAnalysesExceptFor(
+      builder.GetPreservedAnalysis() |
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis |
+      ir::IRContext::kAnalysisCFG);
+
+  return loop_preheader_;
+}
+
+void Loop::SetLatchBlock(BasicBlock* latch) {
+#ifndef NDEBUG
+  assert(latch->GetParent() && "The basic block does not belong to a function");
+
+  latch->ForEachSuccessorLabel([this](uint32_t id) {
+    assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
+           "A predecessor of the continue block does not belong to the loop");
+  });
+#endif  // NDEBUG
+  assert(IsInsideLoop(latch) && "The continue block is not in the loop");
+
+  SetLatchBlockImpl(latch);
+}
+
+void Loop::SetMergeBlock(BasicBlock* merge) {
+#ifndef NDEBUG
+  assert(merge->GetParent() && "The basic block does not belong to a function");
+  CFG& cfg = *merge->GetParent()->GetParent()->context()->cfg();
+
+  for (uint32_t pred : cfg.preds(merge->id())) {
+    assert(IsInsideLoop(pred) &&
+           "A predecessor of the merge block does not belong to the loop");
+  }
+#endif  // NDEBUG
+  assert(!IsInsideLoop(merge) && "The merge block is in the loop");
+
+  SetMergeBlockImpl(merge);
+  if (GetHeaderBlock()->GetLoopMergeInst()) {
+    UpdateLoopMergeInst();
+  }
+}
+
+void Loop::GetExitBlocks(IRContext* context,
+                         std::unordered_set<uint32_t>* exit_blocks) const {
+  ir::CFG* cfg = context->cfg();
+
+  for (uint32_t bb_id : GetBlocks()) {
+    const spvtools::ir::BasicBlock* bb = cfg->block(bb_id);
+    bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
+      if (!IsInsideLoop(succ)) {
+        exit_blocks->insert(succ);
+      }
+    });
+  }
+}
+
+void Loop::GetMergingBlocks(
+    IRContext* context, std::unordered_set<uint32_t>* merging_blocks) const {
+  assert(GetMergeBlock() && "This loop is not structured");
+  ir::CFG* cfg = context->cfg();
+
+  std::stack<const ir::BasicBlock*> to_visit;
+  to_visit.push(GetMergeBlock());
+  while (!to_visit.empty()) {
+    const ir::BasicBlock* bb = to_visit.top();
+    to_visit.pop();
+    merging_blocks->insert(bb->id());
+    for (uint32_t pred_id : cfg->preds(bb->id())) {
+      if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
+        to_visit.push(cfg->block(pred_id));
+      }
+    }
+  }
+}
+
+bool Loop::IsLCSSA() const {
+  IRContext* context = GetHeaderBlock()->GetParent()->GetParent()->context();
+  ir::CFG* cfg = context->cfg();
+  opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+
+  std::unordered_set<uint32_t> exit_blocks;
+  GetExitBlocks(context, &exit_blocks);
+
+  for (uint32_t bb_id : GetBlocks()) {
+    for (Instruction& insn : *cfg->block(bb_id)) {
+      // All uses must be either:
+      //  - In the loop;
+      //  - In an exit block and in a phi instruction.
+      if (!def_use_mgr->WhileEachUser(
+              &insn,
+              [&exit_blocks, context, this](ir::Instruction* use) -> bool {
+                BasicBlock* parent = context->get_instr_block(use);
+                assert(parent && "Invalid analysis");
+                if (IsInsideLoop(parent)) return true;
+                if (use->opcode() != SpvOpPhi) return false;
+                return exit_blocks.count(parent->id());
+              }))
+        return false;
+    }
+  }
+  return true;
+}
+
 LoopDescriptor::LoopDescriptor(const Function* f) { PopulateList(f); }
 
 void LoopDescriptor::PopulateList(const Function* f) {
index 87d457d..3dfb0c1 100644 (file)
@@ -68,14 +68,40 @@ class Loop {
   // OpLoopMerge instruction.
   inline BasicBlock* GetHeaderBlock() { return loop_header_; }
   inline const BasicBlock* GetHeaderBlock() const { return loop_header_; }
+  inline void SetHeaderBlock(BasicBlock* header) { loop_header_ = header; }
+
+  // Updates the OpLoopMerge instruction to reflect the current state of the
+  // loop.
+  inline void UpdateLoopMergeInst() {
+    assert(GetHeaderBlock()->GetLoopMergeInst() &&
+           "The loop is not structured");
+    ir::Instruction* merge_inst = GetHeaderBlock()->GetLoopMergeInst();
+    merge_inst->SetInOperand(0, {GetMergeBlock()->id()});
+  }
 
   // Returns the latch basic block (basic block that holds the back-edge).
+  // These functions return nullptr if the loop is not structured (i.e. if it
+  // has more than one backedge).
   inline BasicBlock* GetLatchBlock() { return loop_continue_; }
   inline const BasicBlock* GetLatchBlock() const { return loop_continue_; }
-
-  // Returns the BasicBlock which marks the end of the loop.
+  // Sets |latch| as the loop unique block branching back to the header.
+  // A latch block must have the following properties:
+  //  - |latch| must be in the loop;
+  //  - must be the only block branching back to the header block.
+  void SetLatchBlock(BasicBlock* latch);
+
+  // Returns the basic block which marks the end of the loop.
+  // These functions return nullptr if the loop is not structured.
   inline BasicBlock* GetMergeBlock() { return loop_merge_; }
   inline const BasicBlock* GetMergeBlock() const { return loop_merge_; }
+  // Sets |merge| as the loop merge block. A merge block must have the following
+  // properties:
+  //  - |merge| must not be in the loop;
+  //  - all its predecessors must be in the loop.
+  //  - it must not be already used as merge block.
+  // If the loop has an OpLoopMerge in its header, this instruction is also
+  // updated.
+  void SetMergeBlock(BasicBlock* merge);
 
   // Returns the loop pre-header, nullptr means that the loop predecessor does
   // not qualify as a preheader.
@@ -87,9 +113,30 @@ class Loop {
   // Returns the loop pre-header.
   inline const BasicBlock* GetPreHeaderBlock() const { return loop_preheader_; }
 
+  // Returns the loop pre-header, if there is no suitable preheader it will be
+  // created.
+  BasicBlock* GetOrCreatePreHeaderBlock(ir::IRContext* context);
+
   // Returns true if this loop contains any nested loops.
   inline bool HasNestedLoops() const { return nested_loops_.size() != 0; }
 
+  // Fills |exit_blocks| with all basic blocks that are not in the loop and has
+  // at least one predecessor in the loop.
+  void GetExitBlocks(IRContext* context,
+                     std::unordered_set<uint32_t>* exit_blocks) const;
+
+  // Fills |merging_blocks| with all basic blocks that are post-dominated by the
+  // merge block. The merge block must exist.
+  // The set |merging_blocks| will only contain the merge block if it is
+  // unreachable.
+  void GetMergingBlocks(IRContext* context,
+                        std::unordered_set<uint32_t>* merging_blocks) const;
+
+  // Returns true if the loop is in a Loop Closed SSA form.
+  // In LCSSA form, all in-loop definitions are used in the loop or in phi
+  // instructions in the loop exit blocks.
+  bool IsLCSSA() const;
+
   // Returns the depth of this loop in the loop nest.
   // The outer-most loop has a depth of 1.
   inline size_t GetDepth() const {
@@ -139,11 +186,20 @@ class Loop {
     assert(IsBasicBlockInLoopSlow(bb) &&
            "Basic block does not belong to the loop");
 
+    AddBasicBlock(bb);
+  }
+
+  // Adds the Basic Block |bb| this loop and its parents.
+  void AddBasicBlock(const BasicBlock* bb) {
     for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
       loop_basic_blocks_.insert(bb->id());
     }
   }
 
+  // Sets the parent loop of this loop, that is, a loop which contains this loop
+  // as a nested child loop.
+  inline void SetParent(Loop* parent) { parent_ = parent; }
+
  private:
   // The block which marks the start of the loop.
   BasicBlock* loop_header_;
@@ -167,22 +223,25 @@ class Loop {
   // computed only when needed on demand.
   BasicBlockListTy loop_basic_blocks_;
 
-  // Check that |bb| is inside the loop using domination properties.
+  // Check that |bb| is inside the loop using domination property.
   // Note: this is for assertion purposes only, IsInsideLoop should be used
   // instead.
   bool IsBasicBlockInLoopSlow(const BasicBlock* bb);
 
-  // Sets the parent loop of this loop, that is, a loop which contains this loop
-  // as a nested child loop.
-  inline void SetParent(Loop* parent) { parent_ = parent; }
-
   // Returns the loop preheader if it exists, returns nullptr otherwise.
   BasicBlock* FindLoopPreheader(IRContext* context,
                                 opt::DominatorAnalysis* dom_analysis);
 
+  // Sets |latch| as the loop unique continue block. No checks are performed
+  // here.
+  inline void SetLatchBlockImpl(BasicBlock* latch) { loop_continue_ = latch; }
+  // Sets |merge| as the loop merge block. No checks are performed here.
+  inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
+
   // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level
   // loops as child.
   friend class LoopDescriptor;
+  friend class LoopUtils;
 };
 
 // Loop descriptions class for a given function.
@@ -231,6 +290,11 @@ class LoopDescriptor {
     return const_iterator::end(&dummy_top_loop_);
   }
 
+  // Returns the inner most loop that contains the basic block |bb|.
+  inline void SetBasicBlockToLoop(uint32_t bb_id, Loop* loop) {
+    basic_block_to_loop_[bb_id] = loop;
+  }
+
  private:
   using LoopContainerType = std::vector<std::unique_ptr<Loop>>;
 
diff --git a/source/opt/loop_utils.cpp b/source/opt/loop_utils.cpp
new file mode 100644 (file)
index 0000000..86bc2eb
--- /dev/null
@@ -0,0 +1,485 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "opt/cfg.h"
+#include "opt/ir_builder.h"
+#include "opt/ir_context.h"
+#include "opt/loop_descriptor.h"
+#include "opt/loop_utils.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+// Return true if |bb| is dominated by at least one block in |exits|
+static inline bool DominatesAnExit(
+    ir::BasicBlock* bb, const std::unordered_set<ir::BasicBlock*>& exits,
+    const opt::DominatorTree& dom_tree) {
+  for (ir::BasicBlock* e_bb : exits)
+    if (dom_tree.Dominates(bb, e_bb)) return true;
+  return false;
+}
+
+// Utility class to rewrite out-of-loop uses of an in-loop definition in terms
+// of phi instructions to achieve a LCSSA form.
+// For a given definition, the class user registers phi instructions using that
+// definition in all loop exit blocks by which the definition escapes.
+// Then, when rewriting a use of the definition, the rewriter walks the
+// paths from the use the loop exits. At each step, it will insert a phi
+// instruction to merge the incoming value according to exit blocks definition.
+class LCSSARewriter {
+ public:
+  LCSSARewriter(ir::IRContext* context, const opt::DominatorTree& dom_tree,
+                const std::unordered_set<ir::BasicBlock*>& exit_bb,
+                ir::BasicBlock* merge_block)
+      : context_(context),
+        cfg_(context_->cfg()),
+        dom_tree_(dom_tree),
+        exit_bb_(exit_bb),
+        merge_block_id_(merge_block ? merge_block->id() : 0) {}
+
+  struct UseRewriter {
+    explicit UseRewriter(LCSSARewriter* base, const ir::Instruction& def_insn)
+        : base_(base), def_insn_(def_insn) {}
+    // Rewrites the use of |def_insn_| by the instruction |user| at the index
+    // |operand_index| in terms of phi instruction. This recursively builds new
+    // phi instructions from |user| to the loop exit blocks' phis. The use of
+    // |def_insn_| in |user| is replaced by the relevant phi instruction at the
+    // end of the operation.
+    // It is assumed that |user| does not dominates any of the loop exit basic
+    // block. This operation does not update the def/use manager, instead it
+    // records what needs to be updated. The actual update is performed by
+    // UpdateManagers.
+    void RewriteUse(ir::BasicBlock* bb, ir::Instruction* user,
+                    uint32_t operand_index) {
+      assert(
+          (user->opcode() != SpvOpPhi || bb != GetParent(user)) &&
+          "The root basic block must be the incoming edge if |user| is a phi "
+          "instruction");
+      assert((user->opcode() == SpvOpPhi || bb == GetParent(user)) &&
+             "The root basic block must be the instruction parent if |user| is "
+             "not "
+             "phi instruction");
+
+      ir::Instruction* new_def = GetOrBuildIncoming(bb->id());
+
+      user->SetOperand(operand_index, {new_def->result_id()});
+      rewritten_.insert(user);
+    }
+
+    // In-place update of some managers (avoid full invalidation).
+    inline void UpdateManagers() {
+      opt::analysis::DefUseManager* def_use_mgr =
+          base_->context_->get_def_use_mgr();
+      // Register all new definitions.
+      for (ir::Instruction* insn : rewritten_) {
+        def_use_mgr->AnalyzeInstDef(insn);
+      }
+      // Register all new uses.
+      for (ir::Instruction* insn : rewritten_) {
+        def_use_mgr->AnalyzeInstUse(insn);
+      }
+    }
+
+   private:
+    // Return the basic block that |instr| belongs to.
+    ir::BasicBlock* GetParent(ir::Instruction* instr) {
+      return base_->context_->get_instr_block(instr);
+    }
+
+    // Builds a phi instruction for the basic block |bb|. The function assumes
+    // that |defining_blocks| contains the list of basic block that define the
+    // usable value for each predecessor of |bb|.
+    inline ir::Instruction* CreatePhiInstruction(
+        ir::BasicBlock* bb, const std::vector<uint32_t>& defining_blocks) {
+      std::vector<uint32_t> incomings;
+      const std::vector<uint32_t>& bb_preds = base_->cfg_->preds(bb->id());
+      assert(bb_preds.size() == defining_blocks.size());
+      for (size_t i = 0; i < bb_preds.size(); i++) {
+        incomings.push_back(
+            GetOrBuildIncoming(defining_blocks[i])->result_id());
+        incomings.push_back(bb_preds[i]);
+      }
+      opt::InstructionBuilder builder(
+          base_->context_, &*bb->begin(),
+          ir::IRContext::kAnalysisInstrToBlockMapping);
+      ir::Instruction* incoming_phi =
+          builder.AddPhi(def_insn_.type_id(), incomings);
+
+      rewritten_.insert(incoming_phi);
+      return incoming_phi;
+    }
+
+    // Builds a phi instruction for the basic block |bb|, all incoming values
+    // will be |value|.
+    inline ir::Instruction* CreatePhiInstruction(ir::BasicBlock* bb,
+                                                 const ir::Instruction& value) {
+      std::vector<uint32_t> incomings;
+      const std::vector<uint32_t>& bb_preds = base_->cfg_->preds(bb->id());
+      for (size_t i = 0; i < bb_preds.size(); i++) {
+        incomings.push_back(value.result_id());
+        incomings.push_back(bb_preds[i]);
+      }
+      opt::InstructionBuilder builder(
+          base_->context_, &*bb->begin(),
+          ir::IRContext::kAnalysisInstrToBlockMapping);
+      ir::Instruction* incoming_phi =
+          builder.AddPhi(def_insn_.type_id(), incomings);
+
+      rewritten_.insert(incoming_phi);
+      return incoming_phi;
+    }
+
+    // Return the new def to use for the basic block |bb_id|.
+    // If |bb_id| does not have a suitable def to use then we:
+    //   - return the common def used by all predecessors;
+    //   - if there is no common def, then we build a new phi instr at the
+    //     beginning of |bb_id| and return this new instruction.
+    ir::Instruction* GetOrBuildIncoming(uint32_t bb_id) {
+      assert(base_->cfg_->block(bb_id) != nullptr && "Unknown basic block");
+
+      ir::Instruction*& incoming_phi = bb_to_phi_[bb_id];
+      if (incoming_phi) {
+        return incoming_phi;
+      }
+
+      ir::BasicBlock* bb = &*base_->cfg_->block(bb_id);
+      // If this is an exit basic block, look if there already is an eligible
+      // phi instruction. An eligible phi has |def_insn_| as all incoming
+      // values.
+      if (base_->exit_bb_.count(bb)) {
+        // Look if there is an eligible phi in this block.
+        if (!bb->WhileEachPhiInst([&incoming_phi, this](ir::Instruction* phi) {
+              for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
+                if (phi->GetSingleWordInOperand(i) != def_insn_.result_id())
+                  return true;
+              }
+              incoming_phi = phi;
+              rewritten_.insert(incoming_phi);
+              return false;
+            })) {
+          return incoming_phi;
+        }
+        incoming_phi = CreatePhiInstruction(bb, def_insn_);
+        return incoming_phi;
+      }
+
+      // Get the block that defines the value to use for each predecessor.
+      // If the vector has 1 value, then it means that this block does not need
+      // to build a phi instruction unless |bb_id| is the loop merge block.
+      const std::vector<uint32_t>& defining_blocks =
+          base_->GetDefiningBlocks(bb_id);
+
+      // Special case for structured loops: merge block might be different from
+      // the exit block set. To maintain structured properties it will ease
+      // transformations if the merge block also holds a phi instruction like
+      // the exit ones.
+      if (defining_blocks.size() > 1 || bb_id == base_->merge_block_id_) {
+        if (defining_blocks.size() > 1) {
+          incoming_phi = CreatePhiInstruction(bb, defining_blocks);
+        } else {
+          assert(bb_id == base_->merge_block_id_);
+          incoming_phi =
+              CreatePhiInstruction(bb, *GetOrBuildIncoming(defining_blocks[0]));
+        }
+      } else {
+        incoming_phi = GetOrBuildIncoming(defining_blocks[0]);
+      }
+
+      return incoming_phi;
+    }
+
+    LCSSARewriter* base_;
+    const ir::Instruction& def_insn_;
+    std::unordered_map<uint32_t, ir::Instruction*> bb_to_phi_;
+    std::unordered_set<ir::Instruction*> rewritten_;
+  };
+
+ private:
+  // Return the new def to use for the basic block |bb_id|.
+  // If |bb_id| does not have a suitable def to use then we:
+  //   - return the common def used by all predecessors;
+  //   - if there is no common def, then we build a new phi instr at the
+  //     beginning of |bb_id| and return this new instruction.
+  const std::vector<uint32_t>& GetDefiningBlocks(uint32_t bb_id) {
+    assert(cfg_->block(bb_id) != nullptr && "Unknown basic block");
+    std::vector<uint32_t>& defining_blocks = bb_to_defining_blocks_[bb_id];
+
+    if (defining_blocks.size()) return defining_blocks;
+
+    // Check if one of the loop exit basic block dominates |bb_id|.
+    for (const ir::BasicBlock* e_bb : exit_bb_) {
+      if (dom_tree_.Dominates(e_bb->id(), bb_id)) {
+        defining_blocks.push_back(e_bb->id());
+        return defining_blocks;
+      }
+    }
+
+    // Process parents, they will returns their suitable blocks.
+    // If they are all the same, this means this basic block is dominated by a
+    // common block, so we won't need to build a phi instruction.
+    for (uint32_t pred_id : cfg_->preds(bb_id)) {
+      const std::vector<uint32_t>& pred_blocks = GetDefiningBlocks(pred_id);
+      if (pred_blocks.size() == 1)
+        defining_blocks.push_back(pred_blocks[0]);
+      else
+        defining_blocks.push_back(pred_id);
+    }
+    assert(defining_blocks.size());
+    if (std::all_of(defining_blocks.begin(), defining_blocks.end(),
+                    [&defining_blocks](uint32_t id) {
+                      return id == defining_blocks[0];
+                    })) {
+      // No need for a phi.
+      defining_blocks.resize(1);
+    }
+
+    return defining_blocks;
+  }
+
+  ir::IRContext* context_;
+  ir::CFG* cfg_;
+  const opt::DominatorTree& dom_tree_;
+  const std::unordered_set<ir::BasicBlock*>& exit_bb_;
+  uint32_t merge_block_id_;
+  // This map represent the set of known paths. For each key, the vector
+  // represent the set of blocks holding the definition to be used to build the
+  // phi instruction.
+  // If the vector has 0 value, then the path is unknown yet, and must be built.
+  // If the vector has 1 value, then the value defined by that basic block
+  //   should be used.
+  // If the vector has more than 1 value, then a phi node must be created, the
+  //   basic block ordering is the same as the predecessor ordering.
+  std::unordered_map<uint32_t, std::vector<uint32_t>> bb_to_defining_blocks_;
+};
+
+// Make the set |blocks| closed SSA. The set is closed SSA if all the uses
+// outside the set are phi instructions in exiting basic block set (hold by
+// |lcssa_rewriter|).
+inline void MakeSetClosedSSA(ir::IRContext* context, ir::Function* function,
+                             const std::unordered_set<uint32_t>& blocks,
+                             const std::unordered_set<ir::BasicBlock*>& exit_bb,
+                             LCSSARewriter* lcssa_rewriter) {
+  ir::CFG& cfg = *context->cfg();
+  opt::DominatorTree& dom_tree =
+      context->GetDominatorAnalysis(function, cfg)->GetDomTree();
+  opt::analysis::DefUseManager* def_use_manager = context->get_def_use_mgr();
+
+  for (uint32_t bb_id : blocks) {
+    ir::BasicBlock* bb = cfg.block(bb_id);
+    // If bb does not dominate an exit block, then it cannot have escaping defs.
+    if (!DominatesAnExit(bb, exit_bb, dom_tree)) continue;
+    for (ir::Instruction& inst : *bb) {
+      LCSSARewriter::UseRewriter rewriter(lcssa_rewriter, inst);
+      def_use_manager->ForEachUse(
+          &inst, [&blocks, &rewriter, &exit_bb, context](
+                     ir::Instruction* use, uint32_t operand_index) {
+            ir::BasicBlock* use_parent = context->get_instr_block(use);
+            assert(use_parent);
+            if (blocks.count(use_parent->id())) return;
+
+            if (use->opcode() == SpvOpPhi) {
+              // If the use is a Phi instruction and the incoming block is
+              // coming from the loop, then that's consistent with LCSSA form.
+              if (exit_bb.count(use_parent)) {
+                return;
+              } else {
+                // That's not an exit block, but the user is a phi instruction.
+                // Consider the incoming branch only.
+                use_parent = context->get_instr_block(
+                    use->GetSingleWordOperand(operand_index + 1));
+              }
+            }
+            // Rewrite the use. Note that this call does not invalidate the
+            // def/use manager. So this operation is safe.
+            rewriter.RewriteUse(use_parent, use, operand_index);
+          });
+      rewriter.UpdateManagers();
+    }
+  }
+}
+
+}  // namespace
+
+void LoopUtils::CreateLoopDedicatedExits() {
+  ir::Function* function = loop_->GetHeaderBlock()->GetParent();
+  ir::LoopDescriptor& loop_desc = *context_->GetLoopDescriptor(function);
+  ir::CFG& cfg = *context_->cfg();
+  opt::analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+
+  const ir::IRContext::Analysis PreservedAnalyses =
+      ir::IRContext::kAnalysisDefUse |
+      ir::IRContext::kAnalysisInstrToBlockMapping;
+
+  // Gathers the set of basic block that are not in this loop and have at least
+  // one predecessor in the loop and one not in the loop.
+  std::unordered_set<uint32_t> exit_bb_set;
+  loop_->GetExitBlocks(context_, &exit_bb_set);
+
+  std::unordered_set<ir::BasicBlock*> new_loop_exits;
+  bool made_change = false;
+  // For each block, we create a new one that gathers all branches from
+  // the loop and fall into the block.
+  for (uint32_t non_dedicate_id : exit_bb_set) {
+    ir::BasicBlock* non_dedicate = cfg.block(non_dedicate_id);
+    const std::vector<uint32_t>& bb_pred = cfg.preds(non_dedicate_id);
+    // Ignore the block if all the predecessors are in the loop.
+    if (std::all_of(bb_pred.begin(), bb_pred.end(),
+                    [this](uint32_t id) { return loop_->IsInsideLoop(id); })) {
+      new_loop_exits.insert(non_dedicate);
+      continue;
+    }
+
+    made_change = true;
+    ir::Function::iterator insert_pt = function->begin();
+    for (; insert_pt != function->end() && &*insert_pt != non_dedicate;
+         ++insert_pt) {
+    }
+    assert(insert_pt != function->end() && "Basic Block not found");
+
+    // Create the dedicate exit basic block.
+    ir::BasicBlock& exit = *insert_pt.InsertBefore(
+        std::unique_ptr<ir::BasicBlock>(new ir::BasicBlock(
+            std::unique_ptr<ir::Instruction>(new ir::Instruction(
+                context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
+    exit.SetParent(function);
+
+    // Redirect in loop predecessors to |exit| block.
+    for (uint32_t exit_pred_id : bb_pred) {
+      if (loop_->IsInsideLoop(exit_pred_id)) {
+        ir::BasicBlock* pred_block = cfg.block(exit_pred_id);
+        pred_block->ForEachSuccessorLabel([non_dedicate, &exit](uint32_t* id) {
+          if (*id == non_dedicate->id()) *id = exit.id();
+        });
+        // Update the CFG.
+        // |non_dedicate|'s predecessor list will be updated at the end of the
+        // loop.
+        cfg.RegisterBlock(pred_block);
+      }
+    }
+
+    // Register the label to the def/use manager, requires for the phi patching.
+    def_use_mgr->AnalyzeInstDefUse(exit.GetLabelInst());
+    context_->set_instr_block(exit.GetLabelInst(), &exit);
+
+    opt::InstructionBuilder builder(context_, &exit, PreservedAnalyses);
+    // Now jump from our dedicate basic block to the old exit.
+    // We also reset the insert point so all instructions are inserted before
+    // the branch.
+    builder.SetInsertPoint(builder.AddBranch(non_dedicate->id()));
+    non_dedicate->ForEachPhiInst([&builder, &exit, def_use_mgr,
+                                  this](ir::Instruction* phi) {
+      // New phi operands for this instruction.
+      std::vector<uint32_t> new_phi_op;
+      // Phi operands for the dedicated exit block.
+      std::vector<uint32_t> exit_phi_op;
+      for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
+        uint32_t def_id = phi->GetSingleWordInOperand(i);
+        uint32_t incoming_id = phi->GetSingleWordInOperand(i + 1);
+        if (loop_->IsInsideLoop(incoming_id)) {
+          exit_phi_op.push_back(def_id);
+          exit_phi_op.push_back(incoming_id);
+        } else {
+          new_phi_op.push_back(def_id);
+          new_phi_op.push_back(incoming_id);
+        }
+      }
+
+      // Build the new phi instruction dedicated exit block.
+      ir::Instruction* exit_phi = builder.AddPhi(phi->type_id(), exit_phi_op);
+      // Build the new incoming branch.
+      new_phi_op.push_back(exit_phi->result_id());
+      new_phi_op.push_back(exit.id());
+      // Rewrite operands.
+      uint32_t idx = 0;
+      for (; idx < new_phi_op.size(); idx++)
+        phi->SetInOperand(idx, {new_phi_op[idx]});
+      // Remove extra operands, from last to first (more efficient).
+      for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--)
+        phi->RemoveInOperand(j);
+      // Update the def/use manager for this |phi|.
+      def_use_mgr->AnalyzeInstUse(phi);
+    });
+    // Update the CFG.
+    cfg.RegisterBlock(&exit);
+    cfg.RemoveNonExistingEdges(non_dedicate->id());
+    new_loop_exits.insert(&exit);
+    // If non_dedicate is in a loop, add the new dedicated exit in that loop.
+    if (ir::Loop* parent_loop = loop_desc[non_dedicate])
+      parent_loop->AddBasicBlock(&exit);
+  }
+
+  if (new_loop_exits.size() == 1) {
+    loop_->SetMergeBlock(*new_loop_exits.begin());
+  }
+
+  if (made_change) {
+    context_->InvalidateAnalysesExceptFor(
+        PreservedAnalyses | ir::IRContext::kAnalysisCFG |
+        ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+  }
+}
+
+void LoopUtils::MakeLoopClosedSSA() {
+  CreateLoopDedicatedExits();
+
+  ir::Function* function = loop_->GetHeaderBlock()->GetParent();
+  ir::CFG& cfg = *context_->cfg();
+  opt::DominatorTree& dom_tree =
+      context_->GetDominatorAnalysis(function, cfg)->GetDomTree();
+
+  std::unordered_set<ir::BasicBlock*> exit_bb;
+  {
+    std::unordered_set<uint32_t> exit_bb_id;
+    loop_->GetExitBlocks(context_, &exit_bb_id);
+    for (uint32_t bb_id : exit_bb_id) {
+      exit_bb.insert(cfg.block(bb_id));
+    }
+  }
+
+  LCSSARewriter lcssa_rewriter(context_, dom_tree, exit_bb,
+                               loop_->GetMergeBlock());
+  MakeSetClosedSSA(context_, function, loop_->GetBlocks(), exit_bb,
+                   &lcssa_rewriter);
+
+  // Make sure all defs post-dominated by the merge block have their last use no
+  // further than the merge block.
+  if (loop_->GetMergeBlock()) {
+    std::unordered_set<uint32_t> merging_bb_id;
+    loop_->GetMergingBlocks(context_, &merging_bb_id);
+    merging_bb_id.erase(loop_->GetMergeBlock()->id());
+    // Reset the exit set, now only the merge block is the exit.
+    exit_bb.clear();
+    exit_bb.insert(loop_->GetMergeBlock());
+    // LCSSARewriter is reusable here only because it forces the creation of a
+    // phi instruction in the merge block.
+    MakeSetClosedSSA(context_, function, merging_bb_id, exit_bb,
+                     &lcssa_rewriter);
+  }
+
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisDefUse |
+      ir::IRContext::Analysis::kAnalysisCFG |
+      ir::IRContext::Analysis::kAnalysisDominatorAnalysis |
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h
new file mode 100644 (file)
index 0000000..fcee9e3
--- /dev/null
@@ -0,0 +1,75 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOOP_UTILS_H_
+#define LIBSPIRV_OPT_LOOP_UTILS_H_
+
+namespace spvtools {
+
+namespace ir {
+class Loop;
+class IRContext;
+}  // namespace ir
+
+namespace opt {
+
+// Set of basic loop transformation.
+class LoopUtils {
+ public:
+  LoopUtils(ir::IRContext* context, ir::Loop* loop)
+      : context_(context), loop_(loop) {}
+
+  // The make the current loop in the loop closed SSA form.
+  // In the loop closed SSA, all loop exiting values goes through a dedicate SSA
+  // instruction. For instance:
+  //
+  // for (...) {
+  //   A1 = ...
+  //   if (...)
+  //     A2 = ...
+  //   A = phi A1, A2
+  // }
+  // ... = op A ...
+  //
+  // Becomes
+  //
+  // for (...) {
+  //   A1 = ...
+  //   if (...)
+  //     A2 = ...
+  //   A = phi A1, A2
+  // }
+  // C = phi A
+  // ... = op C ...
+  //
+  // This makes some loop transformations (such as loop unswitch) simpler
+  // (removes the needs to take care of exiting variables).
+  void MakeLoopClosedSSA();
+
+  // Create dedicate exit basic block. This ensure all exit basic blocks has the
+  // loop as sole predecessors.
+  // By construction, structured control flow already has a dedicated exit
+  // block.
+  // Preserves: CFG, def/use and instruction to block mapping.
+  void CreateLoopDedicatedExits();
+
+ private:
+  ir::IRContext* context_;
+  ir::Loop* loop_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_LOOP_UTILS_H_
index b5360c2..613ccf5 100644 (file)
@@ -25,3 +25,8 @@ add_spvtools_unittest(TARGET loop_descriptor_nested
     LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET lcssa_test
+    SRCS ../function_utils.h
+        lcssa.cpp
+    LIBS SPIRV-Tools-opt
+)
diff --git a/test/opt/loop_optimizations/lcssa.cpp b/test/opt/loop_optimizations/lcssa.cpp
new file mode 100644 (file)
index 0000000..95d5a2a
--- /dev/null
@@ -0,0 +1,614 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+
+#include "opt/build_module.h"
+#include "opt/loop_descriptor.h"
+#include "opt/loop_utils.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+
+#ifdef SPIRV_EFFCEE
+
+bool Validate(const std::vector<uint32_t>& bin) {
+  spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+  spv_context spvContext = spvContextCreate(target_env);
+  spv_diagnostic diagnostic = nullptr;
+  spv_const_binary_t binary = {bin.data(), bin.size()};
+  spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
+  if (error != 0) spvDiagnosticPrint(diagnostic);
+  spvDiagnosticDestroy(diagnostic);
+  spvContextDestroy(spvContext);
+  return error == 0;
+}
+
+void Match(const std::string& original, ir::IRContext* context,
+           bool do_validation = true) {
+  std::vector<uint32_t> bin;
+  context->module()->ToBinary(&bin, true);
+  if (do_validation) {
+    EXPECT_TRUE(Validate(bin));
+  }
+  std::string assembly;
+  SpirvTools tools(SPV_ENV_UNIVERSAL_1_2);
+  EXPECT_TRUE(
+      tools.Disassemble(bin, &assembly, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER))
+      << "Disassembling failed for shader:\n"
+      << assembly << std::endl;
+  auto match_result = effcee::Match(assembly, original);
+  EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+      << match_result.message() << "\nChecking result:\n"
+      << assembly;
+}
+
+using LCSSATest = ::testing::Test;
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  for (; i < 10; i++) {
+  }
+  if (i != 0) {
+    i = 1;
+  }
+}
+*/
+TEST_F(LCSSATest, SimpleLCSSA) {
+  const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] %19 None
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %30 %20
+; CHECK-NEXT: %27 = OpINotEqual {{%\w+}} [[phi]] %9
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %5 = OpTypeVoid
+          %6 = OpTypeFunction %5
+          %7 = OpTypeInt 32 1
+          %8 = OpTypePointer Function %7
+          %9 = OpConstant %7 0
+         %10 = OpConstant %7 10
+         %11 = OpTypeBool
+         %12 = OpConstant %7 1
+         %13 = OpTypeFloat 32
+         %14 = OpTypeVector %13 4
+         %15 = OpTypePointer Output %14
+          %3 = OpVariable %15 Output
+          %2 = OpFunction %5 None %6
+         %16 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
+         %30 = OpPhi %7 %9 %16 %25 %19
+               OpLoopMerge %18 %19 None
+               OpBranch %20
+         %20 = OpLabel
+         %22 = OpSLessThan %11 %30 %10
+               OpBranchConditional %22 %23 %18
+         %23 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %25 = OpIAdd %7 %30 %12
+               OpBranch %17
+         %18 = OpLabel
+         %27 = OpINotEqual %11 %30 %9
+               OpSelectionMerge %28 None
+               OpBranchConditional %27 %29 %28
+         %29 = OpLabel
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpPhi %7 %30 %18 %12 %29
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[17];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  for (; i < 10; i++) {
+  }
+  if (i != 0) {
+    i = 1;
+  }
+}
+*/
+// Same test as above, but should reuse an existing phi.
+TEST_F(LCSSATest, PhiReuseLCSSA) {
+  const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] %19 None
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %30 %20
+; CHECK-NEXT: %27 = OpINotEqual {{%\w+}} [[phi]] %9
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %5 = OpTypeVoid
+          %6 = OpTypeFunction %5
+          %7 = OpTypeInt 32 1
+          %8 = OpTypePointer Function %7
+          %9 = OpConstant %7 0
+         %10 = OpConstant %7 10
+         %11 = OpTypeBool
+         %12 = OpConstant %7 1
+         %13 = OpTypeFloat 32
+         %14 = OpTypeVector %13 4
+         %15 = OpTypePointer Output %14
+          %3 = OpVariable %15 Output
+          %2 = OpFunction %5 None %6
+         %16 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
+         %30 = OpPhi %7 %9 %16 %25 %19
+               OpLoopMerge %18 %19 None
+               OpBranch %20
+         %20 = OpLabel
+         %22 = OpSLessThan %11 %30 %10
+               OpBranchConditional %22 %23 %18
+         %23 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %25 = OpIAdd %7 %30 %12
+               OpBranch %17
+         %18 = OpLabel
+         %32 = OpPhi %7 %30 %20
+         %27 = OpINotEqual %11 %30 %9
+               OpSelectionMerge %28 None
+               OpBranchConditional %27 %29 %28
+         %29 = OpLabel
+               OpBranch %28
+         %28 = OpLabel
+         %31 = OpPhi %7 %30 %18 %12 %29
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[17];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  int j = 0;
+  for (; i < 10; i++) {}
+  for (; j < 10; j++) {}
+  if (j != 0) {
+    i = 1;
+  }
+}
+*/
+TEST_F(LCSSATest, DualLoopLCSSA) {
+  const std::string text = R"(
+; CHECK: %20 = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi %6 %17 %21
+; CHECK: %33 = OpLabel
+; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} [[phi]] %28 %11 %34
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %4 = OpTypeVoid
+          %5 = OpTypeFunction %4
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %8 = OpConstant %6 0
+          %9 = OpConstant %6 10
+         %10 = OpTypeBool
+         %11 = OpConstant %6 1
+         %12 = OpTypeFloat 32
+         %13 = OpTypeVector %12 4
+         %14 = OpTypePointer Output %13
+          %3 = OpVariable %14 Output
+          %2 = OpFunction %4 None %5
+         %15 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+         %17 = OpPhi %6 %8 %15 %18 %19
+               OpLoopMerge %20 %19 None
+               OpBranch %21
+         %21 = OpLabel
+         %22 = OpSLessThan %10 %17 %9
+               OpBranchConditional %22 %23 %20
+         %23 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %18 = OpIAdd %6 %17 %11
+               OpBranch %16
+         %20 = OpLabel
+               OpBranch %24
+         %24 = OpLabel
+         %25 = OpPhi %6 %8 %20 %26 %27
+               OpLoopMerge %28 %27 None
+               OpBranch %29
+         %29 = OpLabel
+         %30 = OpSLessThan %10 %25 %9
+               OpBranchConditional %30 %31 %28
+         %31 = OpLabel
+               OpBranch %27
+         %27 = OpLabel
+         %26 = OpIAdd %6 %25 %11
+               OpBranch %24
+         %28 = OpLabel
+         %32 = OpINotEqual %10 %25 %8
+               OpSelectionMerge %33 None
+               OpBranchConditional %32 %34 %33
+         %34 = OpLabel
+               OpBranch %33
+         %33 = OpLabel
+         %35 = OpPhi %6 %17 %28 %11 %34
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[16];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  int i = 0;
+  if (i != 0) {
+    for (; i < 10; i++) {}
+  }
+  if (i != 0) {
+    i = 1;
+  }
+}
+*/
+TEST_F(LCSSATest, PhiUserLCSSA) {
+  const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] %22 None
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %20 %24
+; CHECK: %17 = OpLabel
+; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} %8 %15 [[phi]] %23
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main" %3
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+               OpName %3 "c"
+               OpDecorate %3 Location 0
+          %4 = OpTypeVoid
+          %5 = OpTypeFunction %4
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %8 = OpConstant %6 0
+          %9 = OpTypeBool
+         %10 = OpConstant %6 10
+         %11 = OpConstant %6 1
+         %12 = OpTypeFloat 32
+         %13 = OpTypeVector %12 4
+         %14 = OpTypePointer Output %13
+          %3 = OpVariable %14 Output
+          %2 = OpFunction %4 None %5
+         %15 = OpLabel
+         %16 = OpINotEqual %9 %8 %8
+               OpSelectionMerge %17 None
+               OpBranchConditional %16 %18 %17
+         %18 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %20 = OpPhi %6 %8 %18 %21 %22
+               OpLoopMerge %23 %22 None
+               OpBranch %24
+         %24 = OpLabel
+         %25 = OpSLessThan %9 %20 %10
+               OpBranchConditional %25 %26 %23
+         %26 = OpLabel
+               OpBranch %22
+         %22 = OpLabel
+         %21 = OpIAdd %6 %20 %11
+               OpBranch %19
+         %23 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
+         %27 = OpPhi %6 %8 %15 %20 %23
+         %28 = OpINotEqual %9 %27 %8
+               OpSelectionMerge %29 None
+               OpBranchConditional %28 %30 %29
+         %30 = OpLabel
+               OpBranch %29
+         %29 = OpLabel
+         %31 = OpPhi %6 %27 %17 %11 %30
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[19];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+void main() {
+  int i = 0;
+  if (i != 0) {
+    for (; i < 10; i++) {
+      if (i > 5) break;
+    }
+  }
+  if (i != 0) {
+    i = 1;
+  }
+}
+*/
+TEST_F(LCSSATest, LCSSAWithBreak) {
+  const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] %19 None
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} %17 %21 %17 %26
+; CHECK: %14 = OpLabel
+; CHECK-NEXT: {{%\w+}} = OpPhi {{%\w+}} %7 %12 [[phi]] [[merge]]
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+          %3 = OpTypeVoid
+          %4 = OpTypeFunction %3
+          %5 = OpTypeInt 32 1
+          %6 = OpTypePointer Function %5
+          %7 = OpConstant %5 0
+          %8 = OpTypeBool
+          %9 = OpConstant %5 10
+         %10 = OpConstant %5 5
+         %11 = OpConstant %5 1
+          %2 = OpFunction %3 None %4
+         %12 = OpLabel
+         %13 = OpINotEqual %8 %7 %7
+               OpSelectionMerge %14 None
+               OpBranchConditional %13 %15 %14
+         %15 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+         %17 = OpPhi %5 %7 %15 %18 %19
+               OpLoopMerge %20 %19 None
+               OpBranch %21
+         %21 = OpLabel
+         %22 = OpSLessThan %8 %17 %9
+               OpBranchConditional %22 %23 %20
+         %23 = OpLabel
+         %24 = OpSGreaterThan %8 %17 %10
+               OpSelectionMerge %25 None
+               OpBranchConditional %24 %26 %25
+         %26 = OpLabel
+               OpBranch %20
+         %25 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %18 = OpIAdd %5 %17 %11
+               OpBranch %16
+         %20 = OpLabel
+               OpBranch %14
+         %14 = OpLabel
+         %27 = OpPhi %5 %7 %12 %17 %20
+         %28 = OpINotEqual %8 %27 %7
+               OpSelectionMerge %29 None
+               OpBranchConditional %28 %30 %29
+         %30 = OpLabel
+               OpBranch %29
+         %29 = OpLabel
+         %31 = OpPhi %5 %27 %14 %11 %30
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[19];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 330 core
+void main() {
+  int i = 0;
+  for (; i < 10; i++) {}
+  for (int j = i; j < 10;) { j = i + j; }
+}
+*/
+TEST_F(LCSSATest, LCSSAUseInNonEligiblePhi) {
+  const std::string text = R"(
+; CHECK: %12 = OpLabel
+; CHECK-NEXT: [[def_to_close:%\w+]] = OpPhi {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]] None
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[closing_phi:%\w+]] = OpPhi {{%\w+}} [[def_to_close]] %17
+; CHECK: %16 = OpLabel
+; CHECK-NEXT: [[use_in_phi:%\w+]] = OpPhi {{%\w+}} %21 %22 [[closing_phi]] [[merge]]
+; CHECK: OpIAdd {{%\w+}} [[closing_phi]] [[use_in_phi]]
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+          %3 = OpTypeVoid
+          %4 = OpTypeFunction %3
+          %5 = OpTypeInt 32 1
+          %6 = OpTypePointer Function %5
+          %7 = OpConstant %5 0
+          %8 = OpConstant %5 10
+          %9 = OpTypeBool
+         %10 = OpConstant %5 1
+          %2 = OpFunction %3 None %4
+         %11 = OpLabel
+               OpBranch %12
+         %12 = OpLabel
+         %13 = OpPhi %5 %7 %11 %14 %15
+               OpLoopMerge %16 %15 None
+               OpBranch %17
+         %17 = OpLabel
+         %18 = OpSLessThan %9 %13 %8
+               OpBranchConditional %18 %19 %16
+         %19 = OpLabel
+               OpBranch %15
+         %15 = OpLabel
+         %14 = OpIAdd %5 %13 %10
+               OpBranch %12
+         %16 = OpLabel
+         %20 = OpPhi %5 %13 %17 %21 %22
+               OpLoopMerge %23 %22 None
+               OpBranch %24
+         %24 = OpLabel
+         %25 = OpSLessThan %9 %20 %8
+               OpBranchConditional %25 %26 %23
+         %26 = OpLabel
+         %21 = OpIAdd %5 %13 %20
+               OpBranch %22
+         %22 = OpLabel
+               OpBranch %16
+         %23 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor ld{f};
+
+  ir::Loop* loop = ld[12];
+  EXPECT_FALSE(loop->IsLCSSA());
+  opt::LoopUtils Util(context.get(), loop);
+  Util.MakeLoopClosedSSA();
+  EXPECT_TRUE(loop->IsLCSSA());
+  Match(text, context.get());
+}
+
+#endif  // SPIRV_EFFCEE
+
+}  // namespace
index d2a9110..a6032db 100644 (file)
@@ -200,6 +200,7 @@ TEST_F(PassClassTest, LoopWithNoPreHeader) {
 
   ir::Loop* loop = ld[27];
   EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr);
+  EXPECT_NE(loop->GetOrCreatePreHeaderBlock(context.get()), nullptr);
 }
 
 }  // namespace
index be42dff..a2ebadf 100644 (file)
@@ -34,6 +34,18 @@ namespace {
 using namespace spvtools;
 using ::testing::UnorderedElementsAre;
 
+bool Validate(const std::vector<uint32_t>& bin) {
+  spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+  spv_context spvContext = spvContextCreate(target_env);
+  spv_diagnostic diagnostic = nullptr;
+  spv_const_binary_t binary = {bin.data(), bin.size()};
+  spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
+  if (error != 0) spvDiagnosticPrint(diagnostic);
+  spvDiagnosticDestroy(diagnostic);
+  spvContextDestroy(spvContext);
+  return error == 0;
+}
+
 using PassClassTest = PassTest<::testing::Test>;
 
 /*
@@ -592,4 +604,193 @@ TEST_F(PassClassTest, LoopParentTest) {
   }
 }
 
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+The preheader of loop %33 and %41 were removed as well.
+
+#version 330 core
+void main() {
+  int a = 0;
+  for (int i = 0; i < 10; ++i) {
+    if (i == 0) {
+      a = 1;
+    } else {
+      a = 2;
+    }
+    for (int j = 0; j < 11; ++j) {
+      a++;
+    }
+  }
+  for (int k = 0; k < 12; ++k) {}
+}
+*/
+TEST_F(PassClassTest, CreatePreheaderTest) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource GLSL 330
+               OpName %2 "main"
+          %3 = OpTypeVoid
+          %4 = OpTypeFunction %3
+          %5 = OpTypeInt 32 1
+          %6 = OpTypePointer Function %5
+          %7 = OpConstant %5 0
+          %8 = OpConstant %5 10
+          %9 = OpTypeBool
+         %10 = OpConstant %5 1
+         %11 = OpConstant %5 2
+         %12 = OpConstant %5 11
+         %13 = OpConstant %5 12
+         %14 = OpUndef %5
+          %2 = OpFunction %3 None %4
+         %15 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+         %17 = OpPhi %5 %7 %15 %18 %19
+         %20 = OpPhi %5 %7 %15 %21 %19
+         %22 = OpPhi %5 %14 %15 %23 %19
+               OpLoopMerge %41 %19 None
+               OpBranch %25
+         %25 = OpLabel
+         %26 = OpSLessThan %9 %20 %8
+               OpBranchConditional %26 %27 %41
+         %27 = OpLabel
+         %28 = OpIEqual %9 %20 %7
+               OpSelectionMerge %33 None
+               OpBranchConditional %28 %30 %31
+         %30 = OpLabel
+               OpBranch %33
+         %31 = OpLabel
+               OpBranch %33
+         %33 = OpLabel
+         %18 = OpPhi %5 %10 %30 %11 %31 %34 %35
+         %23 = OpPhi %5 %7 %30 %7 %31 %36 %35
+               OpLoopMerge %37 %35 None
+               OpBranch %38
+         %38 = OpLabel
+         %39 = OpSLessThan %9 %23 %12
+               OpBranchConditional %39 %40 %37
+         %40 = OpLabel
+         %34 = OpIAdd %5 %18 %10
+               OpBranch %35
+         %35 = OpLabel
+         %36 = OpIAdd %5 %23 %10
+               OpBranch %33
+         %37 = OpLabel
+               OpBranch %19
+         %19 = OpLabel
+         %21 = OpIAdd %5 %20 %10
+               OpBranch %16
+         %41 = OpLabel
+         %42 = OpPhi %5 %7 %25 %43 %44
+               OpLoopMerge %45 %44 None
+               OpBranch %46
+         %46 = OpLabel
+         %47 = OpSLessThan %9 %42 %13
+               OpBranchConditional %47 %48 %45
+         %48 = OpLabel
+               OpBranch %44
+         %44 = OpLabel
+         %43 = OpIAdd %5 %42 %10
+               OpBranch %41
+         %45 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  ir::LoopDescriptor& ld = *context->GetLoopDescriptor(f);
+  // No invalidation of the cfg should occur during this test.
+  ir::CFG* cfg = context->cfg();
+
+  EXPECT_EQ(ld.NumLoops(), 3u);
+
+  {
+    ir::Loop& loop = *ld[16];
+    EXPECT_TRUE(loop.HasNestedLoops());
+    EXPECT_FALSE(loop.IsNested());
+    EXPECT_EQ(loop.GetDepth(), 1u);
+    EXPECT_EQ(loop.GetParent(), nullptr);
+  }
+
+  {
+    ir::Loop& loop = *ld[33];
+    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
+    EXPECT_NE(loop.GetOrCreatePreHeaderBlock(context.get()), nullptr);
+    // Make sure the loop descriptor was properly updated.
+    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
+    {
+      const std::vector<uint32_t>& preds =
+          cfg->preds(loop.GetPreHeaderBlock()->id());
+      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
+      EXPECT_EQ(pred_set.size(), 2u);
+      EXPECT_TRUE(pred_set.count(30));
+      EXPECT_TRUE(pred_set.count(31));
+      // Check the phi instructions.
+      loop.GetPreHeaderBlock()->ForEachPhiInst(
+          [&pred_set](ir::Instruction* phi) {
+            for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
+              EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
+            }
+          });
+    }
+    {
+      const std::vector<uint32_t>& preds =
+          cfg->preds(loop.GetHeaderBlock()->id());
+      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
+      EXPECT_EQ(pred_set.size(), 2u);
+      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
+      EXPECT_TRUE(pred_set.count(35));
+      // Check the phi instructions.
+      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) {
+        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
+          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
+        }
+      });
+    }
+  }
+
+  {
+    ir::Loop& loop = *ld[41];
+    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
+    EXPECT_NE(loop.GetOrCreatePreHeaderBlock(context.get()), nullptr);
+    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
+    EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u);
+    EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u);
+    // Check the phi instructions.
+    loop.GetPreHeaderBlock()->ForEachPhiInst([](ir::Instruction* phi) {
+      EXPECT_EQ(phi->NumInOperands(), 2u);
+      EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
+    });
+    {
+      const std::vector<uint32_t>& preds =
+          cfg->preds(loop.GetHeaderBlock()->id());
+      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
+      EXPECT_EQ(pred_set.size(), 2u);
+      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
+      EXPECT_TRUE(pred_set.count(44));
+      // Check the phi instructions.
+      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](ir::Instruction* phi) {
+        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
+          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
+        }
+      });
+    }
+  }
+
+  // Make sure pre-header insertion leaves the module valid.
+  std::vector<uint32_t> bin;
+  context->module()->ToBinary(&bin, true);
+  EXPECT_TRUE(Validate(bin));
+}
+
 }  // namespace