Initial support for loop unrolling.
authorStephen McGroarty <stephen@codeplay.com>
Wed, 14 Feb 2018 17:03:12 +0000 (17:03 +0000)
committerSteven Perron <31666470+s-perron@users.noreply.github.com>
Wed, 14 Feb 2018 20:44:38 +0000 (15:44 -0500)
This patch adds initial support for loop unrolling in the form of a
series of utility classes which perform the unrolling. The pass can
be run with the command spirv-opt --loop-unroll. This will unroll
loops within the module which have the unroll hint set. The unroller
imposes a number of requirements on the loops it can unroll. These are
documented in the comments for the LoopUtils::CanPerformUnroll method in
loop_utils.h. Some of the restrictions will be lifted in future patches.

16 files changed:
Android.mk
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/dominator_tree.h
source/opt/ir_builder.h
source/opt/loop_descriptor.cpp
source/opt/loop_descriptor.h
source/opt/loop_unroller.cpp [new file with mode: 0644]
source/opt/loop_unroller.h [new file with mode: 0644]
source/opt/loop_utils.h
source/opt/optimizer.cpp
source/opt/passes.h
test/opt/loop_optimizations/CMakeLists.txt
test/opt/loop_optimizations/unroll_assumptions.cpp [new file with mode: 0644]
test/opt/loop_optimizations/unroll_simple.cpp [new file with mode: 0644]
tools/opt/opt.cpp

index 8f5e55b..4429ad1 100644 (file)
@@ -100,6 +100,7 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/local_single_store_elim_pass.cpp \
                source/opt/local_ssa_elim_pass.cpp \
                source/opt/loop_descriptor.cpp \
+               source/opt/loop_unroller.cpp \
                source/opt/mem_pass.cpp \
                source/opt/merge_return_pass.cpp \
                source/opt/module.cpp \
index e738b50..9f3b360 100644 (file)
@@ -512,6 +512,14 @@ Optimizer::PassToken CreateReplaceInvalidOpcodePass();
 // Creates a pass that simplifies instructions using the instruction folder.
 Optimizer::PassToken CreateSimplificationPass();
 
+// Create loop unroller pass.
+// Creates a pass to fully unroll loops which have the "Unroll" loop control
+// mask set. The loops must meet a specific criteria in order to be unrolled
+// safely this criteria is checked before doing the unroll by the
+// LoopUtils::CanPerformUnroll method. Any loop that does not meet the criteria
+// won't be unrolled. See CanPerformUnroll LoopUtils.h for more information.
+Optimizer::PassToken CreateLoopFullyUnrollPass();
+
 }  // namespace spvtools
 
 #endif  // SPIRV_TOOLS_OPTIMIZER_HPP_
