Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationInsertionPass.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationInsertionPass.h"
18
19 #include <cassert>
20 #include <utility>
21 #include <unordered_map>
22
23 #include "backend/controlflow/Config.h"
24 #include "ir/Operand.h"
25 #include "ir/operation/LowerInfo.h"
26 #include "ir/Graph.h"
27 #include "backend/IConfig.h"
28 #include "util/logging.h"
29 #include <memory>
30 #include "ir/operation/Permute.h"
31
32 namespace onert
33 {
34 namespace compiler
35 {
36 namespace pass
37 {
38
39 void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object)
40 {
41   auto &&operand_li = _lowered_graph.getLowerInfo(index);
42   assert(operand_li);
43
44   // NOTE Later, constants also will have Def
45   // Ignore constants
46   if (operand_li->def_factors().size() == 0)
47   {
48     return;
49   }
50
51   std::list<ir::OperationIndex> permute_indexes;
52
53   // Build a map for all necessary type of operands
54   std::unordered_map<ir::operand::PermuteFactor, ir::OperandIndex> factor_to_index;
55   {
56     assert(operand_li->def_factors().size() == 1);
57     for (auto factor : operand_li->def_factors())
58     {
59       factor_to_index.emplace(factor, index);
60     }
61
62     auto insert_set = operand_li->use_factors() - operand_li->def_factors();
63     for (auto factor : insert_set)
64     {
65       const auto permute_operation_index = insertPermute(index, factor);
66       permute_indexes.push_back(permute_operation_index);
67       const auto &permute_operation = _graph.operations().at(permute_operation_index);
68       const auto permuted_operand_index = permute_operation.getOutputs().at(0);
69       factor_to_index.emplace(factor, permuted_operand_index);
70     }
71   }
72
73   // Update operations' input that uses this operand
74   {
75     std::list<ir::OperationIndex> remove_list;
76
77     auto uses = object.getUses();
78     for (auto use : uses)
79     {
80       // If permute operation, ignore it
81       if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
82         continue;
83
84       auto &operation = _graph.operations().at(use);
85       assert(_lowered_graph.op_seqs().containsOperation(use));
86       auto op_seq_index = _lowered_graph.op_seqs().getOperation(use);
87       auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
88       assert(op_seq_li);
89       const auto op_seq_layout = op_seq_li->layout();
90       const backend::Backend *backend = op_seq_li->backend();
91       assert(backend);
92       auto use_node_inputs = operation.getInputs();
93       assert(use_node_inputs.contains(index));
94
95       auto new_index = factor_to_index.at({backend, op_seq_layout});
96       if (index != new_index)
97       {
98         // Update from op_seq
99         // Replace the same inputs of an OpSequence at once for the following reasons:
100         // 1. An OpSequence's inputs are the same inputs of first operation
101         // 2. An OpSequence may have inputs as the same operand (2 or more).
102         // 3. The same inputs of OpSequence have the same PermuteFactor.
103         _lowered_graph.op_seqs().at(op_seq_index).replaceInputs(index, new_index);
104
105         // Update from operation
106         // Replace the same inputs of an operation at once for the following reasons:
107         // No. 2 and 3 above
108         operation.replaceInputs(index, new_index);
109
110         // Update from operand
111         remove_list.push_back(
112             use); // Removal should be done in another loop since we are in the loop
113         _graph.operands().at(new_index).insertUse(use);
114       }
115     }
116
117     for (auto &operation : remove_list)
118     {
119       object.removeUse(operation);
120     }
121   }
122 }
123
124 ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index,
125                                                            const ir::operand::PermuteFactor &factor)
126 {
127   assert(!_graph.isBuildingPhase());
128
129   auto &operand = _graph.operands().at(operand_index);
130
131   // Generate output operand and permute operation
132   auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
133   // change model output if operand_index is model output index and the out operand is controlflow
134   // backend
135   auto &model_outputs = _graph.getOutputs();
136   const backend::Backend *cf_backend = compiler::BackendManager::get().getControlflow();
137   if (model_outputs.contains(operand_index) && factor.backend() == cf_backend)
138   {
139     model_outputs.replace(operand_index, out_operand_index);
140   }
141
142   // Find Permute information
143   auto input_factor = _lowered_graph.getLowerInfo(operand_index)->def_factors().getOnlyElement();
144   auto input_backend = input_factor.backend();
145   auto output_backend = factor.backend();
146   // NOTE Permute may not have specific layout because the layout of input and output may be
147   // different.
148   const auto permute_node_layout = ir::Layout::UNKNOWN;
149   // NOTE If one backend supports several layout, the backend must support Permute operation
150   const backend::Backend *permute_node_backend = compiler::BackendManager::get().getControlflow();
151   if (input_backend == output_backend)
152   {
153     permute_node_backend = input_backend;
154   }
155   const ir::operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
156
157   // Update LowerInfo of input operand
158   auto operand_lower_info = _lowered_graph.getLowerInfo(operand_index);
159   operand_lower_info->removeUsePermuteFactor(factor);
160   operand_lower_info->addUsePermuteFactor(permute_node_factor);
161
162   // Update LowerInfo of output operand
163   auto out_operand_li = std::make_unique<ir::operand::LowerInfo>();
164
165   // The input and output factors of all nodes will be the same except Permute. So Tensor's
166   // allocators allocates memory using only the information of def permutation factor now.
167   // TODO Change param to permute_node_factor
168   out_operand_li->addDefPermuteFactor(factor);
169   out_operand_li->addUsePermuteFactor(factor);
170   _lowered_graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
171
172   // Insert permute operation to the graph
173   const auto input_layout = input_factor.layout();
174   const auto output_layout = factor.layout();
175   using Permute = ir::operation::Permute;
176   const auto permute_type = [&]() {
177     if (input_layout == ir::Layout::NHWC && output_layout == ir::Layout::NCHW)
178     {
179       return Permute::Type::NHWC_TO_NCHW;
180     }
181     else if (input_layout == ir::Layout::NCHW && output_layout == ir::Layout::NHWC)
182     {
183       return Permute::Type::NCHW_TO_NHWC;
184     }
185     else
186     {
187       return Permute::Type::COPY;
188     }
189   }();
190   auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
191
192   auto node_index = _graph.operations().push(std::move(insert_node));
193   const auto &node = _graph.operations().at(node_index);
194
195   VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
196   VERBOSE_F() << "  - Input (original) Operand : " << operand_index << "("
197               << input_factor.backend()->config()->id() << ")" << std::endl;
198   VERBOSE_F() << "  - Output(inserted) Operand : " << out_operand_index << "("
199               << factor.backend()->config()->id() << ")" << std::endl;
200
201   // OpSequence
202   {
203     auto op_seq_index = _lowered_graph.op_seqs().emplace(node_index, permute_node_layout);
204     auto &op_seq = _lowered_graph.op_seqs().at(op_seq_index);
205     op_seq.setInputs(node.getInputs());
206     op_seq.setOutputs(node.getOutputs());
207     _lowered_graph.setLowerInfo(op_seq_index, std::make_unique<ir::operation::LowerInfo>(
208                                                   permute_node_backend, permute_node_layout));
209   }
210
211   // Update Use/Def info
212   {
213     _graph.operands().at(operand_index).insertUse(node_index);
214     _graph.operands().at(out_operand_index).setDef(node_index);
215   }
216   return node_index;
217 }
218 } // namespace pass
219 } // namespace compiler
220 } // namespace onert