Unroller support for multiple induction variables
authorStephen McGroarty <stephen@codeplay.com>
Tue, 27 Feb 2018 11:50:08 +0000 (11:50 +0000)
committerStephen McGroarty <stephen@codeplay.com>
Tue, 27 Feb 2018 11:50:08 +0000 (11:50 +0000)
Support for multiple induction variables within a loop and support for
loop condition operands <= and >=.

include/spirv-tools/optimizer.hpp
source/opt/ir_builder.h
source/opt/loop_descriptor.cpp
source/opt/loop_descriptor.h
source/opt/loop_unroller.cpp
source/opt/loop_unroller.h
source/opt/loop_utils.h
source/opt/optimizer.cpp
test/opt/loop_optimizations/unroll_assumptions.cpp
test/opt/loop_optimizations/unroll_simple.cpp
tools/opt/opt.cpp

index 9f3b360..adb014a 100644 (file)
@@ -513,12 +513,12 @@ Optimizer::PassToken CreateReplaceInvalidOpcodePass();
 Optimizer::PassToken CreateSimplificationPass();
 
 // Create loop unroller pass.
-// Creates a pass to fully unroll loops which have the "Unroll" loop control
+// Creates a pass to unroll loops which have the "Unroll" loop control
 // mask set. The loops must meet a specific criteria in order to be unrolled
 // safely this criteria is checked before doing the unroll by the
 // LoopUtils::CanPerformUnroll method. Any loop that does not meet the criteria
 // won't be unrolled. See CanPerformUnroll LoopUtils.h for more information.
-Optimizer::PassToken CreateLoopFullyUnrollPass();
+Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
 
 }  // namespace spvtools
 
index 3b75ec6..a1a1d1e 100644 (file)
@@ -172,11 +172,20 @@ class InstructionBuilder {
     // Assert that we are not trying to store a negative number in an unsigned
     // type.
     if (!sign)
-      assert(value > 0 &&
+      assert(value >= 0 &&
              "Trying to add a signed integer with an unsigned type!");
 
-    // Get or create the integer type.
-    analysis::Integer int_type(32, sign);
+    analysis::Integer int_type{32, sign};
+
+    // Get or create the integer type. This rebuilds the type and manages the
+    // memory for the rebuilt type.
+    uint32_t type_id =
+        GetContext()->get_type_mgr()->GetTypeInstruction(&int_type);
+
+    // Get the memory managed type so that it is safe to be stored by
+    // GetConstant.
+    analysis::Type* rebuilt_type =
+        GetContext()->get_type_mgr()->GetType(type_id);
 
     // Even if the value is negative we need to pass the bit pattern as a
     // uint32_t to GetConstant.
@@ -184,7 +193,7 @@ class InstructionBuilder {
 
     // Create the constant value.
     const opt::analysis::Constant* constant =
-        GetContext()->get_constant_mgr()->GetConstant(&int_type, {word});
+        GetContext()->get_constant_mgr()->GetConstant(rebuilt_type, {word});
 
     // Create the OpConstant instruction using the type and the value.
     return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
index 32765be..131363c 100644 (file)
@@ -33,7 +33,7 @@ namespace ir {
 // Takes in a phi instruction |induction| and the loop |header| and returns the
 // step operation of the loop.
 ir::Instruction* Loop::GetInductionStepOperation(
-    const ir::Loop* loop, const ir::Instruction* induction) const {
+    const ir::Instruction* induction) const {
   // Induction must be a phi instruction.
   assert(induction->opcode() == SpvOpPhi);
 
@@ -50,7 +50,7 @@ ir::Instruction* Loop::GetInductionStepOperation(
 
     // Check if the block is dominated by header, and thus coming from within
     // the loop.
-    if (loop->IsInsideLoop(incoming_block)) {
+    if (IsInsideLoop(incoming_block)) {
       step = def_use_manager->GetDef(
           induction->GetSingleWordInOperand(operand_id - 1));
       break;
@@ -61,6 +61,21 @@ ir::Instruction* Loop::GetInductionStepOperation(
     return nullptr;
   }
 
+  // The induction variable which binds the loop must only be modified once.
+  uint32_t lhs = step->GetSingleWordInOperand(0);
+  uint32_t rhs = step->GetSingleWordInOperand(1);
+
+  // One of the left hand side or right hand side of the step instruction must
+  // be the induction phi and the other must be an OpConstant.
+  if (lhs != induction->result_id() && rhs != induction->result_id()) {
+    return nullptr;
+  }
+
+  if (def_use_manager->GetDef(lhs)->opcode() != SpvOp::SpvOpConstant &&
+      def_use_manager->GetDef(rhs)->opcode() != SpvOp::SpvOpConstant) {
+    return nullptr;
+  }
+
   return step;
 }
 
@@ -84,17 +99,52 @@ bool Loop::IsSupportedCondition(SpvOp condition) const {
     // >
     case SpvOp::SpvOpUGreaterThan:
     case SpvOp::SpvOpSGreaterThan:
+
+    // >=
+    case SpvOp::SpvOpSGreaterThanEqual:
+    case SpvOp::SpvOpUGreaterThanEqual:
+    // <=
+    case SpvOp::SpvOpSLessThanEqual:
+    case SpvOp::SpvOpULessThanEqual:
+
       return true;
     default:
       return false;
   }
 }
 
+int64_t Loop::GetResidualConditionValue(SpvOp condition, int64_t initial_value,
+                                        int64_t step_value,
+                                        size_t number_of_iterations,
+                                        size_t factor) {
+  int64_t remainder =
+      initial_value + (number_of_iterations % factor) * step_value;
+
+  // We subtract or add one as the above formula calculates the remainder if the
+  // loop where just less than or greater than. Adding or subtracting one should
+  // give a functionally equivalent value.
+  switch (condition) {
+    case SpvOp::SpvOpSGreaterThanEqual:
+    case SpvOp::SpvOpUGreaterThanEqual: {
+      remainder -= 1;
+      break;
+    }
+    case SpvOp::SpvOpSLessThanEqual:
+    case SpvOp::SpvOpULessThanEqual: {
+      remainder += 1;
+      break;
+    }
+
+    default:
+      break;
+  }
+  return remainder;
+}
+
 // Extract the initial value from the |induction| OpPhi instruction and store it
 // in |value|. If the function couldn't find the initial value of |induction|
 // return false.
-bool Loop::GetInductionInitValue(const ir::Loop* loop,
-                                 const ir::Instruction* induction,
+bool Loop::GetInductionInitValue(const ir::Instruction* induction,
                                  int64_t* value) const {
   ir::Instruction* constant_instruction = nullptr;
   opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
@@ -104,7 +154,7 @@ bool Loop::GetInductionInitValue(const ir::Loop* loop,
     ir::BasicBlock* bb = context_->cfg()->block(
         induction->GetSingleWordInOperand(operand_id + 1));
 
-    if (!loop->IsInsideLoop(bb)) {
+    if (!IsInsideLoop(bb)) {
       constant_instruction = def_use_manager->GetDef(
           induction->GetSingleWordInOperand(operand_id));
     }
@@ -413,6 +463,25 @@ bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) {
   return all_outside_loop;
 }
 
+void Loop::ComputeLoopStructuredOrder(
+    std::vector<ir::BasicBlock*>* ordered_loop_blocks, bool include_pre_header,
+    bool include_merge) const {
+  ir::CFG& cfg = *context_->cfg();
+
+  // Reserve the memory: all blocks in the loop + extra if needed.
+  ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header +
+                               include_merge);
+
+  if (include_pre_header && GetPreHeaderBlock())
+    ordered_loop_blocks->push_back(loop_preheader_);
+  cfg.ForEachBlockInReversePostOrder(
+      loop_header_, [ordered_loop_blocks, this](BasicBlock* bb) {
+        if (IsInsideLoop(bb)) ordered_loop_blocks->push_back(bb);
+      });
+  if (include_merge && GetMergeBlock())
+    ordered_loop_blocks->push_back(loop_merge_);
+}
+
 LoopDescriptor::LoopDescriptor(const Function* f) : loops_() {
   PopulateList(f);
 }
@@ -550,7 +619,7 @@ bool Loop::FindNumberOfIterations(const ir::Instruction* induction,
   }
 
   // Find the instruction which is stepping through the loop.
-  ir::Instruction* step_inst = GetInductionStepOperation(this, induction);
+  ir::Instruction* step_inst = GetInductionStepOperation(induction);
   if (!step_inst) return false;
 
   // Find the constant value used by the condition variable.
@@ -577,17 +646,18 @@ bool Loop::FindNumberOfIterations(const ir::Instruction* induction,
 
   // Find the inital value of the loop and make sure it is a constant integer.
   int64_t init_value = 0;
-  if (!GetInductionInitValue(this, induction, &init_value)) return false;
+  if (!GetInductionInitValue(induction, &init_value)) return false;
 
   // If iterations is non null then store the value in that.
-  if (iterations_out) {
-    int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
-                                     init_value, step_value);
+  int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
+                                   init_value, step_value);
 
-    // If the loop body will not be reached return false.
-    if (num_itrs <= 0) {
-      return false;
-    }
+  // If the loop body will not be reached return false.
+  if (num_itrs <= 0) {
+    return false;
+  }
+
+  if (iterations_out) {
     assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
     *iterations_out = static_cast<size_t>(num_itrs);
   }
@@ -611,26 +681,87 @@ int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
                             int64_t init_value, int64_t step_value) const {
   int64_t diff = 0;
 
-  // Take the abs of - step values.
-  step_value = llabs(step_value);
-
   switch (condition) {
     case SpvOp::SpvOpSLessThan:
     case SpvOp::SpvOpULessThan: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value < condition_value)) return 0;
+
       diff = condition_value - init_value;
+
+      // If the operation is a less then operation then the diff and step must
+      // have the same sign otherwise the induction will never cross the
+      // condition (either never true or always true).
+      if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
+        return 0;
+      }
+
       break;
     }
     case SpvOp::SpvOpSGreaterThan:
     case SpvOp::SpvOpUGreaterThan: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value > condition_value)) return 0;
+
       diff = init_value - condition_value;
+
+      // If the operation is a greater than operation then the diff and step
+      // must have opposite signs. Otherwise the condition will always be true
+      // or will never be true.
+      if ((diff < 0 && step_value < 0) || (diff > 0 && step_value > 0)) {
+        return 0;
+      }
+
+      break;
+    }
+
+    case SpvOp::SpvOpSGreaterThanEqual:
+    case SpvOp::SpvOpUGreaterThanEqual: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value >= condition_value)) return 0;
+
+      // We subract one to make it the same as SpvOpGreaterThan as it is
+      // functionally equivalent.
+      diff = init_value - (condition_value - 1);
+
+      // If the operation is a greater than operation then the diff and step
+      // must have opposite signs. Otherwise the condition will always be true
+      // or will never be true.
+      if ((diff > 0 && step_value > 0) || (diff < 0 && step_value < 0)) {
+        return 0;
+      }
+
+      break;
+    }
+
+    case SpvOp::SpvOpSLessThanEqual:
+    case SpvOp::SpvOpULessThanEqual: {
+      // If the condition is not met to begin with the loop will never iterate.
+      if (!(init_value <= condition_value)) return 0;
+
+      // We add one to make it the same as SpvOpLessThan as it is functionally
+      // equivalent.
+      diff = (condition_value + 1) - init_value;
+
+      // If the operation is a less than operation then the diff and step must
+      // have the same sign otherwise the induction will never cross the
+      // condition (either never true or always true).
+      if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
+        return 0;
+      }
+
       break;
     }
+
     default:
       assert(false &&
              "Could not retrieve number of iterations from the loop condition. "
              "Condition is not supported.");
   }
 
+  // Take the abs of - step values.
+  step_value = llabs(step_value);
+  diff = llabs(diff);
   int64_t result = diff / step_value;
 
   if (diff % step_value != 0) {
@@ -639,7 +770,17 @@ int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
   return result;
 }
 