index b3c6ebe..0194851 100644 (file)
@@ -58,6 +58,7 @@ add_library(SPIRV-Tools-opt
   local_ssa_elim_pass.h
   log.h
   loop_descriptor.h
+  loop_unroller.h
   loop_utils.h
   make_unique.h
   mem_pass.h
@@ -130,6 +131,7 @@ add_library(SPIRV-Tools-opt
   local_ssa_elim_pass.cpp
   loop_descriptor.cpp
   loop_utils.cpp
+  loop_unroller.cpp
   mem_pass.cpp
   merge_return_pass.cpp
   module.cpp
index 0be4951..5221eea 100644 (file)
@@ -224,6 +224,19 @@ class DominatorTree {
     return true;
   }
 
+  // Applies the std::function |func| to all nodes in the dominator tree from
+  // |node| downwards. The boolean return from |func| is used to determine
+  // whether or not the children should also be traversed. Tree nodes are
+  // visited in a depth first pre-order.
+  void VisitChildrenIf(std::function<bool(DominatorTreeNode*)> func,
+                       iterator node) {
+    if (func(&*node)) {
+      for (auto n : *node) {
+        VisitChildrenIf(func, n->df_begin());
+      }
+    }
+  }
+
   // Returns the DominatorTreeNode associated with the basic block |bb|.
   // If the |bb| is unknown to the dominator tree, it returns null.
   inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) {
index 9c153de..3b75ec6 100644 (file)
@@ -16,9 +16,9 @@
 #define LIBSPIRV_OPT_IR_BUILDER_H_
 
 #include "opt/basic_block.h"
+#include "opt/constants.h"
 #include "opt/instruction.h"
 #include "opt/ir_context.h"
-
 namespace spvtools {
 namespace opt {
 
@@ -136,6 +136,12 @@ class InstructionBuilder {
     return AddInstruction(std::move(select));
   }
 
+  // Adds a signed int32 constant to the binary.
+  // The |value| parameter is the constant value to be added.
+  ir::Instruction* Add32BitSignedIntegerConstant(int32_t value) {
+    return Add32BitConstantInteger<int32_t>(value, true);
+  }
+
   // Create a composite construct.
   // |type| should be a composite type and the number of elements it has should
   // match the size od |ids|.
@@ -151,6 +157,38 @@ class InstructionBuilder {
                             GetContext()->TakeNextId(), ops));
     return AddInstruction(std::move(construct));
   }
+  // Adds an unsigned int32 constant to the binary.
+  // The |value| parameter is the constant value to be added.
+  ir::Instruction* Add32BitUnsignedIntegerConstant(uint32_t value) {
+    return Add32BitConstantInteger<uint32_t>(value, false);
+  }
+
+  // Adds either a signed or unsigned 32 bit integer constant to the binary
+  // depedning on the |sign|. If |sign| is true then the value is added as a
+  // signed constant otherwise as an unsigned constant. If |sign| is false the
+  // value must not be a negative number.
+  template <typename T>
+  ir::Instruction* Add32BitConstantInteger(T value, bool sign) {
+    // Assert that we are not trying to store a negative number in an unsigned
+    // type.
+    if (!sign)
+      assert(value > 0 &&
+             "Trying to add a signed integer with an unsigned type!");
+
+    // Get or create the integer type.
+    analysis::Integer int_type(32, sign);
+
+    // Even if the value is negative we need to pass the bit pattern as a
+    // uint32_t to GetConstant.
+    uint32_t word = value;
+
+    // Create the constant value.
+    const opt::analysis::Constant* constant =
+        GetContext()->get_constant_mgr()->GetConstant(&int_type, {word});
+
+    // Create the OpConstant instruction using the type and the value.
+    return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
+  }
 
   ir::Instruction* AddCompositeExtract(
       uint32_t type, uint32_t id_of_composite,
index a2b9546..32765be 100644 (file)
@@ -18,6 +18,7 @@
 #include <utility>
 #include <vector>
 
+#include "constants.h"
 #include "opt/cfg.h"
 #include "opt/dominator_tree.h"
 #include "opt/ir_builder.h"
 namespace spvtools {
 namespace ir {
 
+// Takes in a phi instruction |induction| and the loop |header| and returns the
+// step operation of the loop.
+ir::Instruction* Loop::GetInductionStepOperation(
+    const ir::Loop* loop, const ir::Instruction* induction) const {
+  // Induction must be a phi instruction.
+  assert(induction->opcode() == SpvOpPhi);
+
+  ir::Instruction* step = nullptr;
+
+  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  // Traverse the incoming operands of the phi instruction.
+  for (uint32_t operand_id = 1; operand_id < induction->NumInOperands();
+       operand_id += 2) {
+    // Incoming edge.
+    ir::BasicBlock* incoming_block =
+        context_->cfg()->block(induction->GetSingleWordInOperand(operand_id));
+
+    // Check if the block is dominated by header, and thus coming from within
+    // the loop.
+    if (loop->IsInsideLoop(incoming_block)) {
+      step = def_use_manager->GetDef(
+          induction->GetSingleWordInOperand(operand_id - 1));
+      break;
+    }
+  }
+
+  if (!step || !IsSupportedStepOp(step->opcode())) {
+    return nullptr;
+  }
+
+  return step;
+}
+
+// Returns true if the |step| operation is an induction variable step operation
+// which is currently handled.
+bool Loop::IsSupportedStepOp(SpvOp step) const {
+  switch (step) {
+    case SpvOp::SpvOpISub:
+    case SpvOp::SpvOpIAdd:
+      return true;
+    default:
+      return false;
+  }
+}
+
+bool Loop::IsSupportedCondition(SpvOp condition) const {
+  switch (condition) {
+    // <
+    case SpvOp::SpvOpULessThan:
+    case SpvOp::SpvOpSLessThan:
+    // >
+    case SpvOp::SpvOpUGreaterThan:
+    case SpvOp::SpvOpSGreaterThan:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Extract the initial value from the |induction| OpPhi instruction and store it
+// in |value|. If the function couldn't find the initial value of |induction|
+// return false.
+bool Loop::GetInductionInitValue(const ir::Loop* loop,
+                                 const ir::Instruction* induction,
+                                 int64_t* value) const {
+  ir::Instruction* constant_instruction = nullptr;
+  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
+       operand_id += 2) {
+    ir::BasicBlock* bb = context_->cfg()->block(
+        induction->GetSingleWordInOperand(operand_id + 1));
+
+    if (!loop->IsInsideLoop(bb)) {
+      constant_instruction = def_use_manager->GetDef(
+          induction->GetSingleWordInOperand(operand_id));
+    }
+  }
+
+  if (!constant_instruction) return false;
+
+  const opt::analysis::Constant* constant =
+      context_->get_constant_mgr()->FindDeclaredConstant(
+          constant_instruction->result_id());
+  if (!constant) return false;
+
+  if (value) {
+    const opt::analysis::Integer* type =
+        constant->AsIntConstant()->type()->AsInteger();
+
+    if (type->IsSigned()) {
+      *value = constant->AsIntConstant()->GetS32BitValue();
+    } else {
+      *value = constant->AsIntConstant()->GetU32BitValue();
+    }
+  }
+
+  return true;
+}
+
 Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
            BasicBlock* header, BasicBlock* continue_target,
            BasicBlock* merge_target)
@@ -37,12 +139,11 @@ Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
       loop_continue_(continue_target),
       loop_merge_(merge_target),
       loop_preheader_(nullptr),
-      parent_(nullptr) {
+      parent_(nullptr),
+      loop_is_marked_for_removal_(false) {
   assert(context);
   assert(dom_analysis);
   loop_preheader_ = FindLoopPreheader(dom_analysis);
-  AddBasicBlockToLoop(header);
-  AddBasicBlockToLoop(continue_target);
 }
 
 BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) {
@@ -92,7 +193,6 @@ bool Loop::IsInsideLoop(Instruction* inst) const {
 
 bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
   assert(bb->GetParent() && "The basic block does not belong to a function");
-
   opt::DominatorAnalysis* dom_analysis =
       context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg());
   if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false;
@@ -219,14 +319,8 @@ void Loop::SetLatchBlock(BasicBlock* latch) {
 void Loop::SetMergeBlock(BasicBlock* merge) {
 #ifndef NDEBUG
   assert(merge->GetParent() && "The basic block does not belong to a function");
-  CFG& cfg = *merge->GetParent()->GetParent()->context()->cfg();
-
-  for (uint32_t pred : cfg.preds(merge->id())) {
-    assert(IsInsideLoop(pred) &&
-           "A predecessor of the merge block does not belong to the loop");
-  }
-  assert(!IsInsideLoop(merge) && "The merge block is in the loop");
 #endif  // NDEBUG
+  assert(!IsInsideLoop(merge) && "The merge block is in the loop");
 
   SetMergeBlockImpl(merge);
   if (GetHeaderBlock()->GetLoopMergeInst()) {
@@ -327,6 +421,7 @@ LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
 
 void LoopDescriptor::PopulateList(const Function* f) {
   IRContext* context = f->GetParent()->context();
+
   opt::DominatorAnalysis* dom_analysis =
       context->GetDominatorAnalysis(f, *context->cfg());
 
@@ -384,7 +479,7 @@ void LoopDescriptor::PopulateList(const Function* f) {
            make_range(node.df_begin(), node.df_end())) {
         // Check if we are in the loop.
         if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
-        current_loop->AddBasicBlockToLoop(loop_node.bb_);
+        current_loop->AddBasicBlock(loop_node.bb_);
         basic_block_to_loop_.insert(
             std::make_pair(loop_node.bb_->id(), current_loop));
       }
@@ -395,12 +490,262 @@ void LoopDescriptor::PopulateList(const Function* f) {
   }
 }
 
+ir::BasicBlock* Loop::FindConditionBlock() const {
+  const ir::Function& function = *loop_merge_->GetParent();
+  ir::BasicBlock* condition_block = nullptr;
+
+  const opt::DominatorAnalysis* dom_analysis =
+      context_->GetDominatorAnalysis(&function, *context_->cfg());
+  ir::BasicBlock* bb = dom_analysis->ImmediateDominator(loop_merge_);
+
+  if (!bb) return nullptr;
+
+  const ir::Instruction& branch = *bb->ctail();
+
+  // Make sure the branch is a conditional branch.
+  if (branch.opcode() != SpvOpBranchConditional) return nullptr;
+
+  // Make sure one of the two possible branches is to the merge block.
+  if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
+      branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
+    condition_block = bb;
+  }
+
+  return condition_block;
+}
+
+bool Loop::FindNumberOfIterations(const ir::Instruction* induction,
+                                  const ir::Instruction* branch_inst,
+                                  size_t* iterations_out,
+                                  int64_t* step_value_out,
+                                  int64_t* init_value_out) const {
+  // From the branch instruction find the branch condition.
+  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  // Condition instruction from the OpConditionalBranch.
+  ir::Instruction* condition =
+      def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
+
+  assert(IsSupportedCondition(condition->opcode()));
+
+  // Get the constant manager from the ir context.
+  opt::analysis::ConstantManager* const_manager = context_->get_constant_mgr();
+
+  // Find the constant value used by the condition variable. Exit out if it
+  // isn't a constant int.
+  const opt::analysis::Constant* upper_bound =
+      const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
+  if (!upper_bound) return false;
+
+  // Must be integer because of the opcode on the condition.
+  int64_t condition_value = 0;
+
+  const opt::analysis::Integer* type =
+      upper_bound->AsIntConstant()->type()->AsInteger();
+
+  if (type->IsSigned()) {
+    condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
+  } else {
+    condition_value = upper_bound->AsIntConstant()->GetU32BitValue();
+  }
+
+  // Find the instruction which is stepping through the loop.
+  ir::Instruction* step_inst = GetInductionStepOperation(this, induction);
+  if (!step_inst) return false;
+
+  // Find the constant value used by the condition variable.
+  const opt::analysis::Constant* step_constant =
+      const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
+  if (!step_constant) return false;
+
+  // Must be integer because of the opcode on the condition.
+  int64_t step_value = 0;
+
+  const opt::analysis::Integer* step_type =
+      step_constant->AsIntConstant()->type()->AsInteger();
+
+  if (step_type->IsSigned()) {
+    step_value = step_constant->AsIntConstant()->GetS32BitValue();
+  } else {
+    step_value = step_constant->AsIntConstant()->GetU32BitValue();
+  }
+
+  // If this is a subtraction step we should negate the step value.
+  if (step_inst->opcode() == SpvOp::SpvOpISub) {
+    step_value = -step_value;
+  }
+
+  // Find the inital value of the loop and make sure it is a constant integer.
+  int64_t init_value = 0;
+  if (!GetInductionInitValue(this, induction, &init_value)) return false;
+
+  // If iterations is non null then store the value in that.
+  if (iterations_out) {
+    int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
+                                     init_value, step_value);
+
+    // If the loop body will not be reached return false.
+    if (num_itrs <= 0) {
+      return false;
+    }
+    assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
+    *iterations_out = static_cast<size_t>(num_itrs);
+  }
+
+  if (step_value_out) {
+    *step_value_out = step_value;
+  }
+
+  if (init_value_out) {
+    *init_value_out = init_value;
+  }
+
+  return true;
+}
+
+// We retrieve the number of iterations using the following formula, diff /
+// |step_value| where diff is calculated differently according to the
+// |condition| and uses the |condition_value| and |init_value|. If diff /
+// |step_value| is NOT cleanly divisable then we add one to the sum.
+int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
+                            int64_t init_value, int64_t step_value) const {
+  int64_t diff = 0;
+
+  // Take the abs of - step values.
+  step_value = llabs(step_value);
+
+  switch (condition) {
+    case SpvOp::SpvOpSLessThan:
+    case SpvOp::SpvOpULessThan: {
+      diff = condition_value - init_value;
+      break;
+    }
+    case SpvOp::SpvOpSGreaterThan:
+    case SpvOp::SpvOpUGreaterThan: {
+      diff = init_value - condition_value;
+      break;
+    }
+    default:
+      assert(false &&
+             "Could not retrieve number of iterations from the loop condition. "
+             "Condition is not supported.");
+  }
+
+  int64_t result = diff / step_value;
+
+  if (diff % step_value != 0) {
+    result += 1;
+  }
+  return result;
+}
+
+ir::Instruction* Loop::FindInductionVariable(
+    const ir::BasicBlock* condition_block) const {
+  // Find the branch instruction.
+  const ir::Instruction& branch_inst = *condition_block->ctail();
+
+  ir::Instruction* induction = nullptr;
+  // Verify that the branch instruction is a conditional branch.
+  if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) {
+    // From the branch instruction find the branch condition.
+    opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+    // Find the instruction representing the condition used in the conditional
+    // branch.
+    ir::Instruction* condition =
+        def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
+
+    // Ensure that the condition is a less than operation.
+    if (condition && IsSupportedCondition(condition->opcode())) {
+      // The left hand side operand of the operation.
+      ir::Instruction* variable_inst =
+          def_use_manager->GetDef(condition->GetSingleWordOperand(2));
+
+      // Make sure the variable instruction used is a phi.
+      if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr;
+
+      // Make sure the phi instruction only has two incoming blocks. Each
+      // incoming block will be represented by two in operands in the phi
+      // instruction, the value and the block which that value came from. We
+      // assume the cannocalised phi will have two incoming values, one from the
+      // preheader and one from the continue block.
+      size_t max_supported_operands = 4;
+      if (variable_inst->NumInOperands() == max_supported_operands) {
+        // The operand index of the first incoming block label.
+        uint32_t operand_label_1 = 1;
+
+        // The operand index of the second incoming block label.
+        uint32_t operand_label_2 = 3;
+
+        // Make sure one of them is the preheader.
+        if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
+                loop_preheader_->id() &&
+            variable_inst->GetSingleWordInOperand(operand_label_2) !=
+                loop_preheader_->id()) {
+          return nullptr;
+        }
+
+        // And make sure that the other is the latch block.
+        if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
+                loop_continue_->id() &&
+            variable_inst->GetSingleWordInOperand(operand_label_2) !=
+                loop_continue_->id()) {
+          return nullptr;
+        }
+      } else {
+        return nullptr;
+      }
+
+      if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
+        return nullptr;
+      induction = variable_inst;
+    }
+  }
+
+  return induction;
+}
+
+// Add and remove loops which have been marked for addition and removal to
+// maintain the state of the loop descriptor class.
+void LoopDescriptor::PostModificationCleanup() {
+  LoopContainerType loops_to_remove_;
+  for (ir::Loop* loop : loops_) {
+    if (loop->IsMarkedForRemoval()) {
+      loops_to_remove_.push_back(loop);
+      if (loop->HasParent()) {
+        loop->GetParent()->RemoveChildLoop(loop);
+      }
+    }
+  }
+
+  for (ir::Loop* loop : loops_to_remove_) {
+    loops_.erase(std::find(loops_.begin(), loops_.end(), loop));
+  }
+
+  for (auto& pair : loops_to_add_) {
+    ir::Loop* parent = pair.first;
+    ir::Loop* loop = pair.second;
+
+    if (parent) {
+      loop->SetParent(nullptr);
+      parent->AddNestedLoop(loop);
+
+      for (uint32_t block_id : loop->GetBlocks()) {
+        parent->AddBasicBlock(block_id);
+      }
+    }
+
+    loops_.emplace_back(loop);
+  }
+
+  loops_to_add_.clear();
+}
+
 void LoopDescriptor::ClearLoops() {
   for (Loop* loop : loops_) {
     delete loop;
   }
   loops_.clear();
 }
-
 }  // namespace ir
 }  // namespace spvtools
index b4651d7..ef5f19f 100644 (file)
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "opt/basic_block.h"
+#include "opt/module.h"
 #include "opt/tree_iterator.h"
 
 namespace spvtools {
@@ -52,7 +53,8 @@ class Loop {
         loop_continue_(nullptr),
         loop_merge_(nullptr),
         loop_preheader_(nullptr),
-        parent_(nullptr) {}
+        parent_(nullptr),
+        loop_is_marked_for_removal_(false) {}
 
   Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header,
        BasicBlock* continue_target, BasicBlock* merge_target);
@@ -144,6 +146,8 @@ class Loop {
     return lvl;
   }
 
+  inline size_t NumImmediateChildren() const { return nested_loops_.size(); }
+
   // Adds |nested| as a nested loop of this loop. Automatically register |this|
   // as the parent of |nested|.
   inline void AddNestedLoop(Loop* nested) {
@@ -180,6 +184,21 @@ class Loop {
   // Returns true if the instruction |inst| is inside this loop.
   bool IsInsideLoop(Instruction* inst) const;
 
+  // Adds the Basic Block |bb| to this loop and its parents.
+  void AddBasicBlock(const BasicBlock* bb) { AddBasicBlock(bb->id()); }
+
+  // Adds the Basic Block with |id| to this loop and its parents.
+  void AddBasicBlock(uint32_t id) {
+    for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
+      loop_basic_blocks_.insert(id);
+    }
+  }
+
+  // Removes all the basic blocks from the set of basic blocks within the loop.
+  // This does not affect any of the stored pointers to the header, preheader,
+  // merge, or continue blocks.
+  void ClearBlocks() { loop_basic_blocks_.clear(); }
+
   // Adds the Basic Block |bb| this loop and its parents.
   void AddBasicBlockToLoop(const BasicBlock* bb) {
     assert(IsBasicBlockInLoopSlow(bb) &&
@@ -188,11 +207,58 @@ class Loop {
     AddBasicBlock(bb);
   }
 
-  // Adds the Basic Block |bb| this loop and its parents.
-  void AddBasicBlock(const BasicBlock* bb) {
-    for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
-      loop_basic_blocks_.insert(bb->id());
+  // This function uses the |condition| to find the induction variable within
+  // the loop. This only works if the loop is bound by a single condition and a
+  // single induction variable.
+  ir::Instruction* FindInductionVariable(const ir::BasicBlock* condition) const;
+
+  // Returns the number of iterations within a loop when given the |induction|
+  // variable and the loop |condition| check. It stores the found number of
+  // iterations in the output parameter |iterations| and optionally, the step
+  // value in |step_value| and the initial value of the induction variable in
+  // |init_value|.
+  bool FindNumberOfIterations(const ir::Instruction* induction,
+                              const ir::Instruction* condition,
+                              size_t* iterations,
+                              int64_t* step_amount = nullptr,
+                              int64_t* init_value = nullptr) const;
+
+  // Returns the value of the OpLoopMerge control operand as a bool. Loop
+  // control can be None(0), Unroll(1), or DontUnroll(2). This function returns
+  // true if it is set to Unroll.
+  inline bool HasUnrollLoopControl() const {
+    assert(loop_header_);
+    if (!loop_header_->GetLoopMergeInst()) return false;
+
+    return loop_header_->GetLoopMergeInst()->GetSingleWordOperand(2) == 1;
+  }
+
+  // Finds the conditional block with a branch to the merge and continue blocks
+  // within the loop body.
+  ir::BasicBlock* FindConditionBlock() const;
+
+  // Remove the child loop form this loop.
+  inline void RemoveChildLoop(Loop* loop) {
+    nested_loops_.erase(
+        std::find(nested_loops_.begin(), nested_loops_.end(), loop));
+    loop->SetParent(nullptr);
+  }
+
+  // Mark this loop to be removed later by a call to
+  // LoopDescriptor::PostModificationCleanup.
+  inline void MarkLoopForRemoval() { loop_is_marked_for_removal_ = true; }
+
+  // Returns whether or not this loop has been marked for removal.
+  inline bool IsMarkedForRemoval() const { return loop_is_marked_for_removal_; }
+
+  // Returns true if all nested loops have been marked for removal.
+  inline bool AreAllChildrenMarkedForRemoval() const {
+    for (const Loop* child : nested_loops_) {
+      if (!child->IsMarkedForRemoval()) {
+        return false;
+      }
     }
+    return true;
   }
 
   // Sets the parent loop of this loop, that is, a loop which contains this loop
@@ -206,6 +272,28 @@ class Loop {
   // loop
   bool AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst);
 
+  // Extract the initial value from the |induction| variable and store it in
+  // |value|. If the function couldn't find the initial value of |induction|
+  // return false.
+  bool GetInductionInitValue(const ir::Loop* loop,
+                             const ir::Instruction* induction,
+                             int64_t* value) const;
+
+  // Takes in a phi instruction |induction| and the loop |header| and returns
+  // the step operation of the loop.
+  ir::Instruction* GetInductionStepOperation(
+      const ir::Loop* loop, const ir::Instruction* induction) const;
+
+  // Returns true if we can deduce the number of loop iterations in the step
+  // operation |step|. IsSupportedCondition must also be true for the condition
+  // instruction.
+  bool IsSupportedStepOp(SpvOp step) const;
+
+  // Returns true if we can deduce the number of loop iterations in the
+  // condition operation |condition|. IsSupportedStepOp must also be true for
+  // the step instruction.
+  bool IsSupportedCondition(SpvOp condition) const;
+
  private:
   IRContext* context_;
   // The block which marks the start of the loop.
@@ -244,6 +332,17 @@ class Loop {
   // Sets |merge| as the loop merge block. No checks are performed here.
   inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
 
+  // Each differnt loop |condition| affects how we calculate the number of
+  // iterations using the |condition_value|, |init_value|, and |step_values| of
+  // the induction variable. This method will return the number of iterations in
+  // a loop with those values for a given |condition|.
+  int64_t GetIterations(SpvOp condition, int64_t condition_value,
+                        int64_t init_value, int64_t step_value) const;
+
+  // This is to allow for loops to be removed mid iteration without invalidating
+  // the iterators.
+  bool loop_is_marked_for_removal_;
+
   // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level
   // loops as child.
   friend class LoopDescriptor;
@@ -317,10 +416,21 @@ class LoopDescriptor {
     basic_block_to_loop_[bb_id] = loop;
   }
 
+  // Mark the loop |loop_to_add| as needing to be added when the user calls
+  // PostModificationCleanup. |parent| may be null.
+  inline void AddLoop(ir::Loop* loop_to_add, ir::Loop* parent) {
+    loops_to_add_.emplace_back(std::make_pair(parent, loop_to_add));
+  }
+
+  // Should be called to preserve the LoopAnalysis after loops have been marked
+  // for addition with AddLoop or MarkLoopForRemoval.
+  void PostModificationCleanup();
+
  private:
   // TODO(dneto): This should be a vector of unique_ptr.  But VisualStudio 2013
   // is unable to compile it.
   using LoopContainerType = std::vector<Loop*>;
+  using LoopsToAddContainerType = std::vector<std::pair<Loop*, Loop*>>;
 
   // Creates loop descriptors for the function |f|.
   void PopulateList(const Function* f);
@@ -338,9 +448,15 @@ class LoopDescriptor {
   // A list of all the loops in the function.  This variable owns the Loop
   // objects.
   LoopContainerType loops_;
+
   // Dummy root: this "loop" is only there to help iterators creation.
   Loop dummy_top_loop_;
+
   std::unordered_map<uint32_t, Loop*> basic_block_to_loop_;
+
+  // List of the loops marked for addition when PostModificationCleanup is
+  // called.
+  LoopsToAddContainerType loops_to_add_;
 };
 
 }  // namespace ir
diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp
new file mode 100644 (file)
index 0000000..16031e2
--- /dev/null
@@ -0,0 +1,939 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "opt/loop_unroller.h"
+#include <map>
+#include <memory>
+#include <utility>
+#include "opt/ir_builder.h"
+#include "opt/loop_utils.h"
+
+// Implements loop util unrolling functionality for fully and partially
+// unrolling loops. Given a factor it will duplicate the loop that many times,
+// appending each one to the end of the old loop and removing backedges, to
+// create a new unrolled loop.
+//
+// 1 - User calls LoopUtils::FullyUnroll or LoopUtils::PartiallyUnroll with a
+// loop they wish to unroll. LoopUtils::CanPerformUnroll is used to
+// validate that a given loop can be unrolled. That method (along with the
+// constructor of loop) checks that the IR is in the expected canonicalised
+// format.
+//
+// 2 - The LoopUtils methods create a LoopUnrollerUtilsImpl object to actually
+// perform the unrolling. This implements helper methods to copy the loop basic
+// blocks and remap the ids of instructions used inside them.
+//
+// 3 - The core of LoopUnrollerUtilsImpl is the Unroll method, this method
+// actually performs the loop duplication. It does this by creating a
+// LoopUnrollState object and then copying the loop as given by the factor
+// parameter. The LoopUnrollState object retains the state of the unroller
+// between the loop body copies as each iteration needs information on the last
+// to adjust the phi induction variable, adjust the OpLoopMerge instruction in
+// the main loop header, and change the previous continue block to point to the
+// new header and the new continue block to the main loop header.
+//
+// 4 - If the loop is to be fully unrolled then it is simply closed after step
+// 3, with the OpLoopMerge being deleted, the backedge removed, and the
+// condition blocks folded.
+//
+// 5 - If it is being partially unrolled: if the unrolling factor leaves the
+// loop with an even number of bodies with respect to the number of loop
+// iterations then step 3 is all that is needed. If it is uneven then we need to
+// duplicate the loop completely and unroll the duplicated loop to cover the
+// residual part and adjust the first loop to cover only the "even" part. For
+// instance if you request an unroll factor of 3 on a loop with 10 iterations
+// then copying the body three times would leave you with three bodies in the
+// loop
+// where the loop still iterates over each 4 times. So we make two loops one
+// iterating once then a second loop of three iterating 3 times.
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+// This utility class encapsulates some of the state we need to maintain between
+// loop unrolls. Specifically it maintains key blocks and the induction variable
+// in the current loop duplication step and the blocks from the previous one.
+// This is because each step of the unroll needs to use data from both the
+// preceding step and the original loop.
+struct LoopUnrollState {
+  LoopUnrollState()
+      : previous_phi_(nullptr),
+        previous_continue_block_(nullptr),
+        previous_condition_block_(nullptr),
+        new_phi(nullptr),
+        new_continue_block(nullptr),
+        new_condition_block(nullptr),
+        new_header_block(nullptr) {}
+
+  // Initialize from the loop descriptor class.
+  LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block,
+                  ir::BasicBlock* condition)
+      : previous_phi_(induction),
+        previous_continue_block_(continue_block),
+        previous_condition_block_(condition),
+        new_phi(nullptr),
+        new_continue_block(nullptr),
+        new_condition_block(nullptr),
+        new_header_block(nullptr) {}
+
+  // Swap the state so that the new nodes are now the previous nodes.
+  void NextIterationState() {
+    previous_phi_ = new_phi;
+    previous_continue_block_ = new_continue_block;
+    previous_condition_block_ = new_condition_block;
+
+    // Clear new nodes.
+    new_phi = nullptr;
+    new_continue_block = nullptr;
+    new_condition_block = nullptr;
+    new_header_block = nullptr;
+
+    // Clear new block/instruction maps.
+    new_blocks.clear();
+    new_inst.clear();
+  }
+
+  // The induction variable from the immediately preceding loop body.
+  ir::Instruction* previous_phi_;
+
+  // The previous continue block. The backedge will be removed from this and
+  // added to the new continue block.
+  ir::BasicBlock* previous_continue_block_;
+
+  // The previous condition block. This may be folded to flatten the loop.
+  ir::BasicBlock* previous_condition_block_;
+
+  // The new induction variable.
+  ir::Instruction* new_phi;
+
+  // The new continue block.
+  ir::BasicBlock* new_continue_block;
+
+  // The new condition block.
+  ir::BasicBlock* new_condition_block;
+
+  // The new header block.
+  ir::BasicBlock* new_header_block;
+
+  // A mapping of new block ids to the original blocks which they were copied
+  // from.
+  std::unordered_map<uint32_t, ir::BasicBlock*> new_blocks;
+
+  // A mapping of new instruction ids to the instruction ids from which they
+  // were copied.
+  std::unordered_map<uint32_t, uint32_t> new_inst;
+};
+
+// This class implements the actual unrolling. It uses a LoopUnrollState to
+// maintain the state of the unrolling inbetween steps.
+class LoopUnrollerUtilsImpl {
+ public:
+  using BasicBlockListTy = std::vector<std::unique_ptr<ir::BasicBlock>>;
+
+  LoopUnrollerUtilsImpl(ir::IRContext* c, ir::Function* function)
+      : context_(c),
+        function_(*function),
+        loop_condition_block_(nullptr),
+        loop_induction_variable_(nullptr),
+        number_of_loop_iterations_(0),
+        loop_step_value_(0),
+        loop_init_value_(0) {}
+
+  // Unroll the |loop| by given |factor| by copying the whole body |factor|
+  // times. The resulting basicblock structure will remain a loop.
+  void PartiallyUnroll(ir::Loop*, size_t factor);
+
+  // If partially unrolling the |loop| would leave the loop with too many bodies
+  // for its number of iterations then this method should be used. This method
+  // will duplicate the |loop| completely, making the duplicated loop the
+  // successor of the original's merge block. The original loop will have its
+  // condition changed to loop over the residual part and the duplicate will be
+  // partially unrolled. The resulting structure will be two loops.
+  void PartiallyUnrollResidualFactor(ir::Loop* loop, size_t factor);
+
+  // Fully unroll the |loop| by copying the full body by the total number of
+  // loop iterations, folding all conditions, and removing the backedge from the
+  // continue block to the header.
+  void FullyUnroll(ir::Loop* loop);
+
+  // Get the ID of the variable in the |phi| paired with |label|.
+  uint32_t GetPhiDefID(const ir::Instruction* phi, uint32_t label) const;
+
+  // Close the loop by removing the OpLoopMerge from the |loop| header block and
+  // making the backedge point to the merge block.
+  void CloseUnrolledLoop(ir::Loop* loop);
+
+  // Remove the OpConditionalBranch instruction inside |conditional_block| used
+  // to branch to either exit or continue the loop and replace it with an
+  // unconditional OpBranch to block |new_target|.
+  void FoldConditionBlock(ir::BasicBlock* condtion_block, uint32_t new_target);
+
+  // Add all blocks_to_add_ to function_ at the |insert_point|.
+  void AddBlocksToFunction(const ir::BasicBlock* insert_point);
+
+  // Duplicates the |old_loop|, cloning each body and remaping the ids without
+  // removing instructions or changing relative structure. Result will be stored
+  // in |new_loop|.
+  void DuplicateLoop(ir::Loop* old_loop, ir::Loop* new_loop);
+
+  inline size_t GetLoopIterationCount() const {
+    return number_of_loop_iterations_;
+  }
+
+  // Extracts the initial state information from the |loop|.
+  void Init(ir::Loop* loop);
+
+ private:
+  // Remap all the in |basic_block| to new IDs and keep the mapping of new ids
+  // to old
+  // ids. |loop| is used to identify special loop blocks (header, continue,
+  // ect).
+  void AssignNewResultIds(ir::BasicBlock* basic_block);
+
+  // Using the map built by AssignNewResultIds, for each instruction in
+  // |basic_block| use
+  // that map to substitute the IDs used by instructions (in the operands) with
+  // the new ids.
+  void RemapOperands(ir::BasicBlock* basic_block);
+
+  // Copy the whole body of the loop, all blocks dominated by the |loop| header
+  // and not dominated by the |loop| merge. The copied body will be linked to by
+  // the old |loop| continue block and the new body will link to the |loop|
+  // header via the new continue block. |eliminate_conditions| is used to decide
+  // whether or not to fold all the condition blocks other than the last one.
+  void CopyBody(ir::Loop* loop, bool eliminate_conditions);
+
+  // Copy a given |block_to_copy| in the |loop| and record the mapping of the
+  // old/new ids. |preserve_instructions| determines whether or not the method
+  // will modify (other than result_id) instructions which are copied.
+  void CopyBasicBlock(ir::Loop* loop, const ir::BasicBlock* block_to_copy,
+                      bool preserve_instructions);
+
+  // The actual implementation of the unroll step. Unrolls |loop| by given
+  // |factor| by copying the body by |factor| times. Also propagates the
+  // induction variable value throughout the copies.
+  void Unroll(ir::Loop* loop, size_t factor);
+
+  // Fills the loop_blocks_inorder_ field with the ordered list of basic blocks
+  // as computed by the method ComputeLoopOrderedBlocks.
+  void ComputeLoopOrderedBlocks(ir::Loop* loop);
+
+  // Adds the blocks_to_add_ to both the |loop| and to the parent of |loop| if
+  // the parent exists.
+  void AddBlocksToLoop(ir::Loop* loop) const;
+
+  // A pointer to the IRContext. Used to add/remove instructions and for usedef
+  // chains.
+  ir::IRContext* context_;
+
+  // A reference the function the loop is within.
+  ir::Function& function_;
+
+  // A list of basic blocks to be added to the loop at the end of an unroll
+  // step.
+  BasicBlockListTy blocks_to_add_;
+
+  // List of instructions which are now dead and can be removed.
+  std::vector<ir::Instruction*> dead_instructions_;
+
+  // Maintains the current state of the transform between calls to unroll.
+  LoopUnrollState state_;
+
+  // An ordered list containing the loop basic blocks.
+  std::vector<ir::BasicBlock*> loop_blocks_inorder_;
+
+  // The block containing the condition check which contains a conditional
+  // branch to the merge and continue block.
+  ir::BasicBlock* loop_condition_block_;
+
+  // The induction variable of the loop.
+  ir::Instruction* loop_induction_variable_;
+
+  // The number of loop iterations that the loop would preform pre-unroll.
+  size_t number_of_loop_iterations_;
+
+  // The amount that the loop steps each iteration.
+  int64_t loop_step_value_;
+
+  // The value the loop starts stepping from.
+  int64_t loop_init_value_;
+};
+
+/*
+ * Static helper functions.
+ */
+
+// Retrieve the index of the OpPhi instruction |phi| which corresponds to the
+// incoming |block| id.
+static uint32_t GetPhiIndexFromLabel(const ir::BasicBlock* block,
+                                     const ir::Instruction* phi) {
+  for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
+    if (block->id() == phi->GetSingleWordInOperand(i)) {
+      return i;
+    }
+  }
+  assert(false && "Could not find operand in instruction.");
+  return 0;
+}
+
+void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) {
+  loop_condition_block_ = loop->FindConditionBlock();
+
+  // When we reinit the second loop during PartiallyUnrollResidualFactor we need
+  // to use the cached value from the duplicate step as the dominator tree
+  // basded solution, loop->FindConditionBlock, requires all the nodes to be
+  // connected up with the correct branches. They won't be at this point.
+  if (!loop_condition_block_) {
+    loop_condition_block_ = state_.new_condition_block;
+  }
+  assert(loop_condition_block_);
+
+  loop_induction_variable_ = loop->FindInductionVariable(loop_condition_block_);
+  assert(loop_induction_variable_);
+
+  bool found = loop->FindNumberOfIterations(
+      loop_induction_variable_, &*loop_condition_block_->ctail(),
+      &number_of_loop_iterations_, &loop_step_value_, &loop_init_value_);
+  (void)found;  // To silence unused variable warning on release builds.
+  assert(found);
+  ComputeLoopOrderedBlocks(loop);
+}
+
+// This function is used to partially unroll the loop when the factor provided
+// would normally lead to an illegal optimization. Instead of just unrolling the
+// loop it creates two loops and unrolls one and adjusts the condition on the
+// other. The end result being that the new loop pair iterates over the correct
+// number of bodies.
+void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
+                                                          size_t factor) {
+  // Create a new merge block for the first loop.
+  std::unique_ptr<ir::Instruction> new_label{new ir::Instruction(
+      context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})};
+  std::unique_ptr<ir::BasicBlock> new_exit_bb{
+      new ir::BasicBlock(std::move(new_label))};
+
+  // Save the id of the block before we move it.
+  uint32_t new_merge_id = new_exit_bb->id();
+
+  // Add the block the list of blocks to add, we want this merge block to be
+  // right at the start of the new blocks.
+  blocks_to_add_.push_back(std::move(new_exit_bb));
+  ir::BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get();
+  ir::Instruction& original_conditional_branch = *loop_condition_block_->tail();
+
+  // Duplicate the loop, providing access to the blocks of both loops.
+  // This is a naked new due to the VS2013 requirement of not having unique
+  // pointers in vectors, as it will be inserted into a vector with
+  // loop_descriptor.AddLoop.
+  ir::Loop* new_loop = new ir::Loop(*loop);
+
+  // Clear the basic blocks of the new loop.
+  new_loop->ClearBlocks();
+
+  DuplicateLoop(loop, new_loop);
+
+  // Add the blocks to the function.
+  AddBlocksToFunction(loop->GetMergeBlock());
+  blocks_to_add_.clear();
+
+  InstructionBuilder builder{context_, new_exit_bb_raw};
+  // Make the first loop branch to the second.
+  builder.AddBranch(new_loop->GetHeaderBlock()->id());
+
+  loop_condition_block_ = state_.new_condition_block;
+  loop_induction_variable_ = state_.new_phi;
+
+  // Unroll the new loop by the factor with the usual -1 to account for the
+  // existing block iteration.
+  Unroll(new_loop, factor);
+
+  // We need to account for the initial body when calculating the remainder.
+  int64_t remainder = loop_init_value_ +
+                      (number_of_loop_iterations_ % factor) * loop_step_value_;
+
+  assert(remainder > std::numeric_limits<int32_t>::min() &&
+         remainder < std::numeric_limits<int32_t>::max());
+
+  ir::Instruction* new_constant = nullptr;
+
+  // If the remainder is negative then we add a signed constant, otherwise just
+  // add an unsigned constant.
+  if (remainder < 0) {
+    new_constant =
+        builder.Add32BitSignedIntegerConstant(static_cast<int32_t>(remainder));
+  } else {
+    new_constant = builder.Add32BitUnsignedIntegerConstant(
+        static_cast<int32_t>(remainder));
+  }
+
+  uint32_t constant_id = new_constant->result_id();
+
+  // Add the merge block to the back of the binary.
+  blocks_to_add_.push_back(
+      std::unique_ptr<ir::BasicBlock>(new_loop->GetMergeBlock()));
+
+  AddBlocksToLoop(new_loop);
+  // Add the blocks to the function.
+  AddBlocksToFunction(loop->GetMergeBlock());
+
+  // Reset the usedef analysis.
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  // Update the condition check.
+  ir::Instruction* condition_check = def_use_manager->GetDef(
+      original_conditional_branch.GetSingleWordOperand(0));
+
+  // This should have been checked by the LoopUtils::CanPerformUnroll function
+  // before entering this.
+  assert(condition_check->opcode() == SpvOpSLessThan);
+  condition_check->SetInOperand(1, {constant_id});
+
+  // Update the next phi node. The phi will have a constant value coming in from
+  // the preheader block. For the duplicated loop we need to update the constant
+  // to be the amount of iterations covered by the first loop and the incoming
+  // block to be the first loops new merge block.
+  uint32_t phi_incoming_index =
+      GetPhiIndexFromLabel(loop->GetPreHeaderBlock(), loop_induction_variable_);
+  loop_induction_variable_->SetInOperand(phi_incoming_index - 1, {constant_id});
+  loop_induction_variable_->SetInOperand(phi_incoming_index, {new_merge_id});
+
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+
+  context_->ReplaceAllUsesWith(loop->GetMergeBlock()->id(), new_merge_id);
+
+  ir::LoopDescriptor& loop_descriptor =
+      *context_->GetLoopDescriptor(&function_);
+
+  loop_descriptor.AddLoop(new_loop, loop->GetParent());
+}
+
+// Duplicate the |loop| body |factor| number of times while keeping the loop
+// backedge intact.
+void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) {
+  Unroll(loop, factor);
+  AddBlocksToLoop(loop);
+  AddBlocksToFunction(loop->GetMergeBlock());
+}
+
+// Duplicate the |loop| body |factor| number of times while keeping the loop
+// backedge intact.
+void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) {
+  state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(),
+                           loop_condition_block_};
+  for (size_t i = 0; i < factor - 1; ++i) {
+    CopyBody(loop, true);
+  }
+
+  uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_,
+                                            state_.previous_phi_);
+  uint32_t phi_variable =
+      state_.previous_phi_->GetSingleWordInOperand(phi_index - 1);
+  uint32_t phi_label = state_.previous_phi_->GetSingleWordInOperand(phi_index);
+
+  ir::Instruction* original_phi = loop_induction_variable_;
+
+  // SetInOperands are offset by two.
+  original_phi->SetInOperand(phi_index - 1, {phi_variable});
+  original_phi->SetInOperand(phi_index, {phi_label});
+}
+
+// Fully unroll the loop by partially unrolling it by the number of loop
+// iterations minus one for the body already accounted for.
+void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) {
+  // We unroll the loop by number of iterations in the loop.
+  Unroll(loop, number_of_loop_iterations_);
+
+  // The first condition block is preserved until now so it can be copied.
+  FoldConditionBlock(loop_condition_block_, 1);
+
+  // Delete the OpLoopMerge and remove the backedge to the header.
+  CloseUnrolledLoop(loop);
+
+  // Mark the loop for later deletion. This allows us to preserve the loop
+  // iterators but still disregard dead loops.
+  loop->MarkLoopForRemoval();
+
+  // If the loop has a parent add the new blocks to the parent.
+  if (loop->GetParent()) {
+    AddBlocksToLoop(loop->GetParent());
+  }
+
+  // Add the blocks to the function.
+  AddBlocksToFunction(loop->GetMergeBlock());
+
+  // Invalidate all analyses.
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+}
+
+// Copy a given basic block, give it a new result_id, and store the new block
+// and the id mapping in the state. |preserve_instructions| is used to determine
+// whether or not this function should edit instructions other than the
+// |result_id|.
+void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
+                                           const ir::BasicBlock* itr,
+                                           bool preserve_instructions) {
+  // Clone the block exactly, including the IDs.
+  ir::BasicBlock* basic_block = itr->Clone(context_);
+
+  basic_block->SetParent(itr->GetParent());
+
+  // Assign each result a new unique ID and keep a mapping of the old ids to
+  // the new ones.
+  AssignNewResultIds(basic_block);
+
+  // If this is the continue block we are copying.
+  if (itr == loop->GetLatchBlock()) {
+    // Make the OpLoopMerge point to this block for the continue.
+    if (!preserve_instructions) {
+      ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
+      merge_inst->SetInOperand(1, {basic_block->id()});
+    }
+
+    state_.new_continue_block = basic_block;
+  }
+
+  // If this is the header block we are copying.
+  if (itr == loop->GetHeaderBlock()) {
+    state_.new_header_block = basic_block;
+
+    if (!preserve_instructions) {
+      // Remove the loop merge instruction if it exists.
+      ir::Instruction* merge_inst = basic_block->GetLoopMergeInst();
+      if (merge_inst) context_->KillInst(merge_inst);
+    }
+  }
+
+  // If this is the condition block we are copying.
+  if (itr == loop_condition_block_) {
+    state_.new_condition_block = basic_block;
+  }
+
+  // Add this block to the list of blocks to add to the function at the end of
+  // the unrolling process.
+  blocks_to_add_.push_back(std::unique_ptr<ir::BasicBlock>(basic_block));
+
+  // Keep tracking the old block via a map.
+  state_.new_blocks[itr->id()] = basic_block;
+}
+
+void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
+                                     bool eliminate_conditions) {
+  // Copy each basic block in the loop, give them new ids, and save state
+  // information.
+  for (const ir::BasicBlock* itr : loop_blocks_inorder_) {
+    CopyBasicBlock(loop, itr, false);
+  }
+
+  // Set the previous continue block to point to the new header.
+  ir::Instruction& continue_branch = *state_.previous_continue_block_->tail();
+  continue_branch.SetInOperand(0, {state_.new_header_block->id()});
+
+  // As the algorithm copies the original loop blocks exactly, the tail of the
+  // latch block on iterations after the first one will be a branch to the new
+  // header and not the actual loop header. The last continue block in the loop
+  // should always be a backedge to the global header.
+  ir::Instruction& new_continue_branch = *state_.new_continue_block->tail();
+  new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()});
+
+  // Update references to the old phi node with the actual variable.
+  const ir::Instruction* induction = loop_induction_variable_;
+  state_.new_inst[induction->result_id()] =
+      GetPhiDefID(state_.previous_phi_, state_.previous_continue_block_->id());
+
+  if (eliminate_conditions &&
+      state_.new_condition_block != loop_condition_block_) {
+    FoldConditionBlock(state_.new_condition_block, 1);
+  }
+
+  // Only reference to the header block is the backedge in the latch block,
+  // don't change this.
+  state_.new_inst[loop->GetHeaderBlock()->id()] = loop->GetHeaderBlock()->id();
+
+  for (auto& pair : state_.new_blocks) {
+    RemapOperands(pair.second);
+  }
+
+  dead_instructions_.push_back(state_.new_phi);
+
+  // Swap the state so the new is now the previous.
+  state_.NextIterationState();
+}
+
+uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi,
+                                            uint32_t label) const {
+  for (uint32_t operand = 3; operand < phi->NumOperands(); operand += 2) {
+    if (phi->GetSingleWordOperand(operand) == label) {
+      return phi->GetSingleWordOperand(operand - 1);
+    }
+  }
+
+  return 0;
+}
+
+void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block,
+                                               uint32_t operand_label) {
+  // Remove the old conditional branch to the merge and continue blocks.
+  ir::Instruction& old_branch = *condition_block->tail();
+  uint32_t new_target = old_branch.GetSingleWordOperand(operand_label);
+  context_->KillInst(&old_branch);
+
+  // Add the new unconditional branch to the merge block.
+  InstructionBuilder builder{context_, condition_block};
+  builder.AddBranch(new_target);
+}
+
+void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) {
+  // Remove the OpLoopMerge instruction from the function.
+  ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
+  context_->KillInst(merge_inst);
+
+  // Remove the final backedge to the header and make it point instead to the
+  // merge block.
+  state_.previous_continue_block_->tail()->SetInOperand(
+      0, {loop->GetMergeBlock()->id()});
+
+  // Remove the induction variable as the phi will now be invalid. Replace all
+  // uses with the constant initializer value (all uses of the phi will be in
+  // the first iteration with the subsequent phis already having been removed.
+  uint32_t initalizer_id =
+      GetPhiDefID(loop_induction_variable_, loop->GetPreHeaderBlock()->id());
+  context_->ReplaceAllUsesWith(loop_induction_variable_->result_id(),
+                               initalizer_id);
+
+  // Remove the now unused phi.
+  context_->KillInst(loop_induction_variable_);
+}
+
+// Uses the first loop to create a copy of the loop with new IDs.
+void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop,
+                                          ir::Loop* new_loop) {
+  std::vector<ir::BasicBlock*> new_block_order;
+
+  // Copy every block in the old loop.
+  for (const ir::BasicBlock* itr : loop_blocks_inorder_) {
+    CopyBasicBlock(old_loop, itr, true);
+    new_block_order.push_back(blocks_to_add_.back().get());
+  }
+
+  ir::BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_);
+  new_merge->SetParent(old_loop->GetMergeBlock()->GetParent());
+  AssignNewResultIds(new_merge);
+  state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge;
+  for (auto& pair : state_.new_blocks) {
+    RemapOperands(pair.second);
+  }
+
+  loop_blocks_inorder_ = std::move(new_block_order);
+
+  AddBlocksToLoop(new_loop);
+
+  new_loop->SetHeaderBlock(state_.new_header_block);
+  new_loop->SetLatchBlock(state_.new_continue_block);
+  new_loop->SetMergeBlock(new_merge);
+}
+
+void LoopUnrollerUtilsImpl::AddBlocksToFunction(
+    const ir::BasicBlock* insert_point) {
+  for (ir::Instruction* inst : dead_instructions_) {
+    context_->KillInst(inst);
+  }
+
+  for (auto basic_block_iterator = function_.begin();
+       basic_block_iterator != function_.end(); ++basic_block_iterator) {
+    if (basic_block_iterator->id() == insert_point->id()) {
+      basic_block_iterator.InsertBefore(&blocks_to_add_);
+      return;
+    }
+  }
+
+  assert(
+      false &&
+      "Could not add basic blocks to function as insert point was not found.");
+}
+
+// Assign all result_ids in |basic_block| instructions to new IDs and preserve
+// the mapping of new ids to old ones.
+void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) {
+  // Label instructions aren't covered by normal traversal of the
+  // instructions.
+  uint32_t new_label_id = context_->TakeNextId();
+
+  // Assign a new id to the label.
+  state_.new_inst[basic_block->GetLabelInst()->result_id()] = new_label_id;
+  basic_block->GetLabelInst()->SetResultId(new_label_id);
+
+  for (ir::Instruction& inst : *basic_block) {
+    uint32_t old_id = inst.result_id();
+
+    // Ignore stores etc.
+    if (old_id == 0) {
+      continue;
+    }
+
+    // Give the instruction a new id.
+    inst.SetResultId(context_->TakeNextId());
+
+    // Save the mapping of old_id -> new_id.
+    state_.new_inst[old_id] = inst.result_id();
+
+    // Check if this instruction is the induction variable.
+    if (loop_induction_variable_->result_id() == old_id) {
+      // Save a pointer to the new copy of it.
+      state_.new_phi = &inst;
+    }
+  }
+}
+
+// For all instructions in |basic_block| check if the operands used are from a
+// copied instruction and if so swap out the operand for the copy of it.
+void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) {
+  for (ir::Instruction& inst : *basic_block) {
+    auto remap_operands_to_new_ids = [this](uint32_t* id) {
+      auto itr = state_.new_inst.find(*id);
+      if (itr != state_.new_inst.end()) {
+        *id = itr->second;
+      }
+    };
+
+    inst.ForEachInId(remap_operands_to_new_ids);
+  }
+}
+
+// Generate the ordered list of basic blocks in the |loop| and cache it for
+// later use.
+void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(ir::Loop* loop) {
+  loop_blocks_inorder_.clear();
+
+  opt::DominatorAnalysis* analysis =
+      context_->GetDominatorAnalysis(&function_, *context_->cfg());
+  opt::DominatorTree& tree = analysis->GetDomTree();
+
+  // Starting at the loop header BasicBlock, traverse the dominator tree until
+  // we reach the merge block and add every node we traverse to the set of
+  // blocks
+  // which we consider to be the loop.
+  auto begin_itr = tree.GetTreeNode(loop->GetHeaderBlock())->df_begin();
+  const ir::BasicBlock* merge = loop->GetMergeBlock();
+  auto func = [merge, &tree, this](DominatorTreeNode* node) {
+    if (!tree.Dominates(merge->id(), node->id())) {
+      this->loop_blocks_inorder_.push_back(node->bb_);
+      return true;
+    }
+    return false;
+  };
+
+  tree.VisitChildrenIf(func, begin_itr);
+}
+
+// Adds the blocks_to_add_ to both the loop and to the parent.
+void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const {
+  // Add the blocks to this loop.
+  for (auto& block_itr : blocks_to_add_) {
+    loop->AddBasicBlock(block_itr.get());
+  }
+
+  // Add the blocks to the parent as well.
+  if (loop->GetParent()) AddBlocksToLoop(loop->GetParent());
+}
+
+/*
+ * End LoopUtilsImpl.
+ */
+
+}  // namespace
+
+/*
+ *
+ *  Begin Utils.
+ *
+ * */
+
+bool LoopUtils::CanPerformUnroll() {
+  // The loop is expected to be in structured order.
+  if (!loop_->GetHeaderBlock()->GetMergeInst()) {
+    return false;
+  }
+
+  // Find check the loop has a condition we can find and evaluate.
+  const ir::BasicBlock* condition = loop_->FindConditionBlock();
+  if (!condition) return false;
+
+  // Check that we can find and process the induction variable.
+  const ir::Instruction* induction = loop_->FindInductionVariable(condition);
+  if (!induction || induction->opcode() != SpvOpPhi) return false;
+
+  // Check that we can find the number of loop iterations.
+  if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr))
+    return false;
+
+  // Make sure the continue block is a unconditional branch to the header
+  // block.
+  const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail();
+  bool branching_assumption =
+      branch.opcode() == SpvOpBranch &&
+      branch.GetSingleWordInOperand(0) == loop_->GetHeaderBlock()->id();
+  if (!branching_assumption) {
+    return false;
+  }
+
+  // Make sure the induction is the only phi instruction we have in the loop
+  // header. Other optimizations have been seen to leave dead phi nodes in the
+  // header so we also check that the phi is used.
+  for (const ir::Instruction& inst : *loop_->GetHeaderBlock()) {
+    if (inst.opcode() == SpvOpPhi &&
+        inst.result_id() != induction->result_id()) {
+      return false;
+    }
+  }
+
+  // Ban breaks within the loop.
+  const std::vector<uint32_t>& merge_block_preds =
+      context_->cfg()->preds(loop_->GetMergeBlock()->id());
+  if (merge_block_preds.size() != 1) {
+    return false;
+  }
+
+  // Ban continues within the loop.
+  const std::vector<uint32_t>& continue_block_preds =
+      context_->cfg()->preds(loop_->GetLatchBlock()->id());
+  if (continue_block_preds.size() != 1) {
+    return false;
+  }
+
+  // Ban returns in the loop.
+  // Iterate over all the blocks within the loop and check that none of them
+  // exit the loop.
+  for (uint32_t label_id : loop_->GetBlocks()) {
+    const ir::BasicBlock* block = context_->cfg()->block(label_id);
+    if (block->ctail()->opcode() == SpvOp::SpvOpKill ||
+        block->ctail()->opcode() == SpvOp::SpvOpReturn ||
+        block->ctail()->opcode() == SpvOp::SpvOpReturnValue) {
+      return false;
+    }
+  }
+  // Can only unroll inner loops.
+  if (!loop_->AreAllChildrenMarkedForRemoval()) {
+    return false;
+  }
+
+  for (uint32_t block_id : loop_->GetBlocks()) {
+    opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+    ir::BasicBlock& bb = *context_->cfg()->block(block_id);
+    // For every instruction in the block.
+    for (ir::Instruction& inst : bb) {
+      if (inst.result_id() == 0) continue;
+
+      auto is_used_outside_loop = [this,
+                                   def_use_manager](ir::Instruction* user) {
+
+        if (!loop_->IsInsideLoop(user)) {
+          // Some optimization passes have been seen to leave dead phis in the
+          // IR so we check that if a phi is used outside of the loop that the
+          // user is not dead.
+          if (!(user->opcode() == SpvOpPhi &&
+                def_use_manager->NumUsers(user) == 0))
+            return false;
+        }
+        return true;
+      };
+
+      if (!def_use_manager->WhileEachUser(&inst, is_used_outside_loop)) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+bool LoopUtils::PartiallyUnroll(size_t factor) {
+  if (factor == 1 || !CanPerformUnroll()) return false;
+
+  // Create the unroller utility.
+  LoopUnrollerUtilsImpl unroller{context_,
+                                 loop_->GetHeaderBlock()->GetParent()};
+  unroller.Init(loop_);
+
+  // If the unrolling factor is larger than or the same size as the loop just
+  // fully unroll the loop.
+  if (factor >= unroller.GetLoopIterationCount()) {
+    unroller.FullyUnroll(loop_);
+    return true;
+  }
+
+  // If the loop unrolling factor is an residual number of iterations we need to
+  // let run the loop for the residual part then let it branch into the unrolled
+  // remaining part. We add one when calucating the remainder to take into
+  // account the one iteration already in the loop.
+  if (unroller.GetLoopIterationCount() % factor != 0) {
+    unroller.PartiallyUnrollResidualFactor(loop_, factor);
+  } else {
+    unroller.PartiallyUnroll(loop_, factor);
+  }
+
+  return true;
+}
+
+bool LoopUtils::FullyUnroll() {
+  if (!CanPerformUnroll()) return false;
+
+  LoopUnrollerUtilsImpl unroller{context_,
+                                 loop_->GetHeaderBlock()->GetParent()};
+
+  unroller.Init(loop_);
+  unroller.FullyUnroll(loop_);
+
+  return true;
+}
+
+void LoopUtils::Finalize() {
+  // Clean up the loop descriptor to preserve the analysis.
+
+  ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&function_);
+  LD->PostModificationCleanup();
+}
+
+/*
+ *
+ * Begin Pass.
+ *
+ */
+
+Pass::Status LoopUnroller::Process(ir::IRContext* c) {
+  context_ = c;
+  bool changed = false;
+  for (ir::Function& f : *c->module()) {
+    ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&f);
+    for (ir::Loop& loop : *LD) {
+      LoopUtils loop_utils{c, &loop};
+      if (!loop.HasUnrollLoopControl() || !loop_utils.CanPerformUnroll()) {
+        continue;
+      }
+
+      loop_utils.FullyUnroll();
+      changed = true;
+    }
+    LD->PostModificationCleanup();
+  }
+
+  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/loop_unroller.h b/source/opt/loop_unroller.h
new file mode 100644 (file)
index 0000000..38dfa34
--- /dev/null
@@ -0,0 +1,37 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_LOOP_UNROLLER_H_
+#define SOURCE_OPT_LOOP_UNROLLER_H_
+#include "opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+class LoopUnroller : public Pass {
+ public:
+  LoopUnroller() : Pass() {}
+
+  const char* name() const override { return "Loop unroller"; }
+
+  Status Process(ir::IRContext* context) override;
+
+ private:
+  ir::IRContext* context_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_LOOP_UNROLLER_H_
index 65a5431..3eb20de 100644 (file)
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef LIBSPIRV_OPT_LOOP_UTILS_H_
-#define LIBSPIRV_OPT_LOOP_UTILS_H_
+#ifndef SOURCE_OPT_LOOP_UTILS_H_
+#define SOURCE_OPT_LOOP_UTILS_H_
+#include <list>
+#include <memory>
+#include <vector>
+#include "opt/loop_descriptor.h"
 
 namespace spvtools {
 
@@ -24,11 +28,15 @@ class IRContext;
 
 namespace opt {
 
-// Set of basic loop transformation.
+// LoopUtils is used to encapsulte loop optimizations and from the passes which
+// use them. Any pass which needs a loop optimization should do it through this
+// or through a pass which is using this.
 class LoopUtils {
  public:
   LoopUtils(ir::IRContext* context, ir::Loop* loop)
-      : context_(context), loop_(loop) {}
+      : context_(context),
+        loop_(loop),
+        function_(*loop_->GetHeaderBlock()->GetParent()) {}
 
   // The converts the current loop to loop closed SSA form.
   // In the loop closed SSA, all loop exiting values go through a dedicated Phi
@@ -64,12 +72,42 @@ class LoopUtils {
   // Preserves: CFG, def/use and instruction to block mapping.
   void CreateLoopDedicatedExits();
 
+  // Perfom a partial unroll of |loop| by given |factor|. This will copy the
+  // body of the loop |factor| times. So a |factor| of one would give a new loop
+  // with the original body plus one unrolled copy body.
+  bool PartiallyUnroll(size_t factor);
+
+  // Fully unroll |loop|.
+  bool FullyUnroll();
+
+  // This function validates that |loop| meets the assumptions made by the
+  // implementation of the loop unroller. As the implementation accommodates
+  // more types of loops this function can reduce its checks.
+  //
+  // The conditions checked to ensure the loop can be unrolled are as follows:
+  // 1. That the loop is in structured order.
+  // 2. That the condinue block is a branch to the header.
+  // 3. That the only phi used in the loop is the induction variable.
+  //  TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is
+  //  converted into LCSAA form and has a single entry and exit we can rewrite
+  //  the other phis.
+  // 4. That this is an inner most loop, or that loops contained within this
+  // loop have already been fully unrolled.
+  // 5. That each instruction in the loop is only used within the loop.
+  // (Related to the above phi condition).
+  bool CanPerformUnroll();
+
+  // Maintains the loop descriptor object after the unroll functions have been
+  // called, otherwise the analysis should be invalidated.
+  void Finalize();
+
  private:
   ir::IRContext* context_;
   ir::Loop* loop_;
+  ir::Function& function_;
 };
 
 }  // namespace opt
 }  // namespace spvtools
 
-#endif  // LIBSPIRV_OPT_LOOP_UTILS_H_
+#endif  // SOURCE_OPT_LOOP_UTILS_H_
index 9f36f19..8aed7bc 100644 (file)
@@ -389,4 +389,9 @@ Optimizer::PassToken CreateSimplificationPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::SimplificationPass>());
 }
