28cf4137d512a80acf37ab608f49e53854488c38
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / Graph.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/Graph.h"
18
19 #include "OperationValidator.h"
20 #include "verifier/Verifier.h"
21
22 #include "util/Set.h"
23
24 namespace onert
25 {
26 namespace ir
27 {
28
29 Graph::Graph() = default;
30
31 Graph::Graph(const Graph &) = default;
32
33 Graph::~Graph(void) = default;
34
35 OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
36 {
37   return _operands.emplace(shape, type);
38 }
39
40 OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
41 {
42   return _operands.push(std::move(operand), index);
43 }
44
45 bool Graph::checkOperandsForOperation(const Operation &operation)
46 {
47   auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
48   auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
49   for (auto input : inputs)
50     if (!operands().exist(input))
51       return false;
52   for (auto input : outputs)
53     if (!operands().exist(input))
54       return false;
55   return true;
56 }
57
58 void Graph::linkOperandToOperation(OperationIndex index, const Operation &operation)
59 {
60   auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
61   auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
62
63   for (auto input : inputs)
64     operands().at(input).insertUse(index);
65   for (auto output : outputs)
66     operands().at(output).setDef(index);
67 }
68
69 OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&operation)
70 {
71   const Operation &op_ref = *operation;
72   if (!checkOperandsForOperation(op_ref))
73     return OperationIndex{};
74   auto ind = _operations.push(std::move(operation));
75   if (ind.valid())
76     linkOperandToOperation(ind, op_ref);
77   return ind;
78 }
79
80 OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<Operation> &&operation)
81 {
82   const Operation &op_ref = *operation;
83   if (!checkOperandsForOperation(op_ref))
84     return OperationIndex{};
85   auto ind_gen = _operations.push(std::move(operation), index);
86   if (ind_gen.valid())
87   {
88     assert(ind_gen == index);
89     linkOperandToOperation(index, op_ref);
90   }
91   return index;
92 }
93
94 void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data)
95 {
96   assert(_operands.exist(ind));
97   _operands.at(ind).data(std::move(data));
98 }
99
100 void Graph::addInput(const OperandIndex &ind, const std::string &name)
101 {
102   if (!name.empty())
103     _name_to_input.emplace(name, IOIndex{_inputs.size()});
104   _inputs.append(ind);
105 }
106
107 void Graph::addOutput(const OperandIndex &ind, const std::string &name)
108 {
109   if (!name.empty())
110     _name_to_output.emplace(name, IOIndex{_outputs.size()});
111   _outputs.append(ind);
112 }
113
114 IOIndex Graph::getInputIndex(const std::string &name) const
115 {
116   auto itr = _name_to_input.find(name);
117   return (itr == _name_to_input.end()) ? IOIndex{} : itr->second;
118 }
119
120 IOIndex Graph::getOutputIndex(const std::string &name) const
121 {
122   auto itr = _name_to_output.find(name);
123   return (itr == _name_to_output.end()) ? IOIndex{} : itr->second;
124 }
125
126 void Graph::verify(void)
127 {
128   // Call graph verifications for the MODEL phase
129   {
130     // Except for edge consistency, the user might have been given a bad model
131     // so here it throws an execption rather than assertion.
132     if (!verifier::InputOutputChecker().verify(*this))
133       throw std::runtime_error{"One of model input and output operands does not exist."};
134     if (!verifier::DAGChecker().verify(*this))
135       throw std::runtime_error{"The graph is cyclic."};
136     assert(verifier::EdgeChecker().verify(*this));
137   }
138
139   // Check shape independent operation feature
140   // - Operand type
141   // - Shape independent parameter
142   OperationValidator{*this}();
143 }
144
145 void Graph::initializeUseDef()
146 {
147   operations().iterate([&](const OperationIndex &index, const Operation &node) -> void {
148     auto outputs = node.getOutputs();
149     for (auto output : outputs | ir::Remove::UNDEFINED)
150     {
151       operands().at(output).setDef(index);
152     }
153
154     for (auto input : node.getInputs() | ir::Remove::UNDEFINED)
155     {
156       operands().at(input).insertUse(index);
157     }
158   });
159 }
160
161 std::vector<ir::OperationIndex> Graph::topolSortOperations() const
162 {
163   std::vector<ir::OperationIndex> ret;
164   util::Set<ir::OperationIndex> unvisited;
165   operations().iterate(
166     [&](const ir::OperationIndex &index, const ir::Operation &) { unvisited.add(index); });
167
168   std::function<void(const ir::OperationIndex &, const ir::Operation &)> dfs =
169     [&](const ir::OperationIndex &index, const ir::Operation &op) -> void {
170     if (!unvisited.contains(index))
171       return;
172     unvisited.remove(index);
173
174     for (const auto output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
175     {
176       const auto &operand = operands().at(output);
177       for (const auto &use : operand.getUses())
178       {
179         dfs(use, operations().at(use));
180       }
181     }
182     ret.push_back(index);
183   };
184   operations().iterate(dfs);
185
186   assert(unvisited.empty()); // All of the nodes must have been visited
187   // Reversing Postorder DFS result to make it sorted in topoligical order
188   std::reverse(ret.begin(), ret.end());
189   return ret;
190 }
191
192 } // namespace ir
193 } // namespace onert