-ir::Instruction* Loop::FindInductionVariable(
+// Returns the list of induction variables within the loop.
+void Loop::GetInductionVariables(
+    std::vector<ir::Instruction*>& induction_variables) const {
+  for (ir::Instruction& inst : *loop_header_) {
+    if (inst.opcode() == SpvOp::SpvOpPhi) {
+      induction_variables.push_back(&inst);
+    }
+  }
+}
+
+ir::Instruction* Loop::FindConditionVariable(
     const ir::BasicBlock* condition_block) const {
   // Find the branch instruction.
   const ir::Instruction& branch_inst = *condition_block->ctail();
index ef5f19f..d0421d6 100644 (file)
@@ -207,10 +207,13 @@ class Loop {
     AddBasicBlock(bb);
   }
 
-  // This function uses the |condition| to find the induction variable within
-  // the loop. This only works if the loop is bound by a single condition and a
-  // single induction variable.
-  ir::Instruction* FindInductionVariable(const ir::BasicBlock* condition) const;
+  // Returns the list of induction variables within the loop.
+  void GetInductionVariables(std::vector<ir::Instruction*>& inductions) const;
+
+  // This function uses the |condition| to find the induction variable which is
+  // used by the loop condition within the loop. This only works if the loop is
+  // bound by a single condition and single induction variable.
+  ir::Instruction* FindConditionVariable(const ir::BasicBlock* condition) const;
 
   // Returns the number of iterations within a loop when given the |induction|
   // variable and the loop |condition| check. It stores the found number of
@@ -275,14 +278,13 @@ class Loop {
   // Extract the initial value from the |induction| variable and store it in
   // |value|. If the function couldn't find the initial value of |induction|
   // return false.
-  bool GetInductionInitValue(const ir::Loop* loop,
-                             const ir::Instruction* induction,
+  bool GetInductionInitValue(const ir::Instruction* induction,
                              int64_t* value) const;
 
   // Takes in a phi instruction |induction| and the loop |header| and returns
   // the step operation of the loop.
   ir::Instruction* GetInductionStepOperation(
-      const ir::Loop* loop, const ir::Instruction* induction) const;
+      const ir::Instruction* induction) const;
 
   // Returns true if we can deduce the number of loop iterations in the step
   // operation |step|. IsSupportedCondition must also be true for the condition
@@ -294,6 +296,24 @@ class Loop {
   // the step instruction.
   bool IsSupportedCondition(SpvOp condition) const;
 
+  // Creates the list of the loop's basic block in structured order and store
+  // the result in |ordered_loop_blocks|. If |include_pre_header| is true, the
+  // pre-header block will also be included at the beginning of the list if it
+  // exist. If |include_merge| is true, the merge block will also be included at
+  // the end of the list if it exist.
+  void ComputeLoopStructuredOrder(
+      std::vector<ir::BasicBlock*>* ordered_loop_blocks,
+      bool include_pre_header = false, bool include_merge = false) const;
+
+  // Given the loop |condition|, |initial_value|, |step_value|, the trip count
+  // |number_of_iterations|, and the |unroll_factor| requested, get the new
+  // condition value for the residual loop.
+  static int64_t GetResidualConditionValue(SpvOp condition,
+                                           int64_t initial_value,
+                                           int64_t step_value,
+                                           size_t number_of_iterations,
+                                           size_t unroll_factor);
+
  private:
   IRContext* context_;
   // The block which marks the start of the loop.
index 16031e2..39aa108 100644 (file)
@@ -62,6 +62,12 @@ namespace spvtools {
 namespace opt {
 namespace {
 
+// Loop control constant value for DontUnroll flag.
+static const uint32_t kLoopControlDontUnrollIndex = 2;
+
+// Operand index of the loop control parameter of the OpLoopMerge.
+static const uint32_t kLoopControlIndex = 2;
+
 // This utility class encapsulates some of the state we need to maintain between
 // loop unrolls. Specifically it maintains key blocks and the induction variable
 // in the current loop duplication step and the blocks from the previous one.
@@ -79,20 +85,24 @@ struct LoopUnrollState {
 
   // Initialize from the loop descriptor class.
   LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block,
-                  ir::BasicBlock* condition)
+                  ir::BasicBlock* condition,
+                  std::vector<ir::Instruction*>&& phis)
       : previous_phi_(induction),
         previous_continue_block_(continue_block),
         previous_condition_block_(condition),
         new_phi(nullptr),
         new_continue_block(nullptr),
         new_condition_block(nullptr),
-        new_header_block(nullptr) {}
+        new_header_block(nullptr) {
+    previous_phis_ = std::move(phis);
+  }
 
   // Swap the state so that the new nodes are now the previous nodes.
   void NextIterationState() {
     previous_phi_ = new_phi;
     previous_continue_block_ = new_continue_block;
     previous_condition_block_ = new_condition_block;
+    previous_phis_ = std::move(new_phis_);
 
     // Clear new nodes.
     new_phi = nullptr;
@@ -103,11 +113,16 @@ struct LoopUnrollState {
     // Clear new block/instruction maps.
     new_blocks.clear();
     new_inst.clear();
+    ids_to_new_inst.clear();
   }
 
   // The induction variable from the immediately preceding loop body.
   ir::Instruction* previous_phi_;
 
+  // All the phi nodes from the previous loop iteration.
+  std::vector<ir::Instruction*> previous_phis_;
+
+  std::vector<ir::Instruction*> new_phis_;
   // The previous continue block. The backedge will be removed from this and
   // added to the new continue block.
   ir::BasicBlock* previous_continue_block_;
@@ -131,9 +146,11 @@ struct LoopUnrollState {
   // from.
   std::unordered_map<uint32_t, ir::BasicBlock*> new_blocks;
 
-  // A mapping of new instruction ids to the instruction ids from which they
-  // were copied.
+  // A mapping of the original instruction ids to the instruction ids to their
+  // copies.
   std::unordered_map<uint32_t, uint32_t> new_inst;
+
+  std::unordered_map<uint32_t, ir::Instruction*> ids_to_new_inst;
 };
 
 // This class implements the actual unrolling. It uses a LoopUnrollState to
@@ -195,6 +212,22 @@ class LoopUnrollerUtilsImpl {
   // Extracts the initial state information from the |loop|.
   void Init(ir::Loop* loop);
 
+  // Replace the uses of each induction variable outside the loop with the final
+  // value of the induction variable before the loop exit. To reflect the proper
+  // state of a fully unrolled loop.
+  void ReplaceInductionUseWithFinalValue(ir::Loop* loop);
+
+  // Remove all the instructions in the invalidated_instructions_ vector.
+  void RemoveDeadInstructions();
+
+  // Replace any use of induction variables outwith the loop with the final
+  // value of the induction variable in the unrolled loop.
+  void ReplaceOutsideLoopUseWithFinalValue(ir::Loop* loop);
+
+  // Set the LoopControl operand of the OpLoopMerge instruction to be
+  // DontUnroll.
+  void MarkLoopControlAsDontUnroll(ir::Loop* loop) const;
+
  private:
   // Remap all the in |basic_block| to new IDs and keep the mapping of new ids
   // to old
@@ -234,6 +267,12 @@ class LoopUnrollerUtilsImpl {
   // the parent exists.
   void AddBlocksToLoop(ir::Loop* loop) const;
 
+  // After the partially unroll step the phi instructions in the header block
+  // will be in an illegal format. This function makes the phis legal by making
+  // the edge from the latch block come from the new latch block and the value
+  // to be the actual value of the phi at that point.
+  void LinkLastPhisToStart(ir::Loop* loop) const;
+
   // A pointer to the IRContext. Used to add/remove instructions and for usedef
   // chains.
   ir::IRContext* context_;
@@ -246,7 +285,7 @@ class LoopUnrollerUtilsImpl {
   BasicBlockListTy blocks_to_add_;
 
   // List of instructions which are now dead and can be removed.
-  std::vector<ir::Instruction*> dead_instructions_;
+  std::vector<ir::Instruction*> invalidated_instructions_;
 
   // Maintains the current state of the transform between calls to unroll.
   LoopUnrollState state_;
@@ -261,6 +300,10 @@ class LoopUnrollerUtilsImpl {
   // The induction variable of the loop.
   ir::Instruction* loop_induction_variable_;
 
+  // Phis used in the loop need to be remapped to use the actual result values
+  // and then be remapped at the end.
+  std::vector<ir::Instruction*> loop_phi_instructions_;
+
   // The number of loop iterations that the loop would preform pre-unroll.
   size_t number_of_loop_iterations_;
 
@@ -300,7 +343,7 @@ void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) {
   }
   assert(loop_condition_block_);
 
-  loop_induction_variable_ = loop->FindInductionVariable(loop_condition_block_);
+  loop_induction_variable_ = loop->FindConditionVariable(loop_condition_block_);
   assert(loop_induction_variable_);
 
   bool found = loop->FindNumberOfIterations(
@@ -308,6 +351,9 @@ void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) {
       &number_of_loop_iterations_, &loop_step_value_, &loop_init_value_);
   (void)found;  // To silence unused variable warning on release builds.
   assert(found);
+
+  // Blocks are stored in an unordered set of ids in the loop class, we need to
+  // create the dominator ordered list.
   ComputeLoopOrderedBlocks(loop);
 }
 
@@ -318,7 +364,6 @@ void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) {
 // number of bodies.
 void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
                                                           size_t factor) {
-  // Create a new merge block for the first loop.
   std::unique_ptr<ir::Instruction> new_label{new ir::Instruction(
       context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})};
   std::unique_ptr<ir::BasicBlock> new_exit_bb{
@@ -332,7 +377,6 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
   blocks_to_add_.push_back(std::move(new_exit_bb));
   ir::BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get();
   ir::Instruction& original_conditional_branch = *loop_condition_block_->tail();
-
   // Duplicate the loop, providing access to the blocks of both loops.
   // This is a naked new due to the VS2013 requirement of not having unique
   // pointers in vectors, as it will be inserted into a vector with
@@ -348,20 +392,45 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
   AddBlocksToFunction(loop->GetMergeBlock());
   blocks_to_add_.clear();
 
+  // Create a new merge block for the first loop.
   InstructionBuilder builder{context_, new_exit_bb_raw};
   // Make the first loop branch to the second.
   builder.AddBranch(new_loop->GetHeaderBlock()->id());
 
   loop_condition_block_ = state_.new_condition_block;
   loop_induction_variable_ = state_.new_phi;
-
   // Unroll the new loop by the factor with the usual -1 to account for the
   // existing block iteration.
   Unroll(new_loop, factor);
 
+  LinkLastPhisToStart(new_loop);
+  AddBlocksToLoop(new_loop);
+
+  // Add the new merge block to the back of the list of blocks to be added. It
+  // needs to be the last block added to maintain dominator order in the binary.
+  blocks_to_add_.push_back(
+      std::unique_ptr<ir::BasicBlock>(new_loop->GetMergeBlock()));
+
+  // Add the blocks to the function.
+  AddBlocksToFunction(loop->GetMergeBlock());
+
+  // Reset the usedef analysis.
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+  // The loop condition.
+  ir::Instruction* condition_check = def_use_manager->GetDef(
+      original_conditional_branch.GetSingleWordOperand(0));
+
+  // This should have been checked by the LoopUtils::CanPerformUnroll function
+  // before entering this.
+  assert(loop->IsSupportedCondition(condition_check->opcode()));
+
   // We need to account for the initial body when calculating the remainder.
-  int64_t remainder = loop_init_value_ +
-                      (number_of_loop_iterations_ % factor) * loop_step_value_;
+  int64_t remainder = ir::Loop::GetResidualConditionValue(
+      condition_check->opcode(), loop_init_value_, loop_step_value_,
+      number_of_loop_iterations_, factor);
 
   assert(remainder > std::numeric_limits<int32_t>::min() &&
          remainder < std::numeric_limits<int32_t>::max());
@@ -380,36 +449,45 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
 
   uint32_t constant_id = new_constant->result_id();
 
-  // Add the merge block to the back of the binary.
-  blocks_to_add_.push_back(
-      std::unique_ptr<ir::BasicBlock>(new_loop->GetMergeBlock()));
-
-  AddBlocksToLoop(new_loop);
-  // Add the blocks to the function.
-  AddBlocksToFunction(loop->GetMergeBlock());
-
-  // Reset the usedef analysis.
-  context_->InvalidateAnalysesExceptFor(
-      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
-  opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
-
   // Update the condition check.
-  ir::Instruction* condition_check = def_use_manager->GetDef(
-      original_conditional_branch.GetSingleWordOperand(0));
-
-  // This should have been checked by the LoopUtils::CanPerformUnroll function
-  // before entering this.
-  assert(condition_check->opcode() == SpvOpSLessThan);
   condition_check->SetInOperand(1, {constant_id});
 
   // Update the next phi node. The phi will have a constant value coming in from
   // the preheader block. For the duplicated loop we need to update the constant
   // to be the amount of iterations covered by the first loop and the incoming
   // block to be the first loops new merge block.
-  uint32_t phi_incoming_index =
-      GetPhiIndexFromLabel(loop->GetPreHeaderBlock(), loop_induction_variable_);
-  loop_induction_variable_->SetInOperand(phi_incoming_index - 1, {constant_id});
-  loop_induction_variable_->SetInOperand(phi_incoming_index, {new_merge_id});
+  std::vector<ir::Instruction*> new_inductions;
+  new_loop->GetInductionVariables(new_inductions);
+
+  std::vector<ir::Instruction*> old_inductions;
+  loop->GetInductionVariables(old_inductions);
+  for (size_t index = 0; index < new_inductions.size(); ++index) {
+    ir::Instruction* new_induction = new_inductions[index];
+    ir::Instruction* old_induction = old_inductions[index];
+    // Get the index of the loop initalizer, the value coming in from the
+    // preheader.
+    uint32_t initalizer_index =
+        GetPhiIndexFromLabel(new_loop->GetPreHeaderBlock(), old_induction);
+
+    // Replace the second loop initalizer with the phi from the first
+    new_induction->SetInOperand(initalizer_index - 1,
+                                {old_induction->result_id()});
+    new_induction->SetInOperand(initalizer_index, {new_merge_id});
+
+    // If the use of the first loop induction variable is outside of the loop
+    // then replace that use with the second loop induction variable.
+    uint32_t second_loop_induction = new_induction->result_id();
+    auto replace_use_outside_of_loop = [loop, second_loop_induction](
+                                           ir::Instruction* user,
+                                           uint32_t operand_index) {
+      if (!loop->IsInsideLoop(user)) {
+        user->SetOperand(operand_index, {second_loop_induction});
+      }
+    };
+
+    context_->get_def_use_mgr()->ForEachUse(old_induction,
+                                            replace_use_outside_of_loop);
+  }
 
   context_->InvalidateAnalysesExceptFor(
       ir::IRContext::Analysis::kAnalysisLoopAnalysis);
@@ -420,36 +498,58 @@ void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
       *context_->GetLoopDescriptor(&function_);
 
   loop_descriptor.AddLoop(new_loop, loop->GetParent());
+
+  RemoveDeadInstructions();
 }
 
-// Duplicate the |loop| body |factor| number of times while keeping the loop
-// backedge intact.
-void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) {
-  Unroll(loop, factor);
-  AddBlocksToLoop(loop);
-  AddBlocksToFunction(loop->GetMergeBlock());
+// Mark this loop as DontUnroll as it will already be unrolled and it may not
+// be safe to unroll a previously partially unrolled loop.
+void LoopUnrollerUtilsImpl::MarkLoopControlAsDontUnroll(ir::Loop* loop) const {
+  ir::Instruction* loop_merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
+  assert(loop_merge_inst &&
+         "Loop merge instruction could not be found after entering unroller "
+         "(should have exited before this)");
+  loop_merge_inst->SetInOperand(kLoopControlIndex,
+                                {kLoopControlDontUnrollIndex});
 }
 
-// Duplicate the |loop| body |factor| number of times while keeping the loop
-// backedge intact.
+// Duplicate the |loop| body |factor| - 1 number of times while keeping the loop
+// backedge intact. This will leave the loop with |factor| number of bodies
+// after accounting for the initial body.
 void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) {
+  // If we unroll a loop partially it will not be safe to unroll it further.
+  // This is due to the current method of calculating the number of loop
+  // iterations.
+  MarkLoopControlAsDontUnroll(loop);
+
+  std::vector<ir::Instruction*> inductions;
+  loop->GetInductionVariables(inductions);
   state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(),
-                           loop_condition_block_};
+                           loop_condition_block_, std::move(inductions)};
   for (size_t i = 0; i < factor - 1; ++i) {
     CopyBody(loop, true);
   }
+}
 
-  uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_,
-                                            state_.previous_phi_);
-  uint32_t phi_variable =
-      state_.previous_phi_->GetSingleWordInOperand(phi_index - 1);
-  uint32_t phi_label = state_.previous_phi_->GetSingleWordInOperand(phi_index);
-
-  ir::Instruction* original_phi = loop_induction_variable_;
+void LoopUnrollerUtilsImpl::RemoveDeadInstructions() {
+  // Remove the dead instructions.
+  for (ir::Instruction* inst : invalidated_instructions_) {
+    context_->KillInst(inst);
+  }
+}
 
-  // SetInOperands are offset by two.
-  original_phi->SetInOperand(phi_index - 1, {phi_variable});
-  original_phi->SetInOperand(phi_index, {phi_label});
+void LoopUnrollerUtilsImpl::ReplaceInductionUseWithFinalValue(ir::Loop* loop) {
+  context_->InvalidateAnalysesExceptFor(
+      ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+  std::vector<ir::Instruction*> inductions;
+  loop->GetInductionVariables(inductions);
+
+  for (size_t index = 0; index < inductions.size(); ++index) {
+    uint32_t trip_step_id = GetPhiDefID(state_.previous_phis_[index],
+                                        state_.previous_continue_block_->id());
+    context_->ReplaceAllUsesWith(inductions[index]->result_id(), trip_step_id);
+    invalidated_instructions_.push_back(inductions[index]);
+  }
 }
 
 // Fully unroll the loop by partially unrolling it by the number of loop
@@ -476,6 +576,9 @@ void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) {
   // Add the blocks to the function.
   AddBlocksToFunction(loop->GetMergeBlock());
 
+  ReplaceInductionUseWithFinalValue(loop);
+
+  RemoveDeadInstructions();
   // Invalidate all analyses.
   context_->InvalidateAnalysesExceptFor(
       ir::IRContext::Analysis::kAnalysisLoopAnalysis);
@@ -490,7 +593,6 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
                                            bool preserve_instructions) {
   // Clone the block exactly, including the IDs.
   ir::BasicBlock* basic_block = itr->Clone(context_);
-
   basic_block->SetParent(itr->GetParent());
 
   // Assign each result a new unique ID and keep a mapping of the old ids to
@@ -515,7 +617,7 @@ void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
     if (!preserve_instructions) {
       // Remove the loop merge instruction if it exists.
       ir::Instruction* merge_inst = basic_block->GetLoopMergeInst();
-      if (merge_inst) context_->KillInst(merge_inst);
+      if (merge_inst) invalidated_instructions_.push_back(merge_inst);
     }
   }
 
@@ -551,10 +653,26 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
   ir::Instruction& new_continue_branch = *state_.new_continue_block->tail();
   new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()});
 
-  // Update references to the old phi node with the actual variable.
-  const ir::Instruction* induction = loop_induction_variable_;
-  state_.new_inst[induction->result_id()] =
-      GetPhiDefID(state_.previous_phi_, state_.previous_continue_block_->id());
+  std::vector<ir::Instruction*> inductions;
+  loop->GetInductionVariables(inductions);
+  for (size_t index = 0; index < inductions.size(); ++index) {
+    ir::Instruction* master_copy = inductions[index];
+
+    assert(master_copy->result_id() != 0);
+    ir::Instruction* induction_clone =
+        state_.ids_to_new_inst[state_.new_inst[master_copy->result_id()]];
+
+    state_.new_phis_.push_back(induction_clone);
+    assert(induction_clone->result_id() != 0);
+
+    if (!state_.previous_phis_.empty()) {
+      state_.new_inst[master_copy->result_id()] = GetPhiDefID(
+          state_.previous_phis_[index], state_.previous_continue_block_->id());
+    } else {
+      // Do not replace the first phi block ids.
+      state_.new_inst[master_copy->result_id()] = master_copy->result_id();
+    }
+  }
 
   if (eliminate_conditions &&
       state_.new_condition_block != loop_condition_block_) {
@@ -569,7 +687,8 @@ void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
     RemapOperands(pair.second);
   }
 
-  dead_instructions_.push_back(state_.new_phi);
+  for (ir::Instruction* dead_phi : state_.new_phis_)
+    invalidated_instructions_.push_back(dead_phi);
 
   // Swap the state so the new is now the previous.
   state_.NextIterationState();
@@ -582,7 +701,7 @@ uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi,
       return phi->GetSingleWordOperand(operand - 1);
     }
   }
-
+  assert(false && "Could not find a phi index matching the provided label");
   return 0;
 }
 
