[Flatbuffer] Add flatbuffer_opnode
[platform/core/ml/nntrainer.git] / nntrainer / compiler / flatbuffer_opnode.h
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2023 DongHak Park <donghak.park@samsung.com>
4  *
5  * @file   flatbuffer_opnode.h
6  * @date   10 February 2023
7  * @brief  NNTrainer flatbuffer opnode
8  * @see    https://github.com/nnstreamer/nntrainer
9  * @author Donghak Park <donghak.park@samsung.com>
10  * @bug    No known bugs except for NYI items
11  */
12
13 #ifndef __FLATBUFFER_OPNODE_H__
14 #define __FLATBUFFER_OPNODE_H__
15
16 #include <functional>
17 #include <utility>
18 #include <vector>
19
20 #include <nntrainer_schema_generated.h>
21 #include <tensor.h>
22
23 namespace nntrainer {
24
25 class LayerNode;
26 class RunLayerContext;
27
28 /**
29  * @brief FlatBufferOpNode class
30  *
31  */
32 class FlatBufferOpNode {
33 public:
34   using Variables = std::vector<const Tensor *>;
35
36   /**
37    * @brief Construct a new Flat Buffer Op Node object
38    *
39    */
40   FlatBufferOpNode();
41
42   /**
43    * @brief Set the Layer Node object
44    *
45    * @param layer layer node
46    */
47   void setLayerNode(const LayerNode &layer);
48
49   /**
50    * @brief Set the Op Type object
51    *
52    * @param op_type_
53    */
54   void setOpType(nntr::BuiltinOperator op_type_) { op_type = op_type_; }
55
56   /**
57    * @brief Set the Builtin Options object
58    *
59    * @param builtin_option_type_ builtin option type
60    * @param builtin_ops_ flatbuffer offset of builtin ops
61    */
62   void setBuiltinOptions(nntr::BuiltinOptions builtin_option_type_,
63                          const flatbuffers::Offset<void> &builtin_ops_);
64
65   /**
66    * @brief Get the Inputs object
67    *
68    * @return Variables& inputs
69    */
70   Variables &getInputs() { return inputs; }
71
72   /**
73    * @brief Get the Inputs object
74    *
75    * @return const Variables& inputs
76    */
77   const Variables &getInputs() const { return inputs; }
78
79   /**
80    * @brief Get the Weights object
81    *
82    * @return Variables& weights
83    */
84   Variables &getWeights() { return weights; }
85
86   /**
87    * @brief Get the Weights object
88    *
89    * @return const Variables& weights
90    */
91   const Variables &getWeights() const { return weights; }
92
93   /**
94    * @brief Get the Outputs object
95    *
96    * @return Variables& outputs
97    */
98   Variables &getOutputs() { return outputs; }
99
100   /**
101    * @brief Get the Outputs object
102    *
103    * @return const Variables& outputs
104    */
105   const Variables &getOutputs() const { return outputs; }
106
107   /**
108    * @brief check if the node is model input
109    *
110    * @return true if op node is model input
111    * @return false  if op node is not model input
112    */
113   bool isInputNode() const { return is_input; }
114
115   /**
116    * @brief check if the node is model output
117    *
118    * @return true if op node is model output
119    * @return false if op node is not model output
120    */
121   bool isOutputNode() const { return is_output; }
122
123   /**
124    * @brief check if the node is virtual node
125    *
126    * @return true if this op node is virtual node
127    * @return false if this op node is not virtual node
128    */
129   bool isVirtualNode() const { return is_virtual; }
130
131   /**
132    * @brief Get the Op Type object
133    *
134    * @return const nntr::BuiltinOperator
135    */
136   const nntr::BuiltinOperator getOpType() const { return op_type; }
137
138   /**
139    * @brief Get the Option Type object
140    *
141    * @return const nntr::BuiltinOptions
142    */
143   const nntr::BuiltinOptions getOptionType() const {
144     return builtin_option_type;
145   }
146
147   /**
148    * @brief Get the Builtin Ops object
149    * @param f Flatbuffer builder
150    *
151    * @return flatbuffers::Offset<void>
152    */
153   flatbuffers::Offset<void> getBuiltinOps() const;
154
155   /**
156    * @brief Get the Input Nodes object
157    *
158    * @return const std::vector<FlatBufferOpNode *> &input_nodes
159    */
160   const std::vector<FlatBufferOpNode *> &getInputNodes() const {
161     return input_nodes;
162   }
163
164   /**
165    * @brief Set arity
166    *
167    * @param value value to set
168    */
169   void arity(size_t value) { input_nodes.resize(value); }
170
171   /**
172    * @brief Get arity
173    *
174    * @return const unsigned input_nodes size
175    */
176   const unsigned arity() const { return input_nodes.size(); }
177
178   /**
179    * @brief Set the Arg object
180    *
181    * @param index argument index to set
182    * @param node the node to be argument
183    */
184   void setArg(size_t index, FlatBufferOpNode *node) {
185     input_nodes.at(index) = node;
186   }
187
188   /**
189    * @brief Get the Arg object
190    *
191    * @param index argument index to get
192    * @return FlatBufferOpNode *input_nodes.at(index)
193    */
194   FlatBufferOpNode *arg(size_t index) const { return input_nodes.at(index); }
195
196 private:
197   Variables inputs;                            /**< input variables */
198   Variables outputs;                           /**< output variables */
199   Variables weights;                           /**< weight variables */
200   std::vector<FlatBufferOpNode *> input_nodes; /**< input nodes */
201
202   bool is_input;   /**< true if given input is model input */
203   bool is_output;  /**< true if given output is model output */
204   bool is_virtual; /**< true if given node is virtual */
205
206   nntr::BuiltinOperator op_type;            /**< op type */
207   nntr::BuiltinOptions builtin_option_type; /**< builtin option type */
208   flatbuffers::Offset<void> builtin_ops;    /**< builtin ops */
209 };
210
211 } // namespace nntrainer
212
213 #endif // __FLATBUFFER_OPNODE_H__