From dd8400e15077e640043b096d7634976d72fe2f8a Mon Sep 17 00:00:00 2001 From: Stephen McGroarty Date: Wed, 14 Feb 2018 17:03:12 +0000 Subject: [PATCH] Initial support for loop unrolling. This patch adds initial support for loop unrolling in the form of a series of utility classes which perform the unrolling. The pass can be run with the command spirv-opt --loop-unroll. This will unroll loops within the module which have the unroll hint set. The unroller imposes a number of requirements on the loops it can unroll. These are documented in the comments for the LoopUtils::CanPerformUnroll method in loop_utils.h. Some of the restrictions will be lifted in future patches. --- Android.mk | 1 + include/spirv-tools/optimizer.hpp | 8 + source/opt/CMakeLists.txt | 2 + source/opt/dominator_tree.h | 13 + source/opt/ir_builder.h | 40 +- source/opt/loop_descriptor.cpp | 371 ++- source/opt/loop_descriptor.h | 126 +- source/opt/loop_unroller.cpp | 939 +++++++ source/opt/loop_unroller.h | 37 + source/opt/loop_utils.h | 48 +- source/opt/optimizer.cpp | 5 + source/opt/passes.h | 2 +- test/opt/loop_optimizations/CMakeLists.txt | 13 + .../loop_optimizations/unroll_assumptions.cpp | 627 +++++ test/opt/loop_optimizations/unroll_simple.cpp | 2179 +++++++++++++++++ tools/opt/opt.cpp | 2 + 16 files changed, 4388 insertions(+), 25 deletions(-) create mode 100644 source/opt/loop_unroller.cpp create mode 100644 source/opt/loop_unroller.h create mode 100644 test/opt/loop_optimizations/unroll_assumptions.cpp create mode 100644 test/opt/loop_optimizations/unroll_simple.cpp diff --git a/Android.mk b/Android.mk index 8f5e55b8..4429ad18 100644 --- a/Android.mk +++ b/Android.mk @@ -100,6 +100,7 @@ SPVTOOLS_OPT_SRC_FILES := \ source/opt/local_single_store_elim_pass.cpp \ source/opt/local_ssa_elim_pass.cpp \ source/opt/loop_descriptor.cpp \ + source/opt/loop_unroller.cpp \ source/opt/mem_pass.cpp \ source/opt/merge_return_pass.cpp \ source/opt/module.cpp \ diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index e738b506..9f3b3601 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -512,6 +512,14 @@ Optimizer::PassToken CreateReplaceInvalidOpcodePass(); // Creates a pass that simplifies instructions using the instruction folder. Optimizer::PassToken CreateSimplificationPass(); +// Create loop unroller pass. +// Creates a pass to fully unroll loops which have the "Unroll" loop control +// mask set. The loops must meet a specific criteria in order to be unrolled +// safely this criteria is checked before doing the unroll by the +// LoopUtils::CanPerformUnroll method. Any loop that does not meet the criteria +// won't be unrolled. See CanPerformUnroll LoopUtils.h for more information. +Optimizer::PassToken CreateLoopFullyUnrollPass(); + } // namespace spvtools #endif // SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index b3c6ebe0..01948515 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -58,6 +58,7 @@ add_library(SPIRV-Tools-opt local_ssa_elim_pass.h log.h loop_descriptor.h + loop_unroller.h loop_utils.h make_unique.h mem_pass.h @@ -130,6 +131,7 @@ add_library(SPIRV-Tools-opt local_ssa_elim_pass.cpp loop_descriptor.cpp loop_utils.cpp + loop_unroller.cpp mem_pass.cpp merge_return_pass.cpp module.cpp diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h index 0be49514..5221eea1 100644 --- a/source/opt/dominator_tree.h +++ b/source/opt/dominator_tree.h @@ -224,6 +224,19 @@ class DominatorTree { return true; } + // Applies the std::function |func| to all nodes in the dominator tree from + // |node| downwards. The boolean return from |func| is used to determine + // whether or not the children should also be traversed. Tree nodes are + // visited in a depth first pre-order. + void VisitChildrenIf(std::function func, + iterator node) { + if (func(&*node)) { + for (auto n : *node) { + VisitChildrenIf(func, n->df_begin()); + } + } + } + // Returns the DominatorTreeNode associated with the basic block |bb|. // If the |bb| is unknown to the dominator tree, it returns null. inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) { diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h index 9c153deb..3b75ec6d 100644 --- a/source/opt/ir_builder.h +++ b/source/opt/ir_builder.h @@ -16,9 +16,9 @@ #define LIBSPIRV_OPT_IR_BUILDER_H_ #include "opt/basic_block.h" +#include "opt/constants.h" #include "opt/instruction.h" #include "opt/ir_context.h" - namespace spvtools { namespace opt { @@ -136,6 +136,12 @@ class InstructionBuilder { return AddInstruction(std::move(select)); } + // Adds a signed int32 constant to the binary. + // The |value| parameter is the constant value to be added. + ir::Instruction* Add32BitSignedIntegerConstant(int32_t value) { + return Add32BitConstantInteger(value, true); + } + // Create a composite construct. // |type| should be a composite type and the number of elements it has should // match the size od |ids|. @@ -151,6 +157,38 @@ class InstructionBuilder { GetContext()->TakeNextId(), ops)); return AddInstruction(std::move(construct)); } + // Adds an unsigned int32 constant to the binary. + // The |value| parameter is the constant value to be added. + ir::Instruction* Add32BitUnsignedIntegerConstant(uint32_t value) { + return Add32BitConstantInteger(value, false); + } + + // Adds either a signed or unsigned 32 bit integer constant to the binary + // depedning on the |sign|. If |sign| is true then the value is added as a + // signed constant otherwise as an unsigned constant. If |sign| is false the + // value must not be a negative number. + template + ir::Instruction* Add32BitConstantInteger(T value, bool sign) { + // Assert that we are not trying to store a negative number in an unsigned + // type. + if (!sign) + assert(value > 0 && + "Trying to add a signed integer with an unsigned type!"); + + // Get or create the integer type. + analysis::Integer int_type(32, sign); + + // Even if the value is negative we need to pass the bit pattern as a + // uint32_t to GetConstant. + uint32_t word = value; + + // Create the constant value. + const opt::analysis::Constant* constant = + GetContext()->get_constant_mgr()->GetConstant(&int_type, {word}); + + // Create the OpConstant instruction using the type and the value. + return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant); + } ir::Instruction* AddCompositeExtract( uint32_t type, uint32_t id_of_composite, diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp index a2b95462..32765be0 100644 --- a/source/opt/loop_descriptor.cpp +++ b/source/opt/loop_descriptor.cpp @@ -18,6 +18,7 @@ #include #include +#include "constants.h" #include "opt/cfg.h" #include "opt/dominator_tree.h" #include "opt/ir_builder.h" @@ -29,6 +30,107 @@ namespace spvtools { namespace ir { +// Takes in a phi instruction |induction| and the loop |header| and returns the +// step operation of the loop. +ir::Instruction* Loop::GetInductionStepOperation( + const ir::Loop* loop, const ir::Instruction* induction) const { + // Induction must be a phi instruction. + assert(induction->opcode() == SpvOpPhi); + + ir::Instruction* step = nullptr; + + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // Traverse the incoming operands of the phi instruction. + for (uint32_t operand_id = 1; operand_id < induction->NumInOperands(); + operand_id += 2) { + // Incoming edge. + ir::BasicBlock* incoming_block = + context_->cfg()->block(induction->GetSingleWordInOperand(operand_id)); + + // Check if the block is dominated by header, and thus coming from within + // the loop. + if (loop->IsInsideLoop(incoming_block)) { + step = def_use_manager->GetDef( + induction->GetSingleWordInOperand(operand_id - 1)); + break; + } + } + + if (!step || !IsSupportedStepOp(step->opcode())) { + return nullptr; + } + + return step; +} + +// Returns true if the |step| operation is an induction variable step operation +// which is currently handled. +bool Loop::IsSupportedStepOp(SpvOp step) const { + switch (step) { + case SpvOp::SpvOpISub: + case SpvOp::SpvOpIAdd: + return true; + default: + return false; + } +} + +bool Loop::IsSupportedCondition(SpvOp condition) const { + switch (condition) { + // < + case SpvOp::SpvOpULessThan: + case SpvOp::SpvOpSLessThan: + // > + case SpvOp::SpvOpUGreaterThan: + case SpvOp::SpvOpSGreaterThan: + return true; + default: + return false; + } +} + +// Extract the initial value from the |induction| OpPhi instruction and store it +// in |value|. If the function couldn't find the initial value of |induction| +// return false. +bool Loop::GetInductionInitValue(const ir::Loop* loop, + const ir::Instruction* induction, + int64_t* value) const { + ir::Instruction* constant_instruction = nullptr; + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + for (uint32_t operand_id = 0; operand_id < induction->NumInOperands(); + operand_id += 2) { + ir::BasicBlock* bb = context_->cfg()->block( + induction->GetSingleWordInOperand(operand_id + 1)); + + if (!loop->IsInsideLoop(bb)) { + constant_instruction = def_use_manager->GetDef( + induction->GetSingleWordInOperand(operand_id)); + } + } + + if (!constant_instruction) return false; + + const opt::analysis::Constant* constant = + context_->get_constant_mgr()->FindDeclaredConstant( + constant_instruction->result_id()); + if (!constant) return false; + + if (value) { + const opt::analysis::Integer* type = + constant->AsIntConstant()->type()->AsInteger(); + + if (type->IsSigned()) { + *value = constant->AsIntConstant()->GetS32BitValue(); + } else { + *value = constant->AsIntConstant()->GetU32BitValue(); + } + } + + return true; +} + Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis, BasicBlock* header, BasicBlock* continue_target, BasicBlock* merge_target) @@ -37,12 +139,11 @@ Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis, loop_continue_(continue_target), loop_merge_(merge_target), loop_preheader_(nullptr), - parent_(nullptr) { + parent_(nullptr), + loop_is_marked_for_removal_(false) { assert(context); assert(dom_analysis); loop_preheader_ = FindLoopPreheader(dom_analysis); - AddBasicBlockToLoop(header); - AddBasicBlockToLoop(continue_target); } BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) { @@ -92,7 +193,6 @@ bool Loop::IsInsideLoop(Instruction* inst) const { bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) { assert(bb->GetParent() && "The basic block does not belong to a function"); - opt::DominatorAnalysis* dom_analysis = context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg()); if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false; @@ -219,14 +319,8 @@ void Loop::SetLatchBlock(BasicBlock* latch) { void Loop::SetMergeBlock(BasicBlock* merge) { #ifndef NDEBUG assert(merge->GetParent() && "The basic block does not belong to a function"); - CFG& cfg = *merge->GetParent()->GetParent()->context()->cfg(); - - for (uint32_t pred : cfg.preds(merge->id())) { - assert(IsInsideLoop(pred) && - "A predecessor of the merge block does not belong to the loop"); - } - assert(!IsInsideLoop(merge) && "The merge block is in the loop"); #endif // NDEBUG + assert(!IsInsideLoop(merge) && "The merge block is in the loop"); SetMergeBlockImpl(merge); if (GetHeaderBlock()->GetLoopMergeInst()) { @@ -327,6 +421,7 @@ LoopDescriptor::~LoopDescriptor() { ClearLoops(); } void LoopDescriptor::PopulateList(const Function* f) { IRContext* context = f->GetParent()->context(); + opt::DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f, *context->cfg()); @@ -384,7 +479,7 @@ void LoopDescriptor::PopulateList(const Function* f) { make_range(node.df_begin(), node.df_end())) { // Check if we are in the loop. if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue; - current_loop->AddBasicBlockToLoop(loop_node.bb_); + current_loop->AddBasicBlock(loop_node.bb_); basic_block_to_loop_.insert( std::make_pair(loop_node.bb_->id(), current_loop)); } @@ -395,12 +490,262 @@ void LoopDescriptor::PopulateList(const Function* f) { } } +ir::BasicBlock* Loop::FindConditionBlock() const { + const ir::Function& function = *loop_merge_->GetParent(); + ir::BasicBlock* condition_block = nullptr; + + const opt::DominatorAnalysis* dom_analysis = + context_->GetDominatorAnalysis(&function, *context_->cfg()); + ir::BasicBlock* bb = dom_analysis->ImmediateDominator(loop_merge_); + + if (!bb) return nullptr; + + const ir::Instruction& branch = *bb->ctail(); + + // Make sure the branch is a conditional branch. + if (branch.opcode() != SpvOpBranchConditional) return nullptr; + + // Make sure one of the two possible branches is to the merge block. + if (branch.GetSingleWordInOperand(1) == loop_merge_->id() || + branch.GetSingleWordInOperand(2) == loop_merge_->id()) { + condition_block = bb; + } + + return condition_block; +} + +bool Loop::FindNumberOfIterations(const ir::Instruction* induction, + const ir::Instruction* branch_inst, + size_t* iterations_out, + int64_t* step_value_out, + int64_t* init_value_out) const { + // From the branch instruction find the branch condition. + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // Condition instruction from the OpConditionalBranch. + ir::Instruction* condition = + def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0)); + + assert(IsSupportedCondition(condition->opcode())); + + // Get the constant manager from the ir context. + opt::analysis::ConstantManager* const_manager = context_->get_constant_mgr(); + + // Find the constant value used by the condition variable. Exit out if it + // isn't a constant int. + const opt::analysis::Constant* upper_bound = + const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3)); + if (!upper_bound) return false; + + // Must be integer because of the opcode on the condition. + int64_t condition_value = 0; + + const opt::analysis::Integer* type = + upper_bound->AsIntConstant()->type()->AsInteger(); + + if (type->IsSigned()) { + condition_value = upper_bound->AsIntConstant()->GetS32BitValue(); + } else { + condition_value = upper_bound->AsIntConstant()->GetU32BitValue(); + } + + // Find the instruction which is stepping through the loop. + ir::Instruction* step_inst = GetInductionStepOperation(this, induction); + if (!step_inst) return false; + + // Find the constant value used by the condition variable. + const opt::analysis::Constant* step_constant = + const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3)); + if (!step_constant) return false; + + // Must be integer because of the opcode on the condition. + int64_t step_value = 0; + + const opt::analysis::Integer* step_type = + step_constant->AsIntConstant()->type()->AsInteger(); + + if (step_type->IsSigned()) { + step_value = step_constant->AsIntConstant()->GetS32BitValue(); + } else { + step_value = step_constant->AsIntConstant()->GetU32BitValue(); + } + + // If this is a subtraction step we should negate the step value. + if (step_inst->opcode() == SpvOp::SpvOpISub) { + step_value = -step_value; + } + + // Find the inital value of the loop and make sure it is a constant integer. + int64_t init_value = 0; + if (!GetInductionInitValue(this, induction, &init_value)) return false; + + // If iterations is non null then store the value in that. + if (iterations_out) { + int64_t num_itrs = GetIterations(condition->opcode(), condition_value, + init_value, step_value); + + // If the loop body will not be reached return false. + if (num_itrs <= 0) { + return false; + } + assert(static_cast(num_itrs) <= std::numeric_limits::max()); + *iterations_out = static_cast(num_itrs); + } + + if (step_value_out) { + *step_value_out = step_value; + } + + if (init_value_out) { + *init_value_out = init_value; + } + + return true; +} + +// We retrieve the number of iterations using the following formula, diff / +// |step_value| where diff is calculated differently according to the +// |condition| and uses the |condition_value| and |init_value|. If diff / +// |step_value| is NOT cleanly divisable then we add one to the sum. +int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value, + int64_t init_value, int64_t step_value) const { + int64_t diff = 0; + + // Take the abs of - step values. + step_value = llabs(step_value); + + switch (condition) { + case SpvOp::SpvOpSLessThan: + case SpvOp::SpvOpULessThan: { + diff = condition_value - init_value; + break; + } + case SpvOp::SpvOpSGreaterThan: + case SpvOp::SpvOpUGreaterThan: { + diff = init_value - condition_value; + break; + } + default: + assert(false && + "Could not retrieve number of iterations from the loop condition. " + "Condition is not supported."); + } + + int64_t result = diff / step_value; + + if (diff % step_value != 0) { + result += 1; + } + return result; +} + +ir::Instruction* Loop::FindInductionVariable( + const ir::BasicBlock* condition_block) const { + // Find the branch instruction. + const ir::Instruction& branch_inst = *condition_block->ctail(); + + ir::Instruction* induction = nullptr; + // Verify that the branch instruction is a conditional branch. + if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) { + // From the branch instruction find the branch condition. + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // Find the instruction representing the condition used in the conditional + // branch. + ir::Instruction* condition = + def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0)); + + // Ensure that the condition is a less than operation. + if (condition && IsSupportedCondition(condition->opcode())) { + // The left hand side operand of the operation. + ir::Instruction* variable_inst = + def_use_manager->GetDef(condition->GetSingleWordOperand(2)); + + // Make sure the variable instruction used is a phi. + if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr; + + // Make sure the phi instruction only has two incoming blocks. Each + // incoming block will be represented by two in operands in the phi + // instruction, the value and the block which that value came from. We + // assume the cannocalised phi will have two incoming values, one from the + // preheader and one from the continue block. + size_t max_supported_operands = 4; + if (variable_inst->NumInOperands() == max_supported_operands) { + // The operand index of the first incoming block label. + uint32_t operand_label_1 = 1; + + // The operand index of the second incoming block label. + uint32_t operand_label_2 = 3; + + // Make sure one of them is the preheader. + if (variable_inst->GetSingleWordInOperand(operand_label_1) != + loop_preheader_->id() && + variable_inst->GetSingleWordInOperand(operand_label_2) != + loop_preheader_->id()) { + return nullptr; + } + + // And make sure that the other is the latch block. + if (variable_inst->GetSingleWordInOperand(operand_label_1) != + loop_continue_->id() && + variable_inst->GetSingleWordInOperand(operand_label_2) != + loop_continue_->id()) { + return nullptr; + } + } else { + return nullptr; + } + + if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr)) + return nullptr; + induction = variable_inst; + } + } + + return induction; +} + +// Add and remove loops which have been marked for addition and removal to +// maintain the state of the loop descriptor class. +void LoopDescriptor::PostModificationCleanup() { + LoopContainerType loops_to_remove_; + for (ir::Loop* loop : loops_) { + if (loop->IsMarkedForRemoval()) { + loops_to_remove_.push_back(loop); + if (loop->HasParent()) { + loop->GetParent()->RemoveChildLoop(loop); + } + } + } + + for (ir::Loop* loop : loops_to_remove_) { + loops_.erase(std::find(loops_.begin(), loops_.end(), loop)); + } + + for (auto& pair : loops_to_add_) { + ir::Loop* parent = pair.first; + ir::Loop* loop = pair.second; + + if (parent) { + loop->SetParent(nullptr); + parent->AddNestedLoop(loop); + + for (uint32_t block_id : loop->GetBlocks()) { + parent->AddBasicBlock(block_id); + } + } + + loops_.emplace_back(loop); + } + + loops_to_add_.clear(); +} + void LoopDescriptor::ClearLoops() { for (Loop* loop : loops_) { delete loop; } loops_.clear(); } - } // namespace ir } // namespace spvtools diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h index b4651d75..ef5f19f1 100644 --- a/source/opt/loop_descriptor.h +++ b/source/opt/loop_descriptor.h @@ -24,6 +24,7 @@ #include #include "opt/basic_block.h" +#include "opt/module.h" #include "opt/tree_iterator.h" namespace spvtools { @@ -52,7 +53,8 @@ class Loop { loop_continue_(nullptr), loop_merge_(nullptr), loop_preheader_(nullptr), - parent_(nullptr) {} + parent_(nullptr), + loop_is_marked_for_removal_(false) {} Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header, BasicBlock* continue_target, BasicBlock* merge_target); @@ -144,6 +146,8 @@ class Loop { return lvl; } + inline size_t NumImmediateChildren() const { return nested_loops_.size(); } + // Adds |nested| as a nested loop of this loop. Automatically register |this| // as the parent of |nested|. inline void AddNestedLoop(Loop* nested) { @@ -180,6 +184,21 @@ class Loop { // Returns true if the instruction |inst| is inside this loop. bool IsInsideLoop(Instruction* inst) const; + // Adds the Basic Block |bb| to this loop and its parents. + void AddBasicBlock(const BasicBlock* bb) { AddBasicBlock(bb->id()); } + + // Adds the Basic Block with |id| to this loop and its parents. + void AddBasicBlock(uint32_t id) { + for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { + loop_basic_blocks_.insert(id); + } + } + + // Removes all the basic blocks from the set of basic blocks within the loop. + // This does not affect any of the stored pointers to the header, preheader, + // merge, or continue blocks. + void ClearBlocks() { loop_basic_blocks_.clear(); } + // Adds the Basic Block |bb| this loop and its parents. void AddBasicBlockToLoop(const BasicBlock* bb) { assert(IsBasicBlockInLoopSlow(bb) && @@ -188,11 +207,58 @@ class Loop { AddBasicBlock(bb); } - // Adds the Basic Block |bb| this loop and its parents. - void AddBasicBlock(const BasicBlock* bb) { - for (Loop* loop = this; loop != nullptr; loop = loop->parent_) { - loop_basic_blocks_.insert(bb->id()); + // This function uses the |condition| to find the induction variable within + // the loop. This only works if the loop is bound by a single condition and a + // single induction variable. + ir::Instruction* FindInductionVariable(const ir::BasicBlock* condition) const; + + // Returns the number of iterations within a loop when given the |induction| + // variable and the loop |condition| check. It stores the found number of + // iterations in the output parameter |iterations| and optionally, the step + // value in |step_value| and the initial value of the induction variable in + // |init_value|. + bool FindNumberOfIterations(const ir::Instruction* induction, + const ir::Instruction* condition, + size_t* iterations, + int64_t* step_amount = nullptr, + int64_t* init_value = nullptr) const; + + // Returns the value of the OpLoopMerge control operand as a bool. Loop + // control can be None(0), Unroll(1), or DontUnroll(2). This function returns + // true if it is set to Unroll. + inline bool HasUnrollLoopControl() const { + assert(loop_header_); + if (!loop_header_->GetLoopMergeInst()) return false; + + return loop_header_->GetLoopMergeInst()->GetSingleWordOperand(2) == 1; + } + + // Finds the conditional block with a branch to the merge and continue blocks + // within the loop body. + ir::BasicBlock* FindConditionBlock() const; + + // Remove the child loop form this loop. + inline void RemoveChildLoop(Loop* loop) { + nested_loops_.erase( + std::find(nested_loops_.begin(), nested_loops_.end(), loop)); + loop->SetParent(nullptr); + } + + // Mark this loop to be removed later by a call to + // LoopDescriptor::PostModificationCleanup. + inline void MarkLoopForRemoval() { loop_is_marked_for_removal_ = true; } + + // Returns whether or not this loop has been marked for removal. + inline bool IsMarkedForRemoval() const { return loop_is_marked_for_removal_; } + + // Returns true if all nested loops have been marked for removal. + inline bool AreAllChildrenMarkedForRemoval() const { + for (const Loop* child : nested_loops_) { + if (!child->IsMarkedForRemoval()) { + return false; + } } + return true; } // Sets the parent loop of this loop, that is, a loop which contains this loop @@ -206,6 +272,28 @@ class Loop { // loop bool AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst); + // Extract the initial value from the |induction| variable and store it in + // |value|. If the function couldn't find the initial value of |induction| + // return false. + bool GetInductionInitValue(const ir::Loop* loop, + const ir::Instruction* induction, + int64_t* value) const; + + // Takes in a phi instruction |induction| and the loop |header| and returns + // the step operation of the loop. + ir::Instruction* GetInductionStepOperation( + const ir::Loop* loop, const ir::Instruction* induction) const; + + // Returns true if we can deduce the number of loop iterations in the step + // operation |step|. IsSupportedCondition must also be true for the condition + // instruction. + bool IsSupportedStepOp(SpvOp step) const; + + // Returns true if we can deduce the number of loop iterations in the + // condition operation |condition|. IsSupportedStepOp must also be true for + // the step instruction. + bool IsSupportedCondition(SpvOp condition) const; + private: IRContext* context_; // The block which marks the start of the loop. @@ -244,6 +332,17 @@ class Loop { // Sets |merge| as the loop merge block. No checks are performed here. inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; } + // Each differnt loop |condition| affects how we calculate the number of + // iterations using the |condition_value|, |init_value|, and |step_values| of + // the induction variable. This method will return the number of iterations in + // a loop with those values for a given |condition|. + int64_t GetIterations(SpvOp condition, int64_t condition_value, + int64_t init_value, int64_t step_value) const; + + // This is to allow for loops to be removed mid iteration without invalidating + // the iterators. + bool loop_is_marked_for_removal_; + // This is only to allow LoopDescriptor::dummy_top_loop_ to add top level // loops as child. friend class LoopDescriptor; @@ -317,10 +416,21 @@ class LoopDescriptor { basic_block_to_loop_[bb_id] = loop; } + // Mark the loop |loop_to_add| as needing to be added when the user calls + // PostModificationCleanup. |parent| may be null. + inline void AddLoop(ir::Loop* loop_to_add, ir::Loop* parent) { + loops_to_add_.emplace_back(std::make_pair(parent, loop_to_add)); + } + + // Should be called to preserve the LoopAnalysis after loops have been marked + // for addition with AddLoop or MarkLoopForRemoval. + void PostModificationCleanup(); + private: // TODO(dneto): This should be a vector of unique_ptr. But VisualStudio 2013 // is unable to compile it. using LoopContainerType = std::vector; + using LoopsToAddContainerType = std::vector>; // Creates loop descriptors for the function |f|. void PopulateList(const Function* f); @@ -338,9 +448,15 @@ class LoopDescriptor { // A list of all the loops in the function. This variable owns the Loop // objects. LoopContainerType loops_; + // Dummy root: this "loop" is only there to help iterators creation. Loop dummy_top_loop_; + std::unordered_map basic_block_to_loop_; + + // List of the loops marked for addition when PostModificationCleanup is + // called. + LoopsToAddContainerType loops_to_add_; }; } // namespace ir diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp new file mode 100644 index 00000000..16031e29 --- /dev/null +++ b/source/opt/loop_unroller.cpp @@ -0,0 +1,939 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opt/loop_unroller.h" +#include +#include +#include +#include "opt/ir_builder.h" +#include "opt/loop_utils.h" + +// Implements loop util unrolling functionality for fully and partially +// unrolling loops. Given a factor it will duplicate the loop that many times, +// appending each one to the end of the old loop and removing backedges, to +// create a new unrolled loop. +// +// 1 - User calls LoopUtils::FullyUnroll or LoopUtils::PartiallyUnroll with a +// loop they wish to unroll. LoopUtils::CanPerformUnroll is used to +// validate that a given loop can be unrolled. That method (along with the +// constructor of loop) checks that the IR is in the expected canonicalised +// format. +// +// 2 - The LoopUtils methods create a LoopUnrollerUtilsImpl object to actually +// perform the unrolling. This implements helper methods to copy the loop basic +// blocks and remap the ids of instructions used inside them. +// +// 3 - The core of LoopUnrollerUtilsImpl is the Unroll method, this method +// actually performs the loop duplication. It does this by creating a +// LoopUnrollState object and then copying the loop as given by the factor +// parameter. The LoopUnrollState object retains the state of the unroller +// between the loop body copies as each iteration needs information on the last +// to adjust the phi induction variable, adjust the OpLoopMerge instruction in +// the main loop header, and change the previous continue block to point to the +// new header and the new continue block to the main loop header. +// +// 4 - If the loop is to be fully unrolled then it is simply closed after step +// 3, with the OpLoopMerge being deleted, the backedge removed, and the +// condition blocks folded. +// +// 5 - If it is being partially unrolled: if the unrolling factor leaves the +// loop with an even number of bodies with respect to the number of loop +// iterations then step 3 is all that is needed. If it is uneven then we need to +// duplicate the loop completely and unroll the duplicated loop to cover the +// residual part and adjust the first loop to cover only the "even" part. For +// instance if you request an unroll factor of 3 on a loop with 10 iterations +// then copying the body three times would leave you with three bodies in the +// loop +// where the loop still iterates over each 4 times. So we make two loops one +// iterating once then a second loop of three iterating 3 times. + +namespace spvtools { +namespace opt { +namespace { + +// This utility class encapsulates some of the state we need to maintain between +// loop unrolls. Specifically it maintains key blocks and the induction variable +// in the current loop duplication step and the blocks from the previous one. +// This is because each step of the unroll needs to use data from both the +// preceding step and the original loop. +struct LoopUnrollState { + LoopUnrollState() + : previous_phi_(nullptr), + previous_continue_block_(nullptr), + previous_condition_block_(nullptr), + new_phi(nullptr), + new_continue_block(nullptr), + new_condition_block(nullptr), + new_header_block(nullptr) {} + + // Initialize from the loop descriptor class. + LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block, + ir::BasicBlock* condition) + : previous_phi_(induction), + previous_continue_block_(continue_block), + previous_condition_block_(condition), + new_phi(nullptr), + new_continue_block(nullptr), + new_condition_block(nullptr), + new_header_block(nullptr) {} + + // Swap the state so that the new nodes are now the previous nodes. + void NextIterationState() { + previous_phi_ = new_phi; + previous_continue_block_ = new_continue_block; + previous_condition_block_ = new_condition_block; + + // Clear new nodes. + new_phi = nullptr; + new_continue_block = nullptr; + new_condition_block = nullptr; + new_header_block = nullptr; + + // Clear new block/instruction maps. + new_blocks.clear(); + new_inst.clear(); + } + + // The induction variable from the immediately preceding loop body. + ir::Instruction* previous_phi_; + + // The previous continue block. The backedge will be removed from this and + // added to the new continue block. + ir::BasicBlock* previous_continue_block_; + + // The previous condition block. This may be folded to flatten the loop. + ir::BasicBlock* previous_condition_block_; + + // The new induction variable. + ir::Instruction* new_phi; + + // The new continue block. + ir::BasicBlock* new_continue_block; + + // The new condition block. + ir::BasicBlock* new_condition_block; + + // The new header block. + ir::BasicBlock* new_header_block; + + // A mapping of new block ids to the original blocks which they were copied + // from. + std::unordered_map new_blocks; + + // A mapping of new instruction ids to the instruction ids from which they + // were copied. + std::unordered_map new_inst; +}; + +// This class implements the actual unrolling. It uses a LoopUnrollState to +// maintain the state of the unrolling inbetween steps. +class LoopUnrollerUtilsImpl { + public: + using BasicBlockListTy = std::vector>; + + LoopUnrollerUtilsImpl(ir::IRContext* c, ir::Function* function) + : context_(c), + function_(*function), + loop_condition_block_(nullptr), + loop_induction_variable_(nullptr), + number_of_loop_iterations_(0), + loop_step_value_(0), + loop_init_value_(0) {} + + // Unroll the |loop| by given |factor| by copying the whole body |factor| + // times. The resulting basicblock structure will remain a loop. + void PartiallyUnroll(ir::Loop*, size_t factor); + + // If partially unrolling the |loop| would leave the loop with too many bodies + // for its number of iterations then this method should be used. This method + // will duplicate the |loop| completely, making the duplicated loop the + // successor of the original's merge block. The original loop will have its + // condition changed to loop over the residual part and the duplicate will be + // partially unrolled. The resulting structure will be two loops. + void PartiallyUnrollResidualFactor(ir::Loop* loop, size_t factor); + + // Fully unroll the |loop| by copying the full body by the total number of + // loop iterations, folding all conditions, and removing the backedge from the + // continue block to the header. + void FullyUnroll(ir::Loop* loop); + + // Get the ID of the variable in the |phi| paired with |label|. + uint32_t GetPhiDefID(const ir::Instruction* phi, uint32_t label) const; + + // Close the loop by removing the OpLoopMerge from the |loop| header block and + // making the backedge point to the merge block. + void CloseUnrolledLoop(ir::Loop* loop); + + // Remove the OpConditionalBranch instruction inside |conditional_block| used + // to branch to either exit or continue the loop and replace it with an + // unconditional OpBranch to block |new_target|. + void FoldConditionBlock(ir::BasicBlock* condtion_block, uint32_t new_target); + + // Add all blocks_to_add_ to function_ at the |insert_point|. + void AddBlocksToFunction(const ir::BasicBlock* insert_point); + + // Duplicates the |old_loop|, cloning each body and remaping the ids without + // removing instructions or changing relative structure. Result will be stored + // in |new_loop|. + void DuplicateLoop(ir::Loop* old_loop, ir::Loop* new_loop); + + inline size_t GetLoopIterationCount() const { + return number_of_loop_iterations_; + } + + // Extracts the initial state information from the |loop|. + void Init(ir::Loop* loop); + + private: + // Remap all the in |basic_block| to new IDs and keep the mapping of new ids + // to old + // ids. |loop| is used to identify special loop blocks (header, continue, + // ect). + void AssignNewResultIds(ir::BasicBlock* basic_block); + + // Using the map built by AssignNewResultIds, for each instruction in + // |basic_block| use + // that map to substitute the IDs used by instructions (in the operands) with + // the new ids. + void RemapOperands(ir::BasicBlock* basic_block); + + // Copy the whole body of the loop, all blocks dominated by the |loop| header + // and not dominated by the |loop| merge. The copied body will be linked to by + // the old |loop| continue block and the new body will link to the |loop| + // header via the new continue block. |eliminate_conditions| is used to decide + // whether or not to fold all the condition blocks other than the last one. + void CopyBody(ir::Loop* loop, bool eliminate_conditions); + + // Copy a given |block_to_copy| in the |loop| and record the mapping of the + // old/new ids. |preserve_instructions| determines whether or not the method + // will modify (other than result_id) instructions which are copied. + void CopyBasicBlock(ir::Loop* loop, const ir::BasicBlock* block_to_copy, + bool preserve_instructions); + + // The actual implementation of the unroll step. Unrolls |loop| by given + // |factor| by copying the body by |factor| times. Also propagates the + // induction variable value throughout the copies. + void Unroll(ir::Loop* loop, size_t factor); + + // Fills the loop_blocks_inorder_ field with the ordered list of basic blocks + // as computed by the method ComputeLoopOrderedBlocks. + void ComputeLoopOrderedBlocks(ir::Loop* loop); + + // Adds the blocks_to_add_ to both the |loop| and to the parent of |loop| if + // the parent exists. + void AddBlocksToLoop(ir::Loop* loop) const; + + // A pointer to the IRContext. Used to add/remove instructions and for usedef + // chains. + ir::IRContext* context_; + + // A reference the function the loop is within. + ir::Function& function_; + + // A list of basic blocks to be added to the loop at the end of an unroll + // step. + BasicBlockListTy blocks_to_add_; + + // List of instructions which are now dead and can be removed. + std::vector dead_instructions_; + + // Maintains the current state of the transform between calls to unroll. + LoopUnrollState state_; + + // An ordered list containing the loop basic blocks. + std::vector loop_blocks_inorder_; + + // The block containing the condition check which contains a conditional + // branch to the merge and continue block. + ir::BasicBlock* loop_condition_block_; + + // The induction variable of the loop. + ir::Instruction* loop_induction_variable_; + + // The number of loop iterations that the loop would preform pre-unroll. + size_t number_of_loop_iterations_; + + // The amount that the loop steps each iteration. + int64_t loop_step_value_; + + // The value the loop starts stepping from. + int64_t loop_init_value_; +}; + +/* + * Static helper functions. + */ + +// Retrieve the index of the OpPhi instruction |phi| which corresponds to the +// incoming |block| id. +static uint32_t GetPhiIndexFromLabel(const ir::BasicBlock* block, + const ir::Instruction* phi) { + for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) { + if (block->id() == phi->GetSingleWordInOperand(i)) { + return i; + } + } + assert(false && "Could not find operand in instruction."); + return 0; +} + +void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) { + loop_condition_block_ = loop->FindConditionBlock(); + + // When we reinit the second loop during PartiallyUnrollResidualFactor we need + // to use the cached value from the duplicate step as the dominator tree + // basded solution, loop->FindConditionBlock, requires all the nodes to be + // connected up with the correct branches. They won't be at this point. + if (!loop_condition_block_) { + loop_condition_block_ = state_.new_condition_block; + } + assert(loop_condition_block_); + + loop_induction_variable_ = loop->FindInductionVariable(loop_condition_block_); + assert(loop_induction_variable_); + + bool found = loop->FindNumberOfIterations( + loop_induction_variable_, &*loop_condition_block_->ctail(), + &number_of_loop_iterations_, &loop_step_value_, &loop_init_value_); + (void)found; // To silence unused variable warning on release builds. + assert(found); + ComputeLoopOrderedBlocks(loop); +} + +// This function is used to partially unroll the loop when the factor provided +// would normally lead to an illegal optimization. Instead of just unrolling the +// loop it creates two loops and unrolls one and adjusts the condition on the +// other. The end result being that the new loop pair iterates over the correct +// number of bodies. +void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop, + size_t factor) { + // Create a new merge block for the first loop. + std::unique_ptr new_label{new ir::Instruction( + context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})}; + std::unique_ptr new_exit_bb{ + new ir::BasicBlock(std::move(new_label))}; + + // Save the id of the block before we move it. + uint32_t new_merge_id = new_exit_bb->id(); + + // Add the block the list of blocks to add, we want this merge block to be + // right at the start of the new blocks. + blocks_to_add_.push_back(std::move(new_exit_bb)); + ir::BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get(); + ir::Instruction& original_conditional_branch = *loop_condition_block_->tail(); + + // Duplicate the loop, providing access to the blocks of both loops. + // This is a naked new due to the VS2013 requirement of not having unique + // pointers in vectors, as it will be inserted into a vector with + // loop_descriptor.AddLoop. + ir::Loop* new_loop = new ir::Loop(*loop); + + // Clear the basic blocks of the new loop. + new_loop->ClearBlocks(); + + DuplicateLoop(loop, new_loop); + + // Add the blocks to the function. + AddBlocksToFunction(loop->GetMergeBlock()); + blocks_to_add_.clear(); + + InstructionBuilder builder{context_, new_exit_bb_raw}; + // Make the first loop branch to the second. + builder.AddBranch(new_loop->GetHeaderBlock()->id()); + + loop_condition_block_ = state_.new_condition_block; + loop_induction_variable_ = state_.new_phi; + + // Unroll the new loop by the factor with the usual -1 to account for the + // existing block iteration. + Unroll(new_loop, factor); + + // We need to account for the initial body when calculating the remainder. + int64_t remainder = loop_init_value_ + + (number_of_loop_iterations_ % factor) * loop_step_value_; + + assert(remainder > std::numeric_limits::min() && + remainder < std::numeric_limits::max()); + + ir::Instruction* new_constant = nullptr; + + // If the remainder is negative then we add a signed constant, otherwise just + // add an unsigned constant. + if (remainder < 0) { + new_constant = + builder.Add32BitSignedIntegerConstant(static_cast(remainder)); + } else { + new_constant = builder.Add32BitUnsignedIntegerConstant( + static_cast(remainder)); + } + + uint32_t constant_id = new_constant->result_id(); + + // Add the merge block to the back of the binary. + blocks_to_add_.push_back( + std::unique_ptr(new_loop->GetMergeBlock())); + + AddBlocksToLoop(new_loop); + // Add the blocks to the function. + AddBlocksToFunction(loop->GetMergeBlock()); + + // Reset the usedef analysis. + context_->InvalidateAnalysesExceptFor( + ir::IRContext::Analysis::kAnalysisLoopAnalysis); + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + // Update the condition check. + ir::Instruction* condition_check = def_use_manager->GetDef( + original_conditional_branch.GetSingleWordOperand(0)); + + // This should have been checked by the LoopUtils::CanPerformUnroll function + // before entering this. + assert(condition_check->opcode() == SpvOpSLessThan); + condition_check->SetInOperand(1, {constant_id}); + + // Update the next phi node. The phi will have a constant value coming in from + // the preheader block. For the duplicated loop we need to update the constant + // to be the amount of iterations covered by the first loop and the incoming + // block to be the first loops new merge block. + uint32_t phi_incoming_index = + GetPhiIndexFromLabel(loop->GetPreHeaderBlock(), loop_induction_variable_); + loop_induction_variable_->SetInOperand(phi_incoming_index - 1, {constant_id}); + loop_induction_variable_->SetInOperand(phi_incoming_index, {new_merge_id}); + + context_->InvalidateAnalysesExceptFor( + ir::IRContext::Analysis::kAnalysisLoopAnalysis); + + context_->ReplaceAllUsesWith(loop->GetMergeBlock()->id(), new_merge_id); + + ir::LoopDescriptor& loop_descriptor = + *context_->GetLoopDescriptor(&function_); + + loop_descriptor.AddLoop(new_loop, loop->GetParent()); +} + +// Duplicate the |loop| body |factor| number of times while keeping the loop +// backedge intact. +void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) { + Unroll(loop, factor); + AddBlocksToLoop(loop); + AddBlocksToFunction(loop->GetMergeBlock()); +} + +// Duplicate the |loop| body |factor| number of times while keeping the loop +// backedge intact. +void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) { + state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(), + loop_condition_block_}; + for (size_t i = 0; i < factor - 1; ++i) { + CopyBody(loop, true); + } + + uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_, + state_.previous_phi_); + uint32_t phi_variable = + state_.previous_phi_->GetSingleWordInOperand(phi_index - 1); + uint32_t phi_label = state_.previous_phi_->GetSingleWordInOperand(phi_index); + + ir::Instruction* original_phi = loop_induction_variable_; + + // SetInOperands are offset by two. + original_phi->SetInOperand(phi_index - 1, {phi_variable}); + original_phi->SetInOperand(phi_index, {phi_label}); +} + +// Fully unroll the loop by partially unrolling it by the number of loop +// iterations minus one for the body already accounted for. +void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) { + // We unroll the loop by number of iterations in the loop. + Unroll(loop, number_of_loop_iterations_); + + // The first condition block is preserved until now so it can be copied. + FoldConditionBlock(loop_condition_block_, 1); + + // Delete the OpLoopMerge and remove the backedge to the header. + CloseUnrolledLoop(loop); + + // Mark the loop for later deletion. This allows us to preserve the loop + // iterators but still disregard dead loops. + loop->MarkLoopForRemoval(); + + // If the loop has a parent add the new blocks to the parent. + if (loop->GetParent()) { + AddBlocksToLoop(loop->GetParent()); + } + + // Add the blocks to the function. + AddBlocksToFunction(loop->GetMergeBlock()); + + // Invalidate all analyses. + context_->InvalidateAnalysesExceptFor( + ir::IRContext::Analysis::kAnalysisLoopAnalysis); +} + +// Copy a given basic block, give it a new result_id, and store the new block +// and the id mapping in the state. |preserve_instructions| is used to determine +// whether or not this function should edit instructions other than the +// |result_id|. +void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop, + const ir::BasicBlock* itr, + bool preserve_instructions) { + // Clone the block exactly, including the IDs. + ir::BasicBlock* basic_block = itr->Clone(context_); + + basic_block->SetParent(itr->GetParent()); + + // Assign each result a new unique ID and keep a mapping of the old ids to + // the new ones. + AssignNewResultIds(basic_block); + + // If this is the continue block we are copying. + if (itr == loop->GetLatchBlock()) { + // Make the OpLoopMerge point to this block for the continue. + if (!preserve_instructions) { + ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + merge_inst->SetInOperand(1, {basic_block->id()}); + } + + state_.new_continue_block = basic_block; + } + + // If this is the header block we are copying. + if (itr == loop->GetHeaderBlock()) { + state_.new_header_block = basic_block; + + if (!preserve_instructions) { + // Remove the loop merge instruction if it exists. + ir::Instruction* merge_inst = basic_block->GetLoopMergeInst(); + if (merge_inst) context_->KillInst(merge_inst); + } + } + + // If this is the condition block we are copying. + if (itr == loop_condition_block_) { + state_.new_condition_block = basic_block; + } + + // Add this block to the list of blocks to add to the function at the end of + // the unrolling process. + blocks_to_add_.push_back(std::unique_ptr(basic_block)); + + // Keep tracking the old block via a map. + state_.new_blocks[itr->id()] = basic_block; +} + +void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop, + bool eliminate_conditions) { + // Copy each basic block in the loop, give them new ids, and save state + // information. + for (const ir::BasicBlock* itr : loop_blocks_inorder_) { + CopyBasicBlock(loop, itr, false); + } + + // Set the previous continue block to point to the new header. + ir::Instruction& continue_branch = *state_.previous_continue_block_->tail(); + continue_branch.SetInOperand(0, {state_.new_header_block->id()}); + + // As the algorithm copies the original loop blocks exactly, the tail of the + // latch block on iterations after the first one will be a branch to the new + // header and not the actual loop header. The last continue block in the loop + // should always be a backedge to the global header. + ir::Instruction& new_continue_branch = *state_.new_continue_block->tail(); + new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()}); + + // Update references to the old phi node with the actual variable. + const ir::Instruction* induction = loop_induction_variable_; + state_.new_inst[induction->result_id()] = + GetPhiDefID(state_.previous_phi_, state_.previous_continue_block_->id()); + + if (eliminate_conditions && + state_.new_condition_block != loop_condition_block_) { + FoldConditionBlock(state_.new_condition_block, 1); + } + + // Only reference to the header block is the backedge in the latch block, + // don't change this. + state_.new_inst[loop->GetHeaderBlock()->id()] = loop->GetHeaderBlock()->id(); + + for (auto& pair : state_.new_blocks) { + RemapOperands(pair.second); + } + + dead_instructions_.push_back(state_.new_phi); + + // Swap the state so the new is now the previous. + state_.NextIterationState(); +} + +uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi, + uint32_t label) const { + for (uint32_t operand = 3; operand < phi->NumOperands(); operand += 2) { + if (phi->GetSingleWordOperand(operand) == label) { + return phi->GetSingleWordOperand(operand - 1); + } + } + + return 0; +} + +void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block, + uint32_t operand_label) { + // Remove the old conditional branch to the merge and continue blocks. + ir::Instruction& old_branch = *condition_block->tail(); + uint32_t new_target = old_branch.GetSingleWordOperand(operand_label); + context_->KillInst(&old_branch); + + // Add the new unconditional branch to the merge block. + InstructionBuilder builder{context_, condition_block}; + builder.AddBranch(new_target); +} + +void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) { + // Remove the OpLoopMerge instruction from the function. + ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst(); + context_->KillInst(merge_inst); + + // Remove the final backedge to the header and make it point instead to the + // merge block. + state_.previous_continue_block_->tail()->SetInOperand( + 0, {loop->GetMergeBlock()->id()}); + + // Remove the induction variable as the phi will now be invalid. Replace all + // uses with the constant initializer value (all uses of the phi will be in + // the first iteration with the subsequent phis already having been removed. + uint32_t initalizer_id = + GetPhiDefID(loop_induction_variable_, loop->GetPreHeaderBlock()->id()); + context_->ReplaceAllUsesWith(loop_induction_variable_->result_id(), + initalizer_id); + + // Remove the now unused phi. + context_->KillInst(loop_induction_variable_); +} + +// Uses the first loop to create a copy of the loop with new IDs. +void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop, + ir::Loop* new_loop) { + std::vector new_block_order; + + // Copy every block in the old loop. + for (const ir::BasicBlock* itr : loop_blocks_inorder_) { + CopyBasicBlock(old_loop, itr, true); + new_block_order.push_back(blocks_to_add_.back().get()); + } + + ir::BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_); + new_merge->SetParent(old_loop->GetMergeBlock()->GetParent()); + AssignNewResultIds(new_merge); + state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge; + for (auto& pair : state_.new_blocks) { + RemapOperands(pair.second); + } + + loop_blocks_inorder_ = std::move(new_block_order); + + AddBlocksToLoop(new_loop); + + new_loop->SetHeaderBlock(state_.new_header_block); + new_loop->SetLatchBlock(state_.new_continue_block); + new_loop->SetMergeBlock(new_merge); +} + +void LoopUnrollerUtilsImpl::AddBlocksToFunction( + const ir::BasicBlock* insert_point) { + for (ir::Instruction* inst : dead_instructions_) { + context_->KillInst(inst); + } + + for (auto basic_block_iterator = function_.begin(); + basic_block_iterator != function_.end(); ++basic_block_iterator) { + if (basic_block_iterator->id() == insert_point->id()) { + basic_block_iterator.InsertBefore(&blocks_to_add_); + return; + } + } + + assert( + false && + "Could not add basic blocks to function as insert point was not found."); +} + +// Assign all result_ids in |basic_block| instructions to new IDs and preserve +// the mapping of new ids to old ones. +void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) { + // Label instructions aren't covered by normal traversal of the + // instructions. + uint32_t new_label_id = context_->TakeNextId(); + + // Assign a new id to the label. + state_.new_inst[basic_block->GetLabelInst()->result_id()] = new_label_id; + basic_block->GetLabelInst()->SetResultId(new_label_id); + + for (ir::Instruction& inst : *basic_block) { + uint32_t old_id = inst.result_id(); + + // Ignore stores etc. + if (old_id == 0) { + continue; + } + + // Give the instruction a new id. + inst.SetResultId(context_->TakeNextId()); + + // Save the mapping of old_id -> new_id. + state_.new_inst[old_id] = inst.result_id(); + + // Check if this instruction is the induction variable. + if (loop_induction_variable_->result_id() == old_id) { + // Save a pointer to the new copy of it. + state_.new_phi = &inst; + } + } +} + +// For all instructions in |basic_block| check if the operands used are from a +// copied instruction and if so swap out the operand for the copy of it. +void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) { + for (ir::Instruction& inst : *basic_block) { + auto remap_operands_to_new_ids = [this](uint32_t* id) { + auto itr = state_.new_inst.find(*id); + if (itr != state_.new_inst.end()) { + *id = itr->second; + } + }; + + inst.ForEachInId(remap_operands_to_new_ids); + } +} + +// Generate the ordered list of basic blocks in the |loop| and cache it for +// later use. +void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(ir::Loop* loop) { + loop_blocks_inorder_.clear(); + + opt::DominatorAnalysis* analysis = + context_->GetDominatorAnalysis(&function_, *context_->cfg()); + opt::DominatorTree& tree = analysis->GetDomTree(); + + // Starting at the loop header BasicBlock, traverse the dominator tree until + // we reach the merge block and add every node we traverse to the set of + // blocks + // which we consider to be the loop. + auto begin_itr = tree.GetTreeNode(loop->GetHeaderBlock())->df_begin(); + const ir::BasicBlock* merge = loop->GetMergeBlock(); + auto func = [merge, &tree, this](DominatorTreeNode* node) { + if (!tree.Dominates(merge->id(), node->id())) { + this->loop_blocks_inorder_.push_back(node->bb_); + return true; + } + return false; + }; + + tree.VisitChildrenIf(func, begin_itr); +} + +// Adds the blocks_to_add_ to both the loop and to the parent. +void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const { + // Add the blocks to this loop. + for (auto& block_itr : blocks_to_add_) { + loop->AddBasicBlock(block_itr.get()); + } + + // Add the blocks to the parent as well. + if (loop->GetParent()) AddBlocksToLoop(loop->GetParent()); +} + +/* + * End LoopUtilsImpl. + */ + +} // namespace + +/* + * + * Begin Utils. + * + * */ + +bool LoopUtils::CanPerformUnroll() { + // The loop is expected to be in structured order. + if (!loop_->GetHeaderBlock()->GetMergeInst()) { + return false; + } + + // Find check the loop has a condition we can find and evaluate. + const ir::BasicBlock* condition = loop_->FindConditionBlock(); + if (!condition) return false; + + // Check that we can find and process the induction variable. + const ir::Instruction* induction = loop_->FindInductionVariable(condition); + if (!induction || induction->opcode() != SpvOpPhi) return false; + + // Check that we can find the number of loop iterations. + if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr)) + return false; + + // Make sure the continue block is a unconditional branch to the header + // block. + const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail(); + bool branching_assumption = + branch.opcode() == SpvOpBranch && + branch.GetSingleWordInOperand(0) == loop_->GetHeaderBlock()->id(); + if (!branching_assumption) { + return false; + } + + // Make sure the induction is the only phi instruction we have in the loop + // header. Other optimizations have been seen to leave dead phi nodes in the + // header so we also check that the phi is used. + for (const ir::Instruction& inst : *loop_->GetHeaderBlock()) { + if (inst.opcode() == SpvOpPhi && + inst.result_id() != induction->result_id()) { + return false; + } + } + + // Ban breaks within the loop. + const std::vector& merge_block_preds = + context_->cfg()->preds(loop_->GetMergeBlock()->id()); + if (merge_block_preds.size() != 1) { + return false; + } + + // Ban continues within the loop. + const std::vector& continue_block_preds = + context_->cfg()->preds(loop_->GetLatchBlock()->id()); + if (continue_block_preds.size() != 1) { + return false; + } + + // Ban returns in the loop. + // Iterate over all the blocks within the loop and check that none of them + // exit the loop. + for (uint32_t label_id : loop_->GetBlocks()) { + const ir::BasicBlock* block = context_->cfg()->block(label_id); + if (block->ctail()->opcode() == SpvOp::SpvOpKill || + block->ctail()->opcode() == SpvOp::SpvOpReturn || + block->ctail()->opcode() == SpvOp::SpvOpReturnValue) { + return false; + } + } + // Can only unroll inner loops. + if (!loop_->AreAllChildrenMarkedForRemoval()) { + return false; + } + + for (uint32_t block_id : loop_->GetBlocks()) { + opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr(); + + ir::BasicBlock& bb = *context_->cfg()->block(block_id); + // For every instruction in the block. + for (ir::Instruction& inst : bb) { + if (inst.result_id() == 0) continue; + + auto is_used_outside_loop = [this, + def_use_manager](ir::Instruction* user) { + + if (!loop_->IsInsideLoop(user)) { + // Some optimization passes have been seen to leave dead phis in the + // IR so we check that if a phi is used outside of the loop that the + // user is not dead. + if (!(user->opcode() == SpvOpPhi && + def_use_manager->NumUsers(user) == 0)) + return false; + } + return true; + }; + + if (!def_use_manager->WhileEachUser(&inst, is_used_outside_loop)) { + return false; + } + } + } + return true; +} + +bool LoopUtils::PartiallyUnroll(size_t factor) { + if (factor == 1 || !CanPerformUnroll()) return false; + + // Create the unroller utility. + LoopUnrollerUtilsImpl unroller{context_, + loop_->GetHeaderBlock()->GetParent()}; + unroller.Init(loop_); + + // If the unrolling factor is larger than or the same size as the loop just + // fully unroll the loop. + if (factor >= unroller.GetLoopIterationCount()) { + unroller.FullyUnroll(loop_); + return true; + } + + // If the loop unrolling factor is an residual number of iterations we need to + // let run the loop for the residual part then let it branch into the unrolled + // remaining part. We add one when calucating the remainder to take into + // account the one iteration already in the loop. + if (unroller.GetLoopIterationCount() % factor != 0) { + unroller.PartiallyUnrollResidualFactor(loop_, factor); + } else { + unroller.PartiallyUnroll(loop_, factor); + } + + return true; +} + +bool LoopUtils::FullyUnroll() { + if (!CanPerformUnroll()) return false; + + LoopUnrollerUtilsImpl unroller{context_, + loop_->GetHeaderBlock()->GetParent()}; + + unroller.Init(loop_); + unroller.FullyUnroll(loop_); + + return true; +} + +void LoopUtils::Finalize() { + // Clean up the loop descriptor to preserve the analysis. + + ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&function_); + LD->PostModificationCleanup(); +} + +/* + * + * Begin Pass. + * + */ + +Pass::Status LoopUnroller::Process(ir::IRContext* c) { + context_ = c; + bool changed = false; + for (ir::Function& f : *c->module()) { + ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&f); + for (ir::Loop& loop : *LD) { + LoopUtils loop_utils{c, &loop}; + if (!loop.HasUnrollLoopControl() || !loop_utils.CanPerformUnroll()) { + continue; + } + + loop_utils.FullyUnroll(); + changed = true; + } + LD->PostModificationCleanup(); + } + + return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/loop_unroller.h b/source/opt/loop_unroller.h new file mode 100644 index 00000000..38dfa347 --- /dev/null +++ b/source/opt/loop_unroller.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_LOOP_UNROLLER_H_ +#define SOURCE_OPT_LOOP_UNROLLER_H_ +#include "opt/pass.h" + +namespace spvtools { +namespace opt { + +class LoopUnroller : public Pass { + public: + LoopUnroller() : Pass() {} + + const char* name() const override { return "Loop unroller"; } + + Status Process(ir::IRContext* context) override; + + private: + ir::IRContext* context_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_LOOP_UNROLLER_H_ diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h index 65a54312..3eb20de0 100644 --- a/source/opt/loop_utils.h +++ b/source/opt/loop_utils.h @@ -12,8 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef LIBSPIRV_OPT_LOOP_UTILS_H_ -#define LIBSPIRV_OPT_LOOP_UTILS_H_ +#ifndef SOURCE_OPT_LOOP_UTILS_H_ +#define SOURCE_OPT_LOOP_UTILS_H_ +#include +#include +#include +#include "opt/loop_descriptor.h" namespace spvtools { @@ -24,11 +28,15 @@ class IRContext; namespace opt { -// Set of basic loop transformation. +// LoopUtils is used to encapsulte loop optimizations and from the passes which +// use them. Any pass which needs a loop optimization should do it through this +// or through a pass which is using this. class LoopUtils { public: LoopUtils(ir::IRContext* context, ir::Loop* loop) - : context_(context), loop_(loop) {} + : context_(context), + loop_(loop), + function_(*loop_->GetHeaderBlock()->GetParent()) {} // The converts the current loop to loop closed SSA form. // In the loop closed SSA, all loop exiting values go through a dedicated Phi @@ -64,12 +72,42 @@ class LoopUtils { // Preserves: CFG, def/use and instruction to block mapping. void CreateLoopDedicatedExits(); + // Perfom a partial unroll of |loop| by given |factor|. This will copy the + // body of the loop |factor| times. So a |factor| of one would give a new loop + // with the original body plus one unrolled copy body. + bool PartiallyUnroll(size_t factor); + + // Fully unroll |loop|. + bool FullyUnroll(); + + // This function validates that |loop| meets the assumptions made by the + // implementation of the loop unroller. As the implementation accommodates + // more types of loops this function can reduce its checks. + // + // The conditions checked to ensure the loop can be unrolled are as follows: + // 1. That the loop is in structured order. + // 2. That the condinue block is a branch to the header. + // 3. That the only phi used in the loop is the induction variable. + // TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is + // converted into LCSAA form and has a single entry and exit we can rewrite + // the other phis. + // 4. That this is an inner most loop, or that loops contained within this + // loop have already been fully unrolled. + // 5. That each instruction in the loop is only used within the loop. + // (Related to the above phi condition). + bool CanPerformUnroll(); + + // Maintains the loop descriptor object after the unroll functions have been + // called, otherwise the analysis should be invalidated. + void Finalize(); + private: ir::IRContext* context_; ir::Loop* loop_; + ir::Function& function_; }; } // namespace opt } // namespace spvtools -#endif // LIBSPIRV_OPT_LOOP_UTILS_H_ +#endif // SOURCE_OPT_LOOP_UTILS_H_ diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 9f36f19f..8aed7bc1 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -389,4 +389,9 @@ Optimizer::PassToken CreateSimplificationPass() { return MakeUnique( MakeUnique()); } + +Optimizer::PassToken CreateLoopFullyUnrollPass() { + return MakeUnique( + MakeUnique()); +} } // namespace spvtools diff --git a/source/opt/passes.h b/source/opt/passes.h index 890e16e5..9fb98aaf 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -41,6 +41,7 @@ #include "local_single_block_elim_pass.h" #include "local_single_store_elim_pass.h" #include "local_ssa_elim_pass.h" +#include "loop_unroller.h" #include "merge_return_pass.h" #include "null_pass.h" #include "private_to_local_pass.h" @@ -53,5 +54,4 @@ #include "strip_debug_info_pass.h" #include "unify_const_pass.h" #include "workaround1209.h" - #endif // LIBSPIRV_OPT_PASSES_H_ diff --git a/test/opt/loop_optimizations/CMakeLists.txt b/test/opt/loop_optimizations/CMakeLists.txt index 947f5c54..a9cb4991 100644 --- a/test/opt/loop_optimizations/CMakeLists.txt +++ b/test/opt/loop_optimizations/CMakeLists.txt @@ -66,3 +66,16 @@ add_spvtools_unittest(TARGET licm_hoist_no_preheader hoist_without_preheader.cpp LIBS SPIRV-Tools-opt ) + +add_spvtools_unittest(TARGET loop_unroll_simple + SRCS ../function_utils.h + unroll_simple.cpp + LIBS SPIRV-Tools-opt +) + +add_spvtools_unittest(TARGET loop_unroll_assumtion_checks + SRCS ../function_utils.h + unroll_assumptions.cpp + LIBS SPIRV-Tools-opt +) + diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp new file mode 100644 index 00000000..e3ff1ee4 --- /dev/null +++ b/test/opt/loop_optimizations/unroll_assumptions.cpp @@ -0,0 +1,627 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "../assembly_builder.h" +#include "../function_utils.h" +#include "../pass_fixture.h" +#include "../pass_utils.h" +#include "opt/loop_unroller.h" +#include "opt/loop_utils.h" +#include "opt/pass.h" + +namespace { + +using namespace spvtools; +using ::testing::UnorderedElementsAre; + +template +class PartialUnrollerTestPass : public opt::Pass { + public: + PartialUnrollerTestPass() : Pass() {} + + const char* name() const override { return "Loop unroller"; } + + Status Process(ir::IRContext* context) override { + bool changed = false; + for (ir::Function& f : *context->module()) { + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f); + for (auto& loop : loop_descriptor) { + opt::LoopUtils loop_utils{context, &loop}; + if (loop_utils.PartiallyUnroll(factor)) { + changed = true; + } + } + } + + if (changed) return Pass::Status::SuccessWithChange; + return Pass::Status::SuccessWithoutChange; + } +}; + +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 410 core +layout(location = 0) flat in int in_upper_bound; +void main() { + for (int i = ; i < in_upper_bound; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, CheckUpperBound) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "in_upper_bound" +OpName %4 "x" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpTypePointer Input %7 +%3 = OpVariable %10 Input +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%2 = OpFunction %5 None %6 +%20 = OpLabel +%4 = OpVariable %16 Function +OpBranch %21 +%21 = OpLabel +%22 = OpPhi %7 %9 %20 %23 %24 +OpLoopMerge %25 %24 None +OpBranch %26 +%26 = OpLabel +%27 = OpLoad %7 %3 +%28 = OpSLessThan %11 %22 %27 +OpBranchConditional %28 %29 %25 +%29 = OpLabel +%30 = OpAccessChain %18 %4 %22 +OpStore %30 %17 +OpBranch %24 +%24 = OpLabel +%23 = OpIAdd %7 %22 %19 +OpBranch %21 +%25 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[10]; + int i = 0; + for (int i = 0; i < 10; ++i) { + out_array[i] = i; + } + out_array[9] = i*10; +} +*/ +TEST_F(PassClassTest, InductionUsedOutsideOfLoop) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 1 +%18 = OpConstant %6 9 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %15 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpConvertSToF %11 %21 +%29 = OpAccessChain %16 %3 %21 +OpStore %29 %28 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %17 +OpBranch %20 +%24 = OpLabel +%30 = OpIMul %6 %21 %9 +%31 = OpConvertSToF %11 %30 +%32 = OpAccessChain %16 %3 %18 +OpStore %32 %31 +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + // Make sure the pass doesn't run + SinglePassRunAndCheck(text, text, false); + SinglePassRunAndCheck>(text, text, false); + SinglePassRunAndCheck>(text, text, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[10]; + for (uint i = 0; i < 2; i++) { + for (float x = 0; x < 5; ++x) { + out_array[x + i*5] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoopsInvalid) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 2 +%10 = OpTypeBool +%11 = OpTypeInt 32 1 +%12 = OpTypePointer Function %11 +%13 = OpConstant %11 0 +%14 = OpConstant %11 5 +%15 = OpTypeFloat 32 +%16 = OpConstant %6 10 +%17 = OpTypeArray %15 %16 +%18 = OpTypePointer Function %17 +%19 = OpConstant %6 5 +%20 = OpTypePointer Function %15 +%21 = OpConstant %11 1 +%22 = OpUndef %11 +%2 = OpFunction %4 None %5 +%23 = OpLabel +%3 = OpVariable %18 Function +OpBranch %24 +%24 = OpLabel +%25 = OpPhi %6 %8 %23 %26 %27 +%28 = OpPhi %11 %22 %23 %29 %27 +OpLoopMerge %30 %27 None +OpBranch %31 +%31 = OpLabel +%32 = OpULessThan %10 %25 %9 +OpBranchConditional %32 %33 %30 +%33 = OpLabel +OpBranch %34 +%34 = OpLabel +%29 = OpPhi %11 %13 %33 %35 %36 +OpLoopMerge %37 %36 None +OpBranch %38 +%38 = OpLabel +%39 = OpSLessThan %10 %29 %14 +OpBranchConditional %39 %40 %37 +%40 = OpLabel +%41 = OpBitcast %6 %29 +%42 = OpIMul %6 %25 %19 +%43 = OpIAdd %6 %41 %42 +%44 = OpConvertUToF %15 %25 +%45 = OpAccessChain %20 %3 %43 +OpStore %45 %44 +OpBranch %36 +%36 = OpLabel +%35 = OpIAdd %11 %29 %21 +OpBranch %34 +%37 = OpLabel +OpBranch %27 +%27 = OpLabel +%26 = OpIAdd %6 %25 %21 +OpBranch %24 +%30 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + int ind = 0; + for (int i = 0; i < 10; i++) { + ind = i; + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, MultiplePhiInHeader) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 10 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +%20 = OpPhi %6 %8 %18 %21 %22 +%21 = OpPhi %6 %8 %18 %23 %22 +OpLoopMerge %24 %22 None +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpConvertSToF %11 %21 +%29 = OpAccessChain %16 %3 %21 +OpStore %29 %28 +OpBranch %22 +%22 = OpLabel +%23 = OpIAdd %6 %21 %17 +OpBranch %19 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + break; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, BreakInBody) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpBranch %24 +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + continue; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, ContinueInBody) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpBranch %23 +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + for (int i = 0; i < 10; i++) { + if (i == 5) { + return; + } + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, ReturnInBody) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 440 +OpName %2 "main" +OpName %3 "x" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 10 +%10 = OpTypeBool +%11 = OpConstant %6 5 +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpTypePointer Function %12 +%18 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%19 = OpLabel +%3 = OpVariable %16 Function +OpBranch %20 +%20 = OpLabel +%21 = OpPhi %6 %8 %19 %22 %23 +OpLoopMerge %24 %23 Unroll +OpBranch %25 +%25 = OpLabel +%26 = OpSLessThan %10 %21 %9 +OpBranchConditional %26 %27 %24 +%27 = OpLabel +%28 = OpIEqual %10 %21 %11 +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %29 +%30 = OpLabel +OpReturn +%29 = OpLabel +%31 = OpConvertSToF %12 %21 +%32 = OpAccessChain %17 %3 %21 +OpStore %32 %31 +OpBranch %23 +%23 = OpLabel +%22 = OpIAdd %6 %21 %18 +OpBranch %20 +%24 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, text, false); +} + +} // namespace diff --git a/test/opt/loop_optimizations/unroll_simple.cpp b/test/opt/loop_optimizations/unroll_simple.cpp new file mode 100644 index 00000000..59d2f95b --- /dev/null +++ b/test/opt/loop_optimizations/unroll_simple.cpp @@ -0,0 +1,2179 @@ +// Copyright (c) 2018 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "../assembly_builder.h" +#include "../function_utils.h" +#include "../pass_fixture.h" +#include "../pass_utils.h" +#include "opt/loop_unroller.h" +#include "opt/loop_utils.h" +#include "opt/pass.h" + +namespace { + +using namespace spvtools; +using ::testing::UnorderedElementsAre; + +using PassClassTest = PassTest<::testing::Test>; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[4]; + for (int i = 0; i < 4; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimpleFullyUnrollTest) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 4 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 4 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 = OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 4 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 4 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %9 %10 +OpBranch %30 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %9 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %9 %19 +OpBranch %32 +%32 = OpLabel +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %25 %10 +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %18 %4 %25 +OpStore %37 %17 +OpBranch %38 +%38 = OpLabel +%39 = OpIAdd %7 %25 %19 +OpBranch %40 +%40 = OpLabel +OpBranch %42 +%42 = OpLabel +%43 = OpSLessThan %11 %39 %10 +OpBranch %44 +%44 = OpLabel +%45 = OpAccessChain %18 %4 %39 +OpStore %45 %17 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %7 %39 %19 +OpBranch %48 +%48 = OpLabel +OpBranch %50 +%50 = OpLabel +%51 = OpSLessThan %11 %47 %10 +OpBranch %52 +%52 = OpLabel +%53 = OpAccessChain %18 %4 %47 +OpStore %53 %17 +OpBranch %54 +%54 = OpLabel +%55 = OpIAdd %7 %47 %19 +OpBranch %27 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +template +class PartialUnrollerTestPass : public opt::Pass { + public: + PartialUnrollerTestPass() : Pass() {} + + const char* name() const override { return "Loop unroller"; } + + Status Process(ir::IRContext* context) override { + for (ir::Function& f : *context->module()) { + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f); + for (auto& loop : loop_descriptor) { + opt::LoopUtils loop_utils{context, &loop}; + loop_utils.PartiallyUnroll(factor); + } + } + + return Pass::Status::SuccessWithChange; + } +}; + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimplePartialUnroll) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 = OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const std::string output = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %7 %9 %22 %39 %38 +OpLoopMerge %27 %38 Unroll +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %24 %10 +OpBranchConditional %29 %30 %27 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %24 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %24 %19 +OpBranch %32 +%32 = OpLabel +OpBranch %34 +%34 = OpLabel +%35 = OpSLessThan %11 %25 %10 +OpBranch %36 +%36 = OpLabel +%37 = OpAccessChain %18 %4 %25 +OpStore %37 %17 +OpBranch %38 +%38 = OpLabel +%39 = OpIAdd %7 %25 %19 +OpBranch %23 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck>(text, output, false); +} + +/* +Generated from the following GLSL +#version 330 core +layout(location = 0) out vec4 c; +void main() { + float x[10]; + for (int i = 0; i < 10; ++i) { + x[i] = 1.0f; + } +} +*/ +TEST_F(PassClassTest, SimpleUnevenPartialUnroll) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" %3 + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 330 + OpName %2 "main" + OpName %5 "x" + OpName %3 "c" + OpDecorate %3 Location 0 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %10 = OpConstant %8 0 + %11 = OpConstant %8 10 + %12 = OpTypeBool + %13 = OpTypeFloat 32 + %14 = OpTypeInt 32 0 + %15 = OpConstant %14 10 + %16 = OpTypeArray %13 %15 + %17 = OpTypePointer Function %16 + %18 = OpConstant %13 1 + %19 = OpTypePointer Function %13 + %20 = OpConstant %8 1 + %21 = OpTypeVector %13 4 + %22 = OpTypePointer Output %21 + %3 = OpVariable %22 Output + %2 = OpFunction %6 None %7 + %23 = OpLabel + %5 = OpVariable %17 Function + OpBranch %24 + %24 = OpLabel + %35 = OpPhi %8 %10 %23 %34 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %29 = OpSLessThan %12 %35 %11 + OpBranchConditional %29 %30 %25 + %30 = OpLabel + %32 = OpAccessChain %19 %5 %35 + OpStore %32 %18 + OpBranch %26 + %26 = OpLabel + %34 = OpIAdd %8 %35 %20 + OpBranch %24 + %25 = OpLabel + OpReturn + OpFunctionEnd + )"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 330 +OpName %2 "main" +OpName %4 "x" +OpName %3 "c" +OpDecorate %3 Location 0 +%5 = OpTypeVoid +%6 = OpTypeFunction %5 +%7 = OpTypeInt 32 1 +%8 = OpTypePointer Function %7 +%9 = OpConstant %7 0 +%10 = OpConstant %7 10 +%11 = OpTypeBool +%12 = OpTypeFloat 32 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 10 +%15 = OpTypeArray %12 %14 +%16 = OpTypePointer Function %15 +%17 = OpConstant %12 1 +%18 = OpTypePointer Function %12 +%19 = OpConstant %7 1 +%20 = OpTypeVector %12 4 +%21 = OpTypePointer Output %20 +%3 = OpVariable %21 Output +%58 = OpConstant %13 1 +%2 = OpFunction %5 None %6 +%22 = OpLabel +%4 = OpVariable %16 Function +OpBranch %23 +%23 = OpLabel +%24 = OpPhi %7 %9 %22 %25 %26 +OpLoopMerge %32 %26 Unroll +OpBranch %28 +%28 = OpLabel +%29 = OpSLessThan %11 %24 %58 +OpBranchConditional %29 %30 %32 +%30 = OpLabel +%31 = OpAccessChain %18 %4 %24 +OpStore %31 %17 +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %7 %24 %19 +OpBranch %23 +%32 = OpLabel +OpBranch %33 +%33 = OpLabel +%34 = OpPhi %7 %58 %32 %57 %56 +OpLoopMerge %41 %56 Unroll +OpBranch %35 +%35 = OpLabel +%36 = OpSLessThan %11 %34 %10 +OpBranchConditional %36 %37 %41 +%37 = OpLabel +%38 = OpAccessChain %18 %4 %34 +OpStore %38 %17 +OpBranch %39 +%39 = OpLabel +%40 = OpIAdd %7 %34 %19 +OpBranch %42 +%42 = OpLabel +OpBranch %44 +%44 = OpLabel +%45 = OpSLessThan %11 %40 %10 +OpBranch %46 +%46 = OpLabel +%47 = OpAccessChain %18 %4 %40 +OpStore %47 %17 +OpBranch %48 +%48 = OpLabel +%49 = OpIAdd %7 %40 %19 +OpBranch %50 +%50 = OpLabel +OpBranch %52 +%52 = OpLabel +%53 = OpSLessThan %11 %49 %10 +OpBranch %54 +%54 = OpLabel +%55 = OpAccessChain %18 %4 %49 +OpStore %55 %17 +OpBranch %56 +%56 = OpLabel +%57 = OpIAdd %7 %49 %19 +OpBranch %33 +%41 = OpLabel +OpReturn +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + // By unrolling by a factor that doesn't divide evenly into the number of loop + // iterations we perfom an additional transform when partially unrolling to + // account for the remainder. + SinglePassRunAndCheck>(text, output, false); +} + +/* Generated from +#version 410 core +layout(location=0) flat in int upper_bound; +void main() { + float x[10]; + for (int i = 2; i < 8; i+=2) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, SimpleLoopIterationsCheck) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %5 "x" +OpName %3 "upper_bound" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 2 +%11 = OpConstant %8 8 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpTypePointer Input %8 +%3 = OpVariable %19 Input +%2 = OpFunction %6 None %7 +%20 = OpLabel +%5 = OpVariable %17 Function +OpBranch %21 +%21 = OpLabel +%34 = OpPhi %8 %10 %20 %33 %23 +OpLoopMerge %22 %23 Unroll +OpBranch %24 +%24 = OpLabel +%26 = OpSLessThan %12 %34 %11 +OpBranchConditional %26 %27 %22 +%27 = OpLabel +%30 = OpConvertSToF %13 %34 +%31 = OpAccessChain %18 %5 %34 +OpStore %31 %30 +OpBranch %23 +%23 = OpLabel +%33 = OpIAdd %8 %34 %10 +OpBranch %21 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + ir::Function* f = spvtest::GetFunction(module, 2); + + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + ir::BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 24u); + + ir::Instruction* induction = loop.FindInductionVariable(condition); + EXPECT_EQ(induction->result_id(), 34u); + + opt::LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 3u); +} + +/* Generated from +#version 410 core +void main() { + float x[10]; + for (int i = -1; i < 6; i+=3) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, SimpleLoopIterationsCheckSignedInit) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" %3 +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %5 "x" +OpName %3 "upper_bound" +OpDecorate %3 Flat +OpDecorate %3 Location 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 1 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 -1 +%11 = OpConstant %8 6 +%12 = OpTypeBool +%13 = OpTypeFloat 32 +%14 = OpTypeInt 32 0 +%15 = OpConstant %14 10 +%16 = OpTypeArray %13 %15 +%17 = OpTypePointer Function %16 +%18 = OpTypePointer Function %13 +%19 = OpConstant %8 3 +%20 = OpTypePointer Input %8 +%3 = OpVariable %20 Input +%2 = OpFunction %6 None %7 +%21 = OpLabel +%5 = OpVariable %17 Function +OpBranch %22 +%22 = OpLabel +%35 = OpPhi %8 %10 %21 %34 %24 +OpLoopMerge %23 %24 None +OpBranch %25 +%25 = OpLabel +%27 = OpSLessThan %12 %35 %11 +OpBranchConditional %27 %28 %23 +%28 = OpLabel +%31 = OpConvertSToF %13 %35 +%32 = OpAccessChain %18 %5 %35 +OpStore %32 %31 +OpBranch %24 +%24 = OpLabel +%34 = OpIAdd %8 %35 %19 +OpBranch %22 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + ir::Function* f = spvtest::GetFunction(module, 2); + + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_FALSE(loop.HasUnrollLoopControl()); + + ir::BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 25u); + + ir::Instruction* induction = loop.FindInductionVariable(condition); + EXPECT_EQ(induction->result_id(), 35u); + + opt::LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 3u); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[6]; + for (uint i = 0; i < 2; i++) { + for (int x = 0; x < 3; ++x) { + out_array[x + i*3] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoops) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %35 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeInt 32 1 + %20 = OpTypePointer Function %19 + %22 = OpConstant %19 0 + %29 = OpConstant %19 3 + %31 = OpTypeFloat 32 + %32 = OpConstant %6 6 + %33 = OpTypeArray %31 %32 + %34 = OpTypePointer Function %33 + %39 = OpConstant %6 3 + %44 = OpTypePointer Function %31 + %47 = OpConstant %19 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %35 = OpVariable %34 Function + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %50 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpULessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %23 + %23 = OpLabel + %54 = OpPhi %19 %22 %11 %48 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %30 = OpSLessThan %17 %54 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + %37 = OpBitcast %6 %54 + %40 = OpIMul %6 %51 %39 + %41 = OpIAdd %6 %37 %40 + %43 = OpConvertUToF %31 %51 + %45 = OpAccessChain %44 %35 %41 + OpStore %45 %43 + OpBranch %26 + %26 = OpLabel + %48 = OpIAdd %19 %54 %47 + OpBranch %23 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %50 = OpIAdd %6 %51 %47 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 0 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 2 +%10 = OpTypeBool +%11 = OpTypeInt 32 1 +%12 = OpTypePointer Function %11 +%13 = OpConstant %11 0 +%14 = OpConstant %11 3 +%15 = OpTypeFloat 32 +%16 = OpConstant %6 6 +%17 = OpTypeArray %15 %16 +%18 = OpTypePointer Function %17 +%19 = OpConstant %6 3 +%20 = OpTypePointer Function %15 +%21 = OpConstant %11 1 +%2 = OpFunction %4 None %5 +%22 = OpLabel +%3 = OpVariable %18 Function +OpBranch %23 +%23 = OpLabel +OpBranch %28 +%28 = OpLabel +%29 = OpULessThan %10 %8 %9 +OpBranch %30 +%30 = OpLabel +OpBranch %31 +%31 = OpLabel +OpBranch %36 +%36 = OpLabel +%37 = OpSLessThan %10 %13 %14 +OpBranch %38 +%38 = OpLabel +%39 = OpBitcast %6 %13 +%40 = OpIMul %6 %8 %19 +%41 = OpIAdd %6 %39 %40 +%42 = OpConvertUToF %15 %8 +%43 = OpAccessChain %20 %3 %41 +OpStore %43 %42 +OpBranch %34 +%34 = OpLabel +%33 = OpIAdd %11 %13 %21 +OpBranch %44 +%44 = OpLabel +OpBranch %46 +%46 = OpLabel +%47 = OpSLessThan %10 %33 %14 +OpBranch %48 +%48 = OpLabel +%49 = OpBitcast %6 %33 +%50 = OpIMul %6 %8 %19 +%51 = OpIAdd %6 %49 %50 +%52 = OpConvertUToF %15 %8 +%53 = OpAccessChain %20 %3 %51 +OpStore %53 %52 +OpBranch %54 +%54 = OpLabel +%55 = OpIAdd %11 %33 %21 +OpBranch %56 +%56 = OpLabel +OpBranch %58 +%58 = OpLabel +%59 = OpSLessThan %10 %55 %14 +OpBranch %60 +%60 = OpLabel +%61 = OpBitcast %6 %55 +%62 = OpIMul %6 %8 %19 +%63 = OpIAdd %6 %61 %62 +%64 = OpConvertUToF %15 %8 +%65 = OpAccessChain %20 %3 %63 +OpStore %65 %64 +OpBranch %66 +%66 = OpLabel +%67 = OpIAdd %11 %55 %21 +OpBranch %35 +%35 = OpLabel +OpBranch %26 +%26 = OpLabel +%25 = OpIAdd %6 %8 %21 +OpBranch %68 +%68 = OpLabel +OpBranch %70 +%70 = OpLabel +%71 = OpULessThan %10 %25 %9 +OpBranch %72 +%72 = OpLabel +OpBranch %73 +%73 = OpLabel +OpBranch %74 +%74 = OpLabel +%75 = OpSLessThan %10 %13 %14 +OpBranch %76 +%76 = OpLabel +%77 = OpBitcast %6 %13 +%78 = OpIMul %6 %25 %19 +%79 = OpIAdd %6 %77 %78 +%80 = OpConvertUToF %15 %25 +%81 = OpAccessChain %20 %3 %79 +OpStore %81 %80 +OpBranch %82 +%82 = OpLabel +%83 = OpIAdd %11 %13 %21 +OpBranch %84 +%84 = OpLabel +OpBranch %85 +%85 = OpLabel +%86 = OpSLessThan %10 %83 %14 +OpBranch %87 +%87 = OpLabel +%88 = OpBitcast %6 %83 +%89 = OpIMul %6 %25 %19 +%90 = OpIAdd %6 %88 %89 +%91 = OpConvertUToF %15 %25 +%92 = OpAccessChain %20 %3 %90 +OpStore %92 %91 +OpBranch %93 +%93 = OpLabel +%94 = OpIAdd %11 %83 %21 +OpBranch %95 +%95 = OpLabel +OpBranch %96 +%96 = OpLabel +%97 = OpSLessThan %10 %94 %14 +OpBranch %98 +%98 = OpLabel +%99 = OpBitcast %6 %94 +%100 = OpIMul %6 %25 %19 +%101 = OpIAdd %6 %99 %100 +%102 = OpConvertUToF %15 %25 +%103 = OpAccessChain %20 %3 %101 +OpStore %103 %102 +OpBranch %104 +%104 = OpLabel +%105 = OpIAdd %11 %94 %21 +OpBranch %106 +%106 = OpLabel +OpBranch %107 +%107 = OpLabel +%108 = OpIAdd %6 %25 %21 +OpBranch %27 +%27 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[2]; + for (int i = -3; i < -1; i++) { + out_array[3 + i] = i; + } +} +*/ +TEST_F(PassClassTest, NegativeConditionAndInit) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %23 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 -3 + %16 = OpConstant %6 -1 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 2 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 3 + %30 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %23 = OpVariable %22 Function + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %31 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %32 %25 + %28 = OpAccessChain %7 %23 %26 + OpStore %28 %32 + OpBranch %13 + %13 = OpLabel + %31 = OpIAdd %6 %32 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 -3 +%9 = OpConstant %6 -1 +%10 = OpTypeBool +%11 = OpTypeInt 32 0 +%12 = OpConstant %11 2 +%13 = OpTypeArray %6 %12 +%14 = OpTypePointer Function %13 +%15 = OpConstant %6 3 +%16 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%17 = OpLabel +%3 = OpVariable %14 Function +OpBranch %18 +%18 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %10 %8 %9 +OpBranch %25 +%25 = OpLabel +%26 = OpIAdd %6 %8 %15 +%27 = OpAccessChain %7 %3 %26 +OpStore %27 %8 +OpBranch %21 +%21 = OpLabel +%20 = OpIAdd %6 %8 %16 +OpBranch %28 +%28 = OpLabel +OpBranch %30 +%30 = OpLabel +%31 = OpSLessThan %10 %20 %9 +OpBranch %32 +%32 = OpLabel +%33 = OpIAdd %6 %20 %15 +%34 = OpAccessChain %7 %3 %33 +OpStore %34 %20 +OpBranch %35 +%35 = OpLabel +%36 = OpIAdd %6 %20 %16 +OpBranch %22 +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + // SinglePassRunAndCheck(text, expected, false); + + ir::Function* f = spvtest::GetFunction(module, 4); + + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + ir::BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 14u); + + ir::Instruction* induction = loop.FindInductionVariable(condition); + EXPECT_EQ(induction->result_id(), 32u); + + opt::LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 2u); + SinglePassRunAndCheck(text, expected, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[9]; + for (int i = -10; i < -1; i++) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, NegativeConditionAndInitResidualUnroll) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %23 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 -10 + %16 = OpConstant %6 -1 + %17 = OpTypeBool + %19 = OpTypeInt 32 0 + %20 = OpConstant %19 9 + %21 = OpTypeArray %6 %20 + %22 = OpTypePointer Function %21 + %25 = OpConstant %6 10 + %30 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %23 = OpVariable %22 Function + OpBranch %10 + %10 = OpLabel + %32 = OpPhi %6 %9 %5 %31 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSLessThan %17 %32 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %26 = OpIAdd %6 %32 %25 + %28 = OpAccessChain %7 %23 %26 + OpStore %28 %32 + OpBranch %13 + %13 = OpLabel + %31 = OpIAdd %6 %32 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 -10 +%9 = OpConstant %6 -1 +%10 = OpTypeBool +%11 = OpTypeInt 32 0 +%12 = OpConstant %11 9 +%13 = OpTypeArray %6 %12 +%14 = OpTypePointer Function %13 +%15 = OpConstant %6 10 +%16 = OpConstant %6 1 +%48 = OpConstant %6 -9 +%2 = OpFunction %4 None %5 +%17 = OpLabel +%3 = OpVariable %14 Function +OpBranch %18 +%18 = OpLabel +%19 = OpPhi %6 %8 %17 %20 %21 +OpLoopMerge %28 %21 Unroll +OpBranch %23 +%23 = OpLabel +%24 = OpSLessThan %10 %19 %48 +OpBranchConditional %24 %25 %28 +%25 = OpLabel +%26 = OpIAdd %6 %19 %15 +%27 = OpAccessChain %7 %3 %26 +OpStore %27 %19 +OpBranch %21 +%21 = OpLabel +%20 = OpIAdd %6 %19 %16 +OpBranch %18 +%28 = OpLabel +OpBranch %29 +%29 = OpLabel +%30 = OpPhi %6 %48 %28 %47 %46 +OpLoopMerge %38 %46 Unroll +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %10 %30 %9 +OpBranchConditional %32 %33 %38 +%33 = OpLabel +%34 = OpIAdd %6 %30 %15 +%35 = OpAccessChain %7 %3 %34 +OpStore %35 %30 +OpBranch %36 +%36 = OpLabel +%37 = OpIAdd %6 %30 %16 +OpBranch %39 +%39 = OpLabel +OpBranch %41 +%41 = OpLabel +%42 = OpSLessThan %10 %37 %9 +OpBranch %43 +%43 = OpLabel +%44 = OpIAdd %6 %37 %15 +%45 = OpAccessChain %7 %3 %44 +OpStore %45 %37 +OpBranch %46 +%46 = OpLabel +%47 = OpIAdd %6 %37 %16 +OpBranch %29 +%38 = OpLabel +OpReturn +%22 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + ir::Function* f = spvtest::GetFunction(module, 4); + + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + + ir::Loop& loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(loop.HasUnrollLoopControl()); + + ir::BasicBlock* condition = loop.FindConditionBlock(); + EXPECT_EQ(condition->id(), 14u); + + ir::Instruction* induction = loop.FindInductionVariable(condition); + EXPECT_EQ(induction->result_id(), 32u); + + opt::LoopUtils loop_utils{context.get(), &loop}; + EXPECT_TRUE(loop_utils.CanPerformUnroll()); + + size_t iterations = 0; + EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(), + &iterations)); + EXPECT_EQ(iterations, 9u); + SinglePassRunAndCheck>(text, expected, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[10]; + for (uint i = 0; i < 2; i++) { + for (int x = 0; x < 5; ++x) { + out_array[x + i*5] = i; + } + } +} +*/ +TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %35 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 0 + %16 = OpConstant %6 2 + %17 = OpTypeBool + %19 = OpTypeInt 32 1 + %20 = OpTypePointer Function %19 + %22 = OpConstant %19 0 + %29 = OpConstant %19 5 + %31 = OpTypeFloat 32 + %32 = OpConstant %6 10 + %33 = OpTypeArray %31 %32 + %34 = OpTypePointer Function %33 + %39 = OpConstant %6 5 + %44 = OpTypePointer Function %31 + %47 = OpConstant %19 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %35 = OpVariable %34 Function + OpBranch %10 + %10 = OpLabel + %51 = OpPhi %6 %9 %5 %50 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpULessThan %17 %51 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + OpBranch %23 + %23 = OpLabel + %54 = OpPhi %19 %22 %11 %48 %26 + OpLoopMerge %25 %26 Unroll + OpBranch %27 + %27 = OpLabel + %30 = OpSLessThan %17 %54 %29 + OpBranchConditional %30 %24 %25 + %24 = OpLabel + %37 = OpBitcast %6 %54 + %40 = OpIMul %6 %51 %39 + %41 = OpIAdd %6 %37 %40 + %43 = OpConvertUToF %31 %51 + %45 = OpAccessChain %44 %35 %41 + OpStore %45 %43 + OpBranch %26 + %26 = OpLabel + %48 = OpIAdd %19 %54 %47 + OpBranch %23 + %25 = OpLabel + OpBranch %13 + %13 = OpLabel + %50 = OpIAdd %6 %51 %47 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + + // clang-format on + + { // Test fully unroll + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + ir::Function* f = spvtest::GetFunction(module, 4); + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 2u); + + ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + + EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); + + ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); + + EXPECT_EQ(outer_loop.GetBlocks().size(), 9u); + + EXPECT_EQ(inner_loop.GetBlocks().size(), 4u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u); + EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); + + { + opt::LoopUtils loop_utils{context.get(), &inner_loop}; + loop_utils.FullyUnroll(); + loop_utils.Finalize(); + } + + EXPECT_EQ(loop_descriptor.NumLoops(), 1u); + EXPECT_EQ(outer_loop.GetBlocks().size(), 25u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 0u); + { + opt::LoopUtils loop_utils{context.get(), &outer_loop}; + loop_utils.FullyUnroll(); + loop_utils.Finalize(); + } + EXPECT_EQ(loop_descriptor.NumLoops(), 0u); + } + + { // Test partially unroll + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + + ir::Function* f = spvtest::GetFunction(module, 4); + ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f); + EXPECT_EQ(loop_descriptor.NumLoops(), 2u); + + ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1); + + EXPECT_TRUE(outer_loop.HasUnrollLoopControl()); + + ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0); + + EXPECT_TRUE(inner_loop.HasUnrollLoopControl()); + + EXPECT_EQ(outer_loop.GetBlocks().size(), 9u); + + EXPECT_EQ(inner_loop.GetBlocks().size(), 4u); + + EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u); + EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u); + + opt::LoopUtils loop_utils{context.get(), &inner_loop}; + loop_utils.PartiallyUnroll(2); + loop_utils.Finalize(); + + // The number of loops should actually grow. + EXPECT_EQ(loop_descriptor.NumLoops(), 3u); + EXPECT_EQ(outer_loop.GetBlocks().size(), 19u); + EXPECT_EQ(outer_loop.NumImmediateChildren(), 2u); + } +} + +/* +Generated from the following GLSL +#version 440 core +void main(){ + float x[10]; + int i = 1; + i = 0; + for (; i < 10; i++) { + x[i] = i; + } +} +*/ +TEST_F(PassClassTest, UnrollWithInductionOutsideHeader) { + // clang-format off + // With opt::LocalMultiStoreElimPass +const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +OpName %x "x" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10 +%_ptr_Function_float = OpTypePointer Function %float +%main = OpFunction %void None %3 +%5 = OpLabel +%x = OpVariable %_ptr_Function__arr_float_uint_10 Function +OpBranch %11 +%11 = OpLabel +%33 = OpPhi %int %int_0 %5 %32 %14 +OpLoopMerge %13 %14 None +OpBranch %15 +%15 = OpLabel +%19 = OpSLessThan %bool %33 %int_10 +OpBranchConditional %19 %12 %13 +%12 = OpLabel +%28 = OpConvertSToF %float %33 +%30 = OpAccessChain %_ptr_Function_float %x %33 +OpStore %30 %28 +OpBranch %14 +%14 = OpLabel +%32 = OpIAdd %int %33 %int_1 +OpBranch %11 +%13 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string expected = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 440 +OpName %main "main" +OpName %x "x" +%void = OpTypeVoid +%5 = OpTypeFunction %void +%int = OpTypeInt 32 1 +%_ptr_Function_int = OpTypePointer Function %int +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%int_10 = OpConstant %int 10 +%bool = OpTypeBool +%float = OpTypeFloat 32 +%uint = OpTypeInt 32 0 +%uint_10 = OpConstant %uint 10 +%_arr_float_uint_10 = OpTypeArray %float %uint_10 +%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10 +%_ptr_Function_float = OpTypePointer Function %float +%main = OpFunction %void None %5 +%18 = OpLabel +%x = OpVariable %_ptr_Function__arr_float_uint_10 Function +OpBranch %19 +%19 = OpLabel +%20 = OpPhi %int %int_0 %18 %37 %36 +OpLoopMerge %23 %36 None +OpBranch %24 +%24 = OpLabel +%25 = OpSLessThan %bool %20 %int_10 +OpBranchConditional %25 %26 %23 +%26 = OpLabel +%27 = OpConvertSToF %float %20 +%28 = OpAccessChain %_ptr_Function_float %x %20 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpIAdd %int %20 %int_1 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %bool %21 %int_10 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %float %21 +%35 = OpAccessChain %_ptr_Function_float %x %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpIAdd %int %21 %int_1 +OpBranch %19 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + + SinglePassRunAndCheck>(text, expected, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 3; i > 0; --i) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %24 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 3 + %16 = OpConstant %6 0 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 3 + %22 = OpTypeArray %19 %21 + %23 = OpTypePointer Function %22 + %28 = OpTypePointer Function %19 + %31 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %24 = OpVariable %23 Function + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSGreaterThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpConvertSToF %19 %33 + %29 = OpAccessChain %28 %24 %33 + OpStore %29 %27 + OpBranch %13 + %13 = OpLabel + %32 = OpISub %6 %33 %31 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 3 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 1 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 9; i > 0; i-=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource GLSL 410 + OpName %4 "main" + OpName %24 "out_array" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %9 = OpConstant %6 9 + %16 = OpConstant %6 0 + %17 = OpTypeBool + %19 = OpTypeFloat 32 + %20 = OpTypeInt 32 0 + %21 = OpConstant %20 3 + %22 = OpTypeArray %19 %21 + %23 = OpTypePointer Function %22 + %28 = OpTypePointer Function %19 + %30 = OpConstant %6 3 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %24 = OpVariable %23 Function + OpBranch %10 + %10 = OpLabel + %33 = OpPhi %6 %9 %5 %32 %13 + OpLoopMerge %12 %13 Unroll + OpBranch %14 + %14 = OpLabel + %18 = OpSGreaterThan %17 %33 %16 + OpBranchConditional %18 %11 %12 + %11 = OpLabel + %27 = OpConvertSToF %19 %33 + %29 = OpAccessChain %28 %24 %33 + OpStore %29 %27 + OpBranch %13 + %13 = OpLabel + %32 = OpISub %6 %33 %30 + OpBranch %10 + %12 = OpLabel + OpReturn + OpFunctionEnd + )"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 9 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[3]; + for (int i = 0; i < 7; i+=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNonDivisibleStepLoop) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %4 "main" +OpExecutionMode %4 OriginUpperLeft +OpSource GLSL 410 +OpName %4 "main" +OpName %24 "out_array" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%9 = OpConstant %6 0 +%16 = OpConstant %6 7 +%17 = OpTypeBool +%19 = OpTypeFloat 32 +%20 = OpTypeInt 32 0 +%21 = OpConstant %20 3 +%22 = OpTypeArray %19 %21 +%23 = OpTypePointer Function %22 +%28 = OpTypePointer Function %19 +%30 = OpConstant %6 3 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%24 = OpVariable %23 Function +OpBranch %10 +%10 = OpLabel +%33 = OpPhi %6 %9 %5 %32 %13 +OpLoopMerge %12 %13 Unroll +OpBranch %14 +%14 = OpLabel +%18 = OpSLessThan %17 %33 %16 +OpBranchConditional %18 %11 %12 +%11 = OpLabel +%27 = OpConvertSToF %19 %33 +%29 = OpAccessChain %28 %24 %33 +OpStore %29 %27 +OpBranch %13 +%13 = OpLabel +%32 = OpIAdd %6 %33 %30 +OpBranch %10 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 0 +%9 = OpConstant %6 7 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 3 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSLessThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpIAdd %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSLessThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpIAdd %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSLessThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpIAdd %6 %37 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +/* +Generated from the following GLSL +#version 410 core +void main() { + float out_array[4]; + for (int i = 11; i > 0; i-=3) { + out_array[i] = i; + } +} +*/ +TEST_F(PassClassTest, FullyUnrollNegativeNonDivisibleStepLoop) { + // clang-format off + // With opt::LocalMultiStoreElimPass + const std::string text = R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %4 "main" +OpExecutionMode %4 OriginUpperLeft +OpSource GLSL 410 +OpName %4 "main" +OpName %24 "out_array" +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%9 = OpConstant %6 11 +%16 = OpConstant %6 0 +%17 = OpTypeBool +%19 = OpTypeFloat 32 +%20 = OpTypeInt 32 0 +%21 = OpConstant %20 4 +%22 = OpTypeArray %19 %21 +%23 = OpTypePointer Function %22 +%28 = OpTypePointer Function %19 +%30 = OpConstant %6 3 +%4 = OpFunction %2 None %3 +%5 = OpLabel +%24 = OpVariable %23 Function +OpBranch %10 +%10 = OpLabel +%33 = OpPhi %6 %9 %5 %32 %13 +OpLoopMerge %12 %13 Unroll +OpBranch %14 +%14 = OpLabel +%18 = OpSGreaterThan %17 %33 %16 +OpBranchConditional %18 %11 %12 +%11 = OpLabel +%27 = OpConvertSToF %19 %33 +%29 = OpAccessChain %28 %24 %33 +OpStore %29 %27 +OpBranch %13 +%13 = OpLabel +%32 = OpISub %6 %33 %30 +OpBranch %10 +%12 = OpLabel +OpReturn +OpFunctionEnd +)"; + +const std::string output = +R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %2 "main" +OpExecutionMode %2 OriginUpperLeft +OpSource GLSL 410 +OpName %2 "main" +OpName %3 "out_array" +%4 = OpTypeVoid +%5 = OpTypeFunction %4 +%6 = OpTypeInt 32 1 +%7 = OpTypePointer Function %6 +%8 = OpConstant %6 11 +%9 = OpConstant %6 0 +%10 = OpTypeBool +%11 = OpTypeFloat 32 +%12 = OpTypeInt 32 0 +%13 = OpConstant %12 4 +%14 = OpTypeArray %11 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer Function %11 +%17 = OpConstant %6 3 +%2 = OpFunction %4 None %5 +%18 = OpLabel +%3 = OpVariable %15 Function +OpBranch %19 +%19 = OpLabel +OpBranch %24 +%24 = OpLabel +%25 = OpSGreaterThan %10 %8 %9 +OpBranch %26 +%26 = OpLabel +%27 = OpConvertSToF %11 %8 +%28 = OpAccessChain %16 %3 %8 +OpStore %28 %27 +OpBranch %22 +%22 = OpLabel +%21 = OpISub %6 %8 %17 +OpBranch %29 +%29 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpSGreaterThan %10 %21 %9 +OpBranch %33 +%33 = OpLabel +%34 = OpConvertSToF %11 %21 +%35 = OpAccessChain %16 %3 %21 +OpStore %35 %34 +OpBranch %36 +%36 = OpLabel +%37 = OpISub %6 %21 %17 +OpBranch %38 +%38 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpSGreaterThan %10 %37 %9 +OpBranch %42 +%42 = OpLabel +%43 = OpConvertSToF %11 %37 +%44 = OpAccessChain %16 %3 %37 +OpStore %44 %43 +OpBranch %45 +%45 = OpLabel +%46 = OpISub %6 %37 %17 +OpBranch %47 +%47 = OpLabel +OpBranch %49 +%49 = OpLabel +%50 = OpSGreaterThan %10 %46 %9 +OpBranch %51 +%51 = OpLabel +%52 = OpConvertSToF %11 %46 +%53 = OpAccessChain %16 %3 %46 +OpStore %53 %52 +OpBranch %54 +%54 = OpLabel +%55 = OpISub %6 %46 %17 +OpBranch %23 +%23 = OpLabel +OpReturn +OpFunctionEnd +)"; + // clang-format on + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ir::Module* module = context->module(); + EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n" + << text << std::endl; + + opt::LoopUnroller loop_unroller; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + SinglePassRunAndCheck(text, output, false); +} + +} // namespace diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 4525e30d..fb8b3a1d 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -472,6 +472,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer, optimizer->RegisterPass(CreateReplaceInvalidOpcodePass()); } else if (0 == strcmp(cur_arg, "--simplify-instructions")) { optimizer->RegisterPass(CreateSimplificationPass()); + } else if (0 == strcmp(cur_arg, "--loop-unroll")) { + optimizer->RegisterPass(CreateLoopFullyUnrollPass()); } else if (0 == strcmp(cur_arg, "--skip-validation")) { *skip_validator = true; } else if (0 == strcmp(cur_arg, "-O")) { -- 2.34.1