Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / Model.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_IR_MODEL_H__
18 #define __ONERT_IR_MODEL_H__
19
20 #include <memory>
21 #include <unordered_map>
22
23 #include "ir/IGraph.h"
24 #include "ir/Index.h"
25 #include "util/ObjectManager.h"
26
27 namespace onert
28 {
29 namespace backend
30 {
31 namespace custom
32 {
33 class IKernelBuilder;
34 } // namespace custom
35 } // namespace backend
36 } // namespace onert
37
38 namespace onert
39 {
40 namespace ir
41 {
42
43 class Model
44 {
45 public:
46   Model() = default;
47   Model(const Model &obj) = default;
48   Model(Model &&) = default;
49   Model &operator=(const Model &) = default;
50   Model &operator=(Model &&) = default;
51   ~Model() = default;
52
53   /**
54    * @brief Put subgraph in the container with a new Index for that
55    *
56    * @param[in] subg Subgraph to be pushed
57    * @param[in] index Index of subgraph to be pushed
58    * @return Created
59    */
60   void push(SubgraphIndex index, const std::shared_ptr<IGraph> &subg) { _subgraphs[index] = subg; }
61
62   /**
63    * @brief Remove the subgraph that is associated with the given index
64    *
65    * @param[in] index Index of the subgraph to be removed
66    * @return N/A
67    */
68   void remove(const SubgraphIndex &index) { _subgraphs.erase(index); }
69
70   /**
71    * @brief Get the subgraph that is associated with the given index
72    *
73    * @param[in] index Index of the subgraph to be returned
74    * @return IGraph
75    */
76   const std::shared_ptr<IGraph> &at(const SubgraphIndex &index) const
77   {
78     return _subgraphs.at(index);
79   }
80   /**
81    * @brief Get the subgraph that is associated with the given index
82    *
83    * @param[in] index Index of the subgraph to be returned
84    * @return IGraph
85    */
86   std::shared_ptr<IGraph> &at(const SubgraphIndex &index) { return _subgraphs.at(index); }
87
88   /**
89    * @brief Get the subgraph that is associated with the given index
90    *
91    * @param[in] index Index of the subgraph to be returned
92    * @return true if such entry exists otherwise false
93    */
94   bool exist(const SubgraphIndex &index) const
95   {
96     auto it = _subgraphs.find(index);
97     return it != _subgraphs.end();
98   }
99
100   /**
101    * @brief Iterate over the container with given function
102    *
103    * @param[in] fn Function to be run for every container entry
104    * @return N/A
105    */
106   void iterate(const std::function<void(const SubgraphIndex &, const IGraph &)> &fn) const
107   {
108     for (const auto &e : _subgraphs)
109     {
110       fn(e.first, *e.second);
111     }
112   }
113
114   /**
115    * @brief Iterate over the container with given function
116    *
117    * @param[in] fn Function to be run for every container entry
118    * @return N/A
119    */
120   void iterate(const std::function<void(const SubgraphIndex &, IGraph &)> &fn)
121   {
122     for (const auto &e : _subgraphs)
123     {
124       fn(e.first, *e.second);
125     }
126   }
127
128   /**
129    * @brief Get count of Subgraphs
130    *
131    * @return count of Subgraphs
132    */
133   size_t subgraphs_count() const { return _subgraphs.size(); }
134
135   /**
136    * @brief Return the primary subgraph
137    *
138    * @return std::shared_ptr<IGraph> Primary subgraph
139    */
140   std::shared_ptr<IGraph> primary_subgraph() const { return _subgraphs.at(SubgraphIndex{0}); }
141
142   /**
143    * @brief Return whether the model has only typename Graph
144    *
145    * @tparam Graph Type that inherits from IGraph
146    *
147    * @return true if the model has only typename Graph, otherwise false
148    */
149   template <typename Graph, std::enable_if_t<std::is_base_of<IGraph, Graph>::value, bool> = true>
150   bool hasOnly()
151   {
152     for (const auto &e : _subgraphs)
153     {
154       if (std::dynamic_pointer_cast<Graph>(e.second) == nullptr)
155         return false;
156     }
157     return true;
158   }
159
160 private:
161   std::unordered_map<SubgraphIndex, std::shared_ptr<IGraph>> _subgraphs;
162
163   // Custom operations support
164 public:
165   void
166   bindKernelBuilder(const std::shared_ptr<onert::backend::custom::IKernelBuilder> &kernel_builder)
167   {
168     _kernel_builder = kernel_builder;
169   }
170
171   const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const
172   {
173     return _kernel_builder;
174   }
175
176 private:
177   std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder;
178 };
179
180 } // namespace ir
181 } // namespace onert
182
183 #endif // __ONERT_IR_MODEL_H__