Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / Graph.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/Graph.h"
18
19 #include "OperationValidator.h"
20 #include "verifier/Verifier.h"
21
22 #include "util/Set.h"
23
24 namespace onert
25 {
26 namespace ir
27 {
28
29 Graph::Graph() = default;
30
31 Graph::Graph(const Graph &) = default;
32
33 Graph::~Graph(void) = default;
34
35 OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
36 {
37   return _operands.emplace(shape, type);
38 }
39
40 OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
41 {
42   return _operands.push(std::move(operand), index);
43 }
44
45 bool Graph::checkOperandsForOperation(const IOperation &operation)
46 {
47   auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
48   auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
49   for (auto &&input : inputs)
50     if (!operands().exist(input))
51       return false;
52   for (auto &&input : outputs)
53     if (!operands().exist(input))
54       return false;
55   return true;
56 }
57
58 void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation)
59 {
60   auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
61   auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
62
63   for (auto &&input : inputs)
64     operands().at(input).insertUse(index);
65   for (auto &&output : outputs)
66     operands().at(output).setDef(index);
67 }
68
69 OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation)
70 {
71   const IOperation &op_ref = *operation;
72   if (!checkOperandsForOperation(op_ref))
73     return OperationIndex{};
74   auto ind = _operations.push(std::move(operation));
75   if (ind.valid())
76     linkOperandToOperation(ind, op_ref);
77   return ind;
78 }
79
80 OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation)
81 {
82   const IOperation &op_ref = *operation;
83   if (!checkOperandsForOperation(op_ref))
84     return OperationIndex{};
85   auto ind_gen = _operations.push(std::move(operation), index);
86   if (ind_gen.valid())
87   {
88     assert(ind_gen == index);
89     linkOperandToOperation(index, op_ref);
90   }
91   return index;
92 }
93
94 OperationIndex Graph::replaceOperation(OperationIndex index,
95                                        std::unique_ptr<IOperation> &&operation)
96 {
97   const IOperation &op_ref = *operation;
98   if (!checkOperandsForOperation(op_ref) || !_operations.exist(index))
99     return OperationIndex{};
100
101   // Check the new operation has the same inputs/outputs as the existing operation
102   const auto &old_op = _operations.at(index);
103   if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs()))
104   {
105     return OperationIndex{};
106   }
107
108   return _operations.set(index, std::move(operation));
109 }
110
111 void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data)
112 {
113   assert(_operands.exist(ind));
114   _operands.at(ind).data(std::move(data));
115 }
116
117 void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape)
118 {
119   assert(_operands.exist(ind));
120   _operands.at(ind).info().shape(new_shape);
121 }
122
123 void Graph::addInput(const OperandIndex &ind, const std::string &name)
124 {
125   if (!name.empty())
126     _name_to_input.emplace(name, IOIndex{_inputs.size()});
127   _inputs.append(ind);
128 }
129
130 void Graph::addOutput(const OperandIndex &ind, const std::string &name)
131 {
132   if (!name.empty())
133     _name_to_output.emplace(name, IOIndex{_outputs.size()});
134   _outputs.append(ind);
135 }
136
137 IOIndex Graph::getInputIndex(const std::string &name) const
138 {
139   auto itr = _name_to_input.find(name);
140   return (itr == _name_to_input.end()) ? IOIndex{} : itr->second;
141 }
142
143 IOIndex Graph::getOutputIndex(const std::string &name) const
144 {
145   auto itr = _name_to_output.find(name);
146   return (itr == _name_to_output.end()) ? IOIndex{} : itr->second;
147 }
148
149 void Graph::verify(void) const
150 {
151   // Call graph verifications for the MODEL phase
152   {
153     // Except for edge consistency, the user might have been given a bad model
154     // so here it throws an execption rather than assertion.
155     if (!verifier::InputOutputChecker().verify(*this))
156       throw std::runtime_error{"One of model input and output operands does not exist."};
157     if (!verifier::DAGChecker().verify(*this))
158       throw std::runtime_error{"The graph is cyclic."};
159     assert(verifier::EdgeChecker().verify(*this));
160   }
161
162   // Check shape independent operation feature
163   // - Operand type
164   // - Shape independent parameter
165   OperationValidator{*this}();
166 }
167
168 void Graph::initializeUseDef()
169 {
170   operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void {
171     auto outputs = node.getOutputs();
172     for (auto &&output : outputs | ir::Remove::UNDEFINED)
173     {
174       operands().at(output).setDef(index);
175     }
176
177     for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED)
178     {
179       operands().at(input).insertUse(index);
180     }
181   });
182 }
183
184 std::vector<ir::OperationIndex> Graph::topolSortOperations() const
185 {
186   std::vector<ir::OperationIndex> ret;
187   util::Set<ir::OperationIndex> unvisited;
188   operations().iterate(
189     [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); });
190
191   std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
192     [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
193     if (!unvisited.contains(index))
194       return;
195     unvisited.remove(index);
196
197     for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
198     {
199       const auto &operand = operands().at(output);
200       for (const auto &use : operand.getUses())
201       {
202         dfs(use, operations().at(use));
203       }
204     }
205     ret.push_back(index);
206   };
207   operations().iterate(dfs);
208
209   assert(unvisited.empty()); // All of the nodes must have been visited
210   // Reversing Postorder DFS result to make it sorted in topoligical order
211   std::reverse(ret.begin(), ret.end());
212   return ret;
213 }
214
215 } // namespace ir
216 } // namespace onert