Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationEliminationPass.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationEliminationPass.h"
18 #include "backend/controlflow/Config.h"
19
20 #include "util/logging.h"
21
22 namespace onert
23 {
24 namespace compiler
25 {
26 namespace pass
27 {
28
29 void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::Operation &node)
30 {
31   _op_ind = ind;
32   node.accept(*this);
33 };
34
35 void PermutationEliminationPass::visit(const ir::operation::Permute &node)
36 {
37   auto in_operand = node.getInputs().at(0);
38   auto out_operand = node.getOutputs().at(0);
39
40   // Check if two tensors are both portable if not, we can't eliminate the node
41   {
42     auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement();
43     auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement();
44
45     auto in_config = in_def_factor.backend()->config();
46     auto out_config = out_def_factor.backend()->config();
47
48     // FIXME Supporting dynamic tensor does not exactly mean those are portable.
49     //       It may need to have another config option for checking if each uses `IPortableTensor`.
50     if (!(in_config->supportDynamicTensor() && out_config->supportDynamicTensor()))
51       return;
52   }
53
54   if (_graph.getOutputs().contains(out_operand))
55   {
56     // Exceptional case : When the output operand is a model output
57     // In this case we keep the output and remove the input
58
59     auto &out_operand_obj = _graph.operands().at(out_operand);
60     assert(out_operand_obj.getDef() == _op_ind);
61     out_operand_obj.unsetDef();
62     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
63       if (!op_seq.getOutputs().contains(in_operand))
64         return;
65
66       // Update OpSequence/ir::Operation edges and ir::Operand edges
67       op_seq.replaceOutputs(in_operand, out_operand);
68       for (auto op : op_seq.operations())
69       {
70         auto &operation_obj = _graph.operations().at(op);
71         if (operation_obj.getOutputs().contains(in_operand))
72         {
73           operation_obj.replaceOutputs(in_operand, out_operand);
74           out_operand_obj.setDef(op);
75         }
76       }
77     });
78
79     // Remove Permute operation, enclosing OpSequence and the operand
80     {
81       _graph.removeOperand(in_operand);
82
83       auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
84       // Assumes enclosing OpSequence contatins just this Permute operation
85       assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
86       _lowered_graph.op_seqs().remove(op_seq_ind);
87       _graph.operations().remove(_op_ind);
88     }
89
90     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
91       if (!op_seq.getInputs().contains(in_operand))
92         return;
93
94       op_seq.replaceInputs(in_operand, out_operand);
95       for (auto op : op_seq.operations())
96       {
97         auto &operation_obj = _graph.operations().at(op);
98         if (operation_obj.getInputs().contains(in_operand))
99         {
100           operation_obj.replaceInputs(in_operand, out_operand);
101           out_operand_obj.insertUse(op);
102         }
103       }
104     });
105
106     VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
107     VERBOSE(removePermute) << "  - Input (removed) ir::Operand : " << in_operand << std::endl;
108     VERBOSE(removePermute) << "  - Output(kept)    ir::Operand : " << out_operand << std::endl;
109   }
110   else
111   {
112     // Otherwise keep the input and remove the output
113
114     auto &in_operand_obj = _graph.operands().at(in_operand);
115     in_operand_obj.removeUse(_op_ind);
116
117     // Make OpSequences(that use the output) use the input
118     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
119       if (!op_seq.getInputs().contains(out_operand))
120         return;
121
122       op_seq.replaceInputs(out_operand, in_operand);
123       for (auto op : op_seq.operations())
124       {
125         auto &operation_obj = _graph.operations().at(op);
126         if (operation_obj.getInputs().contains(out_operand))
127         {
128           operation_obj.replaceInputs(out_operand, in_operand);
129           in_operand_obj.insertUse(op);
130         }
131       }
132     });
133
134     // Remove Permute operation, enclosing OpSequence and the operand
135     {
136       _graph.removeOperand(out_operand);
137
138       auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
139       // Assumes enclosing OpSequence contatins just this Permute operation
140       assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
141       _lowered_graph.op_seqs().remove(op_seq_ind);
142       _graph.operations().remove(_op_ind);
143     }
144
145     VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
146     VERBOSE(removePermute) << "  - Input (kept)    ir::Operand : " << in_operand << std::endl;
147     VERBOSE(removePermute) << "  - Output(removed) ir::Operand : " << out_operand << std::endl;
148   }
149 }
150
151 } // namespace pass
152 } // namespace compiler
153 } // namespace onert