@@ -591,8 +710,8 @@ void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block,
   // Remove the old conditional branch to the merge and continue blocks.
   ir::Instruction& old_branch = *condition_block->tail();
   uint32_t new_target = old_branch.GetSingleWordOperand(operand_label);
-  context_->KillInst(&old_branch);
 
+  context_->KillInst(&old_branch);
   // Add the new unconditional branch to the merge block.
   InstructionBuilder builder{context_, condition_block};
   builder.AddBranch(new_target);
@@ -601,23 +720,35 @@ void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block,
 void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) {
   // Remove the OpLoopMerge instruction from the function.
   ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
-  context_->KillInst(merge_inst);
+  invalidated_instructions_.push_back(merge_inst);
 
   // Remove the final backedge to the header and make it point instead to the
   // merge block.
   state_.previous_continue_block_->tail()->SetInOperand(
       0, {loop->GetMergeBlock()->id()});
 
-  // Remove the induction variable as the phi will now be invalid. Replace all
-  // uses with the constant initializer value (all uses of the phi will be in
-  // the first iteration with the subsequent phis already having been removed.
-  uint32_t initalizer_id =
-      GetPhiDefID(loop_induction_variable_, loop->GetPreHeaderBlock()->id());
-  context_->ReplaceAllUsesWith(loop_induction_variable_->result_id(),
-                               initalizer_id);
+  // Remove all induction variables as the phis will now be invalid. Replace all
+  // uses with the constant initializer value (all uses of phis will be in
+  // the first iteration with the subsequent phis already having been removed).
+  std::vector<ir::Instruction*> inductions;
+  loop->GetInductionVariables(inductions);
+
+  // We can use the state instruction mechanism to replace all internal loop
+  // values within the first loop trip (as the subsequent ones will be updated
+  // by the copy function) with the value coming in from the preheader and then
+  // use context ReplaceAllUsesWith for the uses outside the loop with the final
+  // trip phi value.
+  state_.new_inst.clear();
+  for (ir::Instruction* induction : inductions) {
+    uint32_t initalizer_id =
+        GetPhiDefID(induction, loop->GetPreHeaderBlock()->id());
+
+    state_.new_inst[induction->result_id()] = initalizer_id;
+  }
 
-  // Remove the now unused phi.
-  context_->KillInst(loop_induction_variable_);
+  for (ir::BasicBlock* block : loop_blocks_inorder_) {
+    RemapOperands(block);
+  }
 }
 
 // Uses the first loop to create a copy of the loop with new IDs.
@@ -631,10 +762,14 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop,
     new_block_order.push_back(blocks_to_add_.back().get());
   }
 
+  // Clone the merge block, give it a new id and record it in the state.
   ir::BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_);
   new_merge->SetParent(old_loop->GetMergeBlock()->GetParent());
   AssignNewResultIds(new_merge);
   state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge;
+
+  // Remap the operands of every instruction in the loop to point to the new
+  // copies.
   for (auto& pair : state_.new_blocks) {
     RemapOperands(pair.second);
   }
@@ -648,12 +783,11 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop,
   new_loop->SetMergeBlock(new_merge);
 }
 
