2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PermutationOperationPass.h"
19 #include "backend/Backend.h"
20 #include "backend/IConfig.h"
22 #include "util/logging.h"
// Per-operation hook of this pass. The OperationIndex argument is intentionally unnamed (unused).
// NOTE(review): presumably dispatches `node` to the matching visit() overload defined below
// (e.g. via node.accept(*this)) — confirm.
33 void PermutationOperationPass::callback(const OperationIndex &, IOperation &node)
38 // TODO Remove this. Expanding ranks of Operand is dangerous
// Raises the rank of every lower-rank operand of `node` (its unique, defined inputs and outputs)
// to the largest rank found among them, by mutating the operand's Shape in place via const_cast.
// Used by the rank-sensitive visitors below (BinaryArithmetic, Concat, Comparison, PReLU, ...).
//
// Throws std::runtime_error for an operand that would need expanding but is used by more than one
// node, presumably because changing its shape in place would also affect those other users.
39 void PermutationOperationPass::applyExpandRanks(const Operation &node)
41 const auto &output_ind = node.getOutputs().at(0);
42 const auto &output = _graph.operands().at(output_ind);
// The operation's own index is recovered as the defining node of its first output.
44 assert(output.getDef().valid());
45 const auto node_index = output.getDef();
46 const auto frontend_layout = _graph.layout();
47 const auto backend_layout = _lowered_graph.lower_info().operation.getRawPtr(node_index)->layout();
// Guard for a node whose backend already uses the frontend layout.
// NOTE(review): presumably an early return (nothing to expand) — confirm.
49 if (frontend_layout == backend_layout)
// Largest rank over all unique, defined inputs and outputs of the node.
54 int32_t expanded_rank = 0;
55 for (const auto &index :
56 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
58 expanded_rank = std::max(expanded_rank, _graph.operands().at(index).shape().rank());
// NOTE(review): guard on ranks below 4 — presumably an early return, so only nodes touching
// 4-D or higher operands get expanded; confirm.
60 if (expanded_rank < 4)
// Extend every operand whose rank is below expanded_rank.
63 for (const auto &index :
64 (node.getInputs() + node.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
66 const auto &operand = _graph.operands().at(index);
67 if (operand.shape().rank() < expanded_rank)
69 if (operand.getUses().size() > 1)
70 throw std::runtime_error("PermutationOperationPass: not supported expanding rank of "
71 "operand used in more than one node");
72 // TODO remove const_cast later. For example, _ctx may need to be a non const variable or
73 // a node to extend shape may be inserted in front of this operation
74 const_cast<Shape &>(operand.shape()).extendRank(expanded_rank);
// Pins `node` to the frontend (graph) layout on its current backend, instead of the backend's own
// layout, and rewrites the PermuteFactors of its operands to match:
//  - the operation's lower info is replaced by (backend, frontend_layout);
//  - every unique, defined input gains the use factor (backend, frontend_layout). Its old use
//    factor (backend, backend_layout) is meant to be dropped only when no other user of the
//    operand still relies on it. A model input or constant whose only def factor is the old
//    factor has that def factor swapped as well;
//  - every unique, defined output has its def factor swapped. Outputs with no uses (asserted to
//    be model outputs) have their use factor swapped too.
79 void PermutationOperationPass::changeToKeepLayout(const Operation &node)
81 const auto &output_ind = node.getOutputs().at(0);
82 const auto &output_obj = _graph.operands().at(output_ind);
// The operation's own index is recovered as the defining node of its first output.
84 assert(output_obj.getDef().valid());
85 const auto node_index = output_obj.getDef();
87 auto &operation_li_map = _lowered_graph.lower_info().operation;
88 auto &operand_li_map = _lowered_graph.lower_info().operand;
89 const auto frontend_layout = _graph.layout();
90 const auto backend_layout = operation_li_map.getRawPtr(node_index)->layout();
// Guard for a node that already runs in the frontend layout.
// NOTE(review): presumably an early return — confirm.
92 if (frontend_layout == backend_layout)
97 // Permutation changing layout beyond 4-D is not supported yet
98 assert(output_obj.shape().rank() <= 4);
100 // Change PermuteFactors of operands and the operation of target node
// `backend` is copied out before set() below replaces the lower info, so `op_li` is presumably
// dangling afterwards and must not be used past that call.
102 const auto op_li = operation_li_map.getRawPtr(node_index);
103 const auto backend = op_li->backend();
// Same backend, new layout: the operation's lower info is replaced wholesale.
105 operation_li_map.set(node_index,
106 std::make_unique<compiler::OperationLowerInfo>(backend, frontend_layout));
// The factor being retired (backend layout) and its replacement (frontend layout).
108 const PermuteFactor removed_factor{backend, backend_layout};
109 const PermuteFactor new_factor{backend, frontend_layout};
110 for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED)
112 // Check if it can be removed by checking if the operand is used by another operation and
113 // it uses the same backend and layout
114 bool canRemove = true;
115 for (const auto &use : _graph.operands().at(input).getUses())
117 if (use != node_index)
119 auto use_op_li = operation_li_map.getRawPtr(use);
// Another user on the same backend still consumes this operand in the old layout.
// NOTE(review): presumably clears canRemove here — confirm.
120 if (use_op_li->backend() == backend && use_op_li->layout() == backend_layout)
// Drop the old use factor only when canRemove survived the scan above; add the new one.
128 auto input_li = operand_li_map.getRawPtr(input);
131 input_li->removeUsePermuteFactor(removed_factor);
133 input_li->addUsePermuteFactor(new_factor);
135 // Check whether node's input is an input of the model or a constant (no defining operation)
136 if (!_graph.operands().at(input).getDef().valid() &&
137 (input_li->def_factors().size() == 1 &&
138 input_li->def_factors().getOnlyElement() == removed_factor))
140 assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
141 input_li->removeDefPermuteFactor(removed_factor);
142 input_li->addDefPermuteFactor(new_factor);
// Outputs are produced by this node, so their def factor always follows the new layout.
146 for (const auto &output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
148 auto lower_info = operand_li_map.getRawPtr(output);
149 lower_info->removeDefPermuteFactor(removed_factor);
150 lower_info->addDefPermuteFactor(new_factor);
152 // Check whether node's output is an output of the model (it has no users)
153 if (_graph.operands().at(output).getUses().size() == 0)
155 assert(_graph.getOutputs().contains(output));
156 lower_info->removeUsePermuteFactor(removed_factor);
157 lower_info->addUsePermuteFactor(new_factor);
163 void PermutationOperationPass::visit(const ir::operation::BinaryArithmetic &node)
165 applyExpandRanks(node);
168 void PermutationOperationPass::visit(const ir::operation::Concat &node) { applyExpandRanks(node); }
170 void PermutationOperationPass::visit(const ir::operation::Comparison &node)
172 applyExpandRanks(node);
175 void PermutationOperationPass::visit(const ir::operation::ElementwiseBinary &node)
177 applyExpandRanks(node);
180 void PermutationOperationPass::visit(const ir::operation::ElementwiseUnary &node)
182 applyExpandRanks(node);
185 void PermutationOperationPass::visit(const ir::operation::FullyConnected &node)
187 const auto &input_ind = node.getInputs().at(ir::operation::FullyConnected::Input::INPUT);
188 const auto &input_obj = _graph.operands().at(input_ind);
189 const auto &input_shape = input_obj.shape();
191 if (input_shape.rank() >= 4)
193 changeToKeepLayout(node);
197 void PermutationOperationPass::visit(const ir::operation::Gather &node)
199 const auto &input_ind = node.getInputs().at(ir::operation::Gather::Input::INPUT);
200 const auto &input_obj = _graph.operands().at(input_ind);
201 const auto &input_shape = input_obj.shape();
203 const auto &output_ind = node.getOutputs().at(0);
204 const auto &output_obj = _graph.operands().at(output_ind);
205 const auto &output_shape = output_obj.shape();
207 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
209 changeToKeepLayout(node);
213 void PermutationOperationPass::visit(const ir::operation::OneHot &node)
215 const auto &output_ind = node.getOutputs().at(0);
216 const auto &output_obj = _graph.operands().at(output_ind);
217 const auto &output_shape = output_obj.shape();
219 if (output_shape.rank() >= 4)
221 changeToKeepLayout(node);
225 void PermutationOperationPass::visit(const ir::operation::Pack &node)
227 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
228 const auto &input_obj = _graph.operands().at(input_ind);
229 const auto &input_shape = input_obj.shape();
231 const auto &output_ind = node.getOutputs().at(0);
232 const auto &output_obj = _graph.operands().at(output_ind);
233 const auto &output_shape = output_obj.shape();
235 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
237 changeToKeepLayout(node);
241 void PermutationOperationPass::visit(const ir::operation::PReLU &node) { applyExpandRanks(node); }
243 void PermutationOperationPass::visit(const ir::operation::Reshape &node)
245 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
246 const auto &input_obj = _graph.operands().at(input_ind);
247 const auto &input_shape = input_obj.shape();
249 const auto &output_ind = node.getOutputs().at(0);
250 const auto &output_obj = _graph.operands().at(output_ind);
251 const auto &output_shape = output_obj.shape();
253 if (input_shape.rank() >= 4 || output_shape.rank() >= 4)
255 changeToKeepLayout(node);
259 void PermutationOperationPass::visit(const ir::operation::SquaredDifference &node)
261 applyExpandRanks(node);
264 void PermutationOperationPass::visit(const ir::operation::Unpack &node)
266 const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
267 const auto &input_obj = _graph.operands().at(input_ind);
268 const auto &input_shape = input_obj.shape();
270 const auto &output_ind = node.getOutputs().at(0);
271 const auto &output_obj = _graph.operands().at(output_ind);
272 const auto &output_shape = output_obj.shape();
274 if (input_shape.rank() < 4 || output_shape.rank() >= 4)
276 changeToKeepLayout(node);
281 } // namespace compiler