// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2023 DongHak Park <donghak.park@samsung.com>
 *
 * @file flatbuffer_opnode.h
 * @date 10 February 2023
 * @brief NNTrainer flatbuffer opnode
 * @see https://github.com/nnstreamer/nntrainer
 * @author Donghak Park <donghak.park@samsung.com>
 * @bug No known bugs except for NYI items
 */
13 #ifndef __FLATBUFFER_OPNODE_H__
14 #define __FLATBUFFER_OPNODE_H__
20 #include <nntrainer_schema_generated.h>
26 class RunLayerContext;
29 * @brief FlatBufferOpNode class
32 class FlatBufferOpNode {
34 using Variables = std::vector<const Tensor *>;
/** @brief Construct a new Flat Buffer Op Node object */
43 * @brief Set the Layer Node object
45 * @param layer layer node
47 void setLayerNode(const LayerNode &layer);
50 * @brief Set the Op Type object
54 void setOpType(nntr::BuiltinOperator op_type_) { op_type = op_type_; }
57 * @brief Set the Builtin Options object
59 * @param builtin_option_type_ builtin option type
60 * @param builtin_ops_ flatbuffer offset of builtin ops
62 void setBuiltinOptions(nntr::BuiltinOptions builtin_option_type_,
63 const flatbuffers::Offset<void> &builtin_ops_);
66 * @brief Get the Inputs object
68 * @return Variables& inputs
70 Variables &getInputs() { return inputs; }
73 * @brief Get the Inputs object
75 * @return const Variables& inputs
77 const Variables &getInputs() const { return inputs; }
80 * @brief Get the Weights object
82 * @return Variables& weights
84 Variables &getWeights() { return weights; }
87 * @brief Get the Weights object
89 * @return const Variables& weights
91 const Variables &getWeights() const { return weights; }
94 * @brief Get the Outputs object
96 * @return Variables& outputs
98 Variables &getOutputs() { return outputs; }
101 * @brief Get the Outputs object
103 * @return const Variables& outputs
105 const Variables &getOutputs() const { return outputs; }
108 * @brief check if the node is model input
110 * @return true if op node is model input
111 * @return false if op node is not model input
113 bool isInputNode() const { return is_input; }
116 * @brief check if the node is model output
118 * @return true if op node is model output
119 * @return false if op node is not model output
121 bool isOutputNode() const { return is_output; }
124 * @brief check if the node is virtual node
126 * @return true if this op node is virtual node
127 * @return false if this op node is not virtual node
129 bool isVirtualNode() const { return is_virtual; }
132 * @brief Get the Op Type object
134 * @return const nntr::BuiltinOperator
136 const nntr::BuiltinOperator getOpType() const { return op_type; }
139 * @brief Get the Option Type object
141 * @return const nntr::BuiltinOptions
143 const nntr::BuiltinOptions getOptionType() const {
144 return builtin_option_type;
148 * @brief Get the Builtin Ops object
149 * @param f Flatbuffer builder
151 * @return flatbuffers::Offset<void>
153 flatbuffers::Offset<void> getBuiltinOps() const;
156 * @brief Get the Input Nodes object
158 * @return const std::vector<FlatBufferOpNode *> &input_nodes
160 const std::vector<FlatBufferOpNode *> &getInputNodes() const {
167 * @param value value to set
169 void arity(size_t value) { input_nodes.resize(value); }
174 * @return const unsigned input_nodes size
176 const unsigned arity() const { return input_nodes.size(); }
179 * @brief Set the Arg object
181 * @param index argument index to set
182 * @param node the node to be argument
184 void setArg(size_t index, FlatBufferOpNode *node) {
185 input_nodes.at(index) = node;
189 * @brief Get the Arg object
191 * @param index argument index to get
192 * @return FlatBufferOpNode *input_nodes.at(index)
194 FlatBufferOpNode *arg(size_t index) const { return input_nodes.at(index); }
197 Variables inputs; /**< input variables */
198 Variables outputs; /**< output variables */
199 Variables weights; /**< weight variables */
200 std::vector<FlatBufferOpNode *> input_nodes; /**< input nodes */
202 bool is_input; /**< true if given input is model input */
203 bool is_output; /**< true if given output is model output */
204 bool is_virtual; /**< true if given node is virtual */
206 nntr::BuiltinOperator op_type; /**< op type */
207 nntr::BuiltinOptions builtin_option_type; /**< builtin option type */
208 flatbuffers::Offset<void> builtin_ops; /**< builtin ops */
211 } // namespace nntrainer
213 #endif // __FLATBUFFER_OPNODE_H__