1 // Copyright (c) 2017 Google Inc.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "opt/loop_descriptor.h"
17 #include <type_traits>
22 #include "opt/dominator_tree.h"
23 #include "opt/ir_builder.h"
24 #include "opt/ir_context.h"
25 #include "opt/iterator.h"
26 #include "opt/make_unique.h"
27 #include "opt/tree_iterator.h"
32 Loop::Loop(IRContext* context, opt::DominatorAnalysis* dom_analysis,
33 BasicBlock* header, BasicBlock* continue_target,
34 BasicBlock* merge_target)
35 : loop_header_(header),
36 loop_continue_(continue_target),
37 loop_merge_(merge_target),
38 loop_preheader_(nullptr),
42 loop_preheader_ = FindLoopPreheader(context, dom_analysis);
43 AddBasicBlockToLoop(header);
44 AddBasicBlockToLoop(continue_target);
47 BasicBlock* Loop::FindLoopPreheader(IRContext* ir_context,
48 opt::DominatorAnalysis* dom_analysis) {
49 CFG* cfg = ir_context->cfg();
50 opt::DominatorTree& dom_tree = dom_analysis->GetDomTree();
51 opt::DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
53 // The loop predecessor.
54 BasicBlock* loop_pred = nullptr;
56 auto header_pred = cfg->preds(loop_header_->id());
57 for (uint32_t p_id : header_pred) {
58 opt::DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
59 if (node && !dom_tree.Dominates(header_node, node)) {
60 // The predecessor is not part of the loop, so potential loop preheader.
61 if (loop_pred && node->bb_ != loop_pred) {
62 // If we saw 2 distinct predecessors that are outside the loop, we don't
63 // have a loop preheader.
66 loop_pred = node->bb_;
69 // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
71 assert(loop_pred && "The header node is the entry block ?");
73 // So we have a unique basic block that can enter this loop.
74 // If this loop is the unique successor of this block, then it is a loop
76 bool is_preheader = true;
77 uint32_t loop_header_id = loop_header_->id();
78 const auto* const_loop_pred = loop_pred;
79 const_loop_pred->ForEachSuccessorLabel(
80 [&is_preheader, loop_header_id](const uint32_t id) {
81 if (id != loop_header_id) is_preheader = false;
83 if (is_preheader) return loop_pred;
87 bool Loop::IsInsideLoop(Instruction* inst) const {
88 const BasicBlock* parent_block = inst->context()->get_instr_block(inst);
89 if (!parent_block) return false;
90 return IsInsideLoop(parent_block);
93 bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
94 assert(bb->GetParent() && "The basic block does not belong to a function");
95 IRContext* context = bb->GetParent()->GetParent()->context();
97 opt::DominatorAnalysis* dom_analysis =
98 context->GetDominatorAnalysis(bb->GetParent(), *context->cfg());
99 if (!dom_analysis->Dominates(GetHeaderBlock(), bb)) return false;
101 opt::PostDominatorAnalysis* postdom_analysis =
102 context->GetPostDominatorAnalysis(bb->GetParent(), *context->cfg());
103 if (!postdom_analysis->Dominates(GetMergeBlock(), bb)) return false;
107 BasicBlock* Loop::GetOrCreatePreHeaderBlock(ir::IRContext* context) {
108 if (loop_preheader_) return loop_preheader_;
110 Function* fn = loop_header_->GetParent();
111 // Find the insertion point for the preheader.
112 Function::iterator header_it =
113 std::find_if(fn->begin(), fn->end(),
114 [this](BasicBlock& bb) { return &bb == loop_header_; });
115 assert(header_it != fn->end());
117 // Create the preheader basic block.
118 loop_preheader_ = &*header_it.InsertBefore(std::unique_ptr<ir::BasicBlock>(
119 new ir::BasicBlock(std::unique_ptr<ir::Instruction>(new ir::Instruction(
120 context, SpvOpLabel, 0, context->TakeNextId(), {})))));
121 loop_preheader_->SetParent(fn);
122 uint32_t loop_preheader_id = loop_preheader_->id();
124 // Redirect the branches and patch the phi:
125 // - For each phi instruction in the header:
126 // - If the header has only 1 out-of-loop incoming branch:
127 // - Change the incomning branch to be the preheader.
128 // - If the header has more than 1 out-of-loop incoming branch:
129 // - Create a new phi in the preheader, gathering all out-of-loops
131 // - Patch the header phi instruction to use the preheader phi
133 // - Redirect all edges coming from outside the loop to the preheader.
134 opt::InstructionBuilder builder(
135 context, loop_preheader_,
136 ir::IRContext::kAnalysisDefUse |
137 ir::IRContext::kAnalysisInstrToBlockMapping);
138 // Patch all the phi instructions.
139 loop_header_->ForEachPhiInst([&builder, context, this](Instruction* phi) {
140 std::vector<uint32_t> preheader_phi_ops;
141 std::vector<uint32_t> header_phi_ops;
142 for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
143 uint32_t def_id = phi->GetSingleWordInOperand(i);
144 uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
145 if (IsInsideLoop(branch_id)) {
146 header_phi_ops.push_back(def_id);
147 header_phi_ops.push_back(branch_id);
149 preheader_phi_ops.push_back(def_id);
150 preheader_phi_ops.push_back(branch_id);
154 Instruction* preheader_insn_def = nullptr;
155 // Create a phi instruction if and only if the preheader_phi_ops has more
157 if (preheader_phi_ops.size() > 2)
158 preheader_insn_def = builder.AddPhi(phi->type_id(), preheader_phi_ops);
161 context->get_def_use_mgr()->GetDef(preheader_phi_ops[0]);
162 // Build the new incoming edge.
163 header_phi_ops.push_back(preheader_insn_def->result_id());
164 header_phi_ops.push_back(loop_preheader_->id());
165 // Rewrite operands of the header's phi instruction.
167 for (; idx < header_phi_ops.size(); idx++)
168 phi->SetInOperand(idx, {header_phi_ops[idx]});
169 // Remove extra operands, from last to first (more efficient).
170 for (uint32_t j = phi->NumInOperands() - 1; j >= idx; j--)
171 phi->RemoveInOperand(j);
173 // Branch from the preheader to the header.
174 builder.AddBranch(loop_header_->id());
176 // Redirect all out of loop branches to the header to the preheader.
177 CFG* cfg = context->cfg();
178 cfg->RegisterBlock(loop_preheader_);
179 for (uint32_t pred_id : cfg->preds(loop_header_->id())) {
180 if (pred_id == loop_preheader_->id()) continue;
181 if (IsInsideLoop(pred_id)) continue;
182 BasicBlock* pred = cfg->block(pred_id);
183 pred->ForEachSuccessorLabel([this, loop_preheader_id](uint32_t* id) {
184 if (*id == loop_header_->id()) *id = loop_preheader_id;
186 cfg->AddEdge(pred_id, loop_preheader_id);
188 // Delete predecessors that are no longer predecessors of the loop header.
189 cfg->RemoveNonExistingEdges(loop_header_->id());
190 // Update the loop descriptors.
192 GetParent()->AddBasicBlock(loop_preheader_);
193 context->GetLoopDescriptor(fn)->SetBasicBlockToLoop(loop_preheader_->id(),
197 context->InvalidateAnalysesExceptFor(
198 builder.GetPreservedAnalysis() |
199 ir::IRContext::Analysis::kAnalysisLoopAnalysis |
200 ir::IRContext::kAnalysisCFG);
202 return loop_preheader_;
205 void Loop::SetLatchBlock(BasicBlock* latch) {
207 assert(latch->GetParent() && "The basic block does not belong to a function");
209 const auto* const_latch = latch;
210 const_latch->ForEachSuccessorLabel([this](uint32_t id) {
211 assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
212 "A predecessor of the continue block does not belong to the loop");
215 assert(IsInsideLoop(latch) && "The continue block is not in the loop");
217 SetLatchBlockImpl(latch);
220 void Loop::SetMergeBlock(BasicBlock* merge) {
222 assert(merge->GetParent() && "The basic block does not belong to a function");
223 CFG& cfg = *merge->GetParent()->GetParent()->context()->cfg();
225 for (uint32_t pred : cfg.preds(merge->id())) {
226 assert(IsInsideLoop(pred) &&
227 "A predecessor of the merge block does not belong to the loop");
230 assert(!IsInsideLoop(merge) && "The merge block is in the loop");
232 SetMergeBlockImpl(merge);
233 if (GetHeaderBlock()->GetLoopMergeInst()) {
234 UpdateLoopMergeInst();
238 void Loop::GetExitBlocks(IRContext* context,
239 std::unordered_set<uint32_t>* exit_blocks) const {
240 ir::CFG* cfg = context->cfg();
242 for (uint32_t bb_id : GetBlocks()) {
243 const spvtools::ir::BasicBlock* bb = cfg->block(bb_id);
244 bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
245 if (!IsInsideLoop(succ)) {
246 exit_blocks->insert(succ);
252 void Loop::GetMergingBlocks(
253 IRContext* context, std::unordered_set<uint32_t>* merging_blocks) const {
254 assert(GetMergeBlock() && "This loop is not structured");
255 ir::CFG* cfg = context->cfg();
257 std::stack<const ir::BasicBlock*> to_visit;
258 to_visit.push(GetMergeBlock());
259 while (!to_visit.empty()) {
260 const ir::BasicBlock* bb = to_visit.top();
262 merging_blocks->insert(bb->id());
263 for (uint32_t pred_id : cfg->preds(bb->id())) {
264 if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
265 to_visit.push(cfg->block(pred_id));
271 bool Loop::IsLCSSA() const {
272 IRContext* context = GetHeaderBlock()->GetParent()->GetParent()->context();
273 ir::CFG* cfg = context->cfg();
274 opt::analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
276 std::unordered_set<uint32_t> exit_blocks;
277 GetExitBlocks(context, &exit_blocks);
279 for (uint32_t bb_id : GetBlocks()) {
280 for (Instruction& insn : *cfg->block(bb_id)) {
281 // All uses must be either:
283 // - In an exit block and in a phi instruction.
284 if (!def_use_mgr->WhileEachUser(
286 [&exit_blocks, context, this](ir::Instruction* use) -> bool {
287 BasicBlock* parent = context->get_instr_block(use);
288 assert(parent && "Invalid analysis");
289 if (IsInsideLoop(parent)) return true;
290 if (use->opcode() != SpvOpPhi) return false;
291 return exit_blocks.count(parent->id());
299 LoopDescriptor::LoopDescriptor(const Function* f) : loops_() {
303 LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
305 void LoopDescriptor::PopulateList(const Function* f) {
306 IRContext* context = f->GetParent()->context();
308 opt::DominatorAnalysis* dom_analysis =
309 context->GetDominatorAnalysis(f, *context->cfg());
313 // Post-order traversal of the dominator tree to find all the OpLoopMerge
315 opt::DominatorTree& dom_tree = dom_analysis->GetDomTree();
316 for (opt::DominatorTreeNode& node :
317 ir::make_range(dom_tree.post_begin(), dom_tree.post_end())) {
318 Instruction* merge_inst = node.bb_->GetLoopMergeInst();
320 // The id of the merge basic block of this loop.
321 uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);
323 // The id of the continue basic block of this loop.
324 uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);
326 // The merge target of this loop.
327 BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);
329 // The continue target of this loop.
330 BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);
332 // The basic block containing the merge instruction.
333 BasicBlock* header_bb = context->get_instr_block(merge_inst);
335 // Add the loop to the list of all the loops in the function.
337 new Loop(context, dom_analysis, header_bb, continue_bb, merge_bb);
338 loops_.push_back(current_loop);
340 // We have a bottom-up construction, so if this loop has nested-loops,
341 // they are by construction at the tail of the loop list.
342 for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
343 Loop* previous_loop = *itr;
345 // If the loop already has a parent, then it has been processed.
346 if (previous_loop->HasParent()) continue;
348 // If the current loop does not dominates the previous loop then it is
350 if (!dom_analysis->Dominates(header_bb,
351 previous_loop->GetHeaderBlock()))
353 // If the current loop merge dominates the previous loop then it is
355 if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
358 current_loop->AddNestedLoop(previous_loop);
360 opt::DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
361 for (opt::DominatorTreeNode& loop_node :
362 make_range(node.df_begin(), node.df_end())) {
363 // Check if we are in the loop.
364 if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
365 current_loop->AddBasicBlockToLoop(loop_node.bb_);
366 basic_block_to_loop_.insert(
367 std::make_pair(loop_node.bb_->id(), current_loop));
371 for (Loop* loop : loops_) {
372 if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
376 void LoopDescriptor::ClearLoops() {
377 for (Loop* loop : loops_) {
384 } // namespace spvtools