source/opt/local_single_store_elim_pass.cpp \
source/opt/local_ssa_elim_pass.cpp \
source/opt/loop_descriptor.cpp \
+ source/opt/loop_unroller.cpp \
source/opt/mem_pass.cpp \
source/opt/merge_return_pass.cpp \
source/opt/module.cpp \
// Creates a pass that simplifies instructions using the instruction folder.
Optimizer::PassToken CreateSimplificationPass();
+// Create loop unroller pass.
+// Creates a pass to fully unroll loops which have the "Unroll" loop control
+// mask set. The loops must meet a specific criteria in order to be unrolled
+// safely this criteria is checked before doing the unroll by the
+// LoopUtils::CanPerformUnroll method. Any loop that does not meet the criteria
+// won't be unrolled. See CanPerformUnroll LoopUtils.h for more information.
+Optimizer::PassToken CreateLoopFullyUnrollPass();
+
} // namespace spvtools
#endif // SPIRV_TOOLS_OPTIMIZER_HPP_
local_ssa_elim_pass.h
log.h
loop_descriptor.h
+ loop_unroller.h
loop_utils.h
make_unique.h
mem_pass.h
local_ssa_elim_pass.cpp
loop_descriptor.cpp
loop_utils.cpp
+ loop_unroller.cpp
mem_pass.cpp
merge_return_pass.cpp
module.cpp
return true;
}
+ // Applies the std::function |func| to all nodes in the dominator tree from
+ // |node| downwards. The boolean return from |func| is used to determine
+ // whether or not the children should also be traversed. Tree nodes are
+ // visited in a depth first pre-order.
+ void VisitChildrenIf(std::function<bool(DominatorTreeNode*)> func,
+ iterator node) {
+ if (func(&*node)) {
+ for (auto n : *node) {
+ VisitChildrenIf(func, n->df_begin());
+ }
+ }
+ }
+
// Returns the DominatorTreeNode associated with the basic block |bb|.
// If the |bb| is unknown to the dominator tree, it returns null.
inline DominatorTreeNode* GetTreeNode(ir::BasicBlock* bb) {
#define LIBSPIRV_OPT_IR_BUILDER_H_
#include "opt/basic_block.h"
+#include "opt/constants.h"
#include "opt/instruction.h"
#include "opt/ir_context.h"
-
namespace spvtools {
namespace opt {
return AddInstruction(std::move(select));
}
+ // Adds a signed int32 constant to the binary.
+ // The |value| parameter is the constant value to be added.
+ ir::Instruction* Add32BitSignedIntegerConstant(int32_t value) {
+ return Add32BitConstantInteger<int32_t>(value, true);
+ }
+
// Create a composite construct.
// |type| should be a composite type and the number of elements it has should
// match the size od |ids|.
GetContext()->TakeNextId(), ops));
return AddInstruction(std::move(construct));
}
+ // Adds an unsigned int32 constant to the binary.
+ // The |value| parameter is the constant value to be added.
+ ir::Instruction* Add32BitUnsignedIntegerConstant(uint32_t value) {
+ return Add32BitConstantInteger<uint32_t>(value, false);
+ }
+
+ // Adds either a signed or unsigned 32 bit integer constant to the binary
+ // depedning on the |sign|. If |sign| is true then the value is added as a
+ // signed constant otherwise as an unsigned constant. If |sign| is false the
+ // value must not be a negative number.
+ template <typename T>
+ ir::Instruction* Add32BitConstantInteger(T value, bool sign) {
+ // Assert that we are not trying to store a negative number in an unsigned
+ // type.
+ if (!sign)
+ assert(value > 0 &&
+ "Trying to add a signed integer with an unsigned type!");
+
+ // Get or create the integer type.
+ analysis::Integer int_type(32, sign);
+
+ // Even if the value is negative we need to pass the bit pattern as a
+ // uint32_t to GetConstant.
+ uint32_t word = value;
+
+ // Create the constant value.
+ const opt::analysis::Constant* constant =
+ GetContext()->get_constant_mgr()->GetConstant(&int_type, {word});
+
+ // Create the OpConstant instruction using the type and the value.
+ return GetContext()->get_constant_mgr()->GetDefiningInstruction(constant);
+ }
ir::Instruction* AddCompositeExtract(
uint32_t type, uint32_t id_of_composite,
#include <utility>
#include <vector>
+#include "constants.h"
#include "opt/cfg.h"
#include "opt/dominator_tree.h"
#include "opt/ir_builder.h"
namespace spvtools {
namespace ir {
+// Takes in a phi instruction |induction| and the loop |header| and returns the
+// step operation of the loop.
+ir::Instruction* Loop::GetInductionStepOperation(
+ const ir::Loop* loop, const ir::Instruction* induction) const {
+ // Induction must be a phi instruction.
+ assert(induction->opcode() == SpvOpPhi);
+
+ ir::Instruction* step = nullptr;
+
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ // Traverse the incoming operands of the phi instruction.
+ for (uint32_t operand_id = 1; operand_id < induction->NumInOperands();
+ operand_id += 2) {
+ // Incoming edge.
+ ir::BasicBlock* incoming_block =
+ context_->cfg()->block(induction->GetSingleWordInOperand(operand_id));
+
+ // Check if the block is dominated by header, and thus coming from within
+ // the loop.
+ if (loop->IsInsideLoop(incoming_block)) {
+ step = def_use_manager->GetDef(
+ induction->GetSingleWordInOperand(operand_id - 1));
+ break;
+ }
+ }
+
+ if (!step || !IsSupportedStepOp(step->opcode())) {
+ return nullptr;
+ }
+
+ return step;
+}
+
+// Returns true if the |step| operation is an induction variable step operation
+// which is currently handled.
+bool Loop::IsSupportedStepOp(SpvOp step) const {
+ switch (step) {
+ case SpvOp::SpvOpISub:
+ case SpvOp::SpvOpIAdd:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool Loop::IsSupportedCondition(SpvOp condition) const {
+ switch (condition) {
+ // <
+ case SpvOp::SpvOpULessThan:
+ case SpvOp::SpvOpSLessThan:
+ // >
+ case SpvOp::SpvOpUGreaterThan:
+ case SpvOp::SpvOpSGreaterThan:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Extract the initial value from the |induction| OpPhi instruction and store it
+// in |value|. If the function couldn't find the initial value of |induction|
+// return false.
+bool Loop::GetInductionInitValue(const ir::Loop* loop,
+ const ir::Instruction* induction,
+ int64_t* value) const {
+ ir::Instruction* constant_instruction = nullptr;
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
+ operand_id += 2) {
+ ir::BasicBlock* bb = context_->cfg()->block(
+ induction->GetSingleWordInOperand(operand_id + 1));
+
+ if (!loop->IsInsideLoop(bb)) {
+ constant_instruction = def_use_manager->GetDef(
+ induction->GetSingleWordInOperand(operand_id));
+ }
+ }
+
+ if (!constant_instruction) return false;
+
+ const opt::analysis::Constant* constant =
+ context_->get_constant_mgr()->FindDeclaredConstant(
+ constant_instruction->result_id());
+ if (!constant) return false;
+
+ if (value) {
+ const opt::analysis::Integer* type =
+ constant->AsIntConstant()->type()->AsInteger();
+
+ if (type->IsSigned()) {
+ *value = constant->AsIntConstant()->GetS32BitValue();
+ } else {
+ *value = constant->AsIntConstant()->GetU32BitValue();
+ }
+ }
+
+ return true;
+}
+
Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
BasicBlock* header, BasicBlock* continue_target,
BasicBlock* merge_target)
loop_continue_(continue_target),
loop_merge_(merge_target),
loop_preheader_(nullptr),
- parent_(nullptr) {
+ parent_(nullptr),
+ loop_is_marked_for_removal_(false) {
assert(context);
assert(dom_analysis);
loop_preheader_ = FindLoopPreheader(dom_analysis);
- AddBasicBlockToLoop(header);
- AddBasicBlockToLoop(continue_target);
}
BasicBlock* Loop::FindLoopPreheader(opt::DominatorAnalysis* dom_analysis) {
bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
assert(bb->GetParent() && "The basic block does not belong to a function");
-
opt::DominatorAnalysis* dom_analysis =
context_->GetDominatorAnalysis(bb->GetParent(), *context_->cfg());
if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false;
void Loop::SetMergeBlock(BasicBlock* merge) {
#ifndef NDEBUG
assert(merge->GetParent() && "The basic block does not belong to a function");
- CFG& cfg = *merge->GetParent()->GetParent()->context()->cfg();
-
- for (uint32_t pred : cfg.preds(merge->id())) {
- assert(IsInsideLoop(pred) &&
- "A predecessor of the merge block does not belong to the loop");
- }
- assert(!IsInsideLoop(merge) && "The merge block is in the loop");
#endif // NDEBUG
+ assert(!IsInsideLoop(merge) && "The merge block is in the loop");
SetMergeBlockImpl(merge);
if (GetHeaderBlock()->GetLoopMergeInst()) {
void LoopDescriptor::PopulateList(const Function* f) {
IRContext* context = f->GetParent()->context();
+
opt::DominatorAnalysis* dom_analysis =
context->GetDominatorAnalysis(f, *context->cfg());
make_range(node.df_begin(), node.df_end())) {
// Check if we are in the loop.
if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
- current_loop->AddBasicBlockToLoop(loop_node.bb_);
+ current_loop->AddBasicBlock(loop_node.bb_);
basic_block_to_loop_.insert(
std::make_pair(loop_node.bb_->id(), current_loop));
}
}
}
+ir::BasicBlock* Loop::FindConditionBlock() const {
+ const ir::Function& function = *loop_merge_->GetParent();
+ ir::BasicBlock* condition_block = nullptr;
+
+ const opt::DominatorAnalysis* dom_analysis =
+ context_->GetDominatorAnalysis(&function, *context_->cfg());
+ ir::BasicBlock* bb = dom_analysis->ImmediateDominator(loop_merge_);
+
+ if (!bb) return nullptr;
+
+ const ir::Instruction& branch = *bb->ctail();
+
+ // Make sure the branch is a conditional branch.
+ if (branch.opcode() != SpvOpBranchConditional) return nullptr;
+
+ // Make sure one of the two possible branches is to the merge block.
+ if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
+ branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
+ condition_block = bb;
+ }
+
+ return condition_block;
+}
+
+bool Loop::FindNumberOfIterations(const ir::Instruction* induction,
+ const ir::Instruction* branch_inst,
+ size_t* iterations_out,
+ int64_t* step_value_out,
+ int64_t* init_value_out) const {
+ // From the branch instruction find the branch condition.
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ // Condition instruction from the OpConditionalBranch.
+ ir::Instruction* condition =
+ def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
+
+ assert(IsSupportedCondition(condition->opcode()));
+
+ // Get the constant manager from the ir context.
+ opt::analysis::ConstantManager* const_manager = context_->get_constant_mgr();
+
+ // Find the constant value used by the condition variable. Exit out if it
+ // isn't a constant int.
+ const opt::analysis::Constant* upper_bound =
+ const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
+ if (!upper_bound) return false;
+
+ // Must be integer because of the opcode on the condition.
+ int64_t condition_value = 0;
+
+ const opt::analysis::Integer* type =
+ upper_bound->AsIntConstant()->type()->AsInteger();
+
+ if (type->IsSigned()) {
+ condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
+ } else {
+ condition_value = upper_bound->AsIntConstant()->GetU32BitValue();
+ }
+
+ // Find the instruction which is stepping through the loop.
+ ir::Instruction* step_inst = GetInductionStepOperation(this, induction);
+ if (!step_inst) return false;
+
+ // Find the constant value used by the condition variable.
+ const opt::analysis::Constant* step_constant =
+ const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
+ if (!step_constant) return false;
+
+ // Must be integer because of the opcode on the condition.
+ int64_t step_value = 0;
+
+ const opt::analysis::Integer* step_type =
+ step_constant->AsIntConstant()->type()->AsInteger();
+
+ if (step_type->IsSigned()) {
+ step_value = step_constant->AsIntConstant()->GetS32BitValue();
+ } else {
+ step_value = step_constant->AsIntConstant()->GetU32BitValue();
+ }
+
+ // If this is a subtraction step we should negate the step value.
+ if (step_inst->opcode() == SpvOp::SpvOpISub) {
+ step_value = -step_value;
+ }
+
+ // Find the inital value of the loop and make sure it is a constant integer.
+ int64_t init_value = 0;
+ if (!GetInductionInitValue(this, induction, &init_value)) return false;
+
+ // If iterations is non null then store the value in that.
+ if (iterations_out) {
+ int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
+ init_value, step_value);
+
+ // If the loop body will not be reached return false.
+ if (num_itrs <= 0) {
+ return false;
+ }
+ assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
+ *iterations_out = static_cast<size_t>(num_itrs);
+ }
+
+ if (step_value_out) {
+ *step_value_out = step_value;
+ }
+
+ if (init_value_out) {
+ *init_value_out = init_value;
+ }
+
+ return true;
+}
+
+// We retrieve the number of iterations using the following formula, diff /
+// |step_value| where diff is calculated differently according to the
+// |condition| and uses the |condition_value| and |init_value|. If diff /
+// |step_value| is NOT cleanly divisable then we add one to the sum.
+int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
+ int64_t init_value, int64_t step_value) const {
+ int64_t diff = 0;
+
+ // Take the abs of - step values.
+ step_value = llabs(step_value);
+
+ switch (condition) {
+ case SpvOp::SpvOpSLessThan:
+ case SpvOp::SpvOpULessThan: {
+ diff = condition_value - init_value;
+ break;
+ }
+ case SpvOp::SpvOpSGreaterThan:
+ case SpvOp::SpvOpUGreaterThan: {
+ diff = init_value - condition_value;
+ break;
+ }
+ default:
+ assert(false &&
+ "Could not retrieve number of iterations from the loop condition. "
+ "Condition is not supported.");
+ }
+
+ int64_t result = diff / step_value;
+
+ if (diff % step_value != 0) {
+ result += 1;
+ }
+ return result;
+}
+
+ir::Instruction* Loop::FindInductionVariable(
+ const ir::BasicBlock* condition_block) const {
+ // Find the branch instruction.
+ const ir::Instruction& branch_inst = *condition_block->ctail();
+
+ ir::Instruction* induction = nullptr;
+ // Verify that the branch instruction is a conditional branch.
+ if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) {
+ // From the branch instruction find the branch condition.
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ // Find the instruction representing the condition used in the conditional
+ // branch.
+ ir::Instruction* condition =
+ def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
+
+ // Ensure that the condition is a less than operation.
+ if (condition && IsSupportedCondition(condition->opcode())) {
+ // The left hand side operand of the operation.
+ ir::Instruction* variable_inst =
+ def_use_manager->GetDef(condition->GetSingleWordOperand(2));
+
+ // Make sure the variable instruction used is a phi.
+ if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr;
+
+ // Make sure the phi instruction only has two incoming blocks. Each
+ // incoming block will be represented by two in operands in the phi
+ // instruction, the value and the block which that value came from. We
+ // assume the cannocalised phi will have two incoming values, one from the
+ // preheader and one from the continue block.
+ size_t max_supported_operands = 4;
+ if (variable_inst->NumInOperands() == max_supported_operands) {
+ // The operand index of the first incoming block label.
+ uint32_t operand_label_1 = 1;
+
+ // The operand index of the second incoming block label.
+ uint32_t operand_label_2 = 3;
+
+ // Make sure one of them is the preheader.
+ if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
+ loop_preheader_->id() &&
+ variable_inst->GetSingleWordInOperand(operand_label_2) !=
+ loop_preheader_->id()) {
+ return nullptr;
+ }
+
+ // And make sure that the other is the latch block.
+ if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
+ loop_continue_->id() &&
+ variable_inst->GetSingleWordInOperand(operand_label_2) !=
+ loop_continue_->id()) {
+ return nullptr;
+ }
+ } else {
+ return nullptr;
+ }
+
+ if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
+ return nullptr;
+ induction = variable_inst;
+ }
+ }
+
+ return induction;
+}
+
+// Add and remove loops which have been marked for addition and removal to
+// maintain the state of the loop descriptor class.
+void LoopDescriptor::PostModificationCleanup() {
+ LoopContainerType loops_to_remove_;
+ for (ir::Loop* loop : loops_) {
+ if (loop->IsMarkedForRemoval()) {
+ loops_to_remove_.push_back(loop);
+ if (loop->HasParent()) {
+ loop->GetParent()->RemoveChildLoop(loop);
+ }
+ }
+ }
+
+ for (ir::Loop* loop : loops_to_remove_) {
+ loops_.erase(std::find(loops_.begin(), loops_.end(), loop));
+ }
+
+ for (auto& pair : loops_to_add_) {
+ ir::Loop* parent = pair.first;
+ ir::Loop* loop = pair.second;
+
+ if (parent) {
+ loop->SetParent(nullptr);
+ parent->AddNestedLoop(loop);
+
+ for (uint32_t block_id : loop->GetBlocks()) {
+ parent->AddBasicBlock(block_id);
+ }
+ }
+
+ loops_.emplace_back(loop);
+ }
+
+ loops_to_add_.clear();
+}
+
void LoopDescriptor::ClearLoops() {
for (Loop* loop : loops_) {
delete loop;
}
loops_.clear();
}
-
} // namespace ir
} // namespace spvtools
#include <vector>
#include "opt/basic_block.h"
+#include "opt/module.h"
#include "opt/tree_iterator.h"
namespace spvtools {
loop_continue_(nullptr),
loop_merge_(nullptr),
loop_preheader_(nullptr),
- parent_(nullptr) {}
+ parent_(nullptr),
+ loop_is_marked_for_removal_(false) {}
Loop(IRContext* context, opt::DominatorAnalysis* analysis, BasicBlock* header,
BasicBlock* continue_target, BasicBlock* merge_target);
return lvl;
}
+ inline size_t NumImmediateChildren() const { return nested_loops_.size(); }
+
// Adds |nested| as a nested loop of this loop. Automatically register |this|
// as the parent of |nested|.
inline void AddNestedLoop(Loop* nested) {
// Returns true if the instruction |inst| is inside this loop.
bool IsInsideLoop(Instruction* inst) const;
+ // Adds the Basic Block |bb| to this loop and its parents.
+ void AddBasicBlock(const BasicBlock* bb) { AddBasicBlock(bb->id()); }
+
+ // Adds the Basic Block with |id| to this loop and its parents.
+ void AddBasicBlock(uint32_t id) {
+ for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
+ loop_basic_blocks_.insert(id);
+ }
+ }
+
+ // Removes all the basic blocks from the set of basic blocks within the loop.
+ // This does not affect any of the stored pointers to the header, preheader,
+ // merge, or continue blocks.
+ void ClearBlocks() { loop_basic_blocks_.clear(); }
+
// Adds the Basic Block |bb| this loop and its parents.
void AddBasicBlockToLoop(const BasicBlock* bb) {
assert(IsBasicBlockInLoopSlow(bb) &&
AddBasicBlock(bb);
}
- // Adds the Basic Block |bb| this loop and its parents.
- void AddBasicBlock(const BasicBlock* bb) {
- for (Loop* loop = this; loop != nullptr; loop = loop->parent_) {
- loop_basic_blocks_.insert(bb->id());
+ // This function uses the |condition| to find the induction variable within
+ // the loop. This only works if the loop is bound by a single condition and a
+ // single induction variable.
+ ir::Instruction* FindInductionVariable(const ir::BasicBlock* condition) const;
+
+ // Returns the number of iterations within a loop when given the |induction|
+ // variable and the loop |condition| check. It stores the found number of
+ // iterations in the output parameter |iterations| and optionally, the step
+ // value in |step_value| and the initial value of the induction variable in
+ // |init_value|.
+ bool FindNumberOfIterations(const ir::Instruction* induction,
+ const ir::Instruction* condition,
+ size_t* iterations,
+ int64_t* step_amount = nullptr,
+ int64_t* init_value = nullptr) const;
+
+ // Returns the value of the OpLoopMerge control operand as a bool. Loop
+ // control can be None(0), Unroll(1), or DontUnroll(2). This function returns
+ // true if it is set to Unroll.
+ inline bool HasUnrollLoopControl() const {
+ assert(loop_header_);
+ if (!loop_header_->GetLoopMergeInst()) return false;
+
+ return loop_header_->GetLoopMergeInst()->GetSingleWordOperand(2) == 1;
+ }
+
+ // Finds the conditional block with a branch to the merge and continue blocks
+ // within the loop body.
+ ir::BasicBlock* FindConditionBlock() const;
+
+ // Remove the child loop form this loop.
+ inline void RemoveChildLoop(Loop* loop) {
+ nested_loops_.erase(
+ std::find(nested_loops_.begin(), nested_loops_.end(), loop));
+ loop->SetParent(nullptr);
+ }
+
+ // Mark this loop to be removed later by a call to
+ // LoopDescriptor::PostModificationCleanup.
+ inline void MarkLoopForRemoval() { loop_is_marked_for_removal_ = true; }
+
+ // Returns whether or not this loop has been marked for removal.
+ inline bool IsMarkedForRemoval() const { return loop_is_marked_for_removal_; }
+
+ // Returns true if all nested loops have been marked for removal.
+ inline bool AreAllChildrenMarkedForRemoval() const {
+ for (const Loop* child : nested_loops_) {
+ if (!child->IsMarkedForRemoval()) {
+ return false;
+ }
}
+ return true;
}
// Sets the parent loop of this loop, that is, a loop which contains this loop
// loop
bool AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst);
+ // Extract the initial value from the |induction| variable and store it in
+ // |value|. If the function couldn't find the initial value of |induction|
+ // return false.
+ bool GetInductionInitValue(const ir::Loop* loop,
+ const ir::Instruction* induction,
+ int64_t* value) const;
+
+ // Takes in a phi instruction |induction| and the loop |header| and returns
+ // the step operation of the loop.
+ ir::Instruction* GetInductionStepOperation(
+ const ir::Loop* loop, const ir::Instruction* induction) const;
+
+ // Returns true if we can deduce the number of loop iterations in the step
+ // operation |step|. IsSupportedCondition must also be true for the condition
+ // instruction.
+ bool IsSupportedStepOp(SpvOp step) const;
+
+ // Returns true if we can deduce the number of loop iterations in the
+ // condition operation |condition|. IsSupportedStepOp must also be true for
+ // the step instruction.
+ bool IsSupportedCondition(SpvOp condition) const;
+
private:
IRContext* context_;
// The block which marks the start of the loop.
// Sets |merge| as the loop merge block. No checks are performed here.
inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
+ // Each differnt loop |condition| affects how we calculate the number of
+ // iterations using the |condition_value|, |init_value|, and |step_values| of
+ // the induction variable. This method will return the number of iterations in
+ // a loop with those values for a given |condition|.
+ int64_t GetIterations(SpvOp condition, int64_t condition_value,
+ int64_t init_value, int64_t step_value) const;
+
+ // This is to allow for loops to be removed mid iteration without invalidating
+ // the iterators.
+ bool loop_is_marked_for_removal_;
+
// This is only to allow LoopDescriptor::dummy_top_loop_ to add top level
// loops as child.
friend class LoopDescriptor;
basic_block_to_loop_[bb_id] = loop;
}
+ // Mark the loop |loop_to_add| as needing to be added when the user calls
+ // PostModificationCleanup. |parent| may be null.
+ inline void AddLoop(ir::Loop* loop_to_add, ir::Loop* parent) {
+ loops_to_add_.emplace_back(std::make_pair(parent, loop_to_add));
+ }
+
+ // Should be called to preserve the LoopAnalysis after loops have been marked
+ // for addition with AddLoop or MarkLoopForRemoval.
+ void PostModificationCleanup();
+
private:
// TODO(dneto): This should be a vector of unique_ptr. But VisualStudio 2013
// is unable to compile it.
using LoopContainerType = std::vector<Loop*>;
+ using LoopsToAddContainerType = std::vector<std::pair<Loop*, Loop*>>;
// Creates loop descriptors for the function |f|.
void PopulateList(const Function* f);
// A list of all the loops in the function. This variable owns the Loop
// objects.
LoopContainerType loops_;
+
// Dummy root: this "loop" is only there to help iterators creation.
Loop dummy_top_loop_;
+
std::unordered_map<uint32_t, Loop*> basic_block_to_loop_;
+
+ // List of the loops marked for addition when PostModificationCleanup is
+ // called.
+ LoopsToAddContainerType loops_to_add_;
};
} // namespace ir
--- /dev/null
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "opt/loop_unroller.h"
+#include <map>
+#include <memory>
+#include <utility>
+#include "opt/ir_builder.h"
+#include "opt/loop_utils.h"
+
+// Implements loop util unrolling functionality for fully and partially
+// unrolling loops. Given a factor it will duplicate the loop that many times,
+// appending each one to the end of the old loop and removing backedges, to
+// create a new unrolled loop.
+//
+// 1 - User calls LoopUtils::FullyUnroll or LoopUtils::PartiallyUnroll with a
+// loop they wish to unroll. LoopUtils::CanPerformUnroll is used to
+// validate that a given loop can be unrolled. That method (along with the
+// constructor of loop) checks that the IR is in the expected canonicalised
+// format.
+//
+// 2 - The LoopUtils methods create a LoopUnrollerUtilsImpl object to actually
+// perform the unrolling. This implements helper methods to copy the loop basic
+// blocks and remap the ids of instructions used inside them.
+//
+// 3 - The core of LoopUnrollerUtilsImpl is the Unroll method, this method
+// actually performs the loop duplication. It does this by creating a
+// LoopUnrollState object and then copying the loop as given by the factor
+// parameter. The LoopUnrollState object retains the state of the unroller
+// between the loop body copies as each iteration needs information on the last
+// to adjust the phi induction variable, adjust the OpLoopMerge instruction in
+// the main loop header, and change the previous continue block to point to the
+// new header and the new continue block to the main loop header.
+//
+// 4 - If the loop is to be fully unrolled then it is simply closed after step
+// 3, with the OpLoopMerge being deleted, the backedge removed, and the
+// condition blocks folded.
+//
+// 5 - If it is being partially unrolled: if the unrolling factor leaves the
+// loop with an even number of bodies with respect to the number of loop
+// iterations then step 3 is all that is needed. If it is uneven then we need to
+// duplicate the loop completely and unroll the duplicated loop to cover the
+// residual part and adjust the first loop to cover only the "even" part. For
+// instance if you request an unroll factor of 3 on a loop with 10 iterations
+// then copying the body three times would leave you with three bodies in the
+// loop
+// where the loop still iterates over each 4 times. So we make two loops one
+// iterating once then a second loop of three iterating 3 times.
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+// This utility class encapsulates some of the state we need to maintain between
+// loop unrolls. Specifically it maintains key blocks and the induction variable
+// in the current loop duplication step and the blocks from the previous one.
+// This is because each step of the unroll needs to use data from both the
+// preceding step and the original loop.
+struct LoopUnrollState {
+ LoopUnrollState()
+ : previous_phi_(nullptr),
+ previous_continue_block_(nullptr),
+ previous_condition_block_(nullptr),
+ new_phi(nullptr),
+ new_continue_block(nullptr),
+ new_condition_block(nullptr),
+ new_header_block(nullptr) {}
+
+ // Initialize from the loop descriptor class.
+ LoopUnrollState(ir::Instruction* induction, ir::BasicBlock* continue_block,
+ ir::BasicBlock* condition)
+ : previous_phi_(induction),
+ previous_continue_block_(continue_block),
+ previous_condition_block_(condition),
+ new_phi(nullptr),
+ new_continue_block(nullptr),
+ new_condition_block(nullptr),
+ new_header_block(nullptr) {}
+
+ // Swap the state so that the new nodes are now the previous nodes.
+ void NextIterationState() {
+ previous_phi_ = new_phi;
+ previous_continue_block_ = new_continue_block;
+ previous_condition_block_ = new_condition_block;
+
+ // Clear new nodes.
+ new_phi = nullptr;
+ new_continue_block = nullptr;
+ new_condition_block = nullptr;
+ new_header_block = nullptr;
+
+ // Clear new block/instruction maps.
+ new_blocks.clear();
+ new_inst.clear();
+ }
+
+ // The induction variable from the immediately preceding loop body.
+ ir::Instruction* previous_phi_;
+
+ // The previous continue block. The backedge will be removed from this and
+ // added to the new continue block.
+ ir::BasicBlock* previous_continue_block_;
+
+ // The previous condition block. This may be folded to flatten the loop.
+ ir::BasicBlock* previous_condition_block_;
+
+ // The new induction variable.
+ ir::Instruction* new_phi;
+
+ // The new continue block.
+ ir::BasicBlock* new_continue_block;
+
+ // The new condition block.
+ ir::BasicBlock* new_condition_block;
+
+ // The new header block.
+ ir::BasicBlock* new_header_block;
+
+ // A mapping of new block ids to the original blocks which they were copied
+ // from.
+ std::unordered_map<uint32_t, ir::BasicBlock*> new_blocks;
+
+ // A mapping of new instruction ids to the instruction ids from which they
+ // were copied.
+ std::unordered_map<uint32_t, uint32_t> new_inst;
+};
+
+// This class implements the actual unrolling. It uses a LoopUnrollState to
+// maintain the state of the unrolling inbetween steps.
+class LoopUnrollerUtilsImpl {
+ public:
+ using BasicBlockListTy = std::vector<std::unique_ptr<ir::BasicBlock>>;
+
+ LoopUnrollerUtilsImpl(ir::IRContext* c, ir::Function* function)
+ : context_(c),
+ function_(*function),
+ loop_condition_block_(nullptr),
+ loop_induction_variable_(nullptr),
+ number_of_loop_iterations_(0),
+ loop_step_value_(0),
+ loop_init_value_(0) {}
+
+ // Unroll the |loop| by given |factor| by copying the whole body |factor|
+ // times. The resulting basicblock structure will remain a loop.
+ void PartiallyUnroll(ir::Loop*, size_t factor);
+
+ // If partially unrolling the |loop| would leave the loop with too many bodies
+ // for its number of iterations then this method should be used. This method
+ // will duplicate the |loop| completely, making the duplicated loop the
+ // successor of the original's merge block. The original loop will have its
+ // condition changed to loop over the residual part and the duplicate will be
+ // partially unrolled. The resulting structure will be two loops.
+ void PartiallyUnrollResidualFactor(ir::Loop* loop, size_t factor);
+
+ // Fully unroll the |loop| by copying the full body by the total number of
+ // loop iterations, folding all conditions, and removing the backedge from the
+ // continue block to the header.
+ void FullyUnroll(ir::Loop* loop);
+
+ // Get the ID of the variable in the |phi| paired with |label|.
+ uint32_t GetPhiDefID(const ir::Instruction* phi, uint32_t label) const;
+
+ // Close the loop by removing the OpLoopMerge from the |loop| header block and
+ // making the backedge point to the merge block.
+ void CloseUnrolledLoop(ir::Loop* loop);
+
+ // Remove the OpConditionalBranch instruction inside |conditional_block| used
+ // to branch to either exit or continue the loop and replace it with an
+ // unconditional OpBranch to block |new_target|.
+ void FoldConditionBlock(ir::BasicBlock* condtion_block, uint32_t new_target);
+
+ // Add all blocks_to_add_ to function_ at the |insert_point|.
+ void AddBlocksToFunction(const ir::BasicBlock* insert_point);
+
+ // Duplicates the |old_loop|, cloning each body and remaping the ids without
+ // removing instructions or changing relative structure. Result will be stored
+ // in |new_loop|.
+ void DuplicateLoop(ir::Loop* old_loop, ir::Loop* new_loop);
+
+ inline size_t GetLoopIterationCount() const {
+ return number_of_loop_iterations_;
+ }
+
+ // Extracts the initial state information from the |loop|.
+ void Init(ir::Loop* loop);
+
+ private:
+ // Remap all the in |basic_block| to new IDs and keep the mapping of new ids
+ // to old
+ // ids. |loop| is used to identify special loop blocks (header, continue,
+ // ect).
+ void AssignNewResultIds(ir::BasicBlock* basic_block);
+
+ // Using the map built by AssignNewResultIds, for each instruction in
+ // |basic_block| use
+ // that map to substitute the IDs used by instructions (in the operands) with
+ // the new ids.
+ void RemapOperands(ir::BasicBlock* basic_block);
+
+ // Copy the whole body of the loop, all blocks dominated by the |loop| header
+ // and not dominated by the |loop| merge. The copied body will be linked to by
+ // the old |loop| continue block and the new body will link to the |loop|
+ // header via the new continue block. |eliminate_conditions| is used to decide
+ // whether or not to fold all the condition blocks other than the last one.
+ void CopyBody(ir::Loop* loop, bool eliminate_conditions);
+
+ // Copy a given |block_to_copy| in the |loop| and record the mapping of the
+ // old/new ids. |preserve_instructions| determines whether or not the method
+ // will modify (other than result_id) instructions which are copied.
+ void CopyBasicBlock(ir::Loop* loop, const ir::BasicBlock* block_to_copy,
+ bool preserve_instructions);
+
+ // The actual implementation of the unroll step. Unrolls |loop| by given
+ // |factor| by copying the body by |factor| times. Also propagates the
+ // induction variable value throughout the copies.
+ void Unroll(ir::Loop* loop, size_t factor);
+
+ // Fills the loop_blocks_inorder_ field with the ordered list of basic blocks
+ // as computed by the method ComputeLoopOrderedBlocks.
+ void ComputeLoopOrderedBlocks(ir::Loop* loop);
+
+ // Adds the blocks_to_add_ to both the |loop| and to the parent of |loop| if
+ // the parent exists.
+ void AddBlocksToLoop(ir::Loop* loop) const;
+
+ // A pointer to the IRContext. Used to add/remove instructions and for usedef
+ // chains.
+ ir::IRContext* context_;
+
+ // A reference the function the loop is within.
+ ir::Function& function_;
+
+ // A list of basic blocks to be added to the loop at the end of an unroll
+ // step.
+ BasicBlockListTy blocks_to_add_;
+
+ // List of instructions which are now dead and can be removed.
+ std::vector<ir::Instruction*> dead_instructions_;
+
+ // Maintains the current state of the transform between calls to unroll.
+ LoopUnrollState state_;
+
+ // An ordered list containing the loop basic blocks.
+ std::vector<ir::BasicBlock*> loop_blocks_inorder_;
+
+ // The block containing the condition check which contains a conditional
+ // branch to the merge and continue block.
+ ir::BasicBlock* loop_condition_block_;
+
+ // The induction variable of the loop.
+ ir::Instruction* loop_induction_variable_;
+
+ // The number of loop iterations that the loop would preform pre-unroll.
+ size_t number_of_loop_iterations_;
+
+ // The amount that the loop steps each iteration.
+ int64_t loop_step_value_;
+
+ // The value the loop starts stepping from.
+ int64_t loop_init_value_;
+};
+
+/*
+ * Static helper functions.
+ */
+
+// Retrieve the index of the OpPhi instruction |phi| which corresponds to the
+// incoming |block| id.
+static uint32_t GetPhiIndexFromLabel(const ir::BasicBlock* block,
+ const ir::Instruction* phi) {
+ for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
+ if (block->id() == phi->GetSingleWordInOperand(i)) {
+ return i;
+ }
+ }
+ assert(false && "Could not find operand in instruction.");
+ return 0;
+}
+
+void LoopUnrollerUtilsImpl::Init(ir::Loop* loop) {
+ loop_condition_block_ = loop->FindConditionBlock();
+
+ // When we reinit the second loop during PartiallyUnrollResidualFactor we need
+ // to use the cached value from the duplicate step as the dominator tree
+ // basded solution, loop->FindConditionBlock, requires all the nodes to be
+ // connected up with the correct branches. They won't be at this point.
+ if (!loop_condition_block_) {
+ loop_condition_block_ = state_.new_condition_block;
+ }
+ assert(loop_condition_block_);
+
+ loop_induction_variable_ = loop->FindInductionVariable(loop_condition_block_);
+ assert(loop_induction_variable_);
+
+ bool found = loop->FindNumberOfIterations(
+ loop_induction_variable_, &*loop_condition_block_->ctail(),
+ &number_of_loop_iterations_, &loop_step_value_, &loop_init_value_);
+ (void)found; // To silence unused variable warning on release builds.
+ assert(found);
+ ComputeLoopOrderedBlocks(loop);
+}
+
+// This function is used to partially unroll the loop when the factor provided
+// would normally lead to an illegal optimization. Instead of just unrolling the
+// loop it creates two loops and unrolls one and adjusts the condition on the
+// other. The end result being that the new loop pair iterates over the correct
+// number of bodies.
+void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(ir::Loop* loop,
+ size_t factor) {
+ // Create a new merge block for the first loop.
+ std::unique_ptr<ir::Instruction> new_label{new ir::Instruction(
+ context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})};
+ std::unique_ptr<ir::BasicBlock> new_exit_bb{
+ new ir::BasicBlock(std::move(new_label))};
+
+ // Save the id of the block before we move it.
+ uint32_t new_merge_id = new_exit_bb->id();
+
+ // Add the block the list of blocks to add, we want this merge block to be
+ // right at the start of the new blocks.
+ blocks_to_add_.push_back(std::move(new_exit_bb));
+ ir::BasicBlock* new_exit_bb_raw = blocks_to_add_[0].get();
+ ir::Instruction& original_conditional_branch = *loop_condition_block_->tail();
+
+ // Duplicate the loop, providing access to the blocks of both loops.
+ // This is a naked new due to the VS2013 requirement of not having unique
+ // pointers in vectors, as it will be inserted into a vector with
+ // loop_descriptor.AddLoop.
+ ir::Loop* new_loop = new ir::Loop(*loop);
+
+ // Clear the basic blocks of the new loop.
+ new_loop->ClearBlocks();
+
+ DuplicateLoop(loop, new_loop);
+
+ // Add the blocks to the function.
+ AddBlocksToFunction(loop->GetMergeBlock());
+ blocks_to_add_.clear();
+
+ InstructionBuilder builder{context_, new_exit_bb_raw};
+ // Make the first loop branch to the second.
+ builder.AddBranch(new_loop->GetHeaderBlock()->id());
+
+ loop_condition_block_ = state_.new_condition_block;
+ loop_induction_variable_ = state_.new_phi;
+
+ // Unroll the new loop by the factor with the usual -1 to account for the
+ // existing block iteration.
+ Unroll(new_loop, factor);
+
+ // We need to account for the initial body when calculating the remainder.
+ int64_t remainder = loop_init_value_ +
+ (number_of_loop_iterations_ % factor) * loop_step_value_;
+
+ assert(remainder > std::numeric_limits<int32_t>::min() &&
+ remainder < std::numeric_limits<int32_t>::max());
+
+ ir::Instruction* new_constant = nullptr;
+
+ // If the remainder is negative then we add a signed constant, otherwise just
+ // add an unsigned constant.
+ if (remainder < 0) {
+ new_constant =
+ builder.Add32BitSignedIntegerConstant(static_cast<int32_t>(remainder));
+ } else {
+ new_constant = builder.Add32BitUnsignedIntegerConstant(
+ static_cast<int32_t>(remainder));
+ }
+
+ uint32_t constant_id = new_constant->result_id();
+
+ // Add the merge block to the back of the binary.
+ blocks_to_add_.push_back(
+ std::unique_ptr<ir::BasicBlock>(new_loop->GetMergeBlock()));
+
+ AddBlocksToLoop(new_loop);
+ // Add the blocks to the function.
+ AddBlocksToFunction(loop->GetMergeBlock());
+
+ // Reset the usedef analysis.
+ context_->InvalidateAnalysesExceptFor(
+ ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ // Update the condition check.
+ ir::Instruction* condition_check = def_use_manager->GetDef(
+ original_conditional_branch.GetSingleWordOperand(0));
+
+ // This should have been checked by the LoopUtils::CanPerformUnroll function
+ // before entering this.
+ assert(condition_check->opcode() == SpvOpSLessThan);
+ condition_check->SetInOperand(1, {constant_id});
+
+ // Update the next phi node. The phi will have a constant value coming in from
+ // the preheader block. For the duplicated loop we need to update the constant
+ // to be the amount of iterations covered by the first loop and the incoming
+ // block to be the first loops new merge block.
+ uint32_t phi_incoming_index =
+ GetPhiIndexFromLabel(loop->GetPreHeaderBlock(), loop_induction_variable_);
+ loop_induction_variable_->SetInOperand(phi_incoming_index - 1, {constant_id});
+ loop_induction_variable_->SetInOperand(phi_incoming_index, {new_merge_id});
+
+ context_->InvalidateAnalysesExceptFor(
+ ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+
+ context_->ReplaceAllUsesWith(loop->GetMergeBlock()->id(), new_merge_id);
+
+ ir::LoopDescriptor& loop_descriptor =
+ *context_->GetLoopDescriptor(&function_);
+
+ loop_descriptor.AddLoop(new_loop, loop->GetParent());
+}
+
+// Duplicate the |loop| body |factor| number of times while keeping the loop
+// backedge intact.
+void LoopUnrollerUtilsImpl::PartiallyUnroll(ir::Loop* loop, size_t factor) {
+ Unroll(loop, factor);
+ AddBlocksToLoop(loop);
+ AddBlocksToFunction(loop->GetMergeBlock());
+}
+
+// Duplicate the |loop| body |factor| number of times while keeping the loop
+// backedge intact.
+void LoopUnrollerUtilsImpl::Unroll(ir::Loop* loop, size_t factor) {
+ state_ = LoopUnrollState{loop_induction_variable_, loop->GetLatchBlock(),
+ loop_condition_block_};
+ for (size_t i = 0; i < factor - 1; ++i) {
+ CopyBody(loop, true);
+ }
+
+ uint32_t phi_index = GetPhiIndexFromLabel(state_.previous_continue_block_,
+ state_.previous_phi_);
+ uint32_t phi_variable =
+ state_.previous_phi_->GetSingleWordInOperand(phi_index - 1);
+ uint32_t phi_label = state_.previous_phi_->GetSingleWordInOperand(phi_index);
+
+ ir::Instruction* original_phi = loop_induction_variable_;
+
+ // SetInOperands are offset by two.
+ original_phi->SetInOperand(phi_index - 1, {phi_variable});
+ original_phi->SetInOperand(phi_index, {phi_label});
+}
+
+// Fully unroll the loop by partially unrolling it by the number of loop
+// iterations minus one for the body already accounted for.
+void LoopUnrollerUtilsImpl::FullyUnroll(ir::Loop* loop) {
+ // We unroll the loop by number of iterations in the loop.
+ Unroll(loop, number_of_loop_iterations_);
+
+ // The first condition block is preserved until now so it can be copied.
+ FoldConditionBlock(loop_condition_block_, 1);
+
+ // Delete the OpLoopMerge and remove the backedge to the header.
+ CloseUnrolledLoop(loop);
+
+ // Mark the loop for later deletion. This allows us to preserve the loop
+ // iterators but still disregard dead loops.
+ loop->MarkLoopForRemoval();
+
+ // If the loop has a parent add the new blocks to the parent.
+ if (loop->GetParent()) {
+ AddBlocksToLoop(loop->GetParent());
+ }
+
+ // Add the blocks to the function.
+ AddBlocksToFunction(loop->GetMergeBlock());
+
+ // Invalidate all analyses.
+ context_->InvalidateAnalysesExceptFor(
+ ir::IRContext::Analysis::kAnalysisLoopAnalysis);
+}
+
+// Copy a given basic block, give it a new result_id, and store the new block
+// and the id mapping in the state. |preserve_instructions| is used to determine
+// whether or not this function should edit instructions other than the
+// |result_id|.
+void LoopUnrollerUtilsImpl::CopyBasicBlock(ir::Loop* loop,
+ const ir::BasicBlock* itr,
+ bool preserve_instructions) {
+ // Clone the block exactly, including the IDs.
+ ir::BasicBlock* basic_block = itr->Clone(context_);
+
+ basic_block->SetParent(itr->GetParent());
+
+ // Assign each result a new unique ID and keep a mapping of the old ids to
+ // the new ones.
+ AssignNewResultIds(basic_block);
+
+ // If this is the continue block we are copying.
+ if (itr == loop->GetLatchBlock()) {
+ // Make the OpLoopMerge point to this block for the continue.
+ if (!preserve_instructions) {
+ ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
+ merge_inst->SetInOperand(1, {basic_block->id()});
+ }
+
+ state_.new_continue_block = basic_block;
+ }
+
+ // If this is the header block we are copying.
+ if (itr == loop->GetHeaderBlock()) {
+ state_.new_header_block = basic_block;
+
+ if (!preserve_instructions) {
+ // Remove the loop merge instruction if it exists.
+ ir::Instruction* merge_inst = basic_block->GetLoopMergeInst();
+ if (merge_inst) context_->KillInst(merge_inst);
+ }
+ }
+
+ // If this is the condition block we are copying.
+ if (itr == loop_condition_block_) {
+ state_.new_condition_block = basic_block;
+ }
+
+ // Add this block to the list of blocks to add to the function at the end of
+ // the unrolling process.
+ blocks_to_add_.push_back(std::unique_ptr<ir::BasicBlock>(basic_block));
+
+ // Keep tracking the old block via a map.
+ state_.new_blocks[itr->id()] = basic_block;
+}
+
+void LoopUnrollerUtilsImpl::CopyBody(ir::Loop* loop,
+ bool eliminate_conditions) {
+ // Copy each basic block in the loop, give them new ids, and save state
+ // information.
+ for (const ir::BasicBlock* itr : loop_blocks_inorder_) {
+ CopyBasicBlock(loop, itr, false);
+ }
+
+ // Set the previous continue block to point to the new header.
+ ir::Instruction& continue_branch = *state_.previous_continue_block_->tail();
+ continue_branch.SetInOperand(0, {state_.new_header_block->id()});
+
+ // As the algorithm copies the original loop blocks exactly, the tail of the
+ // latch block on iterations after the first one will be a branch to the new
+ // header and not the actual loop header. The last continue block in the loop
+ // should always be a backedge to the global header.
+ ir::Instruction& new_continue_branch = *state_.new_continue_block->tail();
+ new_continue_branch.SetInOperand(0, {loop->GetHeaderBlock()->id()});
+
+ // Update references to the old phi node with the actual variable.
+ const ir::Instruction* induction = loop_induction_variable_;
+ state_.new_inst[induction->result_id()] =
+ GetPhiDefID(state_.previous_phi_, state_.previous_continue_block_->id());
+
+ if (eliminate_conditions &&
+ state_.new_condition_block != loop_condition_block_) {
+ FoldConditionBlock(state_.new_condition_block, 1);
+ }
+
+ // Only reference to the header block is the backedge in the latch block,
+ // don't change this.
+ state_.new_inst[loop->GetHeaderBlock()->id()] = loop->GetHeaderBlock()->id();
+
+ for (auto& pair : state_.new_blocks) {
+ RemapOperands(pair.second);
+ }
+
+ dead_instructions_.push_back(state_.new_phi);
+
+ // Swap the state so the new is now the previous.
+ state_.NextIterationState();
+}
+
+uint32_t LoopUnrollerUtilsImpl::GetPhiDefID(const ir::Instruction* phi,
+ uint32_t label) const {
+ for (uint32_t operand = 3; operand < phi->NumOperands(); operand += 2) {
+ if (phi->GetSingleWordOperand(operand) == label) {
+ return phi->GetSingleWordOperand(operand - 1);
+ }
+ }
+
+ return 0;
+}
+
+void LoopUnrollerUtilsImpl::FoldConditionBlock(ir::BasicBlock* condition_block,
+ uint32_t operand_label) {
+ // Remove the old conditional branch to the merge and continue blocks.
+ ir::Instruction& old_branch = *condition_block->tail();
+ uint32_t new_target = old_branch.GetSingleWordOperand(operand_label);
+ context_->KillInst(&old_branch);
+
+ // Add the new unconditional branch to the merge block.
+ InstructionBuilder builder{context_, condition_block};
+ builder.AddBranch(new_target);
+}
+
+void LoopUnrollerUtilsImpl::CloseUnrolledLoop(ir::Loop* loop) {
+ // Remove the OpLoopMerge instruction from the function.
+ ir::Instruction* merge_inst = loop->GetHeaderBlock()->GetLoopMergeInst();
+ context_->KillInst(merge_inst);
+
+ // Remove the final backedge to the header and make it point instead to the
+ // merge block.
+ state_.previous_continue_block_->tail()->SetInOperand(
+ 0, {loop->GetMergeBlock()->id()});
+
+ // Remove the induction variable as the phi will now be invalid. Replace all
+ // uses with the constant initializer value (all uses of the phi will be in
+ // the first iteration with the subsequent phis already having been removed.
+ uint32_t initalizer_id =
+ GetPhiDefID(loop_induction_variable_, loop->GetPreHeaderBlock()->id());
+ context_->ReplaceAllUsesWith(loop_induction_variable_->result_id(),
+ initalizer_id);
+
+ // Remove the now unused phi.
+ context_->KillInst(loop_induction_variable_);
+}
+
+// Uses the first loop to create a copy of the loop with new IDs.
+void LoopUnrollerUtilsImpl::DuplicateLoop(ir::Loop* old_loop,
+ ir::Loop* new_loop) {
+ std::vector<ir::BasicBlock*> new_block_order;
+
+ // Copy every block in the old loop.
+ for (const ir::BasicBlock* itr : loop_blocks_inorder_) {
+ CopyBasicBlock(old_loop, itr, true);
+ new_block_order.push_back(blocks_to_add_.back().get());
+ }
+
+ ir::BasicBlock* new_merge = old_loop->GetMergeBlock()->Clone(context_);
+ new_merge->SetParent(old_loop->GetMergeBlock()->GetParent());
+ AssignNewResultIds(new_merge);
+ state_.new_blocks[old_loop->GetMergeBlock()->id()] = new_merge;
+ for (auto& pair : state_.new_blocks) {
+ RemapOperands(pair.second);
+ }
+
+ loop_blocks_inorder_ = std::move(new_block_order);
+
+ AddBlocksToLoop(new_loop);
+
+ new_loop->SetHeaderBlock(state_.new_header_block);
+ new_loop->SetLatchBlock(state_.new_continue_block);
+ new_loop->SetMergeBlock(new_merge);
+}
+
+void LoopUnrollerUtilsImpl::AddBlocksToFunction(
+ const ir::BasicBlock* insert_point) {
+ for (ir::Instruction* inst : dead_instructions_) {
+ context_->KillInst(inst);
+ }
+
+ for (auto basic_block_iterator = function_.begin();
+ basic_block_iterator != function_.end(); ++basic_block_iterator) {
+ if (basic_block_iterator->id() == insert_point->id()) {
+ basic_block_iterator.InsertBefore(&blocks_to_add_);
+ return;
+ }
+ }
+
+ assert(
+ false &&
+ "Could not add basic blocks to function as insert point was not found.");
+}
+
+// Assign all result_ids in |basic_block| instructions to new IDs and preserve
+// the mapping of new ids to old ones.
+void LoopUnrollerUtilsImpl::AssignNewResultIds(ir::BasicBlock* basic_block) {
+ // Label instructions aren't covered by normal traversal of the
+ // instructions.
+ uint32_t new_label_id = context_->TakeNextId();
+
+ // Assign a new id to the label.
+ state_.new_inst[basic_block->GetLabelInst()->result_id()] = new_label_id;
+ basic_block->GetLabelInst()->SetResultId(new_label_id);
+
+ for (ir::Instruction& inst : *basic_block) {
+ uint32_t old_id = inst.result_id();
+
+ // Ignore stores etc.
+ if (old_id == 0) {
+ continue;
+ }
+
+ // Give the instruction a new id.
+ inst.SetResultId(context_->TakeNextId());
+
+ // Save the mapping of old_id -> new_id.
+ state_.new_inst[old_id] = inst.result_id();
+
+ // Check if this instruction is the induction variable.
+ if (loop_induction_variable_->result_id() == old_id) {
+ // Save a pointer to the new copy of it.
+ state_.new_phi = &inst;
+ }
+ }
+}
+
+// For all instructions in |basic_block| check if the operands used are from a
+// copied instruction and if so swap out the operand for the copy of it.
+void LoopUnrollerUtilsImpl::RemapOperands(ir::BasicBlock* basic_block) {
+ for (ir::Instruction& inst : *basic_block) {
+ auto remap_operands_to_new_ids = [this](uint32_t* id) {
+ auto itr = state_.new_inst.find(*id);
+ if (itr != state_.new_inst.end()) {
+ *id = itr->second;
+ }
+ };
+
+ inst.ForEachInId(remap_operands_to_new_ids);
+ }
+}
+
+// Generate the ordered list of basic blocks in the |loop| and cache it for
+// later use.
+void LoopUnrollerUtilsImpl::ComputeLoopOrderedBlocks(ir::Loop* loop) {
+ loop_blocks_inorder_.clear();
+
+ opt::DominatorAnalysis* analysis =
+ context_->GetDominatorAnalysis(&function_, *context_->cfg());
+ opt::DominatorTree& tree = analysis->GetDomTree();
+
+ // Starting at the loop header BasicBlock, traverse the dominator tree until
+ // we reach the merge block and add every node we traverse to the set of
+ // blocks
+ // which we consider to be the loop.
+ auto begin_itr = tree.GetTreeNode(loop->GetHeaderBlock())->df_begin();
+ const ir::BasicBlock* merge = loop->GetMergeBlock();
+ auto func = [merge, &tree, this](DominatorTreeNode* node) {
+ if (!tree.Dominates(merge->id(), node->id())) {
+ this->loop_blocks_inorder_.push_back(node->bb_);
+ return true;
+ }
+ return false;
+ };
+
+ tree.VisitChildrenIf(func, begin_itr);
+}
+
+// Adds the blocks_to_add_ to both the loop and to the parent.
+void LoopUnrollerUtilsImpl::AddBlocksToLoop(ir::Loop* loop) const {
+ // Add the blocks to this loop.
+ for (auto& block_itr : blocks_to_add_) {
+ loop->AddBasicBlock(block_itr.get());
+ }
+
+ // Add the blocks to the parent as well.
+ if (loop->GetParent()) AddBlocksToLoop(loop->GetParent());
+}
+
+/*
+ * End LoopUtilsImpl.
+ */
+
+} // namespace
+
+/*
+ *
+ * Begin Utils.
+ *
+ * */
+
+bool LoopUtils::CanPerformUnroll() {
+ // The loop is expected to be in structured order.
+ if (!loop_->GetHeaderBlock()->GetMergeInst()) {
+ return false;
+ }
+
+ // Find check the loop has a condition we can find and evaluate.
+ const ir::BasicBlock* condition = loop_->FindConditionBlock();
+ if (!condition) return false;
+
+ // Check that we can find and process the induction variable.
+ const ir::Instruction* induction = loop_->FindInductionVariable(condition);
+ if (!induction || induction->opcode() != SpvOpPhi) return false;
+
+ // Check that we can find the number of loop iterations.
+ if (!loop_->FindNumberOfIterations(induction, &*condition->ctail(), nullptr))
+ return false;
+
+ // Make sure the continue block is a unconditional branch to the header
+ // block.
+ const ir::Instruction& branch = *loop_->GetLatchBlock()->ctail();
+ bool branching_assumption =
+ branch.opcode() == SpvOpBranch &&
+ branch.GetSingleWordInOperand(0) == loop_->GetHeaderBlock()->id();
+ if (!branching_assumption) {
+ return false;
+ }
+
+ // Make sure the induction is the only phi instruction we have in the loop
+ // header. Other optimizations have been seen to leave dead phi nodes in the
+ // header so we also check that the phi is used.
+ for (const ir::Instruction& inst : *loop_->GetHeaderBlock()) {
+ if (inst.opcode() == SpvOpPhi &&
+ inst.result_id() != induction->result_id()) {
+ return false;
+ }
+ }
+
+ // Ban breaks within the loop.
+ const std::vector<uint32_t>& merge_block_preds =
+ context_->cfg()->preds(loop_->GetMergeBlock()->id());
+ if (merge_block_preds.size() != 1) {
+ return false;
+ }
+
+ // Ban continues within the loop.
+ const std::vector<uint32_t>& continue_block_preds =
+ context_->cfg()->preds(loop_->GetLatchBlock()->id());
+ if (continue_block_preds.size() != 1) {
+ return false;
+ }
+
+ // Ban returns in the loop.
+ // Iterate over all the blocks within the loop and check that none of them
+ // exit the loop.
+ for (uint32_t label_id : loop_->GetBlocks()) {
+ const ir::BasicBlock* block = context_->cfg()->block(label_id);
+ if (block->ctail()->opcode() == SpvOp::SpvOpKill ||
+ block->ctail()->opcode() == SpvOp::SpvOpReturn ||
+ block->ctail()->opcode() == SpvOp::SpvOpReturnValue) {
+ return false;
+ }
+ }
+ // Can only unroll inner loops.
+ if (!loop_->AreAllChildrenMarkedForRemoval()) {
+ return false;
+ }
+
+ for (uint32_t block_id : loop_->GetBlocks()) {
+ opt::analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
+
+ ir::BasicBlock& bb = *context_->cfg()->block(block_id);
+ // For every instruction in the block.
+ for (ir::Instruction& inst : bb) {
+ if (inst.result_id() == 0) continue;
+
+ auto is_used_outside_loop = [this,
+ def_use_manager](ir::Instruction* user) {
+
+ if (!loop_->IsInsideLoop(user)) {
+ // Some optimization passes have been seen to leave dead phis in the
+ // IR so we check that if a phi is used outside of the loop that the
+ // user is not dead.
+ if (!(user->opcode() == SpvOpPhi &&
+ def_use_manager->NumUsers(user) == 0))
+ return false;
+ }
+ return true;
+ };
+
+ if (!def_use_manager->WhileEachUser(&inst, is_used_outside_loop)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool LoopUtils::PartiallyUnroll(size_t factor) {
+ if (factor == 1 || !CanPerformUnroll()) return false;
+
+ // Create the unroller utility.
+ LoopUnrollerUtilsImpl unroller{context_,
+ loop_->GetHeaderBlock()->GetParent()};
+ unroller.Init(loop_);
+
+ // If the unrolling factor is larger than or the same size as the loop just
+ // fully unroll the loop.
+ if (factor >= unroller.GetLoopIterationCount()) {
+ unroller.FullyUnroll(loop_);
+ return true;
+ }
+
+ // If the loop unrolling factor is an residual number of iterations we need to
+ // let run the loop for the residual part then let it branch into the unrolled
+ // remaining part. We add one when calucating the remainder to take into
+ // account the one iteration already in the loop.
+ if (unroller.GetLoopIterationCount() % factor != 0) {
+ unroller.PartiallyUnrollResidualFactor(loop_, factor);
+ } else {
+ unroller.PartiallyUnroll(loop_, factor);
+ }
+
+ return true;
+}
+
+bool LoopUtils::FullyUnroll() {
+ if (!CanPerformUnroll()) return false;
+
+ LoopUnrollerUtilsImpl unroller{context_,
+ loop_->GetHeaderBlock()->GetParent()};
+
+ unroller.Init(loop_);
+ unroller.FullyUnroll(loop_);
+
+ return true;
+}
+
+void LoopUtils::Finalize() {
+ // Clean up the loop descriptor to preserve the analysis.
+
+ ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&function_);
+ LD->PostModificationCleanup();
+}
+
+/*
+ *
+ * Begin Pass.
+ *
+ */
+
+Pass::Status LoopUnroller::Process(ir::IRContext* c) {
+ context_ = c;
+ bool changed = false;
+ for (ir::Function& f : *c->module()) {
+ ir::LoopDescriptor* LD = context_->GetLoopDescriptor(&f);
+ for (ir::Loop& loop : *LD) {
+ LoopUtils loop_utils{c, &loop};
+ if (!loop.HasUnrollLoopControl() || !loop_utils.CanPerformUnroll()) {
+ continue;
+ }
+
+ loop_utils.FullyUnroll();
+ changed = true;
+ }
+ LD->PostModificationCleanup();
+ }
+
+ return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+} // namespace opt
+} // namespace spvtools
--- /dev/null
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_LOOP_UNROLLER_H_
+#define SOURCE_OPT_LOOP_UNROLLER_H_
+#include "opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+class LoopUnroller : public Pass {
+ public:
+ LoopUnroller() : Pass() {}
+
+ const char* name() const override { return "Loop unroller"; }
+
+ Status Process(ir::IRContext* context) override;
+
+ private:
+ ir::IRContext* context_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_LOOP_UNROLLER_H_
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef LIBSPIRV_OPT_LOOP_UTILS_H_
-#define LIBSPIRV_OPT_LOOP_UTILS_H_
+#ifndef SOURCE_OPT_LOOP_UTILS_H_
+#define SOURCE_OPT_LOOP_UTILS_H_
+#include <list>
+#include <memory>
+#include <vector>
+#include "opt/loop_descriptor.h"
namespace spvtools {
namespace opt {
-// Set of basic loop transformation.
+// LoopUtils is used to encapsulte loop optimizations and from the passes which
+// use them. Any pass which needs a loop optimization should do it through this
+// or through a pass which is using this.
class LoopUtils {
public:
LoopUtils(ir::IRContext* context, ir::Loop* loop)
- : context_(context), loop_(loop) {}
+ : context_(context),
+ loop_(loop),
+ function_(*loop_->GetHeaderBlock()->GetParent()) {}
// The converts the current loop to loop closed SSA form.
// In the loop closed SSA, all loop exiting values go through a dedicated Phi
// Preserves: CFG, def/use and instruction to block mapping.
void CreateLoopDedicatedExits();
+ // Perfom a partial unroll of |loop| by given |factor|. This will copy the
+ // body of the loop |factor| times. So a |factor| of one would give a new loop
+ // with the original body plus one unrolled copy body.
+ bool PartiallyUnroll(size_t factor);
+
+ // Fully unroll |loop|.
+ bool FullyUnroll();
+
+ // This function validates that |loop| meets the assumptions made by the
+ // implementation of the loop unroller. As the implementation accommodates
+ // more types of loops this function can reduce its checks.
+ //
+ // The conditions checked to ensure the loop can be unrolled are as follows:
+ // 1. That the loop is in structured order.
+ // 2. That the condinue block is a branch to the header.
+ // 3. That the only phi used in the loop is the induction variable.
+ // TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is
+ // converted into LCSAA form and has a single entry and exit we can rewrite
+ // the other phis.
+ // 4. That this is an inner most loop, or that loops contained within this
+ // loop have already been fully unrolled.
+ // 5. That each instruction in the loop is only used within the loop.
+ // (Related to the above phi condition).
+ bool CanPerformUnroll();
+
+ // Maintains the loop descriptor object after the unroll functions have been
+ // called, otherwise the analysis should be invalidated.
+ void Finalize();
+
private:
ir::IRContext* context_;
ir::Loop* loop_;
+ ir::Function& function_;
};
} // namespace opt
} // namespace spvtools
-#endif // LIBSPIRV_OPT_LOOP_UTILS_H_
+#endif // SOURCE_OPT_LOOP_UTILS_H_
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::SimplificationPass>());
}
+
+Optimizer::PassToken CreateLoopFullyUnrollPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::LoopUnroller>());
+}
} // namespace spvtools
#include "local_single_block_elim_pass.h"
#include "local_single_store_elim_pass.h"
#include "local_ssa_elim_pass.h"
+#include "loop_unroller.h"
#include "merge_return_pass.h"
#include "null_pass.h"
#include "private_to_local_pass.h"
#include "strip_debug_info_pass.h"
#include "unify_const_pass.h"
#include "workaround1209.h"
-
#endif // LIBSPIRV_OPT_PASSES_H_
hoist_without_preheader.cpp
LIBS SPIRV-Tools-opt
)
+
+add_spvtools_unittest(TARGET loop_unroll_simple
+ SRCS ../function_utils.h
+ unroll_simple.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET loop_unroll_assumtion_checks
+ SRCS ../function_utils.h
+ unroll_assumptions.cpp
+ LIBS SPIRV-Tools-opt
+)
+
--- /dev/null
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+#include "opt/loop_unroller.h"
+#include "opt/loop_utils.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+template <int factor>
+class PartialUnrollerTestPass : public opt::Pass {
+ public:
+ PartialUnrollerTestPass() : Pass() {}
+
+ const char* name() const override { return "Loop unroller"; }
+
+ Status Process(ir::IRContext* context) override {
+ bool changed = false;
+ for (ir::Function& f : *context->module()) {
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f);
+ for (auto& loop : loop_descriptor) {
+ opt::LoopUtils loop_utils{context, &loop};
+ if (loop_utils.PartiallyUnroll(factor)) {
+ changed = true;
+ }
+ }
+ }
+
+ if (changed) return Pass::Status::SuccessWithChange;
+ return Pass::Status::SuccessWithoutChange;
+ }
+};
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 410 core
+layout(location = 0) flat in int in_upper_bound;
+void main() {
+ for (int i = ; i < in_upper_bound; ++i) {
+ x[i] = 1.0f;
+ }
+}
+*/
+TEST_F(PassClassTest, CheckUpperBound) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "in_upper_bound"
+OpName %4 "x"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpTypePointer Input %7
+%3 = OpVariable %10 Input
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%2 = OpFunction %5 None %6
+%20 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %21
+%21 = OpLabel
+%22 = OpPhi %7 %9 %20 %23 %24
+OpLoopMerge %25 %24 None
+OpBranch %26
+%26 = OpLabel
+%27 = OpLoad %7 %3
+%28 = OpSLessThan %11 %22 %27
+OpBranchConditional %28 %29 %25
+%29 = OpLabel
+%30 = OpAccessChain %18 %4 %22
+OpStore %30 %17
+OpBranch %24
+%24 = OpLabel
+%23 = OpIAdd %7 %22 %19
+OpBranch %21
+%25 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+ // Make sure the pass doesn't run
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[10];
+ int i = 0;
+ for (int i = 0; i < 10; ++i) {
+ out_array[i] = i;
+ }
+ out_array[9] = i*10;
+}
+*/
+TEST_F(PassClassTest, InductionUsedOutsideOfLoop) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 10
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%18 = OpConstant %6 9
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpConvertSToF %11 %21
+%29 = OpAccessChain %16 %3 %21
+OpStore %29 %28
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %17
+OpBranch %20
+%24 = OpLabel
+%30 = OpIMul %6 %21 %9
+%31 = OpConvertSToF %11 %30
+%32 = OpAccessChain %16 %3 %18
+OpStore %32 %31
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+ // Make sure the pass doesn't run
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<1>>(text, text, false);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[10];
+ for (uint i = 0; i < 2; i++) {
+ for (float x = 0; x < 5; ++x) {
+ out_array[x + i*5] = i;
+ }
+ }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoopsInvalid) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 0
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 2
+%10 = OpTypeBool
+%11 = OpTypeInt 32 1
+%12 = OpTypePointer Function %11
+%13 = OpConstant %11 0
+%14 = OpConstant %11 5
+%15 = OpTypeFloat 32
+%16 = OpConstant %6 10
+%17 = OpTypeArray %15 %16
+%18 = OpTypePointer Function %17
+%19 = OpConstant %6 5
+%20 = OpTypePointer Function %15
+%21 = OpConstant %11 1
+%22 = OpUndef %11
+%2 = OpFunction %4 None %5
+%23 = OpLabel
+%3 = OpVariable %18 Function
+OpBranch %24
+%24 = OpLabel
+%25 = OpPhi %6 %8 %23 %26 %27
+%28 = OpPhi %11 %22 %23 %29 %27
+OpLoopMerge %30 %27 None
+OpBranch %31
+%31 = OpLabel
+%32 = OpULessThan %10 %25 %9
+OpBranchConditional %32 %33 %30
+%33 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%29 = OpPhi %11 %13 %33 %35 %36
+OpLoopMerge %37 %36 None
+OpBranch %38
+%38 = OpLabel
+%39 = OpSLessThan %10 %29 %14
+OpBranchConditional %39 %40 %37
+%40 = OpLabel
+%41 = OpBitcast %6 %29
+%42 = OpIMul %6 %25 %19
+%43 = OpIAdd %6 %41 %42
+%44 = OpConvertUToF %15 %25
+%45 = OpAccessChain %20 %3 %43
+OpStore %45 %44
+OpBranch %36
+%36 = OpLabel
+%35 = OpIAdd %11 %29 %21
+OpBranch %34
+%37 = OpLabel
+OpBranch %27
+%27 = OpLabel
+%26 = OpIAdd %6 %25 %21
+OpBranch %24
+%30 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+ float x[10];
+ int ind = 0;
+ for (int i = 0; i < 10; i++) {
+ ind = i;
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, MultiplePhiInHeader) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 10
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+%20 = OpPhi %6 %8 %18 %21 %22
+%21 = OpPhi %6 %8 %18 %23 %22
+OpLoopMerge %24 %22 None
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpConvertSToF %11 %21
+%29 = OpAccessChain %16 %3 %21
+OpStore %29 %28
+OpBranch %22
+%22 = OpLabel
+%23 = OpIAdd %6 %21 %17
+OpBranch %19
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+ float x[10];
+ for (int i = 0; i < 10; i++) {
+ if (i == 5) {
+ break;
+ }
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, BreakInBody) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpBranch %24
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+ float x[10];
+ for (int i = 0; i < 10; i++) {
+ if (i == 5) {
+ continue;
+ }
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, ContinueInBody) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpBranch %23
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+ float x[10];
+ for (int i = 0; i < 10; i++) {
+ if (i == 5) {
+ return;
+ }
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, ReturnInBody) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 440
+OpName %2 "main"
+OpName %3 "x"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 10
+%10 = OpTypeBool
+%11 = OpConstant %6 5
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpTypePointer Function %12
+%18 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%19 = OpLabel
+%3 = OpVariable %16 Function
+OpBranch %20
+%20 = OpLabel
+%21 = OpPhi %6 %8 %19 %22 %23
+OpLoopMerge %24 %23 Unroll
+OpBranch %25
+%25 = OpLabel
+%26 = OpSLessThan %10 %21 %9
+OpBranchConditional %26 %27 %24
+%27 = OpLabel
+%28 = OpIEqual %10 %21 %11
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+OpReturn
+%29 = OpLabel
+%31 = OpConvertSToF %12 %21
+%32 = OpAccessChain %17 %3 %21
+OpStore %32 %31
+OpBranch %23
+%23 = OpLabel
+%22 = OpIAdd %6 %21 %18
+OpBranch %20
+%24 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, text, false);
+}
+
+} // namespace
--- /dev/null
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+
+#include "../assembly_builder.h"
+#include "../function_utils.h"
+#include "../pass_fixture.h"
+#include "../pass_utils.h"
+#include "opt/loop_unroller.h"
+#include "opt/loop_utils.h"
+#include "opt/pass.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using PassClassTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+ float x[4];
+ for (int i = 0; i < 4; ++i) {
+ x[i] = 1.0f;
+ }
+}
+*/
+TEST_F(PassClassTest, SimpleFullyUnrollTest) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 330
+ OpName %2 "main"
+ OpName %5 "x"
+ OpName %3 "c"
+ OpDecorate %3 Location 0
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 4
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 4
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Function %16
+ %18 = OpConstant %13 1
+ %19 = OpTypePointer Function %13
+ %20 = OpConstant %8 1
+ %21 = OpTypeVector %13 4
+ %22 = OpTypePointer Output %21
+ %3 = OpVariable %22 Output
+ %2 = OpFunction %6 None %7
+ %23 = OpLabel
+ %5 = OpVariable %17 Function
+ OpBranch %24
+ %24 = OpLabel
+ %35 = OpPhi %8 %10 %23 %34 %26
+ OpLoopMerge %25 %26 Unroll
+ OpBranch %27
+ %27 = OpLabel
+ %29 = OpSLessThan %12 %35 %11
+ OpBranchConditional %29 %30 %25
+ %30 = OpLabel
+ %32 = OpAccessChain %19 %5 %35
+ OpStore %32 %18
+ OpBranch %26
+ %26 = OpLabel
+ %34 = OpIAdd %8 %35 %20
+ OpBranch %24
+ %25 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 4
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 4
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %9 %10
+OpBranch %30
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %9
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %9 %19
+OpBranch %32
+%32 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%35 = OpSLessThan %11 %25 %10
+OpBranch %36
+%36 = OpLabel
+%37 = OpAccessChain %18 %4 %25
+OpStore %37 %17
+OpBranch %38
+%38 = OpLabel
+%39 = OpIAdd %7 %25 %19
+OpBranch %40
+%40 = OpLabel
+OpBranch %42
+%42 = OpLabel
+%43 = OpSLessThan %11 %39 %10
+OpBranch %44
+%44 = OpLabel
+%45 = OpAccessChain %18 %4 %39
+OpStore %45 %17
+OpBranch %46
+%46 = OpLabel
+%47 = OpIAdd %7 %39 %19
+OpBranch %48
+%48 = OpLabel
+OpBranch %50
+%50 = OpLabel
+%51 = OpSLessThan %11 %47 %10
+OpBranch %52
+%52 = OpLabel
+%53 = OpAccessChain %18 %4 %47
+OpStore %53 %17
+OpBranch %54
+%54 = OpLabel
+%55 = OpIAdd %7 %47 %19
+OpBranch %27
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+template <int factor>
+class PartialUnrollerTestPass : public opt::Pass {
+ public:
+ PartialUnrollerTestPass() : Pass() {}
+
+ const char* name() const override { return "Loop unroller"; }
+
+ Status Process(ir::IRContext* context) override {
+ for (ir::Function& f : *context->module()) {
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(&f);
+ for (auto& loop : loop_descriptor) {
+ opt::LoopUtils loop_utils{context, &loop};
+ loop_utils.PartiallyUnroll(factor);
+ }
+ }
+
+ return Pass::Status::SuccessWithChange;
+ }
+};
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+ float x[10];
+ for (int i = 0; i < 10; ++i) {
+ x[i] = 1.0f;
+ }
+}
+*/
+TEST_F(PassClassTest, SimplePartialUnroll) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 330
+ OpName %2 "main"
+ OpName %5 "x"
+ OpName %3 "c"
+ OpDecorate %3 Location 0
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 10
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 10
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Function %16
+ %18 = OpConstant %13 1
+ %19 = OpTypePointer Function %13
+ %20 = OpConstant %8 1
+ %21 = OpTypeVector %13 4
+ %22 = OpTypePointer Output %21
+ %3 = OpVariable %22 Output
+ %2 = OpFunction %6 None %7
+ %23 = OpLabel
+ %5 = OpVariable %17 Function
+ OpBranch %24
+ %24 = OpLabel
+ %35 = OpPhi %8 %10 %23 %34 %26
+ OpLoopMerge %25 %26 Unroll
+ OpBranch %27
+ %27 = OpLabel
+ %29 = OpSLessThan %12 %35 %11
+ OpBranchConditional %29 %30 %25
+ %30 = OpLabel
+ %32 = OpAccessChain %19 %5 %35
+ OpStore %32 %18
+ OpBranch %26
+ %26 = OpLabel
+ %34 = OpIAdd %8 %35 %20
+ OpBranch %24
+ %25 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const std::string output = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 10
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+%24 = OpPhi %7 %9 %22 %39 %38
+OpLoopMerge %27 %38 Unroll
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %24 %10
+OpBranchConditional %29 %30 %27
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %24
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %24 %19
+OpBranch %32
+%32 = OpLabel
+OpBranch %34
+%34 = OpLabel
+%35 = OpSLessThan %11 %25 %10
+OpBranch %36
+%36 = OpLabel
+%37 = OpAccessChain %18 %4 %25
+OpStore %37 %17
+OpBranch %38
+%38 = OpLabel
+%39 = OpIAdd %7 %25 %19
+OpBranch %23
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 330 core
+layout(location = 0) out vec4 c;
+void main() {
+ float x[10];
+ for (int i = 0; i < 10; ++i) {
+ x[i] = 1.0f;
+ }
+}
+*/
+TEST_F(PassClassTest, SimpleUnevenPartialUnroll) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 330
+ OpName %2 "main"
+ OpName %5 "x"
+ OpName %3 "c"
+ OpDecorate %3 Location 0
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 10
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 10
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Function %16
+ %18 = OpConstant %13 1
+ %19 = OpTypePointer Function %13
+ %20 = OpConstant %8 1
+ %21 = OpTypeVector %13 4
+ %22 = OpTypePointer Output %21
+ %3 = OpVariable %22 Output
+ %2 = OpFunction %6 None %7
+ %23 = OpLabel
+ %5 = OpVariable %17 Function
+ OpBranch %24
+ %24 = OpLabel
+ %35 = OpPhi %8 %10 %23 %34 %26
+ OpLoopMerge %25 %26 Unroll
+ OpBranch %27
+ %27 = OpLabel
+ %29 = OpSLessThan %12 %35 %11
+ OpBranchConditional %29 %30 %25
+ %30 = OpLabel
+ %32 = OpAccessChain %19 %5 %35
+ OpStore %32 %18
+ OpBranch %26
+ %26 = OpLabel
+ %34 = OpIAdd %8 %35 %20
+ OpBranch %24
+ %25 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 330
+OpName %2 "main"
+OpName %4 "x"
+OpName %3 "c"
+OpDecorate %3 Location 0
+%5 = OpTypeVoid
+%6 = OpTypeFunction %5
+%7 = OpTypeInt 32 1
+%8 = OpTypePointer Function %7
+%9 = OpConstant %7 0
+%10 = OpConstant %7 10
+%11 = OpTypeBool
+%12 = OpTypeFloat 32
+%13 = OpTypeInt 32 0
+%14 = OpConstant %13 10
+%15 = OpTypeArray %12 %14
+%16 = OpTypePointer Function %15
+%17 = OpConstant %12 1
+%18 = OpTypePointer Function %12
+%19 = OpConstant %7 1
+%20 = OpTypeVector %12 4
+%21 = OpTypePointer Output %20
+%3 = OpVariable %21 Output
+%58 = OpConstant %13 1
+%2 = OpFunction %5 None %6
+%22 = OpLabel
+%4 = OpVariable %16 Function
+OpBranch %23
+%23 = OpLabel
+%24 = OpPhi %7 %9 %22 %25 %26
+OpLoopMerge %32 %26 Unroll
+OpBranch %28
+%28 = OpLabel
+%29 = OpSLessThan %11 %24 %58
+OpBranchConditional %29 %30 %32
+%30 = OpLabel
+%31 = OpAccessChain %18 %4 %24
+OpStore %31 %17
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %7 %24 %19
+OpBranch %23
+%32 = OpLabel
+OpBranch %33
+%33 = OpLabel
+%34 = OpPhi %7 %58 %32 %57 %56
+OpLoopMerge %41 %56 Unroll
+OpBranch %35
+%35 = OpLabel
+%36 = OpSLessThan %11 %34 %10
+OpBranchConditional %36 %37 %41
+%37 = OpLabel
+%38 = OpAccessChain %18 %4 %34
+OpStore %38 %17
+OpBranch %39
+%39 = OpLabel
+%40 = OpIAdd %7 %34 %19
+OpBranch %42
+%42 = OpLabel
+OpBranch %44
+%44 = OpLabel
+%45 = OpSLessThan %11 %40 %10
+OpBranch %46
+%46 = OpLabel
+%47 = OpAccessChain %18 %4 %40
+OpStore %47 %17
+OpBranch %48
+%48 = OpLabel
+%49 = OpIAdd %7 %40 %19
+OpBranch %50
+%50 = OpLabel
+OpBranch %52
+%52 = OpLabel
+%53 = OpSLessThan %11 %49 %10
+OpBranch %54
+%54 = OpLabel
+%55 = OpAccessChain %18 %4 %49
+OpStore %55 %17
+OpBranch %56
+%56 = OpLabel
+%57 = OpIAdd %7 %49 %19
+OpBranch %33
+%41 = OpLabel
+OpReturn
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ // By unrolling by a factor that doesn't divide evenly into the number of loop
+ // iterations we perfom an additional transform when partially unrolling to
+ // account for the remainder.
+ SinglePassRunAndCheck<PartialUnrollerTestPass<3>>(text, output, false);
+}
+
+/* Generated from
+#version 410 core
+layout(location=0) flat in int upper_bound;
+void main() {
+ float x[10];
+ for (int i = 2; i < 8; i+=2) {
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, SimpleLoopIterationsCheck) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %5 "x"
+OpName %3 "upper_bound"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%8 = OpTypeInt 32 1
+%9 = OpTypePointer Function %8
+%10 = OpConstant %8 2
+%11 = OpConstant %8 8
+%12 = OpTypeBool
+%13 = OpTypeFloat 32
+%14 = OpTypeInt 32 0
+%15 = OpConstant %14 10
+%16 = OpTypeArray %13 %15
+%17 = OpTypePointer Function %16
+%18 = OpTypePointer Function %13
+%19 = OpTypePointer Input %8
+%3 = OpVariable %19 Input
+%2 = OpFunction %6 None %7
+%20 = OpLabel
+%5 = OpVariable %17 Function
+OpBranch %21
+%21 = OpLabel
+%34 = OpPhi %8 %10 %20 %33 %23
+OpLoopMerge %22 %23 Unroll
+OpBranch %24
+%24 = OpLabel
+%26 = OpSLessThan %12 %34 %11
+OpBranchConditional %26 %27 %22
+%27 = OpLabel
+%30 = OpConvertSToF %13 %34
+%31 = OpAccessChain %18 %5 %34
+OpStore %31 %30
+OpBranch %23
+%23 = OpLabel
+%33 = OpIAdd %8 %34 %10
+OpBranch %21
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ ir::Function* f = spvtest::GetFunction(module, 2);
+
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+ EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+ ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+ ir::BasicBlock* condition = loop.FindConditionBlock();
+ EXPECT_EQ(condition->id(), 24u);
+
+ ir::Instruction* induction = loop.FindInductionVariable(condition);
+ EXPECT_EQ(induction->result_id(), 34u);
+
+ opt::LoopUtils loop_utils{context.get(), &loop};
+ EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+ size_t iterations = 0;
+ EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+ &iterations));
+ EXPECT_EQ(iterations, 3u);
+}
+
+/* Generated from
+#version 410 core
+void main() {
+ float x[10];
+ for (int i = -1; i < 6; i+=3) {
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, SimpleLoopIterationsCheckSignedInit) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %5 "x"
+OpName %3 "upper_bound"
+OpDecorate %3 Flat
+OpDecorate %3 Location 0
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%8 = OpTypeInt 32 1
+%9 = OpTypePointer Function %8
+%10 = OpConstant %8 -1
+%11 = OpConstant %8 6
+%12 = OpTypeBool
+%13 = OpTypeFloat 32
+%14 = OpTypeInt 32 0
+%15 = OpConstant %14 10
+%16 = OpTypeArray %13 %15
+%17 = OpTypePointer Function %16
+%18 = OpTypePointer Function %13
+%19 = OpConstant %8 3
+%20 = OpTypePointer Input %8
+%3 = OpVariable %20 Input
+%2 = OpFunction %6 None %7
+%21 = OpLabel
+%5 = OpVariable %17 Function
+OpBranch %22
+%22 = OpLabel
+%35 = OpPhi %8 %10 %21 %34 %24
+OpLoopMerge %23 %24 None
+OpBranch %25
+%25 = OpLabel
+%27 = OpSLessThan %12 %35 %11
+OpBranchConditional %27 %28 %23
+%28 = OpLabel
+%31 = OpConvertSToF %13 %35
+%32 = OpAccessChain %18 %5 %35
+OpStore %32 %31
+OpBranch %24
+%24 = OpLabel
+%34 = OpIAdd %8 %35 %19
+OpBranch %22
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ ir::Function* f = spvtest::GetFunction(module, 2);
+
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+
+ EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+ ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_FALSE(loop.HasUnrollLoopControl());
+
+ ir::BasicBlock* condition = loop.FindConditionBlock();
+ EXPECT_EQ(condition->id(), 25u);
+
+ ir::Instruction* induction = loop.FindInductionVariable(condition);
+ EXPECT_EQ(induction->result_id(), 35u);
+
+ opt::LoopUtils loop_utils{context.get(), &loop};
+ EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+ size_t iterations = 0;
+ EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+ &iterations));
+ EXPECT_EQ(iterations, 3u);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[6];
+ for (uint i = 0; i < 2; i++) {
+ for (int x = 0; x < 3; ++x) {
+ out_array[x + i*3] = i;
+ }
+ }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoops) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %35 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 0
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 0
+ %16 = OpConstant %6 2
+ %17 = OpTypeBool
+ %19 = OpTypeInt 32 1
+ %20 = OpTypePointer Function %19
+ %22 = OpConstant %19 0
+ %29 = OpConstant %19 3
+ %31 = OpTypeFloat 32
+ %32 = OpConstant %6 6
+ %33 = OpTypeArray %31 %32
+ %34 = OpTypePointer Function %33
+ %39 = OpConstant %6 3
+ %44 = OpTypePointer Function %31
+ %47 = OpConstant %19 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %35 = OpVariable %34 Function
+ OpBranch %10
+ %10 = OpLabel
+ %51 = OpPhi %6 %9 %5 %50 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpULessThan %17 %51 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ OpBranch %23
+ %23 = OpLabel
+ %54 = OpPhi %19 %22 %11 %48 %26
+ OpLoopMerge %25 %26 Unroll
+ OpBranch %27
+ %27 = OpLabel
+ %30 = OpSLessThan %17 %54 %29
+ OpBranchConditional %30 %24 %25
+ %24 = OpLabel
+ %37 = OpBitcast %6 %54
+ %40 = OpIMul %6 %51 %39
+ %41 = OpIAdd %6 %37 %40
+ %43 = OpConvertUToF %31 %51
+ %45 = OpAccessChain %44 %35 %41
+ OpStore %45 %43
+ OpBranch %26
+ %26 = OpLabel
+ %48 = OpIAdd %19 %54 %47
+ OpBranch %23
+ %25 = OpLabel
+ OpBranch %13
+ %13 = OpLabel
+ %50 = OpIAdd %6 %51 %47
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 0
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 2
+%10 = OpTypeBool
+%11 = OpTypeInt 32 1
+%12 = OpTypePointer Function %11
+%13 = OpConstant %11 0
+%14 = OpConstant %11 3
+%15 = OpTypeFloat 32
+%16 = OpConstant %6 6
+%17 = OpTypeArray %15 %16
+%18 = OpTypePointer Function %17
+%19 = OpConstant %6 3
+%20 = OpTypePointer Function %15
+%21 = OpConstant %11 1
+%2 = OpFunction %4 None %5
+%22 = OpLabel
+%3 = OpVariable %18 Function
+OpBranch %23
+%23 = OpLabel
+OpBranch %28
+%28 = OpLabel
+%29 = OpULessThan %10 %8 %9
+OpBranch %30
+%30 = OpLabel
+OpBranch %31
+%31 = OpLabel
+OpBranch %36
+%36 = OpLabel
+%37 = OpSLessThan %10 %13 %14
+OpBranch %38
+%38 = OpLabel
+%39 = OpBitcast %6 %13
+%40 = OpIMul %6 %8 %19
+%41 = OpIAdd %6 %39 %40
+%42 = OpConvertUToF %15 %8
+%43 = OpAccessChain %20 %3 %41
+OpStore %43 %42
+OpBranch %34
+%34 = OpLabel
+%33 = OpIAdd %11 %13 %21
+OpBranch %44
+%44 = OpLabel
+OpBranch %46
+%46 = OpLabel
+%47 = OpSLessThan %10 %33 %14
+OpBranch %48
+%48 = OpLabel
+%49 = OpBitcast %6 %33
+%50 = OpIMul %6 %8 %19
+%51 = OpIAdd %6 %49 %50
+%52 = OpConvertUToF %15 %8
+%53 = OpAccessChain %20 %3 %51
+OpStore %53 %52
+OpBranch %54
+%54 = OpLabel
+%55 = OpIAdd %11 %33 %21
+OpBranch %56
+%56 = OpLabel
+OpBranch %58
+%58 = OpLabel
+%59 = OpSLessThan %10 %55 %14
+OpBranch %60
+%60 = OpLabel
+%61 = OpBitcast %6 %55
+%62 = OpIMul %6 %8 %19
+%63 = OpIAdd %6 %61 %62
+%64 = OpConvertUToF %15 %8
+%65 = OpAccessChain %20 %3 %63
+OpStore %65 %64
+OpBranch %66
+%66 = OpLabel
+%67 = OpIAdd %11 %55 %21
+OpBranch %35
+%35 = OpLabel
+OpBranch %26
+%26 = OpLabel
+%25 = OpIAdd %6 %8 %21
+OpBranch %68
+%68 = OpLabel
+OpBranch %70
+%70 = OpLabel
+%71 = OpULessThan %10 %25 %9
+OpBranch %72
+%72 = OpLabel
+OpBranch %73
+%73 = OpLabel
+OpBranch %74
+%74 = OpLabel
+%75 = OpSLessThan %10 %13 %14
+OpBranch %76
+%76 = OpLabel
+%77 = OpBitcast %6 %13
+%78 = OpIMul %6 %25 %19
+%79 = OpIAdd %6 %77 %78
+%80 = OpConvertUToF %15 %25
+%81 = OpAccessChain %20 %3 %79
+OpStore %81 %80
+OpBranch %82
+%82 = OpLabel
+%83 = OpIAdd %11 %13 %21
+OpBranch %84
+%84 = OpLabel
+OpBranch %85
+%85 = OpLabel
+%86 = OpSLessThan %10 %83 %14
+OpBranch %87
+%87 = OpLabel
+%88 = OpBitcast %6 %83
+%89 = OpIMul %6 %25 %19
+%90 = OpIAdd %6 %88 %89
+%91 = OpConvertUToF %15 %25
+%92 = OpAccessChain %20 %3 %90
+OpStore %92 %91
+OpBranch %93
+%93 = OpLabel
+%94 = OpIAdd %11 %83 %21
+OpBranch %95
+%95 = OpLabel
+OpBranch %96
+%96 = OpLabel
+%97 = OpSLessThan %10 %94 %14
+OpBranch %98
+%98 = OpLabel
+%99 = OpBitcast %6 %94
+%100 = OpIMul %6 %25 %19
+%101 = OpIAdd %6 %99 %100
+%102 = OpConvertUToF %15 %25
+%103 = OpAccessChain %20 %3 %101
+OpStore %103 %102
+OpBranch %104
+%104 = OpLabel
+%105 = OpIAdd %11 %94 %21
+OpBranch %106
+%106 = OpLabel
+OpBranch %107
+%107 = OpLabel
+%108 = OpIAdd %6 %25 %21
+OpBranch %27
+%27 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[2];
+ for (int i = -3; i < -1; i++) {
+ out_array[3 + i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, NegativeConditionAndInit) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %23 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 -3
+ %16 = OpConstant %6 -1
+ %17 = OpTypeBool
+ %19 = OpTypeInt 32 0
+ %20 = OpConstant %19 2
+ %21 = OpTypeArray %6 %20
+ %22 = OpTypePointer Function %21
+ %25 = OpConstant %6 3
+ %30 = OpConstant %6 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %23 = OpVariable %22 Function
+ OpBranch %10
+ %10 = OpLabel
+ %32 = OpPhi %6 %9 %5 %31 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpSLessThan %17 %32 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ %26 = OpIAdd %6 %32 %25
+ %28 = OpAccessChain %7 %23 %26
+ OpStore %28 %32
+ OpBranch %13
+ %13 = OpLabel
+ %31 = OpIAdd %6 %32 %30
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 -3
+%9 = OpConstant %6 -1
+%10 = OpTypeBool
+%11 = OpTypeInt 32 0
+%12 = OpConstant %11 2
+%13 = OpTypeArray %6 %12
+%14 = OpTypePointer Function %13
+%15 = OpConstant %6 3
+%16 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+%3 = OpVariable %14 Function
+OpBranch %18
+%18 = OpLabel
+OpBranch %23
+%23 = OpLabel
+%24 = OpSLessThan %10 %8 %9
+OpBranch %25
+%25 = OpLabel
+%26 = OpIAdd %6 %8 %15
+%27 = OpAccessChain %7 %3 %26
+OpStore %27 %8
+OpBranch %21
+%21 = OpLabel
+%20 = OpIAdd %6 %8 %16
+OpBranch %28
+%28 = OpLabel
+OpBranch %30
+%30 = OpLabel
+%31 = OpSLessThan %10 %20 %9
+OpBranch %32
+%32 = OpLabel
+%33 = OpIAdd %6 %20 %15
+%34 = OpAccessChain %7 %3 %33
+OpStore %34 %20
+OpBranch %35
+%35 = OpLabel
+%36 = OpIAdd %6 %20 %16
+OpBranch %22
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ // SinglePassRunAndCheck<opt::LoopUnroller>(text, expected, false);
+
+ ir::Function* f = spvtest::GetFunction(module, 4);
+
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+ EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+ ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+ ir::BasicBlock* condition = loop.FindConditionBlock();
+ EXPECT_EQ(condition->id(), 14u);
+
+ ir::Instruction* induction = loop.FindInductionVariable(condition);
+ EXPECT_EQ(induction->result_id(), 32u);
+
+ opt::LoopUtils loop_utils{context.get(), &loop};
+ EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+ size_t iterations = 0;
+ EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+ &iterations));
+ EXPECT_EQ(iterations, 2u);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[9];
+ for (int i = -10; i < -1; i++) {
+ out_array[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, NegativeConditionAndInitResidualUnroll) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %23 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 -10
+ %16 = OpConstant %6 -1
+ %17 = OpTypeBool
+ %19 = OpTypeInt 32 0
+ %20 = OpConstant %19 9
+ %21 = OpTypeArray %6 %20
+ %22 = OpTypePointer Function %21
+ %25 = OpConstant %6 10
+ %30 = OpConstant %6 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %23 = OpVariable %22 Function
+ OpBranch %10
+ %10 = OpLabel
+ %32 = OpPhi %6 %9 %5 %31 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpSLessThan %17 %32 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ %26 = OpIAdd %6 %32 %25
+ %28 = OpAccessChain %7 %23 %26
+ OpStore %28 %32
+ OpBranch %13
+ %13 = OpLabel
+ %31 = OpIAdd %6 %32 %30
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 -10
+%9 = OpConstant %6 -1
+%10 = OpTypeBool
+%11 = OpTypeInt 32 0
+%12 = OpConstant %11 9
+%13 = OpTypeArray %6 %12
+%14 = OpTypePointer Function %13
+%15 = OpConstant %6 10
+%16 = OpConstant %6 1
+%48 = OpConstant %6 -9
+%2 = OpFunction %4 None %5
+%17 = OpLabel
+%3 = OpVariable %14 Function
+OpBranch %18
+%18 = OpLabel
+%19 = OpPhi %6 %8 %17 %20 %21
+OpLoopMerge %28 %21 Unroll
+OpBranch %23
+%23 = OpLabel
+%24 = OpSLessThan %10 %19 %48
+OpBranchConditional %24 %25 %28
+%25 = OpLabel
+%26 = OpIAdd %6 %19 %15
+%27 = OpAccessChain %7 %3 %26
+OpStore %27 %19
+OpBranch %21
+%21 = OpLabel
+%20 = OpIAdd %6 %19 %16
+OpBranch %18
+%28 = OpLabel
+OpBranch %29
+%29 = OpLabel
+%30 = OpPhi %6 %48 %28 %47 %46
+OpLoopMerge %38 %46 Unroll
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %10 %30 %9
+OpBranchConditional %32 %33 %38
+%33 = OpLabel
+%34 = OpIAdd %6 %30 %15
+%35 = OpAccessChain %7 %3 %34
+OpStore %35 %30
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %6 %30 %16
+OpBranch %39
+%39 = OpLabel
+OpBranch %41
+%41 = OpLabel
+%42 = OpSLessThan %10 %37 %9
+OpBranch %43
+%43 = OpLabel
+%44 = OpIAdd %6 %37 %15
+%45 = OpAccessChain %7 %3 %44
+OpStore %45 %37
+OpBranch %46
+%46 = OpLabel
+%47 = OpIAdd %6 %37 %16
+OpBranch %29
+%38 = OpLabel
+OpReturn
+%22 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+ ir::Function* f = spvtest::GetFunction(module, 4);
+
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+ EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+
+ ir::Loop& loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_TRUE(loop.HasUnrollLoopControl());
+
+ ir::BasicBlock* condition = loop.FindConditionBlock();
+ EXPECT_EQ(condition->id(), 14u);
+
+ ir::Instruction* induction = loop.FindInductionVariable(condition);
+ EXPECT_EQ(induction->result_id(), 32u);
+
+ opt::LoopUtils loop_utils{context.get(), &loop};
+ EXPECT_TRUE(loop_utils.CanPerformUnroll());
+
+ size_t iterations = 0;
+ EXPECT_TRUE(loop.FindNumberOfIterations(induction, &*condition->ctail(),
+ &iterations));
+ EXPECT_EQ(iterations, 9u);
+ SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[10];
+ for (uint i = 0; i < 2; i++) {
+ for (int x = 0; x < 5; ++x) {
+ out_array[x + i*5] = i;
+ }
+ }
+}
+*/
+TEST_F(PassClassTest, UnrollNestedLoopsValidateDescriptor) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %35 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 0
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 0
+ %16 = OpConstant %6 2
+ %17 = OpTypeBool
+ %19 = OpTypeInt 32 1
+ %20 = OpTypePointer Function %19
+ %22 = OpConstant %19 0
+ %29 = OpConstant %19 5
+ %31 = OpTypeFloat 32
+ %32 = OpConstant %6 10
+ %33 = OpTypeArray %31 %32
+ %34 = OpTypePointer Function %33
+ %39 = OpConstant %6 5
+ %44 = OpTypePointer Function %31
+ %47 = OpConstant %19 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %35 = OpVariable %34 Function
+ OpBranch %10
+ %10 = OpLabel
+ %51 = OpPhi %6 %9 %5 %50 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpULessThan %17 %51 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ OpBranch %23
+ %23 = OpLabel
+ %54 = OpPhi %19 %22 %11 %48 %26
+ OpLoopMerge %25 %26 Unroll
+ OpBranch %27
+ %27 = OpLabel
+ %30 = OpSLessThan %17 %54 %29
+ OpBranchConditional %30 %24 %25
+ %24 = OpLabel
+ %37 = OpBitcast %6 %54
+ %40 = OpIMul %6 %51 %39
+ %41 = OpIAdd %6 %37 %40
+ %43 = OpConvertUToF %31 %51
+ %45 = OpAccessChain %44 %35 %41
+ OpStore %45 %43
+ OpBranch %26
+ %26 = OpLabel
+ %48 = OpIAdd %19 %54 %47
+ OpBranch %23
+ %25 = OpLabel
+ OpBranch %13
+ %13 = OpLabel
+ %50 = OpIAdd %6 %51 %47
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ // clang-format on
+
+ { // Test fully unroll
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+ ir::Function* f = spvtest::GetFunction(module, 4);
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+ EXPECT_EQ(loop_descriptor.NumLoops(), 2u);
+
+ ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1);
+
+ EXPECT_TRUE(outer_loop.HasUnrollLoopControl());
+
+ ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_TRUE(inner_loop.HasUnrollLoopControl());
+
+ EXPECT_EQ(outer_loop.GetBlocks().size(), 9u);
+
+ EXPECT_EQ(inner_loop.GetBlocks().size(), 4u);
+ EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u);
+ EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u);
+
+ {
+ opt::LoopUtils loop_utils{context.get(), &inner_loop};
+ loop_utils.FullyUnroll();
+ loop_utils.Finalize();
+ }
+
+ EXPECT_EQ(loop_descriptor.NumLoops(), 1u);
+ EXPECT_EQ(outer_loop.GetBlocks().size(), 25u);
+ EXPECT_EQ(outer_loop.NumImmediateChildren(), 0u);
+ {
+ opt::LoopUtils loop_utils{context.get(), &outer_loop};
+ loop_utils.FullyUnroll();
+ loop_utils.Finalize();
+ }
+ EXPECT_EQ(loop_descriptor.NumLoops(), 0u);
+ }
+
+ { // Test partially unroll
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+
+ ir::Function* f = spvtest::GetFunction(module, 4);
+ ir::LoopDescriptor& loop_descriptor = *context->GetLoopDescriptor(f);
+ EXPECT_EQ(loop_descriptor.NumLoops(), 2u);
+
+ ir::Loop& outer_loop = loop_descriptor.GetLoopByIndex(1);
+
+ EXPECT_TRUE(outer_loop.HasUnrollLoopControl());
+
+ ir::Loop& inner_loop = loop_descriptor.GetLoopByIndex(0);
+
+ EXPECT_TRUE(inner_loop.HasUnrollLoopControl());
+
+ EXPECT_EQ(outer_loop.GetBlocks().size(), 9u);
+
+ EXPECT_EQ(inner_loop.GetBlocks().size(), 4u);
+
+ EXPECT_EQ(outer_loop.NumImmediateChildren(), 1u);
+ EXPECT_EQ(inner_loop.NumImmediateChildren(), 0u);
+
+ opt::LoopUtils loop_utils{context.get(), &inner_loop};
+ loop_utils.PartiallyUnroll(2);
+ loop_utils.Finalize();
+
+ // The number of loops should actually grow.
+ EXPECT_EQ(loop_descriptor.NumLoops(), 3u);
+ EXPECT_EQ(outer_loop.GetBlocks().size(), 19u);
+ EXPECT_EQ(outer_loop.NumImmediateChildren(), 2u);
+ }
+}
+
+/*
+Generated from the following GLSL
+#version 440 core
+void main(){
+ float x[10];
+ int i = 1;
+ i = 0;
+ for (; i < 10; i++) {
+ x[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, UnrollWithInductionOutsideHeader) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 440
+OpName %main "main"
+OpName %x "x"
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_1 = OpConstant %int 1
+%int_0 = OpConstant %int 0
+%int_10 = OpConstant %int 10
+%bool = OpTypeBool
+%float = OpTypeFloat 32
+%uint = OpTypeInt 32 0
+%uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_ptr_Function_float = OpTypePointer Function %float
+%main = OpFunction %void None %3
+%5 = OpLabel
+%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
+OpBranch %11
+%11 = OpLabel
+%33 = OpPhi %int %int_0 %5 %32 %14
+OpLoopMerge %13 %14 None
+OpBranch %15
+%15 = OpLabel
+%19 = OpSLessThan %bool %33 %int_10
+OpBranchConditional %19 %12 %13
+%12 = OpLabel
+%28 = OpConvertSToF %float %33
+%30 = OpAccessChain %_ptr_Function_float %x %33
+OpStore %30 %28
+OpBranch %14
+%14 = OpLabel
+%32 = OpIAdd %int %33 %int_1
+OpBranch %11
+%13 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string expected = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 440
+OpName %main "main"
+OpName %x "x"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_1 = OpConstant %int 1
+%int_0 = OpConstant %int 0
+%int_10 = OpConstant %int 10
+%bool = OpTypeBool
+%float = OpTypeFloat 32
+%uint = OpTypeInt 32 0
+%uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_ptr_Function_float = OpTypePointer Function %float
+%main = OpFunction %void None %5
+%18 = OpLabel
+%x = OpVariable %_ptr_Function__arr_float_uint_10 Function
+OpBranch %19
+%19 = OpLabel
+%20 = OpPhi %int %int_0 %18 %37 %36
+OpLoopMerge %23 %36 None
+OpBranch %24
+%24 = OpLabel
+%25 = OpSLessThan %bool %20 %int_10
+OpBranchConditional %25 %26 %23
+%26 = OpLabel
+%27 = OpConvertSToF %float %20
+%28 = OpAccessChain %_ptr_Function_float %x %20
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpIAdd %int %20 %int_1
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %bool %21 %int_10
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %float %21
+%35 = OpAccessChain %_ptr_Function_float %x %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %int %21 %int_1
+OpBranch %19
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+
+ SinglePassRunAndCheck<PartialUnrollerTestPass<2>>(text, expected, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[3];
+ for (int i = 3; i > 0; --i) {
+ out_array[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeStepLoopTest) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %24 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 3
+ %16 = OpConstant %6 0
+ %17 = OpTypeBool
+ %19 = OpTypeFloat 32
+ %20 = OpTypeInt 32 0
+ %21 = OpConstant %20 3
+ %22 = OpTypeArray %19 %21
+ %23 = OpTypePointer Function %22
+ %28 = OpTypePointer Function %19
+ %31 = OpConstant %6 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %24 = OpVariable %23 Function
+ OpBranch %10
+ %10 = OpLabel
+ %33 = OpPhi %6 %9 %5 %32 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpSGreaterThan %17 %33 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ %27 = OpConvertSToF %19 %33
+ %29 = OpAccessChain %28 %24 %33
+ OpStore %29 %27
+ OpBranch %13
+ %13 = OpLabel
+ %32 = OpISub %6 %33 %31
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 3
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 1
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[3];
+ for (int i = 9; i > 0; i-=3) {
+ out_array[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeNonOneStepLoop) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %24 "out_array"
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 9
+ %16 = OpConstant %6 0
+ %17 = OpTypeBool
+ %19 = OpTypeFloat 32
+ %20 = OpTypeInt 32 0
+ %21 = OpConstant %20 3
+ %22 = OpTypeArray %19 %21
+ %23 = OpTypePointer Function %22
+ %28 = OpTypePointer Function %19
+ %30 = OpConstant %6 3
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %24 = OpVariable %23 Function
+ OpBranch %10
+ %10 = OpLabel
+ %33 = OpPhi %6 %9 %5 %32 %13
+ OpLoopMerge %12 %13 Unroll
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpSGreaterThan %17 %33 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ %27 = OpConvertSToF %19 %33
+ %29 = OpAccessChain %28 %24 %33
+ OpStore %29 %27
+ OpBranch %13
+ %13 = OpLabel
+ %32 = OpISub %6 %33 %30
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 9
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[3];
+ for (int i = 0; i < 7; i+=3) {
+ out_array[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNonDivisibleStepLoop) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+OpSource GLSL 410
+OpName %4 "main"
+OpName %24 "out_array"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%9 = OpConstant %6 0
+%16 = OpConstant %6 7
+%17 = OpTypeBool
+%19 = OpTypeFloat 32
+%20 = OpTypeInt 32 0
+%21 = OpConstant %20 3
+%22 = OpTypeArray %19 %21
+%23 = OpTypePointer Function %22
+%28 = OpTypePointer Function %19
+%30 = OpConstant %6 3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%24 = OpVariable %23 Function
+OpBranch %10
+%10 = OpLabel
+%33 = OpPhi %6 %9 %5 %32 %13
+OpLoopMerge %12 %13 Unroll
+OpBranch %14
+%14 = OpLabel
+%18 = OpSLessThan %17 %33 %16
+OpBranchConditional %18 %11 %12
+%11 = OpLabel
+%27 = OpConvertSToF %19 %33
+%29 = OpAccessChain %28 %24 %33
+OpStore %29 %27
+OpBranch %13
+%13 = OpLabel
+%32 = OpIAdd %6 %33 %30
+OpBranch %10
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 0
+%9 = OpConstant %6 7
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 3
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSLessThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpIAdd %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSLessThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpIAdd %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSLessThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpIAdd %6 %37 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+/*
+Generated from the following GLSL
+#version 410 core
+void main() {
+ float out_array[4];
+ for (int i = 11; i > 0; i-=3) {
+ out_array[i] = i;
+ }
+}
+*/
+TEST_F(PassClassTest, FullyUnrollNegativeNonDivisibleStepLoop) {
+ // clang-format off
+ // With opt::LocalMultiStoreElimPass
+ const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+OpSource GLSL 410
+OpName %4 "main"
+OpName %24 "out_array"
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%9 = OpConstant %6 11
+%16 = OpConstant %6 0
+%17 = OpTypeBool
+%19 = OpTypeFloat 32
+%20 = OpTypeInt 32 0
+%21 = OpConstant %20 4
+%22 = OpTypeArray %19 %21
+%23 = OpTypePointer Function %22
+%28 = OpTypePointer Function %19
+%30 = OpConstant %6 3
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%24 = OpVariable %23 Function
+OpBranch %10
+%10 = OpLabel
+%33 = OpPhi %6 %9 %5 %32 %13
+OpLoopMerge %12 %13 Unroll
+OpBranch %14
+%14 = OpLabel
+%18 = OpSGreaterThan %17 %33 %16
+OpBranchConditional %18 %11 %12
+%11 = OpLabel
+%27 = OpConvertSToF %19 %33
+%29 = OpAccessChain %28 %24 %33
+OpStore %29 %27
+OpBranch %13
+%13 = OpLabel
+%32 = OpISub %6 %33 %30
+OpBranch %10
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+const std::string output =
+R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+OpSource GLSL 410
+OpName %2 "main"
+OpName %3 "out_array"
+%4 = OpTypeVoid
+%5 = OpTypeFunction %4
+%6 = OpTypeInt 32 1
+%7 = OpTypePointer Function %6
+%8 = OpConstant %6 11
+%9 = OpConstant %6 0
+%10 = OpTypeBool
+%11 = OpTypeFloat 32
+%12 = OpTypeInt 32 0
+%13 = OpConstant %12 4
+%14 = OpTypeArray %11 %13
+%15 = OpTypePointer Function %14
+%16 = OpTypePointer Function %11
+%17 = OpConstant %6 3
+%2 = OpFunction %4 None %5
+%18 = OpLabel
+%3 = OpVariable %15 Function
+OpBranch %19
+%19 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%25 = OpSGreaterThan %10 %8 %9
+OpBranch %26
+%26 = OpLabel
+%27 = OpConvertSToF %11 %8
+%28 = OpAccessChain %16 %3 %8
+OpStore %28 %27
+OpBranch %22
+%22 = OpLabel
+%21 = OpISub %6 %8 %17
+OpBranch %29
+%29 = OpLabel
+OpBranch %31
+%31 = OpLabel
+%32 = OpSGreaterThan %10 %21 %9
+OpBranch %33
+%33 = OpLabel
+%34 = OpConvertSToF %11 %21
+%35 = OpAccessChain %16 %3 %21
+OpStore %35 %34
+OpBranch %36
+%36 = OpLabel
+%37 = OpISub %6 %21 %17
+OpBranch %38
+%38 = OpLabel
+OpBranch %40
+%40 = OpLabel
+%41 = OpSGreaterThan %10 %37 %9
+OpBranch %42
+%42 = OpLabel
+%43 = OpConvertSToF %11 %37
+%44 = OpAccessChain %16 %3 %37
+OpStore %44 %43
+OpBranch %45
+%45 = OpLabel
+%46 = OpISub %6 %37 %17
+OpBranch %47
+%47 = OpLabel
+OpBranch %49
+%49 = OpLabel
+%50 = OpSGreaterThan %10 %46 %9
+OpBranch %51
+%51 = OpLabel
+%52 = OpConvertSToF %11 %46
+%53 = OpAccessChain %16 %3 %46
+OpStore %53 %52
+OpBranch %54
+%54 = OpLabel
+%55 = OpISub %6 %46 %17
+OpBranch %23
+%23 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ // clang-format on
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for ushader:\n"
+ << text << std::endl;
+
+ opt::LoopUnroller loop_unroller;
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<opt::LoopUnroller>(text, output, false);
+}
+
+} // namespace
optimizer->RegisterPass(CreateReplaceInvalidOpcodePass());
} else if (0 == strcmp(cur_arg, "--simplify-instructions")) {
optimizer->RegisterPass(CreateSimplificationPass());
+ } else if (0 == strcmp(cur_arg, "--loop-unroll")) {
+ optimizer->RegisterPass(CreateLoopFullyUnrollPass());
} else if (0 == strcmp(cur_arg, "--skip-validation")) {
*skip_validator = true;
} else if (0 == strcmp(cur_arg, "-O")) {