Initial implementation of merge return pass.
[platform/upstream/SPIRV-Tools.git] / source / opt / merge_return_pass.cpp
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "merge_return_pass.h"
16
17 #include "instruction.h"
18 #include "ir_context.h"
19
20 namespace spvtools {
21 namespace opt {
22
23 Pass::Status MergeReturnPass::Process(ir::IRContext* irContext) {
24   InitializeProcessing(irContext);
25
26   // TODO (alanbaker): Support structured control flow. Bail out in the
27   // meantime.
28   if (get_module()->HasCapability(SpvCapabilityShader))
29     return Status::SuccessWithoutChange;
30
31   bool modified = false;
32   for (auto& function : *get_module()) {
33     std::vector<ir::BasicBlock*> returnBlocks = CollectReturnBlocks(&function);
34     modified |= MergeReturnBlocks(&function, returnBlocks);
35   }
36
37   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
38 }
39
40 std::vector<ir::BasicBlock*> MergeReturnPass::CollectReturnBlocks(
41     ir::Function* function) {
42   std::vector<ir::BasicBlock*> returnBlocks;
43   for (auto& block : *function) {
44     ir::Instruction& terminator = *block.tail();
45     if (terminator.opcode() == SpvOpReturn ||
46         terminator.opcode() == SpvOpReturnValue) {
47       returnBlocks.push_back(&block);
48     }
49   }
50
51   return returnBlocks;
52 }
53
54 bool MergeReturnPass::MergeReturnBlocks(
55     ir::Function* function, const std::vector<ir::BasicBlock*>& returnBlocks) {
56   if (returnBlocks.size() <= 1) {
57     // No work to do.
58     return false;
59   }
60
61   // Create a label for the new return block
62   std::unique_ptr<ir::Instruction> returnLabel(
63       new ir::Instruction(SpvOpLabel, 0u, TakeNextId(), {}));
64   uint32_t returnId = returnLabel->result_id();
65
66   // Create the new basic block
67   std::unique_ptr<ir::BasicBlock> returnBlock(
68       new ir::BasicBlock(std::move(returnLabel)));
69   function->AddBasicBlock(std::move(returnBlock));
70   ir::Function::iterator retBlockIter = --function->end();
71
72   // Create the PHI for the merged block (if necessary)
73   // Create new return
74   std::vector<ir::Operand> phiOps;
75   for (auto block : returnBlocks) {
76     if (block->tail()->opcode() == SpvOpReturnValue) {
77       phiOps.push_back(
78           {SPV_OPERAND_TYPE_ID, {block->tail()->GetSingleWordInOperand(0u)}});
79       phiOps.push_back({SPV_OPERAND_TYPE_ID, {block->id()}});
80     }
81   }
82
83   if (!phiOps.empty()) {
84     // Need a PHI node to select the correct return value.
85     uint32_t phiResultId = TakeNextId();
86     uint32_t phiTypeId = function->type_id();
87     std::unique_ptr<ir::Instruction> phiInst(
88         new ir::Instruction(SpvOpPhi, phiTypeId, phiResultId, phiOps));
89     retBlockIter->AddInstruction(std::move(phiInst));
90     ir::BasicBlock::iterator phiIter = retBlockIter->tail();
91
92     std::unique_ptr<ir::Instruction> returnInst(new ir::Instruction(
93         SpvOpReturnValue, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {phiResultId}}}));
94     retBlockIter->AddInstruction(std::move(returnInst));
95     ir::BasicBlock::iterator ret = retBlockIter->tail();
96
97     get_def_use_mgr()->AnalyzeInstDefUse(&*phiIter);
98     get_def_use_mgr()->AnalyzeInstDef(&*ret);
99   } else {
100     std::unique_ptr<ir::Instruction> returnInst(
101         new ir::Instruction(SpvOpReturn));
102     retBlockIter->AddInstruction(std::move(returnInst));
103   }
104
105   // Replace returns with branches
106   for (auto block : returnBlocks) {
107     context()->KillInst(&*block->tail());
108     block->tail()->SetOpcode(SpvOpBranch);
109     block->tail()->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {returnId}}});
110     get_def_use_mgr()->AnalyzeInstUse(&*block->tail());
111     get_def_use_mgr()->AnalyzeInstUse(block->GetLabelInst());
112   }
113
114   get_def_use_mgr()->AnalyzeInstDefUse(retBlockIter->GetLabelInst());
115
116   return true;
117 }
118
119 }  // namespace opt
120 }  // namespace spvtools