2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "PermutationInsertionPass.h"

#include <algorithm>
#include <cassert>
#include <unordered_map>
#include <vector>

#include "backend/controlflow/Config.h"
#include "ir/Operand.h"
#include "ir/operation/LowerInfo.h"

#include "backend/IConfig.h"
#include "util/logging.h"

#include "ir/operation/Permute.h"
39 void PermutationInsertionPass::callback(const OperandIndex &index, Operand &object)
41 auto &&operand_li = _lowered_graph.getLowerInfo(index);
44 // NOTE Later, constants also will have Def
46 if (operand_li->def_factors().size() == 0)
51 std::list<OperationIndex> permute_indexes;
53 // Build a map for all necessary type of operands
54 std::unordered_map<operand::PermuteFactor, OperandIndex> factor_to_index;
56 assert(operand_li->def_factors().size() == 1);
57 for (auto factor : operand_li->def_factors())
59 factor_to_index.emplace(factor, index);
62 auto insert_set = operand_li->use_factors() - operand_li->def_factors();
63 for (auto factor : insert_set)
65 const auto permute_operation_index = insertPermute(index, factor);
66 permute_indexes.push_back(permute_operation_index);
67 const auto &permute_operation = _graph.operations().at(permute_operation_index);
68 const auto permuted_operand_index = permute_operation.getOutputs().at(0);
69 factor_to_index.emplace(factor, permuted_operand_index);
73 // Update operations' input that uses this operand
75 std::list<OperationIndex> remove_list;
77 auto uses = object.getUses();
80 // If permute operation, ignore it
81 if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
84 auto &operation = _graph.operations().at(use);
85 assert(_lowered_graph.op_seqs().containsOperation(use));
86 auto op_seq_index = _lowered_graph.op_seqs().getOperation(use);
87 auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
89 const auto op_seq_layout = op_seq_li->layout();
90 const backend::Backend *backend = op_seq_li->backend();
92 auto use_node_inputs = operation.getInputs();
93 assert(use_node_inputs.contains(index));
95 auto new_index = factor_to_index.at({backend, op_seq_layout});
96 if (index != new_index)
99 // Replace the same inputs of an OpSequence at once for the following reasons:
100 // 1. An OpSequence's inputs are the same inputs of first operation
101 // 2. An OpSequence may have inputs as the same operand (2 or more).
102 // 3. The same inputs of OpSequence have the same PermuteFactor.
103 _lowered_graph.op_seqs().at(op_seq_index).replaceInputs(index, new_index);
105 // Update from operation
106 // Replace the same inputs of an operation at once for the following reasons:
108 operation.replaceInputs(index, new_index);
110 // Update from operand
111 remove_list.push_back(
112 use); // Removal should be done in another loop since we are in the loop
113 _graph.operands().at(new_index).insertUse(use);
117 for (auto &operation : remove_list)
119 object.removeUse(operation);
124 OperationIndex PermutationInsertionPass::insertPermute(const OperandIndex &operand_index,
125 const operand::PermuteFactor &factor)
127 assert(!_graph.isBuildingPhase());
129 auto &operand = _graph.operands().at(operand_index);
131 // Generate output operand and permute operation
132 auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
133 // change model output if operand_index is model output index
134 auto &model_outputs = _graph.getOutputs();
135 if (model_outputs.contains(operand_index))
137 model_outputs.replace(operand_index, out_operand_index);
140 // Find Permute information
141 auto input_factor = _lowered_graph.getLowerInfo(operand_index)->def_factors().getOnlyElement();
142 auto input_backend = input_factor.backend();
143 auto output_backend = factor.backend();
144 // NOTE Permute may not have specific layout because the layout of input and output may be
146 const auto permute_node_layout = Layout::UNKNOWN;
147 // NOTE If one backend supports several layout, the backend must support Permute operation
148 const backend::Backend *permute_node_backend = compiler::BackendManager::get().getControlflow();
149 if (input_backend == output_backend)
151 permute_node_backend = input_backend;
153 const operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
155 // Update LowerInfo of input operand
156 auto operand_lower_info = _lowered_graph.getLowerInfo(operand_index);
157 operand_lower_info->removeUsePermuteFactor(factor);
158 operand_lower_info->addUsePermuteFactor(permute_node_factor);
160 // Update LowerInfo of output operand
161 auto out_operand_li = std::make_unique<operand::LowerInfo>();
163 // The input and output factors of all nodes will be the same except Permute. So Tensor's
164 // allocators allocates memory using only the information of def permutation factor now.
165 // TODO Change param to permute_node_factor
166 out_operand_li->addDefPermuteFactor(factor);
167 out_operand_li->addUsePermuteFactor(factor);
168 _lowered_graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
170 // Insert permute operation to the graph
171 const auto input_layout = input_factor.layout();
172 const auto output_layout = factor.layout();
173 using Permute = operation::Permute;
174 const auto permute_type = [&]() {
175 if (input_layout == Layout::NHWC && output_layout == Layout::NCHW)
177 return Permute::Type::NHWC_TO_NCHW;
179 else if (input_layout == Layout::NCHW && output_layout == Layout::NHWC)
181 return Permute::Type::NCHW_TO_NHWC;
185 return Permute::Type::COPY;
188 auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
190 auto node_index = _graph.operations().push(std::move(insert_node));
191 const auto &node = _graph.operations().at(node_index);
193 VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
194 VERBOSE_F() << " - Input (original) Operand : " << operand_index << std::endl;
195 VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << std::endl;
199 auto op_seq_index = _lowered_graph.op_seqs().emplace(node_index, permute_node_layout);
200 auto &op_seq = _lowered_graph.op_seqs().at(op_seq_index);
201 op_seq.setInputs(node.getInputs());
202 op_seq.setOutputs(node.getOutputs());
203 _lowered_graph.setLowerInfo(op_seq_index, std::make_unique<operation::LowerInfo>(
204 permute_node_backend, permute_node_layout));
207 // Update Use/Def info
209 _graph.operands().at(operand_index).insertUse(node_index);
210 _graph.operands().at(out_operand_index).setDef(node_index);