Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / NNPkg.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_IR_NNPKG_H__
18 #define __ONERT_IR_NNPKG_H__
19
20 #include <memory>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "ir/Index.h"
25 #include "ir/Model.h"
26
27 namespace onert
28 {
29 namespace ir
30 {
31
32 using IODesc = std::tuple<ModelIndex, SubgraphIndex, IOIndex>;
33
34 struct ModelEdge
35 {
36   IODesc from;
37   IODesc to;
38 };
39
40 struct ModelEdgeEqual
41 {
42   bool operator()(const onert::ir::ModelEdge &lhs, const onert::ir::ModelEdge &rhs) const
43   {
44     return lhs.from == rhs.from && lhs.to == rhs.to;
45   }
46 };
47
48 struct ModelEdgeHash
49 {
50   size_t operator()(const ::onert::ir::ModelEdge &edge) const noexcept
51   {
52     unsigned long long h1 = (std::get<0>(edge.from).value() << 24) |
53                             (std::get<1>(edge.from).value() << 16) | std::get<2>(edge.from).value();
54     unsigned long long h2 = (std::get<0>(edge.to).value() << 24) |
55                             (std::get<1>(edge.to).value() << 16) | std::get<2>(edge.to).value();
56     return h1 + h2;
57   }
58 };
59
60 inline std::ostream &operator<<(std::ostream &o, const IODesc &od)
61 {
62   o << std::get<0>(od).value() << ":" << std::get<1>(od).value() << ":" << std::get<2>(od).value();
63   return o;
64 }
65
66 using ModelEdgeSet = std::unordered_set<ir::ModelEdge, ir::ModelEdgeHash, ir::ModelEdgeEqual>;
67
68 /**
69  * @brief Struct to gather model I/O information in multimodel NN package
70  *        Model I/O will have role one of below
71  *        - Package input/output
72  *        - Edge's start/finish point between model
73  */
74 struct ModelEdges
75 {
76   std::vector<ir::IODesc> pkg_inputs;
77   std::vector<ir::IODesc> pkg_outputs;
78   ModelEdgeSet edges;
79 };
80
81 class NNPkg
82 {
83 public:
84   NNPkg() = default;
85   NNPkg(const NNPkg &obj) = default;
86   NNPkg(NNPkg &&) = default;
87   NNPkg &operator=(const NNPkg &) = default;
88   NNPkg &operator=(NNPkg &&) = default;
89   ~NNPkg() = default;
90
91   NNPkg(std::shared_ptr<Model> model) { _models[ModelIndex{0}] = model; }
92   std::shared_ptr<Model> primary_model() const { return _models.at(onert::ir::ModelIndex{0}); }
93
94   /**
95    * @brief Put model at index
96    *
97    * @param[in] model Model to be pushed
98    * @param[in] index Index where Model is to be pushed
99    */
100   void push(ModelIndex index, const std::shared_ptr<Model> &model) { _models[index] = model; }
101
102   /**
103    * @brief Get the count of model
104    *
105    * @return the count of models
106    */
107   size_t model_count() const { return _models.size(); }
108
109   /**
110    * @brief Get model at index
111    *
112    * @param[in] index Index of the model to be returned
113    * @return Model at index
114    */
115   const std::shared_ptr<Model> &model(const ModelIndex &index) const { return _models.at(index); }
116   /**
117    * @brief Get model at index
118    *
119    * @param[in] index Index of the model to be returned
120    * @return Model at index
121    */
122   std::shared_ptr<Model> &model(const ModelIndex &index) { return _models.at(index); }
123
124   /**
125    * @brief Get pkg_input at index
126    *
127    * @param[in] index Index of pkg_input to be returned
128    * @return IODesc at index
129    */
130   const IODesc &input(uint32_t index) const { return _edges.pkg_inputs[index]; }
131   /**
132    * @brief Get pkg_input at index
133    *
134    * @param[in] index Index of pkg_input to be returned
135    * @return IODesc at index
136    */
137   IODesc &input(uint32_t index) { return _edges.pkg_inputs[index]; }
138   /**
139    * @brief Add input at the end
140    *
141    * @param[in] input Input IODesc to be pushed
142    */
143   void addInput(const IODesc &input) { _edges.pkg_inputs.push_back(input); }
144
145   /**
146    * @brief Get pkg_output at index
147    *
148    * @param[in] index Index of pkg_output to be returned
149    * @return IODesc at index
150    */
151   const IODesc &output(uint32_t index) const { return _edges.pkg_outputs[index]; }
152   /**
153    * @brief Get pkg_output at index
154    *
155    * @param[in] index Index of pkg_output to be returned
156    * @return IODesc at index
157    */
158   IODesc &output(uint32_t index) { return _edges.pkg_outputs[index]; }
159   /**
160    * @brief Add output at the end
161    *
162    * @param[in] output Output IODesc to be pushed
163    */
164   void addOutput(const IODesc &output) { _edges.pkg_outputs.push_back(output); }
165
166   /**
167    * @brief Add edge between models at the end
168    *
169    * @param[in] from from IODesc
170    * @param[in] to   to IODesc
171    */
172   void addEdge(const IODesc &from, const IODesc &to)
173   {
174     std::cout << from << " -> " << to << std::endl;
175     _edges.edges.insert(ModelEdge{from, to});
176   }
177   /**
178    * @brief   Get model edge set
179    * @return  Edge set reference
180    */
181   const ModelEdges &model_edges() { return _edges; }
182
183   /**
184    * @brief Verify NNPkg
185    *
186    */
187   void verify(void)
188   {
189     // Verify edges information
190     //
191     // Only duplicates of nnpkg output and Edge `from` are possible.
192     // | Whether duplicates are possible   | Edge `to` | Edge `from` |
193     // | nnpkg input  (input of subgraph)  | X (*1)    | X (*2)      |
194     // | nnpkg output (output of subgraph) | X (*2)    | O           |
195     // *1. The subjects who determine values of each buffer are different.
196     //    - nnpkg input : user input
197     //    - Edge `to`   : output of another subgraph
198     // *2. `IOIndex` of inputs and outputs of subgraph is distinct.
199     //
200     for (const auto &edge : _edges.edges)
201     {
202       if (std::find(_edges.pkg_inputs.begin(), _edges.pkg_inputs.end(), edge.to) !=
203           _edges.pkg_inputs.end())
204       {
205         throw std::runtime_error{
206           "Invalid edge information. NNPkg inputs and Edge `to` cannot be duplicated"};
207       }
208     }
209   }
210
211   // TODO Find better way to handle single model NNPackage and multi model NNPackage on inputSize(),
212   //      outputSize(), inputInfo(), outputInfo()
213
214   /**
215    * @brief   Get model input size
216    */
217   uint32_t inputSize() const
218   {
219     return _models.size() == 1 ? primary_model()->primary_subgraph()->getInputs().size()
220                                : _edges.pkg_inputs.size();
221   }
222
223   /**
224    * @brief   Get model output size
225    */
226   uint32_t outputSize() const
227   {
228     return _models.size() == 1 ? primary_model()->primary_subgraph()->getOutputs().size()
229                                : _edges.pkg_outputs.size();
230   }
231
232   /**
233    * @brief   Get model input info
234    */
235   const OperandInfo &inputInfo(uint32_t index) const
236   {
237     if (_models.size() == 1)
238     {
239       auto const graph = primary_model()->primary_subgraph();
240       auto const operand_index = graph->getInputs().at(index);
241       return graph->operands().at(operand_index).info();
242     }
243
244     auto const &desc = input(index);
245     auto const graph = model(std::get<ModelIndex>(desc))->primary_subgraph();
246     auto const operand_index = graph->getInputs().at(std::get<IOIndex>(desc).value());
247     return graph->operands().at(operand_index).info();
248   }
249
250   /**
251    * @brief   Get model output info
252    */
253   const OperandInfo &outputInfo(uint32_t index) const
254   {
255     if (_models.size() == 1)
256     {
257       auto const graph = primary_model()->primary_subgraph();
258       auto const operand_index = graph->getOutputs().at(index);
259       return graph->operands().at(operand_index).info();
260     }
261
262     auto const &desc = output(index);
263     auto const graph = model(std::get<ModelIndex>(desc))->primary_subgraph();
264     auto const operand_index = graph->getOutputs().at(std::get<IOIndex>(desc).value());
265     return graph->operands().at(operand_index).info();
266   }
267
268   void changeInputShape(uint32_t index, const ir::Shape &new_shape)
269   {
270     if (_models.size() == 1)
271     {
272       auto graph = primary_model()->primary_subgraph();
273       auto const operand_index = graph->getInputs().at(index);
274       graph->changeShape(operand_index, new_shape);
275       return;
276     }
277
278     auto const &desc = input(index);
279     auto graph = model(std::get<ModelIndex>(desc))->primary_subgraph();
280     auto const operand_index = graph->getInputs().at(std::get<IOIndex>(desc).value());
281     graph->changeShape(operand_index, new_shape);
282   }
283
284   /**
285    * @brief Replace model
286    *
287    * @param[in] model Model to be replaced
288    *
289    * TODO:  Support multiple models
290    */
291   void replaceModel(std::shared_ptr<Model> model) { _models[ModelIndex{0}] = model; }
292
293   // TODO: Add iterate() or getter for edges
294
295 private:
296   std::unordered_map<ModelIndex, std::shared_ptr<Model>> _models;
297   ModelEdges _edges;
298 };
299
300 } // namespace ir
301 } // namespace onert
302
303 namespace std
304 {
305
306 template <> struct hash<onert::ir::IODesc>
307 {
308   size_t operator()(const ::onert::ir::IODesc &iodesc) const noexcept
309   {
310     return (std::get<0>(iodesc).value() << 24) | (std::get<1>(iodesc).value() << 16) |
311            std::get<2>(iodesc).value();
312   }
313 };
314
315 } // namespace std
316
317 #endif // __ONERT_IR_NNPKG_H__