2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PermutationOperationPass.h"
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
22 #include "util/logging.h"
// Per-operation hook of this pass. It receives the index of the operation (left unnamed, so it
// is not used by name) together with a mutable reference to the operation itself.
// NOTE(review): the body of this function is not visible in this view of the file. Presumably
// it dispatches the node to the matching visit() overload below - confirm against the full source.
33 void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
38 // TODO Remove this. Expanding ranks of Operand is dangerous
39 void PermutationOperationPass::applyExpandRanks(const Operation &node)
41 const auto &output_ind = node.getOutputs().at(0);
42 const auto &output = _graph.operands().at(output_ind);
44 assert(output.getDef().valid());
45 const auto node_index = output.getDef();
46 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
47 const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
48 const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
50 if (frontend_layout == backend_layout)
55 int32_t expanded_rank = 0;
56 for (const auto &index :
57 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
59 expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
61 if (expanded_rank < 4)
64 for (const auto &index :
65 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
67 const auto &operand = _graph.operands().at(index);
68 if (operand.shape().rank() < expanded_rank)
70 if (operand.getUses().size() > 1)
71 throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
72 "operand used in more than one node");
73 // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
74 // a node to extend shape may be inserted in front of this operation
75 const_cast<Shape &>(operand.shape()).extendRank(expanded_rank);
80 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
82 const auto &output_ind = node.getOutputs().at(0);
83 const auto &output_obj = _graph.operands().at(output_ind);
85 assert(output_obj.getDef().valid());
86 const auto node_index = output_obj.getDef();
87 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
89 const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
90 const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
92 if (frontend_layout == backend_layout)
97 // Permutation changing layout beyond 4-D is not supported yet
98 assert(output_obj.shape().rank() <= 4);
100 // Divide op_seq based on target operation
102 auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
103 auto &operations = _lowered_graph.graph().operations();
105 // Create new op_seq and move information from existing op_seq to new op_seq if target
106 // node is the end of op_seq
107 auto it = prev_op_seq.begin();
108 // Find iterator of target node in op_seq
109 while (*(it++) != node_index)
111 if (it != prev_op_seq.end())
113 const auto &target_op_idx = *it;
114 const auto &target_node = operations.at(target_op_idx);
115 const auto &next_op_seq_index =
116 _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout());
117 auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index);
118 next_op_seq.setInputs(target_node.getInputs());
119 next_op_seq.setOutputs(target_node.getOutputs());
121 std::vector<OperationIndex> remove_list;
122 remove_list.emplace_back(target_op_idx);
123 while (++it != prev_op_seq.end())
125 next_op_seq.appendOperation(target_op_idx);
126 next_op_seq.setOutputs(target_node.getOutputs());
127 remove_list.emplace_back(target_op_idx);
130 prev_op_seq.setOutputs(node.getOutputs());
131 for (const auto &index : remove_list)
133 prev_op_seq.remove(index);
136 const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
137 _lowered_graph.setLowerInfo(
139 std::make_unique<ir::operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout()));
143 // Remove target operation from op_seq and insert the target operation to new op_seq
145 const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend();
147 // Remove target operation from op_sequence
148 _lowered_graph.op_seqs().removeFromOpSequence(node_index);
150 if (!_lowered_graph.op_seqs().exist(op_seq_index))
152 // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
153 _lowered_graph.removeLowerInfo(op_seq_index);
157 // Update op_seq of target operation if the op_seq exists
158 auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
159 const auto &last_node_idx = *(--prev_op_seq.end());
160 const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx);
161 prev_op_seq.setOutputs(last_node.getOutputs());
164 // Create new op_seq and set information to the op_seq
165 auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout);
166 auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index);
167 new_op_seq.setInputs(node.getInputs());
168 new_op_seq.setOutputs(node.getOutputs());
169 _lowered_graph.setLowerInfo(
170 new_op_seq_index, std::make_unique<ir::operation::LowerInfo>(backend, frontend_layout));
173 // Change PermuteFactors of operands of target node
175 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
176 const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
177 const auto backend = op_seq_li->backend();
178 const operand::PermuteFactor removed_factor{backend, backend_layout};
179 const operand::PermuteFactor new_factor{backend, frontend_layout};
180 for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED)
182 bool canRemove = true;
183 for (const auto &use : _graph.operands().at(input).getUses())
185 if (use != node_index)
187 const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use);
188 auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index);
189 if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout)
197 auto lower_info = _lowered_graph.getLowerInfo(input);
200 lower_info->removeUsePermuteFactor(removed_factor);
202 lower_info->addUsePermuteFactor(new_factor);
204 // Whether if node's input is an input of model or a constant
205 if (!_graph.operands().at(input).getDef().valid() &&
206 (lower_info->def_factors().size() == 1 &&
207 lower_info->def_factors().getOnlyElement() == removed_factor))
209 assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
210 lower_info->removeDefPermuteFactor(removed_factor);
211 lower_info->addDefPermuteFactor(new_factor);
215 for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
217 auto lower_info = _lowered_graph.getLowerInfo(output);
218 lower_info->removeDefPermuteFactor(removed_factor);
219 lower_info->addDefPermuteFactor(new_factor);
221 // Whether if node's output is an output of model
222 if (_graph.operands().at(output).getUses().size() == 0)
224 assert(_graph.getOutputs().contains(output));
225 lower_info->removeUsePermuteFactor(removed_factor);
226 lower_info->addUsePermuteFactor(new_factor);
232 void PermutationOperationPass::visit(const ir::operation::BinaryArithmetic &node)
234 applyExpandRanks(node);
237 void PermutationOperationPass::visit(const ir::operation::Concat &node) { applyExpandRanks(node); }
239 void PermutationOperationPass::visit(const ir::operation::Comparison &node)
241 applyExpandRanks(node);
244 void PermutationOperationPass::visit(const ir::operation::ElementwiseBinary &node)
246 applyExpandRanks(node);
249 void PermutationOperationPass::visit(const ir::operation::ElementwiseUnary &node)
251 applyExpandRanks(node);
254 void PermutationOperationPass::visit(const ir::operation::FullyConnected &node)
256 const auto &input_ind = node.getInputs().at(ir::operation::FullyConnected::Input::INPUT);
257 const auto &input_obj = _graph.operands().at(input_ind);
258 const auto &input_shape = input_obj.shape();
260 if (input_shape.rank() >= 4)
262 changeToKeepLayout(node);
266 void PermutationOperationPass::visit(const ir::operation::Gather &node)
268 const auto &input_ind = node.getInputs().at(ir::operation::Gather::Input::INPUT);
269 const auto &input_obj = _graph.operands().at(input_ind);
270 const auto &input_shape = input_obj.shape();
272 const auto &output_ind = node.getOutputs().at(0);
273 const auto &output_obj = _graph.operands().at(output_ind);
274 const auto &output_shape = output_obj.shape();
276 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
278 changeToKeepLayout(node);
282 void PermutationOperationPass::visit(const ir::operation::Pack &node)
284 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
285 const auto &input_obj = _graph.operands().at(input_ind);
286 const auto &input_shape = input_obj.shape();
288 const auto &output_ind = node.getOutputs().at(0);
289 const auto &output_obj = _graph.operands().at(output_ind);
290 const auto &output_shape = output_obj.shape();
292 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
294 changeToKeepLayout(node);
298 void PermutationOperationPass::visit(const ir::operation::PReLU &node) { applyExpandRanks(node); }
300 void PermutationOperationPass::visit(const ir::operation::Reshape &node)
302 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
303 const auto &input_obj = _graph.operands().at(input_ind);
304 const auto &input_shape = input_obj.shape();
306 const auto &output_ind = node.getOutputs().at(0);
307 const auto &output_obj = _graph.operands().at(output_ind);
308 const auto &output_shape = output_obj.shape();
310 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
312 changeToKeepLayout(node);
316 void PermutationOperationPass::visit(const ir::operation::SquaredDifference &node)
318 applyExpandRanks(node);
321 void PermutationOperationPass::visit(const ir::operation::Unpack &node)
323 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
324 const auto &input_obj = _graph.operands().at(input_ind);
325 const auto &input_shape = input_obj.shape();
327 const auto &output_ind = node.getOutputs().at(0);
328 const auto &output_obj = _graph.operands().at(output_ind);
329 const auto &output_shape = output_obj.shape();
331 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
333 changeToKeepLayout(node);
338 } // namespace compiler