Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationOperationPass.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationOperationPass.h"
18
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
21 #include "ir/Graph.h"
22 #include "util/logging.h"
23
24 namespace onert
25 {
26 namespace compiler
27 {
28 namespace pass
29 {
30
31 using namespace ir;
32
33 void PermutationOperationPass::callback(const OperationIndex &, IOperation &node)
34 {
35   node.accept(*this);
36 }
37
38 // TODO Remove this. Expanding ranks of Operand is dangerous
39 void PermutationOperationPass::applyExpandRanks(const Operation &node)
40 {
41   const auto &output_ind = node.getOutputs().at(0);
42   const auto &output = _graph.operands().at(output_ind);
43
44   assert(output.getDef().valid());
45   const auto node_index = output.getDef();
46   const auto frontend_layout = _graph.layout();
47   const auto backend_layout = _lowered_graph.lower_info().operation.getRawPtr(node_index)->layout();
48
49   if (frontend_layout == backend_layout)
50   {
51     return;
52   }
53
54   int32_t expanded_rank = 0;
55   for (const auto &index :
56        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
57   {
58     expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
59   }
60   if (expanded_rank < 4)
61     return;
62
63   for (const auto &index :
64        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
65   {
66     const auto &operand = _graph.operands().at(index);
67     if (operand.shape().rank() < expanded_rank)
68     {
69       if (operand.getUses().size() > 1)
70         throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
71                                  "operand used in more than one node");
72       // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
73       //      a node to extend shape may be inserted in front of this operation
74       const_cast<Shape &>(operand.shape()).extendRank(expanded_rank);
75     }
76   }
77 }
78
79 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
80 {
81   const auto &output_ind = node.getOutputs().at(0);
82   const auto &output_obj = _graph.operands().at(output_ind);
83
84   assert(output_obj.getDef().valid());
85   const auto node_index = output_obj.getDef();
86
87   auto &operation_li_map = _lowered_graph.lower_info().operation;
88   auto &operand_li_map = _lowered_graph.lower_info().operand;
89   const auto frontend_layout = _graph.layout();
90   const auto backend_layout = operation_li_map.getRawPtr(node_index)->layout();
91
92   if (frontend_layout == backend_layout)
93   {
94     return;
95   }
96
97   // Permutation changing layout beyond 4-D is not supported yet
98   assert(output_obj.shape().rank() <= 4);
99
100   // Change PermuteFactors of operands and the operation of target node
101   {
102     const auto op_li = operation_li_map.getRawPtr(node_index);
103     const auto backend = op_li->backend();
104
105     operation_li_map.set(node_index,
106                          std::make_unique<compiler::OperationLowerInfo>(backend, frontend_layout));
107
108     const PermuteFactor removed_factor{backend, backend_layout};
109     const PermuteFactor new_factor{backend, frontend_layout};
110     for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED)
111     {
112       // Check if it can be removed by checking if the operand is used by another operation and
113       // it uses the same backend and layout
114       bool canRemove = true;
115       for (const auto &use : _graph.operands().at(input).getUses())
116       {
117         if (use != node_index)
118         {
119           auto use_op_li = operation_li_map.getRawPtr(use);
120           if (use_op_li->backend() == backend && use_op_li->layout() == backend_layout)
121           {
122             canRemove = false;
123             break;
124           }
125         }
126       }
127
128       auto input_li = operand_li_map.getRawPtr(input);
129       if (canRemove)
130       {
131         input_li->removeUsePermuteFactor(removed_factor);
132       }
133       input_li->addUsePermuteFactor(new_factor);
134
135       // Whether if node's input is an input of model or a constant
136       if (!_graph.operands().at(input).getDef().valid() &&
137           (input_li->def_factors().size() == 1 &&
138            input_li->def_factors().getOnlyElement() == removed_factor))
139       {
140         assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
141         input_li->removeDefPermuteFactor(removed_factor);
142         input_li->addDefPermuteFactor(new_factor);
143       }
144     }
145
146     for (const auto &output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
147     {
148       auto lower_info = operand_li_map.getRawPtr(output);
149       lower_info->removeDefPermuteFactor(removed_factor);
150       lower_info->addDefPermuteFactor(new_factor);
151
152       // Whether if node's output is an output of model
153       if (_graph.operands().at(output).getUses().size() == 0)
154       {
155         assert(_graph.getOutputs().contains(output));
156         lower_info->removeUsePermuteFactor(removed_factor);
157         lower_info->addUsePermuteFactor(new_factor);
158       }
159     }
160   }
161 }
162
163 void PermutationOperationPass::visit(const ir::operation::BinaryArithmetic &node)
164 {
165   applyExpandRanks(node);
166 }
167
168 void PermutationOperationPass::visit(const ir::operation::Concat &node) { applyExpandRanks(node); }
169
170 void PermutationOperationPass::visit(const ir::operation::Comparison &node)
171 {
172   applyExpandRanks(node);
173 }
174
175 void PermutationOperationPass::visit(const ir::operation::ElementwiseBinary &node)
176 {
177   applyExpandRanks(node);
178 }
179
180 void PermutationOperationPass::visit(const ir::operation::ElementwiseUnary &node)
181 {
182   applyExpandRanks(node);
183 }
184
185 void PermutationOperationPass::visit(const ir::operation::FullyConnected &node)
186 {
187   const auto &input_ind = node.getInputs().at(ir::operation::FullyConnected::Input::INPUT);
188   const auto &input_obj = _graph.operands().at(input_ind);
189   const auto &input_shape = input_obj.shape();
190
191   if (input_shape.rank() >= 4)
192   {
193     changeToKeepLayout(node);
194   }
195 }
196
197 void PermutationOperationPass::visit(const ir::operation::Gather &node)
198 {
199   const auto &input_ind = node.getInputs().at(ir::operation::Gather::Input::INPUT);
200   const auto &input_obj = _graph.operands().at(input_ind);
201   const auto &input_shape = input_obj.shape();
202
203   const auto &output_ind = node.getOutputs().at(0);
204   const auto &output_obj = _graph.operands().at(output_ind);
205   const auto &output_shape = output_obj.shape();
206
207   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
208   {
209     changeToKeepLayout(node);
210   }
211 }
212
213 void PermutationOperationPass::visit(const ir::operation::OneHot &node)
214 {
215   const auto &output_ind = node.getOutputs().at(0);
216   const auto &output_obj = _graph.operands().at(output_ind);
217   const auto &output_shape = output_obj.shape();
218
219   if (output_shape.rank() >= 4)
220   {
221     changeToKeepLayout(node);
222   }
223 }
224
225 void PermutationOperationPass::visit(const ir::operation::Pack &node)
226 {
227   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
228   const auto &input_obj = _graph.operands().at(input_ind);
229   const auto &input_shape = input_obj.shape();
230
231   const auto &output_ind = node.getOutputs().at(0);
232   const auto &output_obj = _graph.operands().at(output_ind);
233   const auto &output_shape = output_obj.shape();
234
235   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
236   {
237     changeToKeepLayout(node);
238   }
239 }
240
241 void PermutationOperationPass::visit(const ir::operation::PReLU &node) { applyExpandRanks(node); }
242
243 void PermutationOperationPass::visit(const ir::operation::Reshape &node)
244 {
245   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
246   const auto &input_obj = _graph.operands().at(input_ind);
247   const auto &input_shape = input_obj.shape();
248
249   const auto &output_ind = node.getOutputs().at(0);
250   const auto &output_obj = _graph.operands().at(output_ind);
251   const auto &output_shape = output_obj.shape();
252
253   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
254   {
255     changeToKeepLayout(node);
256   }
257 }
258
259 void PermutationOperationPass::visit(const ir::operation::SquaredDifference &node)
260 {
261   applyExpandRanks(node);
262 }
263
264 void PermutationOperationPass::visit(const ir::operation::Unpack &node)
265 {
266   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
267   const auto &input_obj = _graph.operands().at(input_ind);
268   const auto &input_shape = input_obj.shape();
269
270   const auto &output_ind = node.getOutputs().at(0);
271   const auto &output_obj = _graph.operands().at(output_ind);
272   const auto &output_shape = output_obj.shape();
273
274   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
275   {
276     changeToKeepLayout(node);
277   }
278 }
279
280 } // namespace pass
281 } // namespace compiler
282 } // namespace onert