+// Whenever the utility copies a block it stores it in a tempory buffer, this
+// function adds the buffer into the ir::Function. The blocks will be inserted
+// after the block |insert_point|.
 void LoopUnrollerUtilsImpl::AddBlocksToFunction(
     const ir::BasicBlock* insert_point) {
-  for (ir::Instruction* inst : dead_instructions_) {
-    context_->KillInst(inst);
-  }
-
   for (auto basic_block_iterator = function_.begin();
        basic_block_iterator != function_.end(); ++basic_block_iterator) {
     if (basic_block_iterator->id() == insert_point->id()) {
@@ -691,12 +825,12 @@ void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) {
 
     // Save the mapping of old_id -> new_id.
     state_.new_inst[old_id] = inst.result_id();
-
     // Check if this instruction is the induction variable.
     if (loop_induction_variable_->result_id() == old_id) {
       // Save a pointer to the new copy of it.
       state_.new_phi = &inst;
     }
+    state_.ids_to_new_inst[inst.result_id()] = &inst;
   }
 }
 
@@ -706,6 +840,7 @@ void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) {
   for (ir::Instruction& inst : *basic_block) {
     auto remap_operands_to_new_ids = [this](uint32_t* id) {
       auto itr = state_.new_inst.find(*id);
+
       if (itr != state_.new_inst.end()) {
         *id = itr->second;
       }
@@ -719,26 +854,7 @@ void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) {
 // later use.
 void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(ir::Loop* loop) {
   loop_blocks_inorder_.clear();
-
-  opt::DominatorAnalysis* analysis =
-      context_->GetDominatorAnalysis(&function_, *context_->cfg());
-  opt::DominatorTree& tree = analysis->GetDomTree();
-
-  // Starting at the loop header BasicBlock, traverse the dominator tree until
-  // we reach the merge block and add every node we traverse to the set of
-  // blocks
-  // which we consider to be the loop.
-  auto begin_itr = tree.GetTreeNode(loop->GetHeaderBlock())->df_begin();
-  const ir::BasicBlock* merge = loop->GetMergeBlock();
-  auto func = [merge, &tree, this](DominatorTreeNode* node) {
-    if (!tree.Dominates(merge->id(), node->id())) {
-      this->loop_blocks_inorder_.push_back(node->bb_);
-      return true;
-    }
-    return false;
-  };
-
-  tree.VisitChildrenIf(func, begin_itr);
+  loop->ComputeLoopStructuredOrder(&loop_blocks_inorder_);
 }
 
 // Adds the blocks_to_add_ to both the loop and to the parent.
@@ -752,6 +868,35 @@ void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const {
   if (loop->GetParent()) AddBlocksToLoop(loop->GetParent());
 }
 
+void LoopUnrollerUtilsImpl::LinkLastPhisToStart(ir::Loop* loop) const {
+  std::vector<ir::Instruction*> inductions;
+  loop->GetInductionVariables(inductions);
+
+  for (size_t i = 0; i < inductions.size(); ++i) {
+    ir::Instruction* last_phi_in_block = state_.previous_phis_[i];
+
+    uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_,
+                                              last_phi_in_block);
+    uint32_t phi_variable =
+        last_phi_in_block->GetSingleWordInOperand(phi_index - 1);
+    uint32_t phi_label = last_phi_in_block->GetSingleWordInOperand(phi_index);
+
+    ir::Instruction* phi = inductions[i];
+    phi->SetInOperand(phi_index - 1, {phi_variable});
+    phi->SetInOperand(phi_index, {phi_label});
+  }
+}
+
+// Duplicate the |loop| body |factor| number of times while keeping the loop
+// backedge intact.
+void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) {
+  Unroll(loop, factor);
+  LinkLastPhisToStart(loop);
+  AddBlocksToLoop(loop);
+  AddBlocksToFunction(loop->GetMergeBlock());
+  RemoveDeadInstructions();
+}
+
 /*
  * End LoopUtilsImpl.
  */
@@ -775,7 +920,7 @@ bool LoopUtils::CanPerformUnroll() {
   if (!condition) return false;
 
   // Check that we can find and process the induction variable.
-  const ir::Instruction* induction = loop_->FindInductionVariable(condition);
+  const ir::Instruction* induction = loop_->FindConditionVariable(condition);
   if (!induction || induction->opcode() != SpvOpPhi) return false;
 
   // Check that we can find the number of loop iterations.
@@ -792,15 +937,8 @@ bool LoopUtils::CanPerformUnroll() {
     return false;
   }
 
-  // Make sure the induction is the only phi instruction we have in the loop
-  // header. Other optimizations have been seen to leave dead phi nodes in the
-  // header so we also check that the phi is used.
-  for (const ir::Instruction& inst : *loop_->GetHeaderBlock()) {
-    if (inst.opcode() == SpvOpPhi &&
-        inst.result_id() != induction->result_id()) {
-      return false;
-    }
-  }
+  std::vector<ir::Instruction*> inductions;
+  loop_->GetInductionVariables(inductions);
 
   // Ban breaks within the loop.
   const std::vector<uint32_t>& merge_block_preds =
@@ -832,33 +970,6 @@ bool LoopUtils::CanPerformUnroll() {
     return false;
   }
 
-  for (uint32_t block_id : loop_->GetBlocks()) {
-    opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
-
-    ir::BasicBlock& bb = *context_->cfg()->block(block_id);
-    // For every instruction in the block.
-    for (ir::Instruction& inst : bb) {
-      if (inst.result_id() == 0) continue;
-
-      auto is_used_outside_loop = [this,
-                                   def_use_manager](ir::Instruction* user) {
-
-        if (!loop_->IsInsideLoop(user)) {
-          // Some optimization passes have been seen to leave dead phis in the
-          // IR so we check that if a phi is used outside of the loop that the
-          // user is not dead.
-          if (!(user->opcode() == SpvOpPhi &&
-                def_use_manager->NumUsers(user) == 0))
-            return false;
-        }
-        return true;
-      };
-
-      if (!def_use_manager->WhileEachUser(&inst, is_used_outside_loop)) {
-        return false;
-      }
-    }
-  }
   return true;
 }
 
@@ -893,6 +1004,9 @@ bool LoopUtils::PartiallyUnroll(size_t factor) {
 bool LoopUtils::FullyUnroll() {
   if (!CanPerformUnroll()) return false;
 
+  std::vector<ir::Instruction*> inductions;
+  loop_->GetInductionVariables(inductions);
+
   LoopUnrollerUtilsImpl unroller{context_,
                                  loop_->GetHeaderBlock()->GetParent()};
 
@@ -926,7 +1040,11 @@ Pass::Status LoopUnroller::Process(ir::IRContext* c) {
         continue;
       }
 
-      loop_utils.FullyUnroll();
+      if (fully_unroll_) {
+        loop_utils.FullyUnroll();
+      } else {
+        loop_utils.PartiallyUnroll(unroll_factor_);
+      }
       changed = true;
     }
     LD->PostModificationCleanup();
index 38dfa34..caf0a8e 100644 (file)
@@ -21,7 +21,9 @@ namespace opt {
 
 class LoopUnroller : public Pass {
  public:
-  LoopUnroller() : Pass() {}
+  LoopUnroller() : Pass(), fully_unroll_(true), unroll_factor_(0) {}
+  LoopUnroller(bool fully_unroll, int unroll_factor)
+      : Pass(), fully_unroll_(fully_unroll), unroll_factor_(unroll_factor) {}
 
   const char* name() const override { return "Loop unroller"; }
 
@@ -29,6 +31,8 @@ class LoopUnroller : public Pass {
 
  private:
   ir::IRContext* context_;
+  bool fully_unroll_;
+  int unroll_factor_;
 };
 
 }  // namespace opt
index 3eb20de..89e6936 100644 (file)
@@ -86,7 +86,7 @@ class LoopUtils {
   //
   // The conditions checked to ensure the loop can be unrolled are as follows:
   // 1. That the loop is in structured order.
-  // 2. That the condinue block is a branch to the header.
+  // 2. That the continue block is a branch to the header.
   // 3. That the only phi used in the loop is the induction variable.
   //  TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is
   //  converted into LCSAA form and has a single entry and exit we can rewrite
index 0478d3a..dced5db 100644 (file)
@@ -400,8 +400,8 @@ Optimizer::PassToken CreateSimplificationPass() {
       MakeUnique<opt::SimplificationPass>());
 }
 
-Optimizer::PassToken CreateLoopFullyUnrollPass() {
+Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
   return MakeUnique<Optimizer::PassToken::Impl>(
-      MakeUnique<opt::LoopUnroller>());
+      MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
 }
 }  // namespace spvtools
index e3ff1ee..3a991c7 100644 (file)
@@ -61,7 +61,7 @@ Generated from the following GLSL
 #version 410 core
 layout(location = 0) flat in int in_upper_bound;
 void main() {
-  for (int i = ; i < in_upper_bound; ++i) {
+  for (int i = 0; i < in_upper_bound; ++i) {
     x[i] = 1.0f;
   }
 }
@@ -102,7 +102,7 @@ OpDecorate %3 Location 0
 OpBranch %21
 %21 = OpLabel
 %22 = OpPhi %7 %9 %20 %23 %24
-OpLoopMerge %25 %24 None
+OpLoopMerge %25 %24 Unroll
 OpBranch %26
 %26 = OpLabel
 %27 = OpLoad %7 %3
@@ -141,88 +141,6 @@ Generated from the following GLSL
 #version 410 core
 void main() {
     float out_array[10];
-    int i = 0;
-    for (int i = 0; i < 10; ++i) {
-        out_array[i] = i;
-    }
-    out_array[9] = i*10;
-}
-*/
-TEST_F(PassClassTest, InductionUsedOutsideOfLoop) {
-  // clang-format off
-  // With opt::LocalMultiStoreElimPass
-  const std::string text = R"(OpCapability Shader
-%1 = OpExtInstImport "GLSL.std.450"
-OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %2 "main"
-OpExecutionMode %2 OriginUpperLeft
-OpSource GLSL 410
-OpName %2 "main"
-OpName %3 "out_array"
-%4 = OpTypeVoid
-%5 = OpTypeFunction %4
-%6 = OpTypeInt 32 1
-%7 = OpTypePointer Function %6
-%8 = OpConstant %6 0
-%9 = OpConstant %6 10
-%10 = OpTypeBool
-%11 = OpTypeFloat 32
-%12 = OpTypeInt 32 0
-%13 = OpConstant %12 10
-%14 = OpTypeArray %11 %13
-%15 = OpTypePointer Function %14
-%16 = OpTypePointer Function %11
-%17 = OpConstant %6 1
-%18 = OpConstant %6 9
-%2 = OpFunction %4 None %5
-%19 = OpLabel
-%3 = OpVariable %15 Function
-OpBranch %20
-%20 = OpLabel
-%21 = OpPhi %6 %8 %19 %22 %23
-OpLoopMerge %24 %23 Unroll
-OpBranch %25
-%25 = OpLabel
-%26 = OpSLessThan %10 %21 %9
-OpBranchConditional %26 %27 %24
-%27 = OpLabel
-%28 = OpConvertSToF %11 %21
-%29 = OpAccessChain %16 %3 %21
-OpStore %29 %28
-OpBranch %23
-%23 = OpLabel
-%22 = OpIAdd %6 %21 %17
-OpBranch %20
-%24 = OpLabel
-%30 = OpIMul %6 %21 %9
-%31 = OpConvertSToF %11 %30
-%32 = OpAccessChain %16 %3 %18
-OpStore %32 %31
-OpReturn
-OpFunctionEnd
-)";
-  // clang-format on
-  std::unique_ptr<ir::IRContext> context =
-      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
-                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  ir::Module* module = context->module();
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
-
-  opt::LoopUnroller loop_unroller;
-  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
-
-  // Make sure the pass doesn't run
-  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
-  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
-  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
-}
-
-/*
-Generated from the following GLSL
-#version 410 core
-void main() {
-    float out_array[10];
     for (uint i = 0; i < 2; i++) {
       for (float x = 0; x < 5; ++x) {
         out_array[x + i*5] = i;
@@ -267,7 +185,7 @@ OpBranch %24
 %24 = OpLabel
 %25 = OpPhi %6 %8 %23 %26 %27
 %28 = OpPhi %11 %22 %23 %29 %27
-OpLoopMerge %30 %27 None
+OpLoopMerge %30 %27 Unroll
 OpBranch %31
 %31 = OpLabel
 %32 = OpULessThan %10 %25 %9
@@ -314,81 +232,6 @@ OpFunctionEnd
   SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
 }
 
-
-/*
-Generated from the following GLSL
-#version 440 core
-void main(){
-  float x[10];
-  int ind = 0;
-  for (int i = 0; i < 10; i++) {
-    ind = i;
-    x[i] = i;
-  }
-}
-*/
-TEST_F(PassClassTest, MultiplePhiInHeader) {
-  // clang-format off
-  // With opt::LocalMultiStoreElimPass
-const std::string text = R"(OpCapability Shader
-%1 = OpExtInstImport "GLSL.std.450"
-OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %2 "main"
-OpExecutionMode %2 OriginUpperLeft
-OpSource GLSL 440
-OpName %2 "main"
-OpName %3 "x"
-%4 = OpTypeVoid
-%5 = OpTypeFunction %4
-%6 = OpTypeInt 32 1
-%7 = OpTypePointer Function %6
-%8 = OpConstant %6 0
-%9 = OpConstant %6 10
-%10 = OpTypeBool
-%11 = OpTypeFloat 32
-%12 = OpTypeInt 32 0
-%13 = OpConstant %12 10
-%14 = OpTypeArray %11 %13
-%15 = OpTypePointer Function %14
-%16 = OpTypePointer Function %11
-%17 = OpConstant %6 1
-%2 = OpFunction %4 None %5
-%18 = OpLabel
-%3 = OpVariable %15 Function
-OpBranch %19
-%19 = OpLabel
-%20 = OpPhi %6 %8 %18 %21 %22
-%21 = OpPhi %6 %8 %18 %23 %22
-OpLoopMerge %24 %22 None
-OpBranch %25
-%25 = OpLabel
-%26 = OpSLessThan %10 %21 %9
-OpBranchConditional %26 %27 %24
-%27 = OpLabel
-%28 = OpConvertSToF %11 %21
-%29 = OpAccessChain %16 %3 %21
-OpStore %29 %28
-OpBranch %22
-%22 = OpLabel
-%23 = OpIAdd %6 %21 %17
-OpBranch %19
-%24 = OpLabel
-OpReturn
-OpFunctionEnd
-)";
-  // clang-format on
-  std::unique_ptr<ir::IRContext> context =
-      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
-                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  ir::Module* module = context->module();
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
-
-  opt::LoopUnroller loop_unroller;
-  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
-  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
-}
-
 /*
 Generated from the following GLSL
 #version 440 core
@@ -624,4 +467,980 @@ OpFunctionEnd
   SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
 }
 
+/*
+Generated from the following GLSL
+#version 440 core
+void main() {
+  int j = 0;
+  for (int i = 0; i < 10 && i > 0; i++) {
+    j++;
+  }
+}
+*/
+TEST_F(PassClassTest, MultipleConditionsSingleVariable) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 0
+%8 = OpConstant %5 10
+%9 = OpTypeBool
+%10 = OpConstant %5 1
+%2 = OpFunction %3 None %4
+%11 = OpLabel
+OpBranch %12
+%12 = OpLabel
+%13 = OpPhi %5 %7 %11 %14 %15
+%16 = OpPhi %5 %7 %11 %17 %15
+OpLoopMerge %18 %15 Unroll
+OpBranch %19
+%19 = OpLabel
+%20 = OpSLessThan %9 %16 %8
+%21 = OpSGreaterThan %9 %16 %7
+%22 = OpLogicalAnd %9 %20 %21
+OpBranchConditional %22 %23 %18
+%23 = OpLabel
+%14 = OpIAdd %5 %13 %10
+OpBranch %15
+%15 = OpLabel
+%17 = OpIAdd %5 %16 %10
+OpBranch %12
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main() {
+  int i = 0;
+  int j = 0;
+  int k = 0;
+  for (; i < 10 && j > 0; i++, j++) {
+    k++;
+  }
+}
+*/
+TEST_F(PassClassTest, MultipleConditionsMultipleVariables) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 0
+%8 = OpConstant %5 10
+%9 = OpTypeBool
+%10 = OpConstant %5 1
+%2 = OpFunction %3 None %4
+%11 = OpLabel
+OpBranch %12
+%12 = OpLabel
+%13 = OpPhi %5 %7 %11 %14 %15
+%16 = OpPhi %5 %7 %11 %17 %15
+%18 = OpPhi %5 %7 %11 %19 %15
+OpLoopMerge %20 %15 Unroll
+OpBranch %21
+%21 = OpLabel
+%22 = OpSLessThan %9 %13 %8
+%23 = OpSGreaterThan %9 %16 %7
+%24 = OpLogicalAnd %9 %22 %23
+OpBranchConditional %24 %25 %20
+%25 = OpLabel
+%19 = OpIAdd %5 %18 %10
+OpBranch %15
+%15 = OpLabel
+%14 = OpIAdd %5 %13 %10
+%17 = OpIAdd %5 %16 %10
+OpBranch %12
+%20 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main() {
+  float i = 0.0;
+  int j = 0;
+  for (; i < 10; i++) {
+    j++;
+  }
+}
+*/
+TEST_F(PassClassTest, FloatingPointLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeFloat 32
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 0
+%8 = OpTypeInt 32 1
+%9 = OpTypePointer Function %8
+%10 = OpConstant %8 0
+%11 = OpConstant %5 10
+%12 = OpTypeBool
+%13 = OpConstant %8 1
+%14 = OpConstant %5 1
+%2 = OpFunction %3 None %4
+%15 = OpLabel
+OpBranch %16
+%16 = OpLabel
+%17 = OpPhi %5 %7 %15 %18 %19
+%20 = OpPhi %8 %10 %15 %21 %19
+OpLoopMerge %22 %19 Unroll
+OpBranch %23
+%23 = OpLabel
+%24 = OpFOrdLessThan %12 %17 %11
+OpBranchConditional %24 %25 %22
+%25 = OpLabel
+%21 = OpIAdd %8 %20 %13
+OpBranch %19
+%19 = OpLabel
+%18 = OpFAdd %5 %17 %14
+OpBranch %16
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main() {
+  int i = 2;
+  int j = 0;
+  if (j == 0) { i = 5; }
+  for (; i < 3; ++i) {
+    j++;
+  }
+}
+*/
+TEST_F(PassClassTest, InductionPhiOutsideLoop) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 2
+%8 = OpConstant %5 0
+%9 = OpTypeBool
+%10 = OpConstant %5 5
+%11 = OpConstant %5 3
+%12 = OpConstant %5 1
+%2 = OpFunction %3 None %4
+%13 = OpLabel
+%14 = OpIEqual %9 %8 %8
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpBranch %15
+%15 = OpLabel
+%17 = OpPhi %5 %7 %13 %10 %16
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %5 %17 %15 %20 %21
+%22 = OpPhi %5 %8 %15 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %9 %19 %11
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%23 = OpIAdd %5 %22 %12
+OpBranch %21
+%21 = OpLabel
+%20 = OpIAdd %5 %19 %12
+OpBranch %18
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main() {
+  int j = 0;
+  for (int i = 0; i == 0; ++i) {
+    ++j;
+  }
+  for (int i = 0; i != 3; ++i) {
+    ++j;
+  }
+  for (int i = 0; i < 3; i *= 2) {
+    ++j;
+  }
+  for (int i = 10; i > 3; i /= 2) {
+    ++j;
+  }
+  for (int i = 10; i > 3; i |= 2) {
+    ++j;
+  }
+  for (int i = 10; i > 3; i &= 2) {
+    ++j;
+  }
+  for (int i = 10; i > 3; i ^= 2) {
+    ++j;
+  }
+  for (int i = 0; i < 3; i << 2) {
+    ++j;
+  }
+  for (int i = 10; i > 3; i >> 2) {
+    ++j;
+  }
+}
+*/
+TEST_F(PassClassTest, UnsupportedLoopTypes) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 0
+%8 = OpTypeBool
+%9 = OpConstant %5 1
+%10 = OpConstant %5 3
+%11 = OpConstant %5 2
+%12 = OpConstant %5 10
+%2 = OpFunction %3 None %4
+%13 = OpLabel
+OpBranch %14
+%14 = OpLabel
+%15 = OpPhi %5 %7 %13 %16 %17
+%18 = OpPhi %5 %7 %13 %19 %17
+OpLoopMerge %20 %17 Unroll
+OpBranch %21
+%21 = OpLabel
+%22 = OpIEqual %8 %18 %7
+OpBranchConditional %22 %23 %20
+%23 = OpLabel
+%16 = OpIAdd %5 %15 %9
+OpBranch %17
+%17 = OpLabel
+%19 = OpIAdd %5 %18 %9
+OpBranch %14
+%20 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpPhi %5 %15 %20 %26 %27
+%28 = OpPhi %5 %7 %20 %29 %27
+OpLoopMerge %30 %27 Unroll
+OpBranch %31
+%31 = OpLabel
+%32 = OpINotEqual %8 %28 %10
+OpBranchConditional %32 %33 %30
+%33 = OpLabel
+%26 = OpIAdd %5 %25 %9
+OpBranch %27
+%27 = OpLabel
+%29 = OpIAdd %5 %28 %9
+OpBranch %24
+%30 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%35 = OpPhi %5 %25 %30 %36 %37
+%38 = OpPhi %5 %7 %30 %39 %37
+OpLoopMerge %40 %37 Unroll
+OpBranch %41
+%41 = OpLabel
+%42 = OpSLessThan %8 %38 %10
+OpBranchConditional %42 %43 %40
+%43 = OpLabel
+%36 = OpIAdd %5 %35 %9
+OpBranch %37
+%37 = OpLabel
+%39 = OpIMul %5 %38 %11
+OpBranch %34
+%40 = OpLabel
+OpBranch %44
+%44 = OpLabel
+%45 = OpPhi %5 %35 %40 %46 %47
+%48 = OpPhi %5 %12 %40 %49 %47
+OpLoopMerge %50 %47 Unroll
+OpBranch %51
+%51 = OpLabel
+%52 = OpSGreaterThan %8 %48 %10
+OpBranchConditional %52 %53 %50
+%53 = OpLabel
+%46 = OpIAdd %5 %45 %9
+OpBranch %47
+%47 = OpLabel
+%49 = OpSDiv %5 %48 %11
+OpBranch %44
+%50 = OpLabel
+OpBranch %54
+%54 = OpLabel
+%55 = OpPhi %5 %45 %50 %56 %57
+%58 = OpPhi %5 %12 %50 %59 %57
+OpLoopMerge %60 %57 Unroll
+OpBranch %61
+%61 = OpLabel
+%62 = OpSGreaterThan %8 %58 %10
+OpBranchConditional %62 %63 %60
+%63 = OpLabel
+%56 = OpIAdd %5 %55 %9
+OpBranch %57
+%57 = OpLabel
+%59 = OpBitwiseOr %5 %58 %11
+OpBranch %54
+%60 = OpLabel
+OpBranch %64
+%64 = OpLabel
+%65 = OpPhi %5 %55 %60 %66 %67
+%68 = OpPhi %5 %12 %60 %69 %67
+OpLoopMerge %70 %67 Unroll
+OpBranch %71
+%71 = OpLabel
+%72 = OpSGreaterThan %8 %68 %10
+OpBranchConditional %72 %73 %70
+%73 = OpLabel
+%66 = OpIAdd %5 %65 %9
+OpBranch %67
+%67 = OpLabel
+%69 = OpBitwiseAnd %5 %68 %11
+OpBranch %64
+%70 = OpLabel
+OpBranch %74
+%74 = OpLabel
+%75 = OpPhi %5 %65 %70 %76 %77
+%78 = OpPhi %5 %12 %70 %79 %77
+OpLoopMerge %80 %77 Unroll
+OpBranch %81
+%81 = OpLabel
+%82 = OpSGreaterThan %8 %78 %10
+OpBranchConditional %82 %83 %80
+%83 = OpLabel
+%76 = OpIAdd %5 %75 %9
+OpBranch %77
+%77 = OpLabel
+%79 = OpBitwiseXor %5 %78 %11
+OpBranch %74
+%80 = OpLabel
+OpBranch %84
+%84 = OpLabel
+%85 = OpPhi %5 %75 %80 %86 %87
+OpLoopMerge %88 %87 Unroll
+OpBranch %89
+%89 = OpLabel
+%90 = OpSLessThan %8 %7 %10
+OpBranchConditional %90 %91 %88
+%91 = OpLabel
+%86 = OpIAdd %5 %85 %9
+OpBranch %87
+%87 = OpLabel
+%92 = OpShiftLeftLogical %5 %7 %11
+OpBranch %84
+%88 = OpLabel
+OpBranch %93
+%93 = OpLabel
+%94 = OpPhi %5 %85 %88 %95 %96
+OpLoopMerge %97 %96 Unroll
+OpBranch %98
+%98 = OpLabel
+%99 = OpSGreaterThan %8 %12 %10
+OpBranchConditional %99 %100 %97
+%100 = OpLabel
+%95 = OpIAdd %5 %94 %9
+OpBranch %96
+%96 = OpLabel
+%101 = OpShiftRightArithmetic %5 %12 %11
+OpBranch %93
+%97 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+    for (int j = 2; j < 0; j += 1) {
+      o += 1.0;
+    }
+}
+*/
+TEST_F(PassClassTest, NegativeNumberOfIterations) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 2
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypePointer Output %11
+%3 = OpVariable %12 Output
+%13 = OpConstant %11 1
+%14 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%15 = OpLabel
+OpBranch %16
+%16 = OpLabel
+%17 = OpPhi %6 %8 %15 %18 %19
+OpLoopMerge %20 %19 None
+OpBranch %21
+%21 = OpLabel
+%22 = OpSLessThan %10 %17 %9
+OpBranchConditional %22 %23 %20
+%23 = OpLabel
+%24 = OpLoad %11 %3
+%25 = OpFAdd %11 %24 %13
+OpStore %3 %25
+OpBranch %19
+%19 = OpLabel
+%18 = OpIAdd %6 %17 %14
+OpBranch %16
+%20 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+  float s = 0.0;
+  for (int j = 0; j < 3; j += 1) {
+    s += 1.0;
+    j += 1;
+  }
+  o = s;
+}
+*/
+TEST_F(PassClassTest, MultipleStepOperations) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeFloat 32
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpTypeInt 32 1
+%10 = OpTypePointer Function %9
+%11 = OpConstant %9 0
+%12 = OpConstant %9 3
+%13 = OpTypeBool
+%14 = OpConstant %6 1
+%15 = OpConstant %9 1
+%16 = OpTypePointer Output %6
+%3 = OpVariable %16 Output
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+%22 = OpPhi %9 %11 %17 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %13 %22 %12
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%20 = OpFAdd %6 %19 %14
+%28 = OpIAdd %9 %22 %15
+OpBranch %21
+%21 = OpLabel
+%23 = OpIAdd %9 %28 %15
+OpBranch %18
+%24 = OpLabel
+OpStore %3 %19
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+  float s = 0.0;
+  for (int j = 10; j > 20; j -= 1) {
+    s += 1.0;
+  }
+  o = s;
+}
+*/
+
+TEST_F(PassClassTest, ConditionFalseFromStartGreaterThan) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeFloat 32
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpTypeInt 32 1
+%10 = OpTypePointer Function %9
+%11 = OpConstant %9 10
+%12 = OpConstant %9 20
+%13 = OpTypeBool
+%14 = OpConstant %6 1
+%15 = OpConstant %9 1
+%16 = OpTypePointer Output %6
+%3 = OpVariable %16 Output
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+%22 = OpPhi %9 %11 %17 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSGreaterThan %13 %22 %12
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%20 = OpFAdd %6 %19 %14
+OpBranch %21
+%21 = OpLabel
+%23 = OpISub %9 %22 %15
+OpBranch %18
+%24 = OpLabel
+OpStore %3 %19
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+  float s = 0.0;
+  for (int j = 10; j >= 20; j -= 1) {
+    s += 1.0;
+  }
+  o = s;
+}
+*/
+TEST_F(PassClassTest, ConditionFalseFromStartGreaterThanOrEqual) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeFloat 32
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpTypeInt 32 1
+%10 = OpTypePointer Function %9
+%11 = OpConstant %9 10
+%12 = OpConstant %9 20
+%13 = OpTypeBool
+%14 = OpConstant %6 1
+%15 = OpConstant %9 1
+%16 = OpTypePointer Output %6
+%3 = OpVariable %16 Output
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+%22 = OpPhi %9 %11 %17 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSGreaterThanEqual %13 %22 %12
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%20 = OpFAdd %6 %19 %14
+OpBranch %21
+%21 = OpLabel
+%23 = OpISub %9 %22 %15
+OpBranch %18
+%24 = OpLabel
+OpStore %3 %19
+OpReturn
+OpFunctionEnd
+)";
+
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+  float s = 0.0;
+  for (int j = 20; j < 10; j -= 1) {
+    s += 1.0;
+  }
+  o = s;
+}
+*/
+TEST_F(PassClassTest, ConditionFalseFromStartLessThan) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeFloat 32
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpTypeInt 32 1
+%10 = OpTypePointer Function %9
+%11 = OpConstant %9 20
+%12 = OpConstant %9 10
+%13 = OpTypeBool
+%14 = OpConstant %6 1
+%15 = OpConstant %9 1
+%16 = OpTypePointer Output %6
+%3 = OpVariable %16 Output
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+%22 = OpPhi %9 %11 %17 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %13 %22 %12
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%20 = OpFAdd %6 %19 %14
+OpBranch %21
+%21 = OpLabel
+%23 = OpISub %9 %22 %15
+OpBranch %18
+%24 = OpLabel
+OpStore %3 %19
+OpReturn
+OpFunctionEnd
+)";
+
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+#version 430
+
+layout(location = 0) out float o;
+
+void main(void) {
+  float s = 0.0;
+  for (int j = 20; j <= 10; j -= 1) {
+    s += 1.0;
+  }
+  o = s;
+}
+*/
+TEST_F(PassClassTest, ConditionFalseFromStartLessThanEqual) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 430
+OpName %2 "main"
+OpName %3 "o"
+OpDecorate %3 Location 0
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeFloat 32
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpTypeInt 32 1
+%10 = OpTypePointer Function %9
+%11 = OpConstant %9 20
+%12 = OpConstant %9 10
+%13 = OpTypeBool
+%14 = OpConstant %6 1
+%15 = OpConstant %9 1
+%16 = OpTypePointer Output %6
+%3 = OpVariable %16 Output
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+%22 = OpPhi %9 %11 %17 %23 %21
+OpLoopMerge %24 %21 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThanEqual %13 %22 %12
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%20 = OpFAdd %6 %19 %14
+OpBranch %21
+%21 = OpLabel
+%23 = OpISub %9 %22 %15
+OpBranch %18
+%24 = OpLabel
+OpStore %3 %19
+OpReturn
+OpFunctionEnd
+)";
+
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+  // Make sure the pass doesn't run
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
 }  // namespace
