Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / pass / PermutationEliminationPass.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PermutationEliminationPass.h"
18
19 #include "backend/Backend.h"
20 #include "util/logging.h"
21
22 namespace onert
23 {
24 namespace compiler
25 {
26 namespace pass
27 {
28
29 void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::IOperation &node)
30 {
31   _op_ind = ind;
32   node.accept(*this);
33 };
34
35 void PermutationEliminationPass::visit(const ir::operation::Permute &node)
36 {
37   auto in_operand = node.getInputs().at(0);
38   auto out_operand = node.getOutputs().at(0);
39
40   // Check if two tensors are both portable if not, we can't eliminate the node
41   {
42     auto &operand_li_map = _lowered_graph.lower_info().operand;
43     auto in_def_factor = operand_li_map.getRawPtr(in_operand)->def_factors().getOnlyElement();
44     auto out_def_factor = operand_li_map.getRawPtr(out_operand)->def_factors().getOnlyElement();
45
46     auto in_config = in_def_factor.backend()->config();
47     auto out_config = out_def_factor.backend()->config();
48
49     // FIXME Supporting dynamic tensor does not exactly mean those are portable.
50     //       It may need to have another config option for checking if each uses `IPortableTensor`.
51     if (!(in_config->supportDynamicTensor() && out_config->supportDynamicTensor()))
52       return;
53   }
54
55   if (_graph.getOutputs().contains(out_operand))
56   {
57     // If the input is a const, we cannot remove it since we cannot put the constant data in the
58     // output buffer during prepare phase.
59     auto permute_input = node.getInputs().at(0);
60     if (_graph.operands().at(permute_input).isConstant())
61       return;
62     // If the input is a model input, we cannot remove it since our API lets users to set different
63     // buffers for inputs and outputs even though one tensor is both at the same time.
64     auto permute_output = node.getOutputs().at(0);
65     if (_graph.getInputs().contains(permute_input) && _graph.getOutputs().contains(permute_output))
66       return;
67     // Likewise, if copying between outputs to outputs, keep it.
68     if (_graph.getOutputs().contains(permute_input) && _graph.getOutputs().contains(permute_output))
69       return;
70
71     // Exceptional case : When the output operand is a model output
72     // In this case we keep the output and remove the input
73
74     auto &out_operand_obj = _graph.operands().at(out_operand);
75     assert(out_operand_obj.getDef() == _op_ind);
76     out_operand_obj.unsetDef();
77     _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
78       if (!op.getOutputs().contains(in_operand))
79         return;
80       // Update Operation and Operand edges
81       op.replaceOutputs(in_operand, out_operand);
82       out_operand_obj.setDef(op_ind);
83     });
84
85     // Remove Permute operation and the operand
86     {
87       _graph.removeOperand(in_operand);
88       _graph.operations().remove(_op_ind);
89     }
90
91     _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
92       if (!op.getInputs().contains(in_operand))
93         return;
94       op.replaceInputs(in_operand, out_operand);
95       out_operand_obj.insertUse(op_ind);
96     });
97
98     VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
99     VERBOSE(removePermute) << "  - Input (removed) Operand : " << in_operand << std::endl;
100     VERBOSE(removePermute) << "  - Output(kept)    Operand : " << out_operand << std::endl;
101   }
102   else
103   {
104     // Otherwise keep the input and remove the output
105
106     auto &in_operand_obj = _graph.operands().at(in_operand);
107     in_operand_obj.removeUse(_op_ind);
108
109     // Make operations(that use the output) use the input
110     _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
111       if (!op.getInputs().contains(out_operand))
112         return;
113       op.replaceInputs(out_operand, in_operand);
114       in_operand_obj.insertUse(op_ind);
115     });
116
117     // Remove the Permute operation and out_operand
118     {
119       _graph.removeOperand(out_operand);
120       _graph.operations().remove(_op_ind);
121     }
122
123     VERBOSE(removePermute) << "Permute Op removed : " << _op_ind << std::endl;
124     VERBOSE(removePermute) << "  - Input (kept)    Operand : " << in_operand << std::endl;
125     VERBOSE(removePermute) << "  - Output(removed) Operand : " << out_operand << std::endl;
126   }
127 }
128
129 } // namespace pass
130 } // namespace compiler
131 } // namespace onert