2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PermutationInsertionPass.h"
21 #include <unordered_map>
23 #include "backend/controlflow/Config.h"
24 #include "ir/Operand.h"
25 #include "ir/operation/LowerInfo.h"
27 #include "backend/IConfig.h"
28 #include "util/logging.h"
30 #include "ir/operation/Permute.h"
39 void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object)
41 auto &&operand_li = _lowered_graph.getLowerInfo(index);
44 // NOTE Later, constants also will have Def
46 if (operand_li->def_factors().size() == 0)
51 std::list<ir::OperationIndex> permute_indexes;
53 // Build a map for all necessary type of operands
54 std::unordered_map<ir::operand::PermuteFactor, ir::OperandIndex> factor_to_index;
56 assert(operand_li->def_factors().size() == 1);
57 for (auto factor : operand_li->def_factors())
59 factor_to_index.emplace(factor, index);
62 auto insert_set = operand_li->use_factors() - operand_li->def_factors();
63 for (auto factor : insert_set)
65 const auto permute_operation_index = insertPermute(index, factor);
66 permute_indexes.push_back(permute_operation_index);
67 const auto &permute_operation = _graph.operations().at(permute_operation_index);
68 const auto permuted_operand_index = permute_operation.getOutputs().at(0);
69 factor_to_index.emplace(factor, permuted_operand_index);
73 // Update operations' input that uses this operand
75 std::list<ir::OperationIndex> remove_list;
77 auto uses = object.getUses();
80 // If permute operation, ignore it
81 if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
84 auto &operation = _graph.operations().at(use);
85 assert(_lowered_graph.op_seqs().containsOperation(use));
86 auto op_seq_index = _lowered_graph.op_seqs().getOperation(use);
87 auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
89 const auto op_seq_layout = op_seq_li->layout();
90 const backend::Backend *backend = op_seq_li->backend();
92 auto use_node_inputs = operation.getInputs();
93 assert(use_node_inputs.contains(index));
95 auto new_index = factor_to_index.at({backend, op_seq_layout});
96 if (index != new_index)
99 // Replace the same inputs of an OpSequence at once for the following reasons:
100 // 1. An OpSequence's inputs are the same inputs of first operation
101 // 2. An OpSequence may have inputs as the same operand (2 or more).
102 // 3. The same inputs of OpSequence have the same PermuteFactor.
103 _lowered_graph.op_seqs().at(op_seq_index).replaceInputs(index, new_index);
105 // Update from operation
106 // Replace the same inputs of an operation at once for the following reasons:
108 operation.replaceInputs(index, new_index);
110 // Update from operand
111 remove_list.push_back(
112 use); // Removal should be done in another loop since we are in the loop
113 _graph.operands().at(new_index).insertUse(use);
117 for (auto &operation : remove_list)
119 object.removeUse(operation);
124 ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index,
125 const ir::operand::PermuteFactor &factor)
127 assert(!_graph.isBuildingPhase());
129 auto &operand = _graph.operands().at(operand_index);
131 // Generate output operand and permute operation
132 auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
133 // change model output if operand_index is model output index and the out operand is controlflow
135 auto &model_outputs = _graph.getOutputs();
136 const backend::Backend *cf_backend = compiler::BackendManager::get().getControlflow();
137 if (model_outputs.contains(operand_index) && factor.backend() == cf_backend)
139 model_outputs.replace(operand_index, out_operand_index);
142 // Find Permute information
143 auto input_factor = _lowered_graph.getLowerInfo(operand_index)->def_factors().getOnlyElement();
144 auto input_backend = input_factor.backend();
145 auto output_backend = factor.backend();
146 // NOTE Permute may not have specific layout because the layout of input and output may be
148 const auto permute_node_layout = ir::Layout::UNKNOWN;
149 // NOTE If one backend supports several layout, the backend must support Permute operation
150 const backend::Backend *permute_node_backend = compiler::BackendManager::get().getControlflow();
151 if (input_backend == output_backend)
153 permute_node_backend = input_backend;
155 const ir::operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
157 // Update LowerInfo of input operand
158 auto operand_lower_info = _lowered_graph.getLowerInfo(operand_index);
159 operand_lower_info->removeUsePermuteFactor(factor);
160 operand_lower_info->addUsePermuteFactor(permute_node_factor);
162 // Update LowerInfo of output operand
163 auto out_operand_li = std::make_unique<ir::operand::LowerInfo>();
165 // The input and output factors of all nodes will be the same except Permute. So Tensor's
166 // allocators allocates memory using only the information of def permutation factor now.
167 // TODO Change param to permute_node_factor
168 out_operand_li->addDefPermuteFactor(factor);
169 out_operand_li->addUsePermuteFactor(factor);
170 _lowered_graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
172 // Insert permute operation to the graph
173 const auto input_layout = input_factor.layout();
174 const auto output_layout = factor.layout();
175 using Permute = ir::operation::Permute;
176 const auto permute_type = [&]() {
177 if (input_layout == ir::Layout::NHWC && output_layout == ir::Layout::NCHW)
179 return Permute::Type::NHWC_TO_NCHW;
181 else if (input_layout == ir::Layout::NCHW && output_layout == ir::Layout::NHWC)
183 return Permute::Type::NCHW_TO_NHWC;
187 return Permute::Type::COPY;
190 auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
192 auto node_index = _graph.operations().push(std::move(insert_node));
193 const auto &node = _graph.operations().at(node_index);
195 VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
196 VERBOSE_F() << " - Input (original) Operand : " << operand_index << "("
197 << input_factor.backend()->config()->id() << ")" << std::endl;
198 VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << "("
199 << factor.backend()->config()->id() << ")" << std::endl;
203 auto op_seq_index = _lowered_graph.op_seqs().emplace(node_index, permute_node_layout);
204 auto &op_seq = _lowered_graph.op_seqs().at(op_seq_index);
205 op_seq.setInputs(node.getInputs());
206 op_seq.setOutputs(node.getOutputs());
207 _lowered_graph.setLowerInfo(op_seq_index, std::make_unique<ir::operation::LowerInfo>(
208 permute_node_backend, permute_node_layout));
211 // Update Use/Def info
213 _graph.operands().at(operand_index).insertUse(node_index);
214 _graph.operands().at(out_operand_index).setDef(node_index);
219 } // namespace compiler