index 59d2f95..3d8fb67 100644 (file)
@@ -316,7 +316,7 @@ OpDecorate %3 Location 0
 OpBranch %23
 %23 = OpLabel
 %24 = OpPhi %7 %9 %22 %39 %38
-OpLoopMerge %27 %38 Unroll
+OpLoopMerge %27 %38 DontUnroll
 OpBranch %28
 %28 = OpLabel
 %29 = OpSLessThan %11 %24 %10
@@ -474,8 +474,8 @@ OpBranch %23
 %32 = OpLabel
 OpBranch %33
 %33 = OpLabel
-%34 = OpPhi %7 %58 %32 %57 %56
-OpLoopMerge %41 %56 Unroll
+%34 = OpPhi %7 %24 %32 %57 %56
+OpLoopMerge %41 %56 DontUnroll
 OpBranch %35
 %35 = OpLabel
 %36 = OpSLessThan %11 %34 %10
@@ -617,7 +617,7 @@ OpFunctionEnd
   ir::BasicBlock* condition = loop.FindConditionBlock();
   EXPECT_EQ(condition->id(), 24u);
 
-  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  ir::Instruction* induction = loop.FindConditionVariable(condition);
   EXPECT_EQ(induction->result_id(), 34u);
 
   opt::LoopUtils loop_utils{context.get(), &loop};
