2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
#include "PermutationInsertionPass.h"

#include "../../backend/builtin/Config.h"

#include "compiler/OperationLowerInfo.h"
#include "ir/operation/Permute.h"
#include "util/logging.h"

#include <algorithm>
#include <cassert>
#include <memory>
#include <unordered_map>
#include <vector>
38 void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object)
40 auto &operand_li_map = _lowered_graph.lower_info().operand;
41 auto &&operand_li = operand_li_map.getRawPtr(index);
44 // NOTE Later, constants also will have Def
46 if (operand_li->def_factors().size() == 0)
51 std::list<ir::OperationIndex> permute_indexes;
53 // Build a map for all necessary type of operands
54 std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index;
56 assert(operand_li->def_factors().size() == 1);
57 for (auto &&factor : operand_li->def_factors())
59 factor_to_index.emplace(factor, index);
62 auto insert_set = operand_li->use_factors() - operand_li->def_factors();
63 for (auto &&factor : insert_set)
65 const auto permute_operation_index = insertPermute(index, factor);
66 permute_indexes.push_back(permute_operation_index);
67 const auto &permute_operation = _graph.operations().at(permute_operation_index);
68 const auto permuted_operand_index = permute_operation.getOutputs().at(0);
69 factor_to_index.emplace(factor, permuted_operand_index);
73 // Update operations' input that uses this operand
75 std::list<ir::OperationIndex> remove_list;
77 auto uses = object.getUses();
78 for (auto &&use : uses)
80 // If permute operation, ignore it
81 if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
84 auto &operation = _graph.operations().at(use);
85 auto op_li = _lowered_graph.lower_info().operation.getRawPtr(use);
87 const auto op_layout = op_li->layout();
88 const backend::Backend *backend = op_li->backend();
90 auto use_node_inputs = operation.getInputs();
91 assert(use_node_inputs.contains(index));
93 auto new_index = factor_to_index.at({backend, op_layout});
94 if (index != new_index)
96 // Update from operation
97 // Replace the same inputs of an operation at once for the following reasons:
99 operation.replaceInputs(index, new_index);
101 // Update from operand
102 remove_list.push_back(
103 use); // Removal should be done in another loop since we are in the loop
104 _graph.operands().at(new_index).insertUse(use);
108 for (const auto &operation_index : remove_list)
110 object.removeUse(operation_index);
115 ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index,
116 const PermuteFactor &factor)
118 auto &operand = _graph.operands().at(operand_index);
120 // Generate output operand and permute operation
121 auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
122 // change model output if operand_index is model output index and the out operand is builtin
124 auto &model_outputs = _graph.getOutputs();
125 const backend::Backend *builtin_backend = compiler::BackendManager::get().getBuiltin();
126 assert(builtin_backend->config()->id() == onert::backend::builtin::Config::ID);
128 if (model_outputs.contains(operand_index) && factor.backend() == builtin_backend)
130 model_outputs.replace(operand_index, out_operand_index);
133 auto &operand_li_map = _lowered_graph.lower_info().operand;
135 // Find Permute information
136 auto input_factor = operand_li_map.getRawPtr(operand_index)->def_factors().getOnlyElement();
137 auto input_backend = input_factor.backend();
138 auto output_backend = factor.backend();
139 // NOTE Permute may not have specific layout because the layout of input and output may be
141 const auto permute_node_layout = ir::Layout::UNKNOWN;
142 // NOTE If one backend supports several layout, the backend must support Permute operation
143 const backend::Backend *permute_node_backend = compiler::BackendManager::get().getBuiltin();
144 assert(permute_node_backend->config()->id() == onert::backend::builtin::Config::ID);
146 if (input_backend == output_backend)
148 permute_node_backend = input_backend;
150 const PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
152 // Update LowerInfo of input operand
153 auto operand_lower_info = operand_li_map.getRawPtr(operand_index);
154 operand_lower_info->removeUsePermuteFactor(factor);
155 operand_lower_info->addUsePermuteFactor(permute_node_factor);
157 // Update LowerInfo of output operand
158 auto out_operand_li = std::make_unique<compiler::OperandLowerInfo>();
160 // The input and output factors of all nodes will be the same except Permute. So Tensor's
161 // allocators allocates memory using only the information of def permutation factor now.
162 // TODO Change param to permute_node_factor
163 out_operand_li->addDefPermuteFactor(factor);
164 out_operand_li->addUsePermuteFactor(factor);
165 operand_li_map.set(out_operand_index, std::move(out_operand_li));
167 // Insert permute operation to the graph
168 const auto input_layout = input_factor.layout();
169 const auto output_layout = factor.layout();
170 using Permute = ir::operation::Permute;
171 const auto permute_type = [&]() {
172 if (input_layout == ir::Layout::NHWC && output_layout == ir::Layout::NCHW)
174 return Permute::Type::NHWC_TO_NCHW;
176 else if (input_layout == ir::Layout::NCHW && output_layout == ir::Layout::NHWC)
178 return Permute::Type::NCHW_TO_NHWC;
182 return Permute::Type::COPY;
185 auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
187 auto node_index = _graph.operations().push(std::move(insert_node));
189 VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
190 VERBOSE_F() << " - Input (original) Operand : " << operand_index << "("
191 << input_factor.backend()->config()->id() << ")" << std::endl;
192 VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << "("
193 << factor.backend()->config()->id() << ")" << std::endl;
195 // Operation LowerInfo
197 auto &operation_li_map = _lowered_graph.lower_info().operation;
198 operation_li_map.set(node_index, std::make_unique<compiler::OperationLowerInfo>(
199 permute_node_backend, permute_node_layout));
202 // Update Use/Def info
204 _graph.operands().at(operand_index).insertUse(node_index);
205 _graph.operands().at(out_operand_index).setDef(node_index);
210 } // namespace compiler