/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "PermutationOperationPass.h"
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
22 #include "util/logging.h"
// Per-operation callback of this pass. The OperationIndex argument is unnamed, so it is not
// used by the callback itself.
// NOTE(review): the body is not visible in this view; presumably it dispatches `node` to the
// visit() overloads of this class (e.g. via node.accept(*this)) — confirm against the full file.
31 void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
36 // TODO Remove this. Expanding ranks of Operand is dangerous
37 void PermutationOperationPass::applyExpandRanks(const Operation &node)
39 const auto &output_ind = node.getOutputs().at(0);
40 const auto &output = _graph.operands().at(output_ind);
42 assert(output.getDef().valid());
43 const auto node_index = output.getDef();
44 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
45 const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
46 const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
48 if (frontend_layout == backend_layout)
53 int32_t expanded_rank = 0;
54 for (const auto &index :
55 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
57 expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
59 if (expanded_rank < 4)
62 for (const auto &index :
63 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
65 const auto &operand = _graph.operands().at(index);
66 if (operand.shape().rank() < expanded_rank)
68 if (operand.getUses().size() > 1)
69 throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
70 "operand used in more than one node");
71 // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
72 // a node to extend shape may be inserted in front of this operation
73 const_cast<ir::Shape &>(operand.shape()).extendRank(expanded_rank);
78 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
80 const auto &output_ind = node.getOutputs().at(0);
81 const auto &output_obj = _graph.operands().at(output_ind);
83 assert(output_obj.getDef().valid());
84 const auto node_index = output_obj.getDef();
85 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
87 const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
88 const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
90 if (frontend_layout == backend_layout)
95 // Permutation changing layout beyond 4-D is not supported yet
96 assert(output_obj.shape().rank() <= 4);
98 // Divide op_seq based on target operation
100 auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
101 auto &operations = _lowered_graph.graph().operations();
103 // Create new op_seq and move information from existing op_seq to new op_seq if target
104 // node is the end of op_seq
105 auto it = prev_op_seq.begin();
106 // Find iterator of target node in op_seq
107 while (*(it++) != node_index)
109 if (it != prev_op_seq.end())
111 const auto &target_op_idx = *it;
112 const auto &target_node = operations.at(target_op_idx);
113 const auto &next_op_seq_index =
114 _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout());
115 auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index);
116 next_op_seq.setInputs(target_node.getInputs());
117 next_op_seq.setOutputs(target_node.getOutputs());
119 std::vector<OperationIndex> remove_list;
120 remove_list.emplace_back(target_op_idx);
121 while (++it != prev_op_seq.end())
123 next_op_seq.appendOperation(target_op_idx);
124 next_op_seq.setOutputs(target_node.getOutputs());
125 remove_list.emplace_back(target_op_idx);
128 prev_op_seq.setOutputs(node.getOutputs());
129 for (const auto &index : remove_list)
131 prev_op_seq.remove(index);
134 const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
135 _lowered_graph.setLowerInfo(
137 std::make_unique<operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout()));
141 // Remove target operation from op_seq and insert the target operation to new op_seq
143 const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend();
145 // Remove target operation from op_sequence
146 _lowered_graph.op_seqs().removeFromOpSequence(node_index);
148 if (!_lowered_graph.op_seqs().exist(op_seq_index))
150 // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
151 _lowered_graph.removeLowerInfo(op_seq_index);
155 // Update op_seq of target operation if the op_seq exists
156 auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
157 const auto &last_node_idx = *(--prev_op_seq.end());
158 const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx);
159 prev_op_seq.setOutputs(last_node.getOutputs());
162 // Create new op_seq and set information to the op_seq
163 auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout);
164 auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index);
165 new_op_seq.setInputs(node.getInputs());
166 new_op_seq.setOutputs(node.getOutputs());
167 _lowered_graph.setLowerInfo(new_op_seq_index,
168 std::make_unique<operation::LowerInfo>(backend, frontend_layout));
171 // Change PermuteFactors of operands of target node
173 const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
174 const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
175 const auto backend = op_seq_li->backend();
176 const operand::PermuteFactor removed_factor{backend, backend_layout};
177 const operand::PermuteFactor new_factor{backend, frontend_layout};
178 for (const auto &input : node.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
180 bool canRemove = true;
181 for (const auto &use : _graph.operands().at(input).getUses())
183 if (use != node_index)
185 const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use);
186 auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index);
187 if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout)
195 auto lower_info = _lowered_graph.getLowerInfo(input);
198 lower_info->removeUsePermuteFactor(removed_factor);
200 lower_info->addUsePermuteFactor(new_factor);
202 // Whether if node's input is an input of model or a constant
203 if (!_graph.operands().at(input).getDef().valid() &&
204 (lower_info->def_factors().size() == 1 &&
205 lower_info->def_factors().getOnlyElement() == removed_factor))
207 assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
208 lower_info->removeDefPermuteFactor(removed_factor);
209 lower_info->addDefPermuteFactor(new_factor);
213 for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
215 auto lower_info = _lowered_graph.getLowerInfo(output);
216 lower_info->removeDefPermuteFactor(removed_factor);
217 lower_info->addDefPermuteFactor(new_factor);
219 // Whether if node's output is an output of model
220 if (_graph.operands().at(output).getUses().size() == 0)
222 assert(_graph.getOutputs().contains(output));
223 lower_info->removeUsePermuteFactor(removed_factor);
224 lower_info->addUsePermuteFactor(new_factor);
230 void PermutationOperationPass::visit(const operation::Add &node) { applyExpandRanks(node); }
232 void PermutationOperationPass::visit(const operation::Concat &node) { applyExpandRanks(node); }
234 void PermutationOperationPass::visit(const operation::Comparison &node) { applyExpandRanks(node); }
236 void PermutationOperationPass::visit(const operation::Div &node) { applyExpandRanks(node); }
238 void PermutationOperationPass::visit(const operation::FullyConnected &node)
240 const auto &input_ind = node.getInputs().at(operation::FullyConnected::Input::INPUT);
241 const auto &input_obj = _graph.operands().at(input_ind);
242 const auto &input_shape = input_obj.shape();
244 if (input_shape.rank() >= 4)
246 changeToKeepLayout(node);
250 void PermutationOperationPass::visit(const operation::Gather &node)
252 const auto &input_ind = node.getInputs().at(operation::Gather::Input::INPUT);
253 const auto &input_obj = _graph.operands().at(input_ind);
254 const auto &input_shape = input_obj.shape();
256 const auto &output_ind = node.getOutputs().at(0);
257 const auto &output_obj = _graph.operands().at(output_ind);
258 const auto &output_shape = output_obj.shape();
260 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
262 changeToKeepLayout(node);
266 void PermutationOperationPass::visit(const operation::LogicalAnd &node) { applyExpandRanks(node); }
268 void PermutationOperationPass::visit(const operation::LogicalNot &node) { applyExpandRanks(node); }
270 void PermutationOperationPass::visit(const operation::LogicalOr &node) { applyExpandRanks(node); }
272 void PermutationOperationPass::visit(const operation::Max &node) { applyExpandRanks(node); }
274 void PermutationOperationPass::visit(const operation::Min &node) { applyExpandRanks(node); }
276 void PermutationOperationPass::visit(const operation::Mul &node) { applyExpandRanks(node); }
278 void PermutationOperationPass::visit(const operation::Pack &node)
280 const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
281 const auto &input_obj = _graph.operands().at(input_ind);
282 const auto &input_shape = input_obj.shape();
284 const auto &output_ind = node.getOutputs().at(0);
285 const auto &output_obj = _graph.operands().at(output_ind);
286 const auto &output_shape = output_obj.shape();
288 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
290 changeToKeepLayout(node);
294 void PermutationOperationPass::visit(const operation::PReLU &node) { applyExpandRanks(node); }
296 void PermutationOperationPass::visit(const operation::Reshape &node)
298 const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
299 const auto &input_obj = _graph.operands().at(input_ind);
300 const auto &input_shape = input_obj.shape();
302 const auto &output_ind = node.getOutputs().at(0);
303 const auto &output_obj = _graph.operands().at(output_ind);
304 const auto &output_shape = output_obj.shape();
306 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
308 changeToKeepLayout(node);
312 void PermutationOperationPass::visit(const operation::SquaredDifference &node)
314 applyExpandRanks(node);
317 void PermutationOperationPass::visit(const operation::Sub &node) { applyExpandRanks(node); }
319 void PermutationOperationPass::visit(const operation::Unpack &node)
321 const auto &input_ind = node.getInputs().at(operation::Reshape::Input::INPUT);
322 const auto &input_obj = _graph.operands().at(input_ind);
323 const auto &input_shape = input_obj.shape();
325 const auto &output_ind = node.getOutputs().at(0);
326 const auto &output_obj = _graph.operands().at(output_ind);
327 const auto &output_shape = output_obj.shape();
329 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
331 changeToKeepLayout(node);