+
+Optimizer::PassToken CreateLoopFullyUnrollPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::LoopUnroller>());
+}
 }  // namespace spvtools
index 890e16e..9fb98aa 100644 (file)
@@ -41,6 +41,7 @@
 #include "local_single_block_elim_pass.h"
 #include "local_single_store_elim_pass.h"
 #include "local_ssa_elim_pass.h"
+#include "loop_unroller.h"
 #include "merge_return_pass.h"
 #include "null_pass.h"
 #include "private_to_local_pass.h"
@@ -53,5 +54,4 @@
 #include "strip_debug_info_pass.h"
 #include "unify_const_pass.h"
 #include "workaround1209.h"
-
 #endif  // LIBSPIRV_OPT_PASSES_H_
index 947f5c5..a9cb499 100644 (file)
@@ -66,3 +66,16 @@ add_spvtools_unittest(TARGET licm_hoist_no_preheader
         hoist_without_preheader.cpp
     LIBS SPIRV-Tools-opt
 )
+
+add_spvtools_unittest(TARGET loop_unroll_simple
+    SRCS ../function_utils.h
+        unroll_simple.cpp
+    LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET loop_unroll_assumtion_checks
+    SRCS ../function_utils.h
+        unroll_assumptions.cpp
+    LIBS SPIRV-Tools-opt
+)
+
diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp
new file mode 100644 (file)
index 0000000..e3ff1ee
--- /dev/null
@@ -0,0 +1,627 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+#include "opt/loop_unroller.h"
+#include "opt/loop_utils.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+template <int factor>
+class PartialUnrollerTestPass : public opt::Pass {
+ public:
+  PartialUnrollerTestPass() : Pass() {}
+
+  const char* name() const override { return "Loop unroller"; }
+
+  Status Process(ir::IRContext* context) override {
+    bool changed = false;
+    for (ir::Function& f : *context->module()) {
+      ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f);
+      for (auto& loop : loop_descriptor) {
+        opt::LoopUtils loop_utils{context, &loop};
+        if (loop_utils.PartiallyUnroll(factor)) {
+          changed = true;
+        }
+      }
+    }
+
+    if (changed) return Pass::Status::SuccessWithChange;
+    return Pass::Status::SuccessWithoutChange;
+  }
+};
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 410 core
+layout(location = 0) flat in int in_upper_bound;
+void main() {
+  for (int i = ; i < in_upper_bound; ++i) {
+    x[i] = 1.0f;
+  }
+}
+*/
+TEST_F(PassClassTest, CheckUpperBound) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "in_upper_bound"
+OpName %4 "x"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpTypePointer Input %7
+%3 = OpVariable %10 Input
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%2 = OpFunction %5 None %6
+%20 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %21
+%21 = OpLabel
+%22 = OpPhi %7 %9 %20 %23 %24
+OpLoopMerge %25 %24 None
+OpBranch %26
+%26 = OpLabel
+%27 = OpLoad %7 %3
+%28 = OpSLessThan %11 %22 %27
+OpBranchConditional %28 %29 %25
+%29 = OpLabel
+%30 = OpAccessChain %18 %4 %22
+OpStore %30 %17
+OpBranch %24
+%24 = OpLabel
+%23 = OpIAdd %7 %22 %19
+OpBranch %21
+%25 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[10];
+    int i = 0;
+    for (int i = 0; i < 10; ++i) {
+        out_array[i] = i;
+    }
+    out_array[9] = i*10;
+}
+*/
+TEST_F(PassClassTest, InductionUsedOutsideOfLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 10
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%18 = OpConstant %6 9
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpConvertSToF %11 %21
+%29 = OpAccessChain %16 %3 %21
+OpStore %29 %28
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %17
+OpBranch %20
+%24 = OpLabel
+%30 = OpIMul %6 %21 %9
+%31 = OpConvertSToF %11 %30
+%32 = OpAccessChain %16 %3 %18
+OpStore %32 %31
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[10];
+    for (uint i = 0; i < 2; i++) {
+      for (float x = 0; x < 5; ++x) {
+        out_array[x + i*5] = i;
+      }
+    }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoopsInvalid) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 0
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 2
+%10 = OpTypeBool
+%11 = OpTypeInt 32 1
+%12 = OpTypePointer Function %11
+%13 = OpConstant %11 0
+%14 = OpConstant %11 5
+%15 = OpTypeFloat 32
+%16 = OpConstant %6 10
+%17 = OpTypeArray %15 %16
+%18 = OpTypePointer Function %17
+%19 = OpConstant %6 5
+%20 = OpTypePointer Function %15
+%21 = OpConstant %11 1
+%22 = OpUndef %11
+%2 = OpFunction %4 None %5
+%23 = OpLabel
+%3 = OpVariable %18 Function
+OpBranch %24
+%24 = OpLabel
+%25 = OpPhi %6 %8 %23 %26 %27
+%28 = OpPhi %11 %22 %23 %29 %27
+OpLoopMerge %30 %27 None
+OpBranch %31
+%31 = OpLabel
+%32 = OpULessThan %10 %25 %9
+OpBranchConditional %32 %33 %30
+%33 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%29 = OpPhi %11 %13 %33 %35 %36
+OpLoopMerge %37 %36 None
+OpBranch %38
+%38 = OpLabel
+%39 = OpSLessThan %10 %29 %14
+OpBranchConditional %39 %40 %37
+%40 = OpLabel
+%41 = OpBitcast %6 %29
+%42 = OpIMul %6 %25 %19
+%43 = OpIAdd %6 %41 %42
+%44 = OpConvertUToF %15 %25
+%45 = OpAccessChain %20 %3 %43
+OpStore %45 %44
+OpBranch %36
+%36 = OpLabel
+%35 = OpIAdd %11 %29 %21
+OpBranch %34
+%37 = OpLabel
+OpBranch %27
+%27 = OpLabel
+%26 = OpIAdd %6 %25 %21
+OpBranch %24
+%30 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+  float x[10];
+  int ind = 0;
+  for (int i = 0; i < 10; i++) {
+    ind = i;
+    x[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, MultiplePhiInHeader) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 10
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+%20 = OpPhi %6 %8 %18 %21 %22
+%21 = OpPhi %6 %8 %18 %23 %22
+OpLoopMerge %24 %22 None
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpConvertSToF %11 %21
+%29 = OpAccessChain %16 %3 %21
+OpStore %29 %28
+OpBranch %22
+%22 = OpLabel
+%23 = OpIAdd %6 %21 %17
+OpBranch %19
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+  float x[10];
+  for (int i = 0; i < 10; i++) {
+    if (i == 5) {
+      break;
+    }
+    x[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, BreakInBody) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpBranch %24
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+  float x[10];
+  for (int i = 0; i < 10; i++) {
+    if (i == 5) {
+      continue;
+    }
+    x[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, ContinueInBody) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpBranch %23
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+  float x[10];
+  for (int i = 0; i < 10; i++) {
+    if (i == 5) {
+      return;
+    }
+    x[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, ReturnInBody) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpReturn
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+}  // namespace
diff --git a/test/opt/loop_optimizations/unroll_simple.cpp b/test/opt/loop_optimizations/unroll_simple.cpp
new file mode 100644 (file)
index 0000000..59d2f95
--- /dev/null
@@ -0,0 +1,2179 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+#include "opt/loop_unroller.h"
+#include "opt/loop_utils.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  float x[4];
+  for (int i = 0; i < 4; ++i) {
+    x[i] = 1.0f;
+  }
+}
+*/
+TEST_F(PassClassTest, SimpleFullyUnrollTest) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+            OpCapability Shader
+            %1 = OpExtInstImport "GLSL.std.450"
+            OpMemoryModel Logical GLSL450
+            OpEntryPoint Fragment %2 "main" %3
+            OpExecutionMode %2 OriginUpperLeft
+            OpSource GLSL 330
+            OpName %2 "main"
+            OpName %5 "x"
+            OpName %3 "c"
+            OpDecorate %3 Location 0
+            %6 = OpTypeVoid
+            %7 = OpTypeFunction %6
+            %8 = OpTypeInt 32 1
+            %9 = OpTypePointer Function %8
+            %10 = OpConstant %8 0
+            %11 = OpConstant %8 4
+            %12 = OpTypeBool
+            %13 = OpTypeFloat 32
+            %14 = OpTypeInt 32 0
+            %15 = OpConstant %14 4
+            %16 = OpTypeArray %13 %15
+            %17 = OpTypePointer Function %16
+            %18 = OpConstant %13 1
+            %19 = OpTypePointer Function %13
+            %20 = OpConstant %8 1
+            %21 = OpTypeVector %13 4
+            %22 = OpTypePointer Output %21
+            %3 = OpVariable %22 Output
+            %2 = OpFunction %6 None %7
+            %23 = OpLabel
+            %5 = OpVariable %17 Function
+            OpBranch %24
+            %24 = OpLabel
+            %35 = OpPhi %8 %10 %23 %34 %26
+            OpLoopMerge %25 %26 Unroll
+            OpBranch %27
+            %27 = OpLabel
+            %29 = OpSLessThan %12 %35 %11
+            OpBranchConditional %29 %30 %25
+            %30 = OpLabel
+            %32 = OpAccessChain %19 %5 %35
+            OpStore %32 %18
+            OpBranch %26
+            %26 = OpLabel
+            %34 = OpIAdd %8 %35 %20
+            OpBranch %24
+            %25 = OpLabel
+            OpReturn
+            OpFunctionEnd
+  )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 4
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 4
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %9 %10
+OpBranch %30
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %9
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %9 %19
+OpBranch %32
+%32 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%35 = OpSLessThan %11 %25 %10
+OpBranch %36
+%36 = OpLabel
+%37 = OpAccessChain %18 %4 %25
+OpStore %37 %17
+OpBranch %38
+%38 = OpLabel
+%39 = OpIAdd %7 %25 %19
+OpBranch %40
+%40 = OpLabel
+OpBranch %42
+%42 = OpLabel
+%43 = OpSLessThan %11 %39 %10
+OpBranch %44
+%44 = OpLabel
+%45 = OpAccessChain %18 %4 %39
+OpStore %45 %17
+OpBranch %46
+%46 = OpLabel
+%47 = OpIAdd %7 %39 %19
+OpBranch %48
+%48 = OpLabel
+OpBranch %50
+%50 = OpLabel
+%51 = OpSLessThan %11 %47 %10
+OpBranch %52
+%52 = OpLabel
+%53 = OpAccessChain %18 %4 %47
+OpStore %53 %17
+OpBranch %54
+%54 = OpLabel
+%55 = OpIAdd %7 %47 %19
+OpBranch %27
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+template <int factor>
+class PartialUnrollerTestPass : public opt::Pass {
+ public:
+  PartialUnrollerTestPass() : Pass() {}
+
+  const char* name() const override { return "Loop unroller"; }
+
+  Status Process(ir::IRContext* context) override {
+    for (ir::Function& f : *context->module()) {
+      ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f);
+      for (auto& loop : loop_descriptor) {
+        opt::LoopUtils loop_utils{context, &loop};
+        loop_utils.PartiallyUnroll(factor);
+      }
+    }
+
+    return Pass::Status::SuccessWithChange;
+  }
+};
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  float x[10];
+  for (int i = 0; i < 10; ++i) {
+    x[i] = 1.0f;
+  }
+}
+*/
+TEST_F(PassClassTest, SimplePartialUnroll) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+            OpCapability Shader
+            %1 = OpExtInstImport "GLSL.std.450"
+            OpMemoryModel Logical GLSL450
+            OpEntryPoint Fragment %2 "main" %3
+            OpExecutionMode %2 OriginUpperLeft
+            OpSource GLSL 330
+            OpName %2 "main"
+            OpName %5 "x"
+            OpName %3 "c"
+            OpDecorate %3 Location 0
+            %6 = OpTypeVoid
+            %7 = OpTypeFunction %6
+            %8 = OpTypeInt 32 1
+            %9 = OpTypePointer Function %8
+            %10 = OpConstant %8 0
+            %11 = OpConstant %8 10
+            %12 = OpTypeBool
+            %13 = OpTypeFloat 32
+            %14 = OpTypeInt 32 0
+            %15 = OpConstant %14 10
+            %16 = OpTypeArray %13 %15
+            %17 = OpTypePointer Function %16
+            %18 = OpConstant %13 1
+            %19 = OpTypePointer Function %13
+            %20 = OpConstant %8 1
+            %21 = OpTypeVector %13 4
+            %22 = OpTypePointer Output %21
+            %3 = OpVariable %22 Output
+            %2 = OpFunction %6 None %7
+            %23 = OpLabel
+            %5 = OpVariable %17 Function
+            OpBranch %24
+            %24 = OpLabel
+            %35 = OpPhi %8 %10 %23 %34 %26
+            OpLoopMerge %25 %26 Unroll
+            OpBranch %27
+            %27 = OpLabel
+            %29 = OpSLessThan %12 %35 %11
+            OpBranchConditional %29 %30 %25
+            %30 = OpLabel
+            %32 = OpAccessChain %19 %5 %35
+            OpStore %32 %18
+            OpBranch %26
+            %26 = OpLabel
+            %34 = OpIAdd %8 %35 %20
+            OpBranch %24
+            %25 = OpLabel
+            OpReturn
+            OpFunctionEnd
+  )";
+
+  const std::string output = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 10
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+%24 = OpPhi %7 %9 %22 %39 %38
+OpLoopMerge %27 %38 Unroll
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %24 %10
+OpBranchConditional %29 %30 %27
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %24
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %24 %19
+OpBranch %32
+%32 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%35 = OpSLessThan %11 %25 %10
+OpBranch %36
+%36 = OpLabel
+%37 = OpAccessChain %18 %4 %25
+OpStore %37 %17
+OpBranch %38
+%38 = OpLabel
+%39 = OpIAdd %7 %25 %19
+OpBranch %23
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+  float x[10];
+  for (int i = 0; i < 10; ++i) {
+    x[i] = 1.0f;
+  }
+}
+*/
+TEST_F(PassClassTest, SimpleUnevenPartialUnroll) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+            OpCapability Shader
+            %1 = OpExtInstImport "GLSL.std.450"
+            OpMemoryModel Logical GLSL450
+            OpEntryPoint Fragment %2 "main" %3
+            OpExecutionMode %2 OriginUpperLeft
+            OpSource GLSL 330
+            OpName %2 "main"
+            OpName %5 "x"
+            OpName %3 "c"
+            OpDecorate %3 Location 0
+            %6 = OpTypeVoid
+            %7 = OpTypeFunction %6
+            %8 = OpTypeInt 32 1
+            %9 = OpTypePointer Function %8
+            %10 = OpConstant %8 0
+            %11 = OpConstant %8 10
+            %12 = OpTypeBool
+            %13 = OpTypeFloat 32
+            %14 = OpTypeInt 32 0
+            %15 = OpConstant %14 10
+            %16 = OpTypeArray %13 %15
+            %17 = OpTypePointer Function %16
+            %18 = OpConstant %13 1
+            %19 = OpTypePointer Function %13
+            %20 = OpConstant %8 1
+            %21 = OpTypeVector %13 4
+            %22 = OpTypePointer Output %21
+            %3 = OpVariable %22 Output
+            %2 = OpFunction %6 None %7
+            %23 = OpLabel
+            %5 = OpVariable %17 Function
+            OpBranch %24
+            %24 = OpLabel
+            %35 = OpPhi %8 %10 %23 %34 %26
+            OpLoopMerge %25 %26 Unroll
+            OpBranch %27
+            %27 = OpLabel
+            %29 = OpSLessThan %12 %35 %11
+            OpBranchConditional %29 %30 %25
+            %30 = OpLabel
+            %32 = OpAccessChain %19 %5 %35
+            OpStore %32 %18
+            OpBranch %26
+            %26 = OpLabel
+            %34 = OpIAdd %8 %35 %20
+            OpBranch %24
+            %25 = OpLabel
+            OpReturn
+            OpFunctionEnd
+  )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 10
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%58 = OpConstant %13 1
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+%24 = OpPhi %7 %9 %22 %25 %26
+OpLoopMerge %32 %26 Unroll
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %24 %58
+OpBranchConditional %29 %30 %32
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %24
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %24 %19
+OpBranch %23
+%32 = OpLabel
+OpBranch %33
+%33 = OpLabel
+%34 = OpPhi %7 %58 %32 %57 %56
+OpLoopMerge %41 %56 Unroll
+OpBranch %35
+%35 = OpLabel
+%36 = OpSLessThan %11 %34 %10
+OpBranchConditional %36 %37 %41
+%37 = OpLabel
+%38 = OpAccessChain %18 %4 %34
+OpStore %38 %17
+OpBranch %39
+%39 = OpLabel
+%40 = OpIAdd %7 %34 %19
+OpBranch %42
+%42 = OpLabel
+OpBranch %44
+%44 = OpLabel
+%45 = OpSLessThan %11 %40 %10
+OpBranch %46
+%46 = OpLabel
+%47 = OpAccessChain %18 %4 %40
+OpStore %47 %17
+OpBranch %48
+%48 = OpLabel
+%49 = OpIAdd %7 %40 %19
+OpBranch %50
+%50 = OpLabel
+OpBranch %52
+%52 = OpLabel
+%53 = OpSLessThan %11 %49 %10
+OpBranch %54
+%54 = OpLabel
+%55 = OpAccessChain %18 %4 %49
+OpStore %55 %17
+OpBranch %56
+%56 = OpLabel
+%57 = OpIAdd %7 %49 %19
+OpBranch %33
+%41 = OpLabel
+OpReturn
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  // By unrolling by a factor that doesn't divide evenly into the number of loop
+  // iterations we perfom an additional transform when partially unrolling to
+  // account for the remainder.
+  SinglePassRunAndCheck<PartialUnrollerTestPass<3>>(text, output, false);
+}
+
+/* Generated from
+#version 410 core
+layout(location=0) flat in int upper_bound;
+void main() {
+    float x[10];
+    for (int i = 2; i < 8; i+=2) {
+        x[i] = i;
+    }
+}
+*/
+TEST_F(PassClassTest, SimpleLoopIterationsCheck) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %5 "x"
+OpName %3 "upper_bound"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%8 = OpTypeInt 32 1
+%9 = OpTypePointer Function %8
+%10 = OpConstant %8 2
+%11 = OpConstant %8 8
+%12 = OpTypeBool
+%13 = OpTypeFloat 32
+%14 = OpTypeInt 32 0
+%15 = OpConstant %14 10
+%16 = OpTypeArray %13 %15
+%17 = OpTypePointer Function %16
+%18 = OpTypePointer Function %13
+%19 = OpTypePointer Input %8
+%3 = OpVariable %19 Input
+%2 = OpFunction %6 None %7
+%20 = OpLabel
+%5 = OpVariable %17 Function
+OpBranch %21
+%21 = OpLabel
+%34 = OpPhi %8 %10 %20 %33 %23
+OpLoopMerge %22 %23 Unroll
+OpBranch %24
+%24 = OpLabel
+%26 = OpSLessThan %12 %34 %11
+OpBranchConditional %26 %27 %22
+%27 = OpLabel
+%30 = OpConvertSToF %13 %34
+%31 = OpAccessChain %18 %5 %34
+OpStore %31 %30
+OpBranch %23
+%23 = OpLabel
+%33 = OpIAdd %8 %34 %10
+OpBranch %21
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  ir::Function* f = spvtest::GetFunction(module, 2);
+
+  ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+  EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+  ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+  EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+  ir::BasicBlock* condition = loop.FindConditionBlock();
+  EXPECT_EQ(condition->id(), 24u);
+
+  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  EXPECT_EQ(induction->result_id(), 34u);
+
+  opt::LoopUtils loop_utils{context.get(), &loop};
+  EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+  size_t iterations = 0;
+  EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+                                          &iterations));
+  EXPECT_EQ(iterations, 3u);
+}
+
+/* Generated from
+#version 410 core
+void main() {
+    float x[10];
+    for (int i = -1; i < 6; i+=3) {
+        x[i] = i;
+    }
+}
+*/
+TEST_F(PassClassTest, SimpleLoopIterationsCheckSignedInit) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %5 "x"
+OpName %3 "upper_bound"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%8 = OpTypeInt 32 1
+%9 = OpTypePointer Function %8
+%10 = OpConstant %8 -1
+%11 = OpConstant %8 6
+%12 = OpTypeBool
+%13 = OpTypeFloat 32
+%14 = OpTypeInt 32 0
+%15 = OpConstant %14 10
+%16 = OpTypeArray %13 %15
+%17 = OpTypePointer Function %16
+%18 = OpTypePointer Function %13
+%19 = OpConstant %8 3
+%20 = OpTypePointer Input %8
+%3 = OpVariable %20 Input
+%2 = OpFunction %6 None %7
+%21 = OpLabel
+%5 = OpVariable %17 Function
+OpBranch %22
+%22 = OpLabel
+%35 = OpPhi %8 %10 %21 %34 %24
+OpLoopMerge %23 %24 None
+OpBranch %25
+%25 = OpLabel
+%27 = OpSLessThan %12 %35 %11
+OpBranchConditional %27 %28 %23
+%28 = OpLabel
+%31 = OpConvertSToF %13 %35
+%32 = OpAccessChain %18 %5 %35
+OpStore %32 %31
+OpBranch %24
+%24 = OpLabel
+%34 = OpIAdd %8 %35 %19
+OpBranch %22
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  ir::Function* f = spvtest::GetFunction(module, 2);
+
+  ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+
+  EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+  ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+  EXPECT_FALSE(loop.HasUnrollLoopControl());
+
+  ir::BasicBlock* condition = loop.FindConditionBlock();
+  EXPECT_EQ(condition->id(), 25u);
+
+  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  EXPECT_EQ(induction->result_id(), 35u);
+
+  opt::LoopUtils loop_utils{context.get(), &loop};
+  EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+  size_t iterations = 0;
+  EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+                                          &iterations));
+  EXPECT_EQ(iterations, 3u);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[6];
+    for (uint i = 0; i < 2; i++) {
+      for (int x = 0; x < 3; ++x) {
+        out_array[x + i*3] = i;
+      }
+    }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoops) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %35 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 0
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 0
+         %16 = OpConstant %6 2
+         %17 = OpTypeBool
+         %19 = OpTypeInt 32 1
+         %20 = OpTypePointer Function %19
+         %22 = OpConstant %19 0
+         %29 = OpConstant %19 3
+         %31 = OpTypeFloat 32
+         %32 = OpConstant %6 6
+         %33 = OpTypeArray %31 %32
+         %34 = OpTypePointer Function %33
+         %39 = OpConstant %6 3
+         %44 = OpTypePointer Function %31
+         %47 = OpConstant %19 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %35 = OpVariable %34 Function
+               OpBranch %10
+         %10 = OpLabel
+         %51 = OpPhi %6 %9 %5 %50 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpULessThan %17 %51 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+               OpBranch %23
+         %23 = OpLabel
+         %54 = OpPhi %19 %22 %11 %48 %26
+               OpLoopMerge %25 %26 Unroll
+               OpBranch %27
+         %27 = OpLabel
+         %30 = OpSLessThan %17 %54 %29
+               OpBranchConditional %30 %24 %25
+         %24 = OpLabel
+         %37 = OpBitcast %6 %54
+         %40 = OpIMul %6 %51 %39
+         %41 = OpIAdd %6 %37 %40
+         %43 = OpConvertUToF %31 %51
+         %45 = OpAccessChain %44 %35 %41
+               OpStore %45 %43
+               OpBranch %26
+         %26 = OpLabel
+         %48 = OpIAdd %19 %54 %47
+               OpBranch %23
+         %25 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+         %50 = OpIAdd %6 %51 %47
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 0
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 2
+%10 = OpTypeBool
+%11 = OpTypeInt 32 1
+%12 = OpTypePointer Function %11
+%13 = OpConstant %11 0
+%14 = OpConstant %11 3
+%15 = OpTypeFloat 32
+%16 = OpConstant %6 6
+%17 = OpTypeArray %15 %16
+%18 = OpTypePointer Function %17
+%19 = OpConstant %6 3
+%20 = OpTypePointer Function %15
+%21 = OpConstant %11 1
+%2 = OpFunction %4 None %5
+%22 = OpLabel
+%3 = OpVariable %18 Function
+OpBranch %23
+%23 = OpLabel
+OpBranch %28
+%28 = OpLabel
+%29 = OpULessThan %10 %8 %9
+OpBranch %30
+%30 = OpLabel
+OpBranch %31
+%31 = OpLabel
+OpBranch %36
+%36 = OpLabel
+%37 = OpSLessThan %10 %13 %14
+OpBranch %38
+%38 = OpLabel
+%39 = OpBitcast %6 %13
+%40 = OpIMul %6 %8 %19
+%41 = OpIAdd %6 %39 %40
+%42 = OpConvertUToF %15 %8
+%43 = OpAccessChain %20 %3 %41
+OpStore %43 %42
+OpBranch %34
+%34 = OpLabel
+%33 = OpIAdd %11 %13 %21
+OpBranch %44
+%44 = OpLabel
+OpBranch %46
+%46 = OpLabel
+%47 = OpSLessThan %10 %33 %14
+OpBranch %48
+%48 = OpLabel
+%49 = OpBitcast %6 %33
+%50 = OpIMul %6 %8 %19
+%51 = OpIAdd %6 %49 %50
+%52 = OpConvertUToF %15 %8
+%53 = OpAccessChain %20 %3 %51
+OpStore %53 %52
+OpBranch %54
+%54 = OpLabel
+%55 = OpIAdd %11 %33 %21
+OpBranch %56
+%56 = OpLabel
+OpBranch %58
+%58 = OpLabel
+%59 = OpSLessThan %10 %55 %14
+OpBranch %60
+%60 = OpLabel
+%61 = OpBitcast %6 %55
+%62 = OpIMul %6 %8 %19
+%63 = OpIAdd %6 %61 %62
+%64 = OpConvertUToF %15 %8
+%65 = OpAccessChain %20 %3 %63
+OpStore %65 %64
+OpBranch %66
+%66 = OpLabel
+%67 = OpIAdd %11 %55 %21
+OpBranch %35
+%35 = OpLabel
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %6 %8 %21
+OpBranch %68
+%68 = OpLabel
+OpBranch %70
+%70 = OpLabel
+%71 = OpULessThan %10 %25 %9
+OpBranch %72
+%72 = OpLabel
+OpBranch %73
+%73 = OpLabel
+OpBranch %74
+%74 = OpLabel
+%75 = OpSLessThan %10 %13 %14
+OpBranch %76
+%76 = OpLabel
+%77 = OpBitcast %6 %13
+%78 = OpIMul %6 %25 %19
+%79 = OpIAdd %6 %77 %78
+%80 = OpConvertUToF %15 %25
+%81 = OpAccessChain %20 %3 %79
+OpStore %81 %80
+OpBranch %82
+%82 = OpLabel
+%83 = OpIAdd %11 %13 %21
+OpBranch %84
+%84 = OpLabel
+OpBranch %85
+%85 = OpLabel
+%86 = OpSLessThan %10 %83 %14
+OpBranch %87
+%87 = OpLabel
+%88 = OpBitcast %6 %83
+%89 = OpIMul %6 %25 %19
+%90 = OpIAdd %6 %88 %89
+%91 = OpConvertUToF %15 %25
+%92 = OpAccessChain %20 %3 %90
+OpStore %92 %91
+OpBranch %93
+%93 = OpLabel
+%94 = OpIAdd %11 %83 %21
+OpBranch %95
+%95 = OpLabel
+OpBranch %96
+%96 = OpLabel
+%97 = OpSLessThan %10 %94 %14
+OpBranch %98
+%98 = OpLabel
+%99 = OpBitcast %6 %94
+%100 = OpIMul %6 %25 %19
+%101 = OpIAdd %6 %99 %100
+%102 = OpConvertUToF %15 %25
+%103 = OpAccessChain %20 %3 %101
+OpStore %103 %102
+OpBranch %104
+%104 = OpLabel
+%105 = OpIAdd %11 %94 %21
+OpBranch %106
+%106 = OpLabel
+OpBranch %107
+%107 = OpLabel
+%108 = OpIAdd %6 %25 %21
+OpBranch %27
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[2];
+    for (int i = -3; i < -1; i++) {
+      out_array[3 + i] = i;
+    }
+}
+*/
+TEST_F(PassClassTest, NegativeConditionAndInit) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %23 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 -3
+         %16 = OpConstant %6 -1
+         %17 = OpTypeBool
+         %19 = OpTypeInt 32 0
+         %20 = OpConstant %19 2
+         %21 = OpTypeArray %6 %20
+         %22 = OpTypePointer Function %21
+         %25 = OpConstant %6 3
+         %30 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %23 = OpVariable %22 Function
+               OpBranch %10
+         %10 = OpLabel
+         %32 = OpPhi %6 %9 %5 %31 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpSLessThan %17 %32 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+         %26 = OpIAdd %6 %32 %25
+         %28 = OpAccessChain %7 %23 %26
+               OpStore %28 %32
+               OpBranch %13
+         %13 = OpLabel
+         %31 = OpIAdd %6 %32 %30
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 -3
+%9 = OpConstant %6 -1
+%10 = OpTypeBool
+%11 = OpTypeInt 32 0
+%12 = OpConstant %11 2
+%13 = OpTypeArray %6 %12
+%14 = OpTypePointer Function %13
+%15 = OpConstant %6 3
+%16 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+%3 = OpVariable %14 Function
+OpBranch %18
+%18 = OpLabel
+OpBranch %23
+%23 = OpLabel
+%24 = OpSLessThan %10 %8 %9
+OpBranch %25
+%25 = OpLabel
+%26 = OpIAdd %6 %8 %15
+%27 = OpAccessChain %7 %3 %26
+OpStore %27 %8
+OpBranch %21
+%21 = OpLabel
+%20 = OpIAdd %6 %8 %16
+OpBranch %28
+%28 = OpLabel
+OpBranch %30
+%30 = OpLabel
+%31 = OpSLessThan %10 %20 %9
+OpBranch %32
+%32 = OpLabel
+%33 = OpIAdd %6 %20 %15
+%34 = OpAccessChain %7 %3 %33
+OpStore %34 %20
+OpBranch %35
+%35 = OpLabel
+%36 = OpIAdd %6 %20 %16
+OpBranch %22
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  // SinglePassRunAndCheck<opt::LoopUnroller>(text, expected, false);
+
+  ir::Function* f = spvtest::GetFunction(module, 4);
+
+  ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+  EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+  ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+  EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+  ir::BasicBlock* condition = loop.FindConditionBlock();
+  EXPECT_EQ(condition->id(), 14u);
+
+  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  EXPECT_EQ(induction->result_id(), 32u);
+
+  opt::LoopUtils loop_utils{context.get(), &loop};
+  EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+  size_t iterations = 0;
+  EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+                                          &iterations));
+  EXPECT_EQ(iterations, 2u);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[9];
+    for (int i = -10; i < -1; i++) {
+      out_array[i] = i;
+    }
+}
+*/
+TEST_F(PassClassTest, NegativeConditionAndInitResidualUnroll) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %23 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 -10
+         %16 = OpConstant %6 -1
+         %17 = OpTypeBool
+         %19 = OpTypeInt 32 0
+         %20 = OpConstant %19 9
+         %21 = OpTypeArray %6 %20
+         %22 = OpTypePointer Function %21
+         %25 = OpConstant %6 10
+         %30 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %23 = OpVariable %22 Function
+               OpBranch %10
+         %10 = OpLabel
+         %32 = OpPhi %6 %9 %5 %31 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpSLessThan %17 %32 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+         %26 = OpIAdd %6 %32 %25
+         %28 = OpAccessChain %7 %23 %26
+               OpStore %28 %32
+               OpBranch %13
+         %13 = OpLabel
+         %31 = OpIAdd %6 %32 %30
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 -10
+%9 = OpConstant %6 -1
+%10 = OpTypeBool
+%11 = OpTypeInt 32 0
+%12 = OpConstant %11 9
+%13 = OpTypeArray %6 %12
+%14 = OpTypePointer Function %13
+%15 = OpConstant %6 10
+%16 = OpConstant %6 1
+%48 = OpConstant %6 -9
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+%3 = OpVariable %14 Function
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+OpLoopMerge %28 %21 Unroll
+OpBranch %23
+%23 = OpLabel
+%24 = OpSLessThan %10 %19 %48
+OpBranchConditional %24 %25 %28
+%25 = OpLabel
+%26 = OpIAdd %6 %19 %15
+%27 = OpAccessChain %7 %3 %26
+OpStore %27 %19
+OpBranch %21
+%21 = OpLabel
+%20 = OpIAdd %6 %19 %16
+OpBranch %18
+%28 = OpLabel
+OpBranch %29
+%29 = OpLabel
+%30 = OpPhi %6 %48 %28 %47 %46
+OpLoopMerge %38 %46 Unroll
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %10 %30 %9
+OpBranchConditional %32 %33 %38
+%33 = OpLabel
+%34 = OpIAdd %6 %30 %15
+%35 = OpAccessChain %7 %3 %34
+OpStore %35 %30
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %6 %30 %16
+OpBranch %39
+%39 = OpLabel
+OpBranch %41
+%41 = OpLabel
+%42 = OpSLessThan %10 %37 %9
+OpBranch %43
+%43 = OpLabel
+%44 = OpIAdd %6 %37 %15
+%45 = OpAccessChain %7 %3 %44
+OpStore %45 %37
+OpBranch %46
+%46 = OpLabel
+%47 = OpIAdd %6 %37 %16
+OpBranch %29
+%38 = OpLabel
+OpReturn
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  ir::Function* f = spvtest::GetFunction(module, 4);
+
+  ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+  EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+  ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+  EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+  ir::BasicBlock* condition = loop.FindConditionBlock();
+  EXPECT_EQ(condition->id(), 14u);
+
+  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  EXPECT_EQ(induction->result_id(), 32u);
+
+  opt::LoopUtils loop_utils{context.get(), &loop};
+  EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+  size_t iterations = 0;
+  EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+                                          &iterations));
+  EXPECT_EQ(iterations, 9u);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+    float out_array[10];
+    for (uint i = 0; i < 2; i++) {
+      for (int x = 0; x < 5; ++x) {
+        out_array[x + i*5] = i;
+      }
+    }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %35 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 0
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 0
+         %16 = OpConstant %6 2
+         %17 = OpTypeBool
+         %19 = OpTypeInt 32 1
+         %20 = OpTypePointer Function %19
+         %22 = OpConstant %19 0
+         %29 = OpConstant %19 5
+         %31 = OpTypeFloat 32
+         %32 = OpConstant %6 10
+         %33 = OpTypeArray %31 %32
+         %34 = OpTypePointer Function %33
+         %39 = OpConstant %6 5
+         %44 = OpTypePointer Function %31
+         %47 = OpConstant %19 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %35 = OpVariable %34 Function
+               OpBranch %10
+         %10 = OpLabel
+         %51 = OpPhi %6 %9 %5 %50 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpULessThan %17 %51 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+               OpBranch %23
+         %23 = OpLabel
+         %54 = OpPhi %19 %22 %11 %48 %26
+               OpLoopMerge %25 %26 Unroll
+               OpBranch %27
+         %27 = OpLabel
+         %30 = OpSLessThan %17 %54 %29
+               OpBranchConditional %30 %24 %25
+         %24 = OpLabel
+         %37 = OpBitcast %6 %54
+         %40 = OpIMul %6 %51 %39
+         %41 = OpIAdd %6 %37 %40
+         %43 = OpConvertUToF %31 %51
+         %45 = OpAccessChain %44 %35 %41
+               OpStore %45 %43
+               OpBranch %26
+         %26 = OpLabel
+         %48 = OpIAdd %19 %54 %47
+               OpBranch %23
+         %25 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+         %50 = OpIAdd %6 %51 %47
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+
+  // clang-format on
+
+  {  // Test fully unroll
+    std::unique_ptr<ir::IRContext> context =
+        BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                    SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+    ir::Module* module = context->module();
+    EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                               << text << std::endl;
+    SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+    ir::Function* f = spvtest::GetFunction(module, 4);
+    ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+    EXPECT_EQ(loop_descriptor.NumLoops(), 2u);
+
+    ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1);
+
+    EXPECT_TRUE(outer_loop.HasUnrollLoopControl());
+
+    ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0);
+
+    EXPECT_TRUE(inner_loop.HasUnrollLoopControl());
+
+    EXPECT_EQ(outer_loop.GetBlocks().size(), 9u);
+
+    EXPECT_EQ(inner_loop.GetBlocks().size(), 4u);
+    EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u);
+    EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u);
+
+    {
+      opt::LoopUtils loop_utils{context.get(), &inner_loop};
+      loop_utils.FullyUnroll();
+      loop_utils.Finalize();
+    }
+
+    EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+    EXPECT_EQ(outer_loop.GetBlocks().size(), 25u);
+    EXPECT_EQ(outer_loop.NumImmediateChildren(), 0u);
+    {
+      opt::LoopUtils loop_utils{context.get(), &outer_loop};
+      loop_utils.FullyUnroll();
+      loop_utils.Finalize();
+    }
+    EXPECT_EQ(loop_descriptor.NumLoops(), 0u);
+  }
+
+  {  // Test partially unroll
+    std::unique_ptr<ir::IRContext> context =
+        BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                    SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+    ir::Module* module = context->module();
+    EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                               << text << std::endl;
+    SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+    ir::Function* f = spvtest::GetFunction(module, 4);
+    ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+    EXPECT_EQ(loop_descriptor.NumLoops(), 2u);
+
+    ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1);
+
+    EXPECT_TRUE(outer_loop.HasUnrollLoopControl());
+
+    ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0);
+
+    EXPECT_TRUE(inner_loop.HasUnrollLoopControl());
+
+    EXPECT_EQ(outer_loop.GetBlocks().size(), 9u);
+
+    EXPECT_EQ(inner_loop.GetBlocks().size(), 4u);
+
+    EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u);
+    EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u);
+
+    opt::LoopUtils loop_utils{context.get(), &inner_loop};
+    loop_utils.PartiallyUnroll(2);
+    loop_utils.Finalize();
+
+    // The number of loops should actually grow.
+    EXPECT_EQ(loop_descriptor.NumLoops(), 3u);
+    EXPECT_EQ(outer_loop.GetBlocks().size(), 19u);
+    EXPECT_EQ(outer_loop.NumImmediateChildren(), 2u);
+  }
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+  float x[10];
+  int i = 1;
+  i = 0;
+  for (; i < 10; i++) {
+    x[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, UnrollWithInductionOutsideHeader) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 440
+OpName %main "main"
+OpName %x "x"
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_1 = OpConstant %int 1
+%int_0 = OpConstant %int 0
+%int_10 = OpConstant %int 10
+%bool = OpTypeBool
+%float = OpTypeFloat 32
+%uint = OpTypeInt 32 0
+%uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_ptr_Function_float = OpTypePointer Function %float
+%main = OpFunction %void None %3
+%5 = OpLabel
+%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
+OpBranch %11
+%11 = OpLabel
+%33 = OpPhi %int %int_0 %5 %32 %14
+OpLoopMerge %13 %14 None
+OpBranch %15
+%15 = OpLabel
+%19 = OpSLessThan %bool %33 %int_10
+OpBranchConditional %19 %12 %13
+%12 = OpLabel
+%28 = OpConvertSToF %float %33
+%30 = OpAccessChain %_ptr_Function_float %x %33
+OpStore %30 %28
+OpBranch %14
+%14 = OpLabel
+%32 = OpIAdd %int %33 %int_1
+OpBranch %11
+%13 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 440
+OpName %main "main"
+OpName %x "x"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_1 = OpConstant %int 1
+%int_0 = OpConstant %int 0
+%int_10 = OpConstant %int 10
+%bool = OpTypeBool
+%float = OpTypeFloat 32
+%uint = OpTypeInt 32 0
+%uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_ptr_Function_float = OpTypePointer Function %float
+%main = OpFunction %void None %5
+%18 = OpLabel
+%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
+OpBranch %19
+%19 = OpLabel
+%20 = OpPhi %int %int_0 %18 %37 %36
+OpLoopMerge %23 %36 None
+OpBranch %24
+%24 = OpLabel
+%25 = OpSLessThan %bool %20 %int_10
+OpBranchConditional %25 %26 %23
+%26 = OpLabel
+%27 = OpConvertSToF %float %20
+%28 = OpAccessChain %_ptr_Function_float %x %20
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpIAdd %int %20 %int_1
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %bool %21 %int_10
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %float %21
+%35 = OpAccessChain %_ptr_Function_float %x %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %int %21 %int_1
+OpBranch %19
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+  float out_array[3];
+  for (int i = 3; i > 0; --i) {
+    out_array[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %24 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 3
+         %16 = OpConstant %6 0
+         %17 = OpTypeBool
+         %19 = OpTypeFloat 32
+         %20 = OpTypeInt 32 0
+         %21 = OpConstant %20 3
+         %22 = OpTypeArray %19 %21
+         %23 = OpTypePointer Function %22
+         %28 = OpTypePointer Function %19
+         %31 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %24 = OpVariable %23 Function
+               OpBranch %10
+         %10 = OpLabel
+         %33 = OpPhi %6 %9 %5 %32 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpSGreaterThan %17 %33 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+         %27 = OpConvertSToF %19 %33
+         %29 = OpAccessChain %28 %24 %33
+               OpStore %29 %27
+               OpBranch %13
+         %13 = OpLabel
+         %32 = OpISub %6 %33 %31
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 3
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+  float out_array[3];
+  for (int i = 9; i > 0; i-=3) {
+    out_array[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %24 "out_array"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 9
+         %16 = OpConstant %6 0
+         %17 = OpTypeBool
+         %19 = OpTypeFloat 32
+         %20 = OpTypeInt 32 0
+         %21 = OpConstant %20 3
+         %22 = OpTypeArray %19 %21
+         %23 = OpTypePointer Function %22
+         %28 = OpTypePointer Function %19
+         %30 = OpConstant %6 3
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %24 = OpVariable %23 Function
+               OpBranch %10
+         %10 = OpLabel
+         %33 = OpPhi %6 %9 %5 %32 %13
+               OpLoopMerge %12 %13 Unroll
+               OpBranch %14
+         %14 = OpLabel
+         %18 = OpSGreaterThan %17 %33 %16
+               OpBranchConditional %18 %11 %12
+         %11 = OpLabel
+         %27 = OpConvertSToF %19 %33
+         %29 = OpAccessChain %28 %24 %33
+               OpStore %29 %27
+               OpBranch %13
+         %13 = OpLabel
+         %32 = OpISub %6 %33 %30
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 9
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+  float out_array[3];
+  for (int i = 0; i < 7; i+=3) {
+    out_array[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNonDivisibleStepLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+OpSource GLSL 410
+OpName %4 "main"
+OpName %24 "out_array"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%9 = OpConstant %6 0
+%16 = OpConstant %6 7
+%17 = OpTypeBool
+%19 = OpTypeFloat 32
+%20 = OpTypeInt 32 0
+%21 = OpConstant %20 3
+%22 = OpTypeArray %19 %21
+%23 = OpTypePointer Function %22
+%28 = OpTypePointer Function %19
+%30 = OpConstant %6 3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%24 = OpVariable %23 Function
+OpBranch %10
+%10 = OpLabel
+%33 = OpPhi %6 %9 %5 %32 %13
+OpLoopMerge %12 %13 Unroll
+OpBranch %14
+%14 = OpLabel
+%18 = OpSLessThan %17 %33 %16
+OpBranchConditional %18 %11 %12
+%11 = OpLabel
+%27 = OpConvertSToF %19 %33
+%29 = OpAccessChain %28 %24 %33
+OpStore %29 %27
+OpBranch %13
+%13 = OpLabel
+%32 = OpIAdd %6 %33 %30
+OpBranch %10
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 7
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSLessThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpIAdd %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSLessThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpIAdd %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+  float out_array[4];
+  for (int i = 11; i > 0; i-=3) {
+    out_array[i] = i;
+  }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeNonDivisibleStepLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+OpSource GLSL 410
+OpName %4 "main"
+OpName %24 "out_array"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%9 = OpConstant %6 11
+%16 = OpConstant %6 0
+%17 = OpTypeBool
+%19 = OpTypeFloat 32
+%20 = OpTypeInt 32 0
+%21 = OpConstant %20 4
+%22 = OpTypeArray %19 %21
+%23 = OpTypePointer Function %22
+%28 = OpTypePointer Function %19
+%30 = OpConstant %6 3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%24 = OpVariable %23 Function
+OpBranch %10
+%10 = OpLabel
+%33 = OpPhi %6 %9 %5 %32 %13
+OpLoopMerge %12 %13 Unroll
+OpBranch %14
+%14 = OpLabel
+%18 = OpSGreaterThan %17 %33 %16
+OpBranchConditional %18 %11 %12
+%11 = OpLabel
+%27 = OpConvertSToF %19 %33
+%29 = OpAccessChain %28 %24 %33
+OpStore %29 %27
+OpBranch %13
+%13 = OpLabel
+%32 = OpISub %6 %33 %30
+OpBranch %10
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 11
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 4
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %47
+%47 = OpLabel
+OpBranch %49
+%49 = OpLabel
+%50 = OpSGreaterThan %10 %46 %9
+OpBranch %51
+%51 = OpLabel
+%52 = OpConvertSToF %11 %46
+%53 = OpAccessChain %16 %3 %46
+OpStore %53 %52
+OpBranch %54
+%54 = OpLabel
+%55 = OpISub %6 %46 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+}  // namespace
index 4525e30..fb8b3a1 100644 (file)
@@ -472,6 +472,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         optimizer->RegisterPass(CreateReplaceInvalidOpcodePass());
       } else if (0 == strcmp(cur_arg, "--simplify-instructions")) {
         optimizer->RegisterPass(CreateSimplificationPass());
+      } else if (0 == strcmp(cur_arg, "--loop-unroll")) {
+        optimizer->RegisterPass(CreateLoopFullyUnrollPass());
       } else if (0 == strcmp(cur_arg, "--skip-validation")) {
         *skip_validator = true;
       } else if (0 == strcmp(cur_arg, "-O")) {