@@ -714,7 +714,7 @@ OpFunctionEnd
   ir::BasicBlock* condition = loop.FindConditionBlock();
   EXPECT_EQ(condition->id(), 25u);
 
-  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  ir::Instruction* induction = loop.FindConditionVariable(condition);
   EXPECT_EQ(induction->result_id(), 35u);
 
   opt::LoopUtils loop_utils{context.get(), &loop};
@@ -1115,7 +1115,7 @@ OpFunctionEnd
   ir::BasicBlock* condition = loop.FindConditionBlock();
   EXPECT_EQ(condition->id(), 14u);
 
-  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  ir::Instruction* induction = loop.FindConditionVariable(condition);
   EXPECT_EQ(induction->result_id(), 32u);
 
   opt::LoopUtils loop_utils{context.get(), &loop};
@@ -1231,8 +1231,8 @@ OpBranch %18
 %28 = OpLabel
 OpBranch %29
 %29 = OpLabel
-%30 = OpPhi %6 %48 %28 %47 %46
-OpLoopMerge %38 %46 Unroll
+%30 = OpPhi %6 %19 %28 %47 %46
+OpLoopMerge %38 %46 DontUnroll
 OpBranch %31
 %31 = OpLabel
 %32 = OpSLessThan %10 %30 %9
@@ -1264,7 +1264,6 @@ OpReturn
 OpReturn
 OpFunctionEnd
 )";
-  // clang-format on
 
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
@@ -1288,7 +1287,7 @@ OpFunctionEnd
   ir::BasicBlock* condition = loop.FindConditionBlock();
   EXPECT_EQ(condition->id(), 14u);
 
-  ir::Instruction* induction = loop.FindInductionVariable(condition);
+  ir::Instruction* induction = loop.FindConditionVariable(condition);
   EXPECT_EQ(induction->result_id(), 32u);
 
   opt::LoopUtils loop_utils{context.get(), &loop};
