Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / pass / ConstantInsertionPass.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ConstantInsertionPass.h"
18
19 #include "backend/Backend.h"
20 #include <ir/Graph.h>
21 #include <util/Utils.h>
22
23 namespace onert
24 {
25 namespace ir
26 {
27 namespace pass
28 {
29
30 void ConstantInsertionPass::callback(const OperationIndex &node_index, Operation &node)
31 {
32   const auto &op_sequence_index = _lowered_graph.op_seqs().getOperation(node_index);
33   const auto op_seq_lower_info = _lowered_graph.getLowerInfo(op_sequence_index);
34   const auto backend = op_seq_lower_info->backend();
35   const auto layout = op_seq_lower_info->layout();
36   const auto factor = operand::PermuteFactor{backend, layout};
37
38   for (const auto input : node.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
39   {
40     auto &object = _graph.operands().at(input);
41
42     if (object.isConstant())
43     {
44       const auto key = ReplaceKey{input, factor};
45       if (_replace_operands_map.count(key) == 0)
46       {
47         auto new_object = object;
48         new_object.unsetDef();
49         // TODO Remove const_case
50         const_cast<OperationIndexSet &>(new_object.getUses()).clear();
51         const auto new_index = _graph.operands().emplace(new_object);
52         _replace_operands_map[key] = new_index;
53       }
54
55       const auto replaced_input = _replace_operands_map[key];
56       // Update op_seq
57       if (_lowered_graph.op_seqs().at(op_sequence_index).getInputs().contains(input))
58       {
59         // All inputs of op_seq have the same PermuteFactor because those inputs are inputs of first
60         // operation
61         _lowered_graph.op_seqs().at(op_sequence_index).replaceInputs(input, replaced_input);
62       }
63
64       // Update the same inputs of a node at once because inputs of an operation have the same
65       // PermuteFactor
66       node.replaceInputs(input, replaced_input);
67
68       // Update operand
69       auto &replaced_object = _graph.operands().at(replaced_input);
70       replaced_object.insertUse(node_index);
71
72       // Remove this node from uses of origin operand
73       // Constant operand has no def.
74       assert(!object.getDef().valid());
75       object.removeUse(node_index);
76
77       // Remove origin operand
78       if (object.getUses().size() == 0)
79         _graph.removeOperand(input);
80     }
81   }
82
83   // Now this runtime does not support the node making output as constant
84   for (const auto &output : node.getOutputs())
85   {
86     UNUSED_RELEASE(output);
87     assert(!_graph.operands().at(output).isConstant());
88   }
89 }
90
91 } // namespace pass
92 } // namespace ir
93 } // namespace onert