Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / pass / PermutationEliminationPass.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationEliminationPass.h"
18 #include "backend/controlflow/Config.h"
19
20 #include "util/logging.h"
21
22 namespace onert
23 {
24 namespace ir
25 {
26 namespace pass
27 {
28
29 void PermutationEliminationPass::callback(const OperationIndex &ind, Operation &node)
30 {
31   _op_ind = ind;
32   node.accept(*this);
33 };
34
35 void PermutationEliminationPass::visit(const operation::Permute &node)
36 {
37   auto in_operand = node.getInputs().at(0);
38   auto out_operand = node.getOutputs().at(0);
39
40   // Check if two tensors are both portable
41   // TODO Make this general, this is just a workaround to check two tensors are portable
42   {
43     auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement();
44     auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement();
45
46     auto in_backend_id = in_def_factor.backend()->config()->id();
47     auto out_backend_id = out_def_factor.backend()->config()->id();
48
49     // TODO Fix this workaround that removes only Permute between cpu and controlflow backend.
50     //      This should be general.
51     if (!((in_backend_id == backend::controlflow::Config::ID && out_backend_id == "cpu") ||
52           (in_backend_id == "cpu" && out_backend_id == backend::controlflow::Config::ID)))
53       return;
54   }
55
56   if (_graph.getOutputs().contains(out_operand))
57   {
58     // Exceptional case : When the output operand is a model output
59     // In this case we keep the output and remove the input
60
61     auto &out_operand_obj = _graph.operands().at(out_operand);
62     assert(out_operand_obj.getDef() == _op_ind);
63     out_operand_obj.unsetDef();
64     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
65       if (!op_seq.getOutputs().contains(in_operand))
66         return;
67
68       // Update OpSequence/Operation edges and Operand edges
69       op_seq.replaceOutputs(in_operand, out_operand);
70       for (auto op : op_seq.operations())
71       {
72         auto &operation_obj = _graph.operations().at(op);
73         if (operation_obj.getOutputs().contains(in_operand))
74         {
75           operation_obj.replaceOutputs(in_operand, out_operand);
76           out_operand_obj.setDef(op);
77         }
78       }
79     });
80
81     // Remove Permute operation, enclosing OpSequence and the operand
82     {
83       _graph.removeOperand(in_operand);
84
85       auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
86       // Assumes enclosing OpSequence contatins just this Permute operation
87       assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
88       _lowered_graph.op_seqs().remove(op_seq_ind);
89       _graph.operations().remove(_op_ind);
90     }
91
92     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
93       if (!op_seq.getInputs().contains(in_operand))
94         return;
95
96       op_seq.replaceInputs(in_operand, out_operand);
97       for (auto op : op_seq.operations())
98       {
99         auto &operation_obj = _graph.operations().at(op);
100         if (operation_obj.getInputs().contains(in_operand))
101         {
102           operation_obj.replaceInputs(in_operand, out_operand);
103           out_operand_obj.insertUse(op);
104         }
105       }
106     });
107
108     VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
109     VERBOSE(removePermute) << "  - Input (removed) Operand : " << in_operand << std::endl;
110     VERBOSE(removePermute) << "  - Output(kept)    Operand : " << out_operand << std::endl;
111   }
112   else
113   {
114     // Otherwise keep the input and remove the output
115
116     auto &in_operand_obj = _graph.operands().at(in_operand);
117     in_operand_obj.removeUse(_op_ind);
118
119     // Make OpSequences(that use the output) use the input
120     _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
121       if (!op_seq.getInputs().contains(out_operand))
122         return;
123
124       op_seq.replaceInputs(out_operand, in_operand);
125       for (auto op : op_seq.operations())
126       {
127         auto &operation_obj = _graph.operations().at(op);
128         if (operation_obj.getInputs().contains(out_operand))
129         {
130           operation_obj.replaceInputs(out_operand, in_operand);
131           in_operand_obj.insertUse(op);
132         }
133       }
134     });
135
136     // Remove Permute operation, enclosing OpSequence and the operand
137     {
138       _graph.removeOperand(out_operand);
139
140       auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
141       // Assumes enclosing OpSequence contatins just this Permute operation
142       assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
143       _lowered_graph.op_seqs().remove(op_seq_ind);
144       _graph.operations().remove(_op_ind);
145     }
146
147     VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
148     VERBOSE(removePermute) << "  - Input (kept)    Operand : " << in_operand << std::endl;
149     VERBOSE(removePermute) << "  - Output(removed) Operand : " << out_operand << std::endl;
150   }
151 }
152
153 } // namespace pass
154 } // namespace ir
155 } // namespace onert