@@ -1464,137 +1463,13 @@ TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) {
 
     // The number of loops should actually grow.
     EXPECT_EQ(loop_descriptor.NumLoops(), 3u);
-    EXPECT_EQ(outer_loop.GetBlocks().size(), 19u);
+    EXPECT_EQ(outer_loop.GetBlocks().size(), 18u);
     EXPECT_EQ(outer_loop.NumImmediateChildren(), 2u);
   }
 }
 
 /*
 Generated from the following GLSL
-#version 440 core
-void main(){
-  float x[10];
-  int i = 1;
-  i = 0;
-  for (; i < 10; i++) {
-    x[i] = i;
-  }
-}
-*/
-TEST_F(PassClassTest, UnrollWithInductionOutsideHeader) {
-  // clang-format off
-  // With opt::LocalMultiStoreElimPass
-const std::string text = R"(OpCapability Shader
-%1 = OpExtInstImport "GLSL.std.450"
-OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %main "main"
-OpExecutionMode %main OriginUpperLeft
-OpSource GLSL 440
-OpName %main "main"
-OpName %x "x"
-%void = OpTypeVoid
-%3 = OpTypeFunction %void
-%int = OpTypeInt 32 1
-%_ptr_Function_int = OpTypePointer Function %int
-%int_1 = OpConstant %int 1
-%int_0 = OpConstant %int 0
-%int_10 = OpConstant %int 10
-%bool = OpTypeBool
-%float = OpTypeFloat 32
-%uint = OpTypeInt 32 0
-%uint_10 = OpConstant %uint 10
-%_arr_float_uint_10 = OpTypeArray %float %uint_10
-%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
-%_ptr_Function_float = OpTypePointer Function %float
-%main = OpFunction %void None %3
-%5 = OpLabel
-%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
-OpBranch %11
-%11 = OpLabel
-%33 = OpPhi %int %int_0 %5 %32 %14
-OpLoopMerge %13 %14 None
-OpBranch %15
-%15 = OpLabel
-%19 = OpSLessThan %bool %33 %int_10
-OpBranchConditional %19 %12 %13
-%12 = OpLabel
-%28 = OpConvertSToF %float %33
-%30 = OpAccessChain %_ptr_Function_float %x %33
-OpStore %30 %28
-OpBranch %14
-%14 = OpLabel
-%32 = OpIAdd %int %33 %int_1
-OpBranch %11
-%13 = OpLabel
-OpReturn
-OpFunctionEnd
-)";
-
-const std::string expected = R"(OpCapability Shader
-%1 = OpExtInstImport "GLSL.std.450"
-OpMemoryModel Logical GLSL450
-OpEntryPoint Fragment %main "main"
-OpExecutionMode %main OriginUpperLeft
-OpSource GLSL 440
-OpName %main "main"
-OpName %x "x"
-%void = OpTypeVoid
-%5 = OpTypeFunction %void
-%int = OpTypeInt 32 1
-%_ptr_Function_int = OpTypePointer Function %int
-%int_1 = OpConstant %int 1
-%int_0 = OpConstant %int 0
-%int_10 = OpConstant %int 10
-%bool = OpTypeBool
-%float = OpTypeFloat 32
-%uint = OpTypeInt 32 0
-%uint_10 = OpConstant %uint 10
-%_arr_float_uint_10 = OpTypeArray %float %uint_10
-%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
-%_ptr_Function_float = OpTypePointer Function %float
-%main = OpFunction %void None %5
-%18 = OpLabel
-%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
-OpBranch %19
-%19 = OpLabel
-%20 = OpPhi %int %int_0 %18 %37 %36
-OpLoopMerge %23 %36 None
-OpBranch %24
-%24 = OpLabel
-%25 = OpSLessThan %bool %20 %int_10
-OpBranchConditional %25 %26 %23
-%26 = OpLabel
-%27 = OpConvertSToF %float %20
-%28 = OpAccessChain %_ptr_Function_float %x %20
-OpStore %28 %27
-OpBranch %22
-%22 = OpLabel
-%21 = OpIAdd %int %20 %int_1
-OpBranch %29
-%29 = OpLabel
-OpBranch %31
-%31 = OpLabel
-%32 = OpSLessThan %bool %21 %int_10
-OpBranch %33
-%33 = OpLabel
-%34 = OpConvertSToF %float %21
-%35 = OpAccessChain %_ptr_Function_float %x %21
-OpStore %35 %34
-OpBranch %36
-%36 = OpLabel
-%37 = OpIAdd %int %21 %int_1
-OpBranch %19
-%23 = OpLabel
-OpReturn
-OpFunctionEnd
-)";
-  // clang-format on
-
-  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, expected, false);
-}
-
-/*
-Generated from the following GLSL
 #version 410 core
 void main() {
   float out_array[3];
@@ -2176,4 +2051,749 @@ OpFunctionEnd
   SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
 }
 
+// clang-format off
+// With opt::LocalMultiStoreElimPass
+static const std::string multiple_phi_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 410
+               OpName %4 "main"
+               OpName %8 "foo("
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypeFunction %6
+         %10 = OpTypePointer Function %6
+         %12 = OpConstant %6 0
+         %14 = OpConstant %6 3
+         %22 = OpConstant %6 6
+         %23 = OpTypeBool
+         %31 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+         %40 = OpFunctionCall %6 %8
+               OpReturn
+               OpFunctionEnd
+          %8 = OpFunction %6 None %7
+          %9 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+         %41 = OpPhi %6 %12 %9 %34 %19
+         %42 = OpPhi %6 %14 %9 %29 %19
+         %43 = OpPhi %6 %12 %9 %32 %19
+               OpLoopMerge %18 %19 Unroll
+               OpBranch %20
+         %20 = OpLabel
+         %24 = OpSLessThan %23 %43 %22
+               OpBranchConditional %24 %17 %18
+         %17 = OpLabel
+         %27 = OpIMul %6 %43 %41
+         %29 = OpIAdd %6 %42 %27
+               OpBranch %19
+         %19 = OpLabel
+         %32 = OpIAdd %6 %43 %31
+         %34 = OpISub %6 %41 %31
+               OpBranch %16
+         %18 = OpLabel
+         %37 = OpIAdd %6 %42 %41
+               OpReturnValue %37
+               OpFunctionEnd
+    )";
+// clang-format on
+
+TEST_F(PassClassTest, PartiallyUnrollResidualMultipleInductionVariables) {
+  // clang-format off
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "foo("
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypeFunction %6
+%8 = OpTypePointer Function %6
+%9 = OpConstant %6 0
+%10 = OpConstant %6 3
+%11 = OpConstant %6 6
+%12 = OpTypeBool
+%13 = OpConstant %6 1
+%82 = OpTypeInt 32 0
+%83 = OpConstant %82 2
+%2 = OpFunction %4 None %5
+%14 = OpLabel
+%15 = OpFunctionCall %6 %3
+OpReturn
+OpFunctionEnd
+%3 = OpFunction %6 None %7
+%16 = OpLabel
+OpBranch %17
+%17 = OpLabel
+%18 = OpPhi %6 %9 %16 %19 %20
+%21 = OpPhi %6 %10 %16 %22 %20
+%23 = OpPhi %6 %9 %16 %24 %20
+OpLoopMerge %31 %20 Unroll
+OpBranch %26
+%26 = OpLabel
+%27 = OpSLessThan %12 %23 %83
+OpBranchConditional %27 %28 %31
+%28 = OpLabel
+%29 = OpIMul %6 %23 %18
+%22 = OpIAdd %6 %21 %29
+OpBranch %20
+%20 = OpLabel
+%24 = OpIAdd %6 %23 %13
+%19 = OpISub %6 %18 %13
+OpBranch %17
+%31 = OpLabel
+OpBranch %32
+%32 = OpLabel
+%33 = OpPhi %6 %18 %31 %81 %79
+%34 = OpPhi %6 %21 %31 %78 %79
+%35 = OpPhi %6 %23 %31 %80 %79
+OpLoopMerge %44 %79 DontUnroll
+OpBranch %36
+%36 = OpLabel
+%37 = OpSLessThan %12 %35 %11
+OpBranchConditional %37 %38 %44
+%38 = OpLabel
+%39 = OpIMul %6 %35 %33
+%40 = OpIAdd %6 %34 %39
+OpBranch %41
+%41 = OpLabel
+%42 = OpIAdd %6 %35 %13
+%43 = OpISub %6 %33 %13
+OpBranch %46
+%46 = OpLabel
+OpBranch %50
+%50 = OpLabel
+%51 = OpSLessThan %12 %42 %11
+OpBranch %52
+%52 = OpLabel
+%53 = OpIMul %6 %42 %43
+%54 = OpIAdd %6 %40 %53
+OpBranch %55
+%55 = OpLabel
+%56 = OpIAdd %6 %42 %13
+%57 = OpISub %6 %43 %13
+OpBranch %58
+%58 = OpLabel
+OpBranch %62
+%62 = OpLabel
+%63 = OpSLessThan %12 %56 %11
+OpBranch %64
+%64 = OpLabel
+%65 = OpIMul %6 %56 %57
+%66 = OpIAdd %6 %54 %65
+OpBranch %67
+%67 = OpLabel
+%68 = OpIAdd %6 %56 %13
+%69 = OpISub %6 %57 %13
+OpBranch %70
+%70 = OpLabel
+OpBranch %74
+%74 = OpLabel
+%75 = OpSLessThan %12 %68 %11
+OpBranch %76
+%76 = OpLabel
+%77 = OpIMul %6 %68 %69
+%78 = OpIAdd %6 %66 %77
+OpBranch %79
+%79 = OpLabel
+%80 = OpIAdd %6 %68 %13
+%81 = OpISub %6 %69 %13
+OpBranch %32
+%44 = OpLabel
+%45 = OpIAdd %6 %34 %33
+OpReturnValue %45
+%25 = OpLabel
+%30 = OpIAdd %6 %34 %33
+OpReturnValue %30
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << multiple_phi_shader << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<4>>(multiple_phi_shader, output,
+                                                    false);
+}
+
+TEST_F(PassClassTest, PartiallyUnrollMultipleInductionVariables) {
+  // clang-format off
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "foo("
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypeFunction %6
+%8 = OpTypePointer Function %6
+%9 = OpConstant %6 0
+%10 = OpConstant %6 3
+%11 = OpConstant %6 6
+%12 = OpTypeBool
+%13 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%14 = OpLabel
+%15 = OpFunctionCall %6 %3
+OpReturn
+OpFunctionEnd
+%3 = OpFunction %6 None %7
+%16 = OpLabel
+OpBranch %17
+%17 = OpLabel
+%18 = OpPhi %6 %9 %16 %42 %40
+%21 = OpPhi %6 %10 %16 %39 %40
+%23 = OpPhi %6 %9 %16 %41 %40
+OpLoopMerge %25 %40 DontUnroll
+OpBranch %26
+%26 = OpLabel
+%27 = OpSLessThan %12 %23 %11
+OpBranchConditional %27 %28 %25
+%28 = OpLabel
+%29 = OpIMul %6 %23 %18
+%22 = OpIAdd %6 %21 %29
+OpBranch %20
+%20 = OpLabel
+%24 = OpIAdd %6 %23 %13
+%19 = OpISub %6 %18 %13
+OpBranch %31
+%31 = OpLabel
+OpBranch %35
+%35 = OpLabel
+%36 = OpSLessThan %12 %24 %11
+OpBranch %37
+%37 = OpLabel
+%38 = OpIMul %6 %24 %19
+%39 = OpIAdd %6 %22 %38
+OpBranch %40
+%40 = OpLabel
+%41 = OpIAdd %6 %24 %13
+%42 = OpISub %6 %19 %13
+OpBranch %17
+%25 = OpLabel
+%30 = OpIAdd %6 %21 %18
+OpReturnValue %30
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << multiple_phi_shader << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(multiple_phi_shader, output,
+                                                    false);
+}
+
+TEST_F(PassClassTest, FullyUnrollMultipleInductionVariables) {
+  // clang-format off
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "foo("
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypeFunction %6
+%8 = OpTypePointer Function %6
+%9 = OpConstant %6 0
+%10 = OpConstant %6 3
+%11 = OpConstant %6 6
+%12 = OpTypeBool
+%13 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%14 = OpLabel
+%15 = OpFunctionCall %6 %3
+OpReturn
+OpFunctionEnd
+%3 = OpFunction %6 None %7
+%16 = OpLabel
+OpBranch %17
+%17 = OpLabel
+OpBranch %26
+%26 = OpLabel
+%27 = OpSLessThan %12 %9 %11
+OpBranch %28
+%28 = OpLabel
+%29 = OpIMul %6 %9 %9
+%22 = OpIAdd %6 %10 %29
+OpBranch %20
+%20 = OpLabel
+%24 = OpIAdd %6 %9 %13
+%19 = OpISub %6 %9 %13
+OpBranch %31
+%31 = OpLabel
+OpBranch %35
+%35 = OpLabel
+%36 = OpSLessThan %12 %24 %11
+OpBranch %37
+%37 = OpLabel
+%38 = OpIMul %6 %24 %19
+%39 = OpIAdd %6 %22 %38
+OpBranch %40
+%40 = OpLabel
+%41 = OpIAdd %6 %24 %13
+%42 = OpISub %6 %19 %13
+OpBranch %43
+%43 = OpLabel
+OpBranch %47
+%47 = OpLabel
+%48 = OpSLessThan %12 %41 %11
+OpBranch %49
+%49 = OpLabel
+%50 = OpIMul %6 %41 %42
+%51 = OpIAdd %6 %39 %50
+OpBranch %52
+%52 = OpLabel
+%53 = OpIAdd %6 %41 %13
+%54 = OpISub %6 %42 %13
+OpBranch %55
+%55 = OpLabel
+OpBranch %59
+%59 = OpLabel
+%60 = OpSLessThan %12 %53 %11
+OpBranch %61
+%61 = OpLabel
+%62 = OpIMul %6 %53 %54
+%63 = OpIAdd %6 %51 %62
+OpBranch %64
+%64 = OpLabel
+%65 = OpIAdd %6 %53 %13
+%66 = OpISub %6 %54 %13
+OpBranch %67
+%67 = OpLabel
+OpBranch %71
+%71 = OpLabel
+%72 = OpSLessThan %12 %65 %11
+OpBranch %73
+%73 = OpLabel
+%74 = OpIMul %6 %65 %66
+%75 = OpIAdd %6 %63 %74
+OpBranch %76
+%76 = OpLabel
+%77 = OpIAdd %6 %65 %13
+%78 = OpISub %6 %66 %13
+OpBranch %79
+%79 = OpLabel
+OpBranch %83
+%83 = OpLabel
+%84 = OpSLessThan %12 %77 %11
+OpBranch %85
+%85 = OpLabel
+%86 = OpIMul %6 %77 %78
+%87 = OpIAdd %6 %75 %86
+OpBranch %88
+%88 = OpLabel
+%89 = OpIAdd %6 %77 %13
+%90 = OpISub %6 %78 %13
+OpBranch %25
+%25 = OpLabel
+%30 = OpIAdd %6 %87 %90
+OpReturnValue %30
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, multiple_phi_shader,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << multiple_phi_shader << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(multiple_phi_shader, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main()
+{
+    int j = 0;
+    for (int i = 0; i <= 2; ++i)
+        ++j;
+
+    for (int i = 1; i >= 0; --i)
+        ++j;
+}
+*/
+TEST_F(PassClassTest, FullyUnrollEqualToOperations) {
+  // clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource GLSL 440
+               OpName %4 "main"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 0
+         %17 = OpConstant %6 2
+         %18 = OpTypeBool
+         %21 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+         %37 = OpPhi %6 %9 %5 %22 %14
+         %38 = OpPhi %6 %9 %5 %24 %14
+               OpLoopMerge %13 %14 Unroll
+               OpBranch %15
+         %15 = OpLabel
+         %19 = OpSLessThanEqual %18 %38 %17
+               OpBranchConditional %19 %12 %13
+         %12 = OpLabel
+         %22 = OpIAdd %6 %37 %21
+               OpBranch %14
+         %14 = OpLabel
+         %24 = OpIAdd %6 %38 %21
+               OpBranch %11
+         %13 = OpLabel
+               OpBranch %26
+         %26 = OpLabel
+         %39 = OpPhi %6 %37 %13 %34 %29
+         %40 = OpPhi %6 %21 %13 %36 %29
+               OpLoopMerge %28 %29 Unroll
+               OpBranch %30
+         %30 = OpLabel
+         %32 = OpSGreaterThanEqual %18 %40 %9
+               OpBranchConditional %32 %27 %28
+         %27 = OpLabel
+         %34 = OpIAdd %6 %39 %21
+               OpBranch %29
+         %29 = OpLabel
+         %36 = OpISub %6 %40 %21
+               OpBranch %26
+         %28 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpTypePointer Function %5
+%7 = OpConstant %5 0
+%8 = OpConstant %5 2
+%9 = OpTypeBool
+%10 = OpConstant %5 1
+%2 = OpFunction %3 None %4
+%11 = OpLabel
+OpBranch %12
+%12 = OpLabel
+OpBranch %19
+%19 = OpLabel
+%20 = OpSLessThanEqual %9 %7 %8
+OpBranch %21
+%21 = OpLabel
+%14 = OpIAdd %5 %7 %10
+OpBranch %15
+%15 = OpLabel
+%17 = OpIAdd %5 %7 %10
+OpBranch %41
+%41 = OpLabel
+OpBranch %44
+%44 = OpLabel
+%45 = OpSLessThanEqual %9 %17 %8
+OpBranch %46
+%46 = OpLabel
+%47 = OpIAdd %5 %14 %10
+OpBranch %48
+%48 = OpLabel
+%49 = OpIAdd %5 %17 %10
+OpBranch %50
+%50 = OpLabel
+OpBranch %53
+%53 = OpLabel
+%54 = OpSLessThanEqual %9 %49 %8
+OpBranch %55
+%55 = OpLabel
+%56 = OpIAdd %5 %47 %10
+OpBranch %57
+%57 = OpLabel
+%58 = OpIAdd %5 %49 %10
+OpBranch %18
+%18 = OpLabel
+OpBranch %22
+%22 = OpLabel
+OpBranch %29
+%29 = OpLabel
+%30 = OpSGreaterThanEqual %9 %10 %7
+OpBranch %31
+%31 = OpLabel
+%24 = OpIAdd %5 %56 %10
+OpBranch %25
+%25 = OpLabel
+%27 = OpISub %5 %10 %10
+OpBranch %32
+%32 = OpLabel
+OpBranch %35
+%35 = OpLabel
+%36 = OpSGreaterThanEqual %9 %27 %7
+OpBranch %37
+%37 = OpLabel
+%38 = OpIAdd %5 %24 %10
+OpBranch %39
+%39 = OpLabel
+%40 = OpISub %5 %27 %10
+OpBranch %28
+%28 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << text << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+// clang-format off
+  // With opt::LocalMultiStoreElimPass
+  const std::string condition_in_header = R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %o
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 430
+               OpDecorate %o Location 0
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+     %int_n2 = OpConstant %int -2
+      %int_2 = OpConstant %int 2
+       %bool = OpTypeBool
+      %float = OpTypeFloat 32
+%_ptr_Output_float = OpTypePointer Output %float
+          %o = OpVariable %_ptr_Output_float Output
+    %float_1 = OpConstant %float 1
+       %main = OpFunction %void None %6
+         %15 = OpLabel
+               OpBranch %16
+         %16 = OpLabel
+         %27 = OpPhi %int %int_n2 %15 %26 %18
+         %21 = OpSLessThanEqual %bool %27 %int_2
+               OpLoopMerge %17 %18 Unroll
+               OpBranchConditional %21 %22 %17
+         %22 = OpLabel
+         %23 = OpLoad %float %o
+         %24 = OpFAdd %float %23 %float_1
+               OpStore %o %24
+               OpBranch %18
+         %18 = OpLabel
+         %26 = OpIAdd %int %27 %int_2
+               OpBranch %16
+         %17 = OpLabel
+               OpReturn
+               OpFunctionEnd
+    )";
+//clang-format on
+
+
+TEST_F(PassClassTest, FullyUnrollConditionIsInHeaderBlock) {
+
+// clang-format off
+const std::string output =
+R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main" %2
+OpExecutionMode %1 OriginUpperLeft
+OpSource GLSL 430
+OpDecorate %2 Location 0
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpConstant %5 -2
+%7 = OpConstant %5 2
+%8 = OpTypeBool
+%9 = OpTypeFloat 32
+%10 = OpTypePointer Output %9
+%2 = OpVariable %10 Output
+%11 = OpConstant %9 1
+%1 = OpFunction %3 None %4
+%12 = OpLabel
+OpBranch %13
+%13 = OpLabel
+%17 = OpSLessThanEqual %8 %6 %7
+OpBranch %19
+%19 = OpLabel
+%20 = OpLoad %9 %2
+%21 = OpFAdd %9 %20 %11
+OpStore %2 %21
+OpBranch %16
+%16 = OpLabel
+%15 = OpIAdd %5 %6 %7
+OpBranch %22
+%22 = OpLabel
+%24 = OpSLessThanEqual %8 %15 %7
+OpBranch %25
+%25 = OpLabel
+%26 = OpLoad %9 %2
+%27 = OpFAdd %9 %26 %11
+OpStore %2 %27
+OpBranch %28
+%28 = OpLabel
+%29 = OpIAdd %5 %15 %7
+OpBranch %30
+%30 = OpLabel
+%32 = OpSLessThanEqual %8 %29 %7
+OpBranch %33
+%33 = OpLabel
+%34 = OpLoad %9 %2
+%35 = OpFAdd %9 %34 %11
+OpStore %2 %35
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %5 %29 %7
+OpBranch %18
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << condition_in_header << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<opt::LoopUnroller>(condition_in_header, output, false);
+}
+
+TEST_F(PassClassTest, PartiallyUnrollResidualConditionIsInHeaderBlock) {
+  // clang-format off
+const std::string output =
+R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main" %2
+OpExecutionMode %1 OriginUpperLeft
+OpSource GLSL 430
+OpDecorate %2 Location 0
+%3 = OpTypeVoid
+%4 = OpTypeFunction %3
+%5 = OpTypeInt 32 1
+%6 = OpConstant %5 -2
+%7 = OpConstant %5 2
+%8 = OpTypeBool
+%9 = OpTypeFloat 32
+%10 = OpTypePointer Output %9
+%2 = OpVariable %10 Output
+%11 = OpConstant %9 1
+%40 = OpTypeInt 32 0
+%41 = OpConstant %40 1
+%1 = OpFunction %3 None %4
+%12 = OpLabel
+OpBranch %13
+%13 = OpLabel
+%14 = OpPhi %5 %6 %12 %15 %16
+%17 = OpSLessThanEqual %8 %14 %41
+OpLoopMerge %22 %16 Unroll
+OpBranchConditional %17 %19 %22
+%19 = OpLabel
+%20 = OpLoad %9 %2
+%21 = OpFAdd %9 %20 %11
+OpStore %2 %21
+OpBranch %16
+%16 = OpLabel
+%15 = OpIAdd %5 %14 %7
+OpBranch %13
+%22 = OpLabel
+OpBranch %23
+%23 = OpLabel
+%24 = OpPhi %5 %14 %22 %39 %38
+%25 = OpSLessThanEqual %8 %24 %7
+OpLoopMerge %31 %38 DontUnroll
+OpBranchConditional %25 %26 %31
+%26 = OpLabel
+%27 = OpLoad %9 %2
+%28 = OpFAdd %9 %27 %11
+OpStore %2 %28
+OpBranch %29
+%29 = OpLabel
+%30 = OpIAdd %5 %24 %7
+OpBranch %32
+%32 = OpLabel
+%34 = OpSLessThanEqual %8 %30 %7
+OpBranch %35
+%35 = OpLabel
+%36 = OpLoad %9 %2
+%37 = OpFAdd %9 %36 %11
+OpStore %2 %37
+OpBranch %38
+%38 = OpLabel
+%39 = OpIAdd %5 %30 %7
+OpBranch %23
+%31 = OpLabel
+OpReturn
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, condition_in_header,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+                             << condition_in_header << std::endl;
+
+  opt::LoopUnroller loop_unroller;
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(condition_in_header, output,
+                                                    false);
+}
+
 }  // namespace
index fb8b3a1..3f8dd88 100644 (file)
@@ -169,6 +169,12 @@ Options (in lexicographical order):
   --local-redundancy-elimination
                Looks for instructions in the same basic block that compute the
                same value, and deletes the redundant ones.
+  --loop-unroll
+               Fully unrolls loops marked with the Unroll flag
+  --loop-unroll-partial
+               Partially unrolls loops marked with the Unroll flag. Takes an
+               additional non-0 integer argument to set the unroll factor, or
+               how many times a loop body should be duplicated
   --merge-blocks
                Join two blocks into a single block if the second has the
                first as its only predecessor. Performed only on entry point
@@ -355,6 +361,21 @@ OptStatus ParseOconfigFlag(const char* prog_name, const char* opt_flag,
                     in_file, out_file, nullptr, &skip_validator);
 }
 
+OptStatus ParseLoopUnrollPartialArg(int argc, const char** argv, int argi,
+                                    Optimizer* optimizer) {
+  if (argi < argc) {
+    int factor = atoi(argv[argi]);
+    if (factor != 0) {
+      optimizer->RegisterPass(CreateLoopUnrollPass(false, factor));
+      return {OPT_CONTINUE, 0};
+    }
+  }
+  fprintf(stderr,
+          "error: --loop-unroll-partial must be followed by a non-0 "
+          "integer\n");
+  return {OPT_STOP, 1};
+}
+
 // Parses command-line flags. |argc| contains the number of command-line flags.
 // |argv| points to an array of strings holding the flags. |optimizer| is the
 // Optimizer instance used to optimize the program.
@@ -473,7 +494,13 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
       } else if (0 == strcmp(cur_arg, "--simplify-instructions")) {
         optimizer->RegisterPass(CreateSimplificationPass());
       } else if (0 == strcmp(cur_arg, "--loop-unroll")) {
-        optimizer->RegisterPass(CreateLoopFullyUnrollPass());
+        optimizer->RegisterPass(CreateLoopUnrollPass(true));
+      } else if (0 == strcmp(cur_arg, "--loop-unroll-partial")) {
+        OptStatus status =
+            ParseLoopUnrollPartialArg(argc, argv, ++argi, optimizer);
+        if (status.action != OPT_CONTINUE) {
+          return status;
+        }
       } else if (0 == strcmp(cur_arg, "--skip-validation")) {
         *skip_validator = true;
       } else if (0 == strcmp(cur_arg, "-O")) {