2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PermutationEliminationPass.h"
18 #include "backend/controlflow/Config.h"
20 #include "util/logging.h"
// Per-operation callback of the pass: receives the operation's index `ind` and the operation `node`.
// NOTE(review): the body of this function is not visible in this excerpt. `visit` below reads
// `_op_ind`, so presumably this stores `ind` there before dispatching `node` to the visitor - confirm.
29 void PermutationEliminationPass::callback(const OperationIndex &ind, Operation &node)
// Visits a Permute operation and, when it connects the cpu and controlflow backends, removes it
// from the graph by merging its input and output operands into one.
//
// Two cases are handled:
//  - the Permute output is a model output: the output operand is kept and the input operand is
//    removed; the producer and the other users of the input are rewired to the output operand.
//  - otherwise: the input operand is kept and the output operand is removed; users of the output
//    are rewired to the input operand.
// In both cases the Permute operation and its enclosing OpSequence are removed.
//
// `node` is the Permute operation being visited; its own index is read from `_op_ind`.
35 void PermutationEliminationPass::visit(const operation::Permute &node)
// Only the first input and the first output of the Permute are considered.
37 auto in_operand = node.getInputs().at(0);
38 auto out_operand = node.getOutputs().at(0);
40 // Check if two tensors are both portable
41 // TODO Make this general, this is just a workaround to check two tensors are portable
// Take the single def factor recorded in each operand's lower info (getOnlyElement()).
43 auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement();
44 auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement();
// Backend ids of the input-side and output-side def factors.
46 auto in_backend_id = in_def_factor.backend()->config()->id();
47 auto out_backend_id = out_def_factor.backend()->config()->id();
49 // TODO Fix this workaround that removes only Permute between cpu and controlflow backend.
50 // This should be general.
// The condition is true when the pair is NOT (controlflow -> cpu) and NOT (cpu -> controlflow).
// NOTE(review): the statement guarded by this condition is not visible in this excerpt;
// presumably an early return that leaves the Permute untouched - confirm.
51 if (!((in_backend_id == backend::controlflow::Config::ID && out_backend_id == "cpu") ||
52 (in_backend_id == "cpu" && out_backend_id == backend::controlflow::Config::ID)))
56 if (_graph.getOutputs().contains(out_operand))
58 // Exceptional case : When the output operand is a model output
59 // In this case we keep the output and remove the input
61 auto &out_operand_obj = _graph.operands().at(out_operand);
// The Permute is expected to be the current def of the output operand; detach it so that the
// producer of the input operand can be recorded as the def below.
62 assert(out_operand_obj.getDef() == _op_ind);
63 out_operand_obj.unsetDef();
64 _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
// NOTE(review): the statement guarded by this check is not visible in this excerpt; presumably
// it skips OpSequences that do not output in_operand - confirm.
65 if (!op_seq.getOutputs().contains(in_operand))
68 // Update OpSequence/Operation edges and Operand edges
69 op_seq.replaceOutputs(in_operand, out_operand);
70 for (auto op : op_seq.operations())
72 auto &operation_obj = _graph.operations().at(op);
73 if (operation_obj.getOutputs().contains(in_operand))
// This operation produced in_operand: make it produce out_operand and record it as the def.
75 operation_obj.replaceOutputs(in_operand, out_operand);
76 out_operand_obj.setDef(op);
81 // Remove Permute operation, enclosing OpSequence and the operand
83 _graph.removeOperand(in_operand);
// Look up the OpSequence that holds the Permute operation.
85 auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
86 // Assumes enclosing OpSequence contains just this Permute operation
87 assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
88 _lowered_graph.op_seqs().remove(op_seq_ind);
89 _graph.operations().remove(_op_ind);
// Rewire the remaining users of in_operand so that they read out_operand instead.
// Note this runs after the Permute (itself a user of in_operand) has been removed above.
92 _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
// NOTE(review): the statement guarded by this check is not visible in this excerpt; presumably
// it skips OpSequences that do not take in_operand as an input - confirm.
93 if (!op_seq.getInputs().contains(in_operand))
96 op_seq.replaceInputs(in_operand, out_operand);
97 for (auto op : op_seq.operations())
99 auto &operation_obj = _graph.operations().at(op);
100 if (operation_obj.getInputs().contains(in_operand))
// This operation consumed in_operand: make it consume out_operand and record the use.
102 operation_obj.replaceInputs(in_operand, out_operand);
103 out_operand_obj.insertUse(op);
108 VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
109 VERBOSE(removePermute) << " - Input (removed) Operand : " << in_operand << std::endl;
110 VERBOSE(removePermute) << " - Output(kept) Operand : " << out_operand << std::endl;
114 // Otherwise keep the input and remove the output
116 auto &in_operand_obj = _graph.operands().at(in_operand);
// The Permute no longer counts as a user of the input operand.
117 in_operand_obj.removeUse(_op_ind);
119 // Make OpSequences(that use the output) use the input
120 _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
// NOTE(review): the statement guarded by this check is not visible in this excerpt; presumably
// it skips OpSequences that do not take out_operand as an input - confirm.
121 if (!op_seq.getInputs().contains(out_operand))
124 op_seq.replaceInputs(out_operand, in_operand);
125 for (auto op : op_seq.operations())
127 auto &operation_obj = _graph.operations().at(op);
128 if (operation_obj.getInputs().contains(out_operand))
// This operation consumed out_operand: make it consume in_operand and record the use.
130 operation_obj.replaceInputs(out_operand, in_operand);
131 in_operand_obj.insertUse(op);
136 // Remove Permute operation, enclosing OpSequence and the operand
138 _graph.removeOperand(out_operand);
// Look up the OpSequence that holds the Permute operation.
140 auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
141 // Assumes enclosing OpSequence contains just this Permute operation
142 assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
143 _lowered_graph.op_seqs().remove(op_seq_ind);
144 _graph.operations().remove(_op_ind);
147 VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
148 VERBOSE(removePermute) << " - Input (kept) Operand : " << in_operand << std::endl;
149 VERBOSE(removePermute) << " - Output(removed) Operand : " << out_operand << std::endl;