Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationInsertionPass.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "PermutationInsertionPass.h"
19
20 #include "../../backend/builtin/Config.h"
21
22 #include "compiler/OperationLowerInfo.h"
23 #include "ir/operation/Permute.h"
24 #include "util/logging.h"
25
26 #include <cassert>
27 #include <memory>
28 #include <unordered_map>
29 #include <utility>
30
31 namespace onert
32 {
33 namespace compiler
34 {
35 namespace pass
36 {
37
38 void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object)
39 {
40   auto &operand_li_map = _lowered_graph.lower_info().operand;
41   auto &&operand_li = operand_li_map.getRawPtr(index);
42   assert(operand_li);
43
44   // NOTE Later, constants also will have Def
45   // Ignore constants
46   if (operand_li->def_factors().size() == 0)
47   {
48     return;
49   }
50
51   std::list<ir::OperationIndex> permute_indexes;
52
53   // Build a map for all necessary type of operands
54   std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index;
55   {
56     assert(operand_li->def_factors().size() == 1);
57     for (auto &&factor : operand_li->def_factors())
58     {
59       factor_to_index.emplace(factor, index);
60     }
61
62     auto insert_set = operand_li->use_factors() - operand_li->def_factors();
63     for (auto &&factor : insert_set)
64     {
65       const auto permute_operation_index = insertPermute(index, factor);
66       permute_indexes.push_back(permute_operation_index);
67       const auto &permute_operation = _graph.operations().at(permute_operation_index);
68       const auto permuted_operand_index = permute_operation.getOutputs().at(0);
69       factor_to_index.emplace(factor, permuted_operand_index);
70     }
71   }
72
73   // Update operations' input that uses this operand
74   {
75     std::list<ir::OperationIndex> remove_list;
76
77     auto uses = object.getUses();
78     for (auto &&use : uses)
79     {
80       // If permute operation, ignore it
81       if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
82         continue;
83
84       auto &operation = _graph.operations().at(use);
85       auto op_li = _lowered_graph.lower_info().operation.getRawPtr(use);
86       assert(op_li);
87       const auto op_layout = op_li->layout();
88       const backend::Backend *backend = op_li->backend();
89       assert(backend);
90       auto use_node_inputs = operation.getInputs();
91       assert(use_node_inputs.contains(index));
92
93       auto new_index = factor_to_index.at({backend, op_layout});
94       if (index != new_index)
95       {
96         // Update from operation
97         // Replace the same inputs of an operation at once for the following reasons:
98         // No. 2 and 3 above
99         operation.replaceInputs(index, new_index);
100
101         // Update from operand
102         remove_list.push_back(
103           use); // Removal should be done in another loop since we are in the loop
104         _graph.operands().at(new_index).insertUse(use);
105       }
106     }
107
108     for (const auto &operation_index : remove_list)
109     {
110       object.removeUse(operation_index);
111     }
112   }
113 }
114
115 ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index,
116                                                            const PermuteFactor &factor)
117 {
118   auto &operand = _graph.operands().at(operand_index);
119
120   // Generate output operand and permute operation
121   auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
122   // change model output if operand_index is model output index and the out operand is builtin
123   // backend
124   auto &model_outputs = _graph.getOutputs();
125   const backend::Backend *builtin_backend = compiler::BackendManager::get().getBuiltin();
126   assert(builtin_backend->config()->id() == onert::backend::builtin::Config::ID);
127
128   if (model_outputs.contains(operand_index) && factor.backend() == builtin_backend)
129   {
130     model_outputs.replace(operand_index, out_operand_index);
131   }
132
133   auto &operand_li_map = _lowered_graph.lower_info().operand;
134
135   // Find Permute information
136   auto input_factor = operand_li_map.getRawPtr(operand_index)->def_factors().getOnlyElement();
137   auto input_backend = input_factor.backend();
138   auto output_backend = factor.backend();
139   // NOTE Permute may not have specific layout because the layout of input and output may be
140   // different.
141   const auto permute_node_layout = ir::Layout::UNKNOWN;
142   // NOTE If one backend supports several layout, the backend must support Permute operation
143   const backend::Backend *permute_node_backend = compiler::BackendManager::get().getBuiltin();
144   assert(permute_node_backend->config()->id() == onert::backend::builtin::Config::ID);
145
146   if (input_backend == output_backend)
147   {
148     permute_node_backend = input_backend;
149   }
150   const PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
151
152   // Update LowerInfo of input operand
153   auto operand_lower_info = operand_li_map.getRawPtr(operand_index);
154   operand_lower_info->removeUsePermuteFactor(factor);
155   operand_lower_info->addUsePermuteFactor(permute_node_factor);
156
157   // Update LowerInfo of output operand
158   auto out_operand_li = std::make_unique<compiler::OperandLowerInfo>();
159
160   // The input and output factors of all nodes will be the same except Permute. So Tensor's
161   // allocators allocates memory using only the information of def permutation factor now.
162   // TODO Change param to permute_node_factor
163   out_operand_li->addDefPermuteFactor(factor);
164   out_operand_li->addUsePermuteFactor(factor);
165   operand_li_map.set(out_operand_index, std::move(out_operand_li));
166
167   // Insert permute operation to the graph
168   const auto input_layout = input_factor.layout();
169   const auto output_layout = factor.layout();
170   using Permute = ir::operation::Permute;
171   const auto permute_type = [&]() {
172     if (input_layout == ir::Layout::NHWC && output_layout == ir::Layout::NCHW)
173     {
174       return Permute::Type::NHWC_TO_NCHW;
175     }
176     else if (input_layout == ir::Layout::NCHW && output_layout == ir::Layout::NHWC)
177     {
178       return Permute::Type::NCHW_TO_NHWC;
179     }
180     else
181     {
182       return Permute::Type::COPY;
183     }
184   }();
185   auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
186
187   auto node_index = _graph.operations().push(std::move(insert_node));
188
189   VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
190   VERBOSE_F() << "  - Input (original) Operand : " << operand_index << "("
191               << input_factor.backend()->config()->id() << ")" << std::endl;
192   VERBOSE_F() << "  - Output(inserted) Operand : " << out_operand_index << "("
193               << factor.backend()->config()->id() << ")" << std::endl;
194
195   // Operation LowerInfo
196   {
197     auto &operation_li_map = _lowered_graph.lower_info().operation;
198     operation_li_map.set(node_index, std::make_unique<compiler::OperationLowerInfo>(
199                                        permute_node_backend, permute_node_layout));
200   }
201
202   // Update Use/Def info
203   {
204     _graph.operands().at(operand_index).insertUse(node_index);
205     _graph.operands().at(out_operand_index).setDef(node_index);
206   }
207   return node_index;
208 }
209 } // namespace pass
210 } // namespace compiler
211 } // namespace onert