6eb412cf1460b11301116404082699e82e715377
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / pass / PermutationOperationPass.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationOperationPass.h"
18
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
21 #include "ir/Graph.h"
22 #include "util/logging.h"
23
24 namespace onert
25 {
26 namespace ir
27 {
28 namespace pass
29 {
30
31 void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
32 {
33   node.accept(*this);
34 };
35
36 // TODO Remove this. Expanding ranks of Operand is dangerous
37 void PermutationOperationPass::applyExpandRanks(const Operation &node)
38 {
39   const auto &output_ind = node.getOutputs().at(0);
40   const auto &output = _graph.operands().at(output_ind);
41
42   assert(output.getDef().valid());
43   const auto node_index = output.getDef();
44   const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
45   const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
46   const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
47
48   if (frontend_layout == backend_layout)
49   {
50     return;
51   }
52
53   int32_t expanded_rank = 0;
54   for (const auto &index :
55        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
56   {
57     expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
58   }
59   if (expanded_rank < 4)
60     return;
61
62   for (const auto &index :
63        (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
64   {
65     const auto &operand = _graph.operands().at(index);
66     if (operand.shape().rank() < expanded_rank)
67     {
68       if (operand.getUses().size() > 1)
69         throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
70                                  "operand used in more than one node");
71       // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
72       //      a node to extend shape may be inserted in front of this operation
73       const_cast<ir::Shape &>(operand.shape()).extendRank(expanded_rank);
74     }
75   }
76 }
77
78 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
79 {
80   const auto &output_ind = node.getOutputs().at(0);
81   const auto &output_obj = _graph.operands().at(output_ind);
82
83   assert(output_obj.getDef().valid());
84   const auto node_index = output_obj.getDef();
85   const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
86
87   const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
88   const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
89
90   if (frontend_layout == backend_layout)
91   {
92     return;
93   }
94
95   // Permutation changing layout beyond 4-D is not supported yet
96   assert(output_obj.shape().rank() <= 4);
97
98   // Divide op_seq based on target operation
99   {
100     auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
101     auto &operations = _lowered_graph.graph().operations();
102
103     // Create new op_seq and move information from existing op_seq to new op_seq if target
104     // node is the end of op_seq
105     auto it = prev_op_seq.begin();
106     // Find iterator of target node in op_seq
107     while (*(it++) != node_index)
108       ;
109     if (it != prev_op_seq.end())
110     {
111       const auto &target_op_idx = *it;
112       const auto &target_node = operations.at(target_op_idx);
113       const auto &next_op_seq_index =
114           _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout());
115       auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index);
116       next_op_seq.setInputs(target_node.getInputs());
117       next_op_seq.setOutputs(target_node.getOutputs());
118
119       std::vector<OperationIndex> remove_list;
120       remove_list.emplace_back(target_op_idx);
121       while (++it != prev_op_seq.end())
122       {
123         next_op_seq.appendOperation(target_op_idx);
124         next_op_seq.setOutputs(target_node.getOutputs());
125         remove_list.emplace_back(target_op_idx);
126       }
127
128       prev_op_seq.setOutputs(node.getOutputs());
129       for (const auto &index : remove_list)
130       {
131         prev_op_seq.remove(index);
132       }
133
134       const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
135       _lowered_graph.setLowerInfo(
136           next_op_seq_index,
137           std::make_unique<operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout()));
138     }
139   }
140
141   // Remove target operation from op_seq and insert the target operation to new op_seq
142   {
143     const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend();
144
145     // Remove target operation from op_sequence
146     _lowered_graph.op_seqs().removeFromOpSequence(node_index);
147
148     if (!_lowered_graph.op_seqs().exist(op_seq_index))
149     {
150       // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
151       _lowered_graph.removeLowerInfo(op_seq_index);
152     }
153     else
154     {
155       // Update op_seq of target operation if the op_seq exists
156       auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
157       const auto &last_node_idx = *(--prev_op_seq.end());
158       const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx);
159       prev_op_seq.setOutputs(last_node.getOutputs());
160     }
161
162     // Create new op_seq and set information to the op_seq
163     auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout);
164     auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index);
165     new_op_seq.setInputs(node.getInputs());
166     new_op_seq.setOutputs(node.getOutputs());
167     _lowered_graph.setLowerInfo(new_op_seq_index,
168                                 std::make_unique<operation::LowerInfo>(backend, frontend_layout));
169   }
170
171   // Change PermuteFactors of operands of target node
172   {
173     const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
174     const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
175     const auto backend = op_seq_li->backend();
176     const operand::PermuteFactor removed_factor{backend, backend_layout};
177     const operand::PermuteFactor new_factor{backend, frontend_layout};
178     for (const auto &input : node.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
179     {
180       bool canRemove = true;
181       for (const auto &use : _graph.operands().at(input).getUses())
182       {
183         if (use != node_index)
184         {
185           const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use);
186           auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index);
187           if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout)
188           {
189             canRemove = false;
190             break;
191           }
192         }
193       }
194
195       auto lower_info = _lowered_graph.getLowerInfo(input);
196       if (canRemove)
197       {
198         lower_info->removeUsePermuteFactor(removed_factor);
199       }
200       lower_info->addUsePermuteFactor(new_factor);
201
202       // Whether if node's input is an input of model or a constant
203       if (!_graph.operands().at(input).getDef().valid() &&
204           (lower_info->def_factors().size() == 1 &&
205            lower_info->def_factors().getOnlyElement() == removed_factor))
206       {
207         assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
208         lower_info->removeDefPermuteFactor(removed_factor);
209         lower_info->addDefPermuteFactor(new_factor);
210       }
211     }
212
213     for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
214     {
215       auto lower_info = _lowered_graph.getLowerInfo(output);
216       lower_info->removeDefPermuteFactor(removed_factor);
217       lower_info->addDefPermuteFactor(new_factor);
218
219       // Whether if node's output is an output of model
220       if (_graph.operands().at(output).getUses().size() == 0)
221       {
222         assert(_graph.getOutputs().contains(output));
223         lower_info->removeUsePermuteFactor(removed_factor);
224         lower_info->addUsePermuteFactor(new_factor);
225       }
226     }
227   }
228 }
229
230 void PermutationOperationPass::visit(const operation::Add &node) { applyExpandRanks(node); }
231
232 void PermutationOperationPass::visit(const operation::Concat &node) { applyExpandRanks(node); }
233
234 void PermutationOperationPass::visit(const operation::Comparison &node) { applyExpandRanks(node); }
235
236 void PermutationOperationPass::visit(const operation::Div &node) { applyExpandRanks(node); }
237
238 void PermutationOperationPass::visit(const operation::FullyConnected &node)
239 {
240   const auto &input_ind = node.getInputs().at(operation::FullyConnected::Input::INPUT);
241   const auto &input_obj = _graph.operands().at(input_ind);
242   const auto &input_shape = input_obj.shape();
243
244   if (input_shape.rank() >= 4)
245   {
246     changeToKeepLayout(node);
247   }
248 }
249
250 void PermutationOperationPass::visit(const operation::Gather &node)
251 {
252   const auto &input_ind = node.getInputs().at(operation::Gather::Input::INPUT);
253   const auto &input_obj = _graph.operands().at(input_ind);
254   const auto &input_shape = input_obj.shape();
255
256   const auto &output_ind = node.getOutputs().at(0);
257   const auto &output_obj = _graph.operands().at(output_ind);
258   const auto &output_shape = output_obj.shape();
259
260   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
261   {
262     changeToKeepLayout(node);
263   }
264 }
265
266 void PermutationOperationPass::visit(const operation::LogicalAnd &node) { applyExpandRanks(node); }
267
268 void PermutationOperationPass::visit(const operation::LogicalNot &node) { applyExpandRanks(node); }
269
270 void PermutationOperationPass::visit(const operation::LogicalOr &node) { applyExpandRanks(node); }
271
272 void PermutationOperationPass::visit(const operation::Max &node) { applyExpandRanks(node); }
273
274 void PermutationOperationPass::visit(const operation::Min &node) { applyExpandRanks(node); }
275
276 void PermutationOperationPass::visit(const operation::Mul &node) { applyExpandRanks(node); }
277
278 void PermutationOperationPass::visit(const operation::Pack &node)
279 {
280   const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
281   const auto &input_obj = _graph.operands().at(input_ind);
282   const auto &input_shape = input_obj.shape();
283
284   const auto &output_ind = node.getOutputs().at(0);
285   const auto &output_obj = _graph.operands().at(output_ind);
286   const auto &output_shape = output_obj.shape();
287
288   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
289   {
290     changeToKeepLayout(node);
291   }
292 }
293
294 void PermutationOperationPass::visit(const operation::PReLU &node) { applyExpandRanks(node); }
295
296 void PermutationOperationPass::visit(const operation::Reshape &node)
297 {
298   const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
299   const auto &input_obj = _graph.operands().at(input_ind);
300   const auto &input_shape = input_obj.shape();
301
302   const auto &output_ind = node.getOutputs().at(0);
303   const auto &output_obj = _graph.operands().at(output_ind);
304   const auto &output_shape = output_obj.shape();
305
306   if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
307   {
308     changeToKeepLayout(node);
309   }
310 }
311
312 void PermutationOperationPass::visit(const operation::SquaredDifference &node)
313 {
314   applyExpandRanks(node);
315 }
316
317 void PermutationOperationPass::visit(const operation::Sub &node) { applyExpandRanks(node); }
318
319 void PermutationOperationPass::visit(const operation::Unpack &node)
320 {
321   const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
322   const auto &input_obj = _graph.operands().at(input_ind);
323   const auto &input_shape = input_obj.shape();
324
325   const auto &output_ind = node.getOutputs().at(0);
326   const auto &output_obj = _graph.operands().at(output_ind);
327   const auto &output_shape = output_obj.shape();
328
329   if (input_shape.rank() < 4 || output_shape.rank() >= 4)
330   {
331     changeToKeepLayout(node);
332   }
333 }
334
335 } // namespace pass
336 } // namespace ir
337 } // namespace onert