Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationOperationPass.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationOperationPass.h"
18
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
21 #include "ir/Graph.h"
22 #include "util/logging.h"
23
24 namespace onert
25 {
26 namespace compiler
27 {
28 namespace pass
29 {
30
31 using namespace ir;
32
33 void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
34 {
35   node.accept(*this);
36 };
37
38 // TODO Remove this. Expanding ranks of Operand is dangerous
39 void PermutationOperationPass::applyExpandRanks(const Operation &node)
40 {
41   const auto &output_ind = node.getOutputs().at(0);
42   const auto &output = _graph.operands().at(output_ind);
43
44   assert(output.getDef().valid());
45   const auto node_index = output.getDef();
46   const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
47   const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
48   const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
49
50   if (frontend_layout == backend_layout)
51   {
52     return;
53   }
54
55   int32_t expanded_rank = 0;
56   for (const auto &index :
57        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
58   {
59     expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
60   }
61   if (expanded_rank < 4)
62     return;
63
64   for (const auto &index :
65        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
66   {
67     const auto &operand = _graph.operands().at(index);
68     if (operand.shape().rank() < expanded_rank)
69     {
70       if (operand.getUses().size() > 1)
71         throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
72                                  "operand used in more than one node");
73       // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
74       //      a node to extend shape may be inserted in front of this operation
75       const_cast<Shape &>(operand.shape()).extendRank(expanded_rank);
76     }
77   }
78 }
79
80 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
81 {
82   const auto &output_ind = node.getOutputs().at(0);
83   const auto &output_obj = _graph.operands().at(output_ind);
84
85   assert(output_obj.getDef().valid());
86   const auto node_index = output_obj.getDef();
87   const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
88
89   const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
90   const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
91
92   if (frontend_layout == backend_layout)
93   {
94     return;
95   }
96
97   // Permutation changing layout beyond 4-D is not supported yet
98   assert(output_obj.shape().rank() <= 4);
99
100   // Divide op_seq based on target operation
101   {
102     auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
103     auto &operations = _lowered_graph.graph().operations();
104
105     // Create new op_seq and move information from existing op_seq to new op_seq if target
106     // node is the end of op_seq
107     auto it = prev_op_seq.begin();
108     // Find iterator of target node in op_seq
109     while (*(it++) != node_index)
110       ;
111     if (it != prev_op_seq.end())
112     {
113       const auto &target_op_idx = *it;
114       const auto &target_node = operations.at(target_op_idx);
115       const auto &next_op_seq_index =
116           _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout());
117       auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index);
118       next_op_seq.setInputs(target_node.getInputs());
119       next_op_seq.setOutputs(target_node.getOutputs());
120
121       std::vector<OperationIndex> remove_list;
122       remove_list.emplace_back(target_op_idx);
123       while (++it != prev_op_seq.end())
124       {
125         next_op_seq.appendOperation(target_op_idx);
126         next_op_seq.setOutputs(target_node.getOutputs());
127         remove_list.emplace_back(target_op_idx);
128       }
129
130       prev_op_seq.setOutputs(node.getOutputs());
131       for (const auto &index : remove_list)
132       {
133         prev_op_seq.remove(index);
134       }
135
136       const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
137       _lowered_graph.setLowerInfo(
138           next_op_seq_index,
139           std::make_unique<ir::operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout()));
140     }
141   }
142
143   // Remove target operation from op_seq and insert the target operation to new op_seq
144   {
145     const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend();
146
147     // Remove target operation from op_sequence
148     _lowered_graph.op_seqs().removeFromOpSequence(node_index);
149
150     if (!_lowered_graph.op_seqs().exist(op_seq_index))
151     {
152       // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
153       _lowered_graph.removeLowerInfo(op_seq_index);
154     }
155     else
156     {
157       // Update op_seq of target operation if the op_seq exists
158       auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
159       const auto &last_node_idx = *(--prev_op_seq.end());
160       const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx);
161       prev_op_seq.setOutputs(last_node.getOutputs());
162     }
163
164     // Create new op_seq and set information to the op_seq
165     auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout);
166     auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index);
167     new_op_seq.setInputs(node.getInputs());
168     new_op_seq.setOutputs(node.getOutputs());
169     _lowered_graph.setLowerInfo(
170         new_op_seq_index, std::make_unique<ir::operation::LowerInfo>(backend, frontend_layout));
171   }
172
173   // Change PermuteFactors of operands of target node
174   {
175     const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
176     const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
177     const auto backend = op_seq_li->backend();
178     const operand::PermuteFactor removed_factor{backend, backend_layout};
179     const operand::PermuteFactor new_factor{backend, frontend_layout};
180     for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED)
181     {
182       bool canRemove = true;
183       for (const auto &use : _graph.operands().at(input).getUses())
184       {
185         if (use != node_index)
186         {
187           const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use);
188           auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index);
189           if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout)
190           {
191             canRemove = false;
192             break;
193           }
194         }
195       }
196
197       auto lower_info = _lowered_graph.getLowerInfo(input);
198       if (canRemove)
199       {
200         lower_info->removeUsePermuteFactor(removed_factor);
201       }
202       lower_info->addUsePermuteFactor(new_factor);
203
204       // Whether if node's input is an input of model or a constant
205       if (!_graph.operands().at(input).getDef().valid() &&
206           (lower_info->def_factors().size() == 1 &&
207            lower_info->def_factors().getOnlyElement() == removed_factor))
208       {
209         assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
210         lower_info->removeDefPermuteFactor(removed_factor);
211         lower_info->addDefPermuteFactor(new_factor);
212       }
213     }
214
215     for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
216     {
217       auto lower_info = _lowered_graph.getLowerInfo(output);
218       lower_info->removeDefPermuteFactor(removed_factor);
219       lower_info->addDefPermuteFactor(new_factor);
220
221       // Whether if node's output is an output of model
222       if (_graph.operands().at(output).getUses().size() == 0)
223       {
224         assert(_graph.getOutputs().contains(output));
225         lower_info->removeUsePermuteFactor(removed_factor);
226         lower_info->addUsePermuteFactor(new_factor);
227       }
228     }
229   }
230 }
231
232 void PermutationOperationPass::visit(const ir::operation::BinaryArithmetic &node)
233 {
234   applyExpandRanks(node);
235 }
236
237 void PermutationOperationPass::visit(const ir::operation::Concat &node) { applyExpandRanks(node); }
238
239 void PermutationOperationPass::visit(const ir::operation::Comparison &node)
240 {
241   applyExpandRanks(node);
242 }
243
244 void PermutationOperationPass::visit(const ir::operation::ElementwiseBinary &node)
245 {
246   applyExpandRanks(node);
247 }
248
249 void PermutationOperationPass::visit(const ir::operation::ElementwiseUnary &node)
250 {
251   applyExpandRanks(node);
252 }
253
254 void PermutationOperationPass::visit(const ir::operation::FullyConnected &node)
255 {
256   const auto &input_ind = node.getInputs().at(ir::operation::FullyConnected::Input::INPUT);
257   const auto &input_obj = _graph.operands().at(input_ind);
258   const auto &input_shape = input_obj.shape();
259
260   if (input_shape.rank() >= 4)
261   {
262     changeToKeepLayout(node);
263   }
264 }
265
266 void PermutationOperationPass::visit(const ir::operation::Gather &node)
267 {
268   const auto &input_ind = node.getInputs().at(ir::operation::Gather::Input::INPUT);
269   const auto &input_obj = _graph.operands().at(input_ind);
270   const auto &input_shape = input_obj.shape();
271
272   const auto &output_ind = node.getOutputs().at(0);
273   const auto &output_obj = _graph.operands().at(output_ind);
274   const auto &output_shape = output_obj.shape();
275
276   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
277   {
278     changeToKeepLayout(node);
279   }
280 }
281
282 void PermutationOperationPass::visit(const ir::operation::Pack &node)
283 {
284   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
285   const auto &input_obj = _graph.operands().at(input_ind);
286   const auto &input_shape = input_obj.shape();
287
288   const auto &output_ind = node.getOutputs().at(0);
289   const auto &output_obj = _graph.operands().at(output_ind);
290   const auto &output_shape = output_obj.shape();
291
292   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
293   {
294     changeToKeepLayout(node);
295   }
296 }
297
298 void PermutationOperationPass::visit(const ir::operation::PReLU &node) { applyExpandRanks(node); }
299
300 void PermutationOperationPass::visit(const ir::operation::Reshape &node)
301 {
302   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
303   const auto &input_obj = _graph.operands().at(input_ind);
304   const auto &input_shape = input_obj.shape();
305
306   const auto &output_ind = node.getOutputs().at(0);
307   const auto &output_obj = _graph.operands().at(output_ind);
308   const auto &output_shape = output_obj.shape();
309
310   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
311   {
312     changeToKeepLayout(node);
313   }
314 }
315
316 void PermutationOperationPass::visit(const ir::operation::SquaredDifference &node)
317 {
318   applyExpandRanks(node);
319 }
320
321 void PermutationOperationPass::visit(const ir::operation::Unpack &node)
322 {
323   const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
324   const auto &input_obj = _graph.operands().at(input_ind);
325   const auto &input_shape = input_obj.shape();
326
327   const auto &output_ind = node.getOutputs().at(0);
328   const auto &output_obj = _graph.operands().at(output_ind);
329   const auto &output_shape = output_obj.shape();
330
331   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
332   {
333     changeToKeepLayout(node);
334   }
335 }
336
337 } // namespace pass
338 } // namespace compiler
339 } // namespace onert