1 // SPDX-License-Identifier: Apache-2.0
3 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
7 * @see https://github.com/nnstreamer/nntrainer
8 * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 * @bug No known bugs except for NYI items
10 * @brief This is the graph node interface for c++ API
13 #ifndef __GRAPH_NODE_H__
14 #define __GRAPH_NODE_H__
9 * @class GraphNode Base class for the graph node
10 * @brief Base class for all nodes in the graph
30 * @brief Provides the time/order at which the node will be executed.
31 * @details This time will be finalized once the graph has been calculated.
32 * Each element indicates the order in which the operations below
33 * for each node are executed:
38 * One constraint is that they must be sorted in ascending order.
39 * This ensures that the operations are executed in the order of their
42 typedef std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>
46 * @brief Destructor of GraphNode Class
48 virtual ~GraphNode() = default;
51 * @brief Get the Name of the underlying object
53 * @return std::string Name of the underlying object
54 * @note name of each node in the graph must be unique
56 virtual const std::string getName() const noexcept = 0;
59 * @brief Set the Name of the underlying object
61 * @param[in] name Name for the underlying object
62 * @note name of each node in the graph must be unique, and caller must ensure
65 virtual void setName(const std::string &name) = 0;
68 * @brief Get the Type of the underlying object
70 * @return const std::string type representation
72 virtual const std::string getType() const = 0;
75 * @brief Get the trainable parameter
77 * @return bool true if the node is trainable, false otherwise
79 virtual bool getTrainable() const = 0;
82 * @brief Get the input connections for this node
84 * @return list of names of the nodes which form input connections
86 virtual const std::vector<std::string> getInputConnections() const = 0;
89 * @brief Get the output connections for this node
91 * @return list of names of the nodes which form output connections
93 virtual const std::vector<std::string> getOutputConnections() const = 0;
96 * @brief get the execution order/location of this node
98 * @return the execution order/location of this node
99 * @details The four values hold the execution orders of the operations listed for ExecutionOrder
102 virtual ExecutionOrder getExecutionOrder() const = 0;
105 * @brief set the execution order/location of this node
107 * @param exec_order_ the execution order/location of this node
108 * @details The four values hold the execution orders of the operations listed for ExecutionOrder
111 virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
115 * @brief Iterator for GraphNode which returns a const
116 * std::shared_ptr<LayerNodeType> object upon dereference
118 * @note This does not include the complete list of required functions. Add
121 * @note GraphNodeType allows the iterator to be used for both GraphNode and const GraphNode
123 template <typename LayerNodeType, typename GraphNodeType>
124 class GraphNodeIterator
125 : public std::iterator<std::random_access_iterator_tag, GraphNodeType> {
126 GraphNodeType *p; /** underlying object of GraphNode */
130 * @brief iterator_traits types definition
132 * @note these are not required to be explicitly defined now, but defining them maintains
133 * forward compatibility for c++17 and later
135 * @note value_type, pointer and reference are different from standard
138 typedef const std::shared_ptr<LayerNodeType> value_type;
139 typedef std::random_access_iterator_tag iterator_category;
140 typedef std::ptrdiff_t difference_type;
141 typedef const std::shared_ptr<LayerNodeType> *pointer;
142 typedef const std::shared_ptr<LayerNodeType> &reference;
145 * @brief Construct a new Graph Node Iterator object
147 * @param x underlying object of GraphNode
149 GraphNodeIterator(GraphNodeType *x) : p(x) {}
152 * @brief reference operator
155 * @note unlike a standard iterator, this returns the value by copy, not a reference
157 value_type operator*() const {
158 return std::static_pointer_cast<LayerNodeType>(*p);
162 * @brief pointer operator
165 * @note unlike a standard iterator, this returns the shared_ptr by value, not a raw pointer
167 value_type operator->() const {
168 return std::static_pointer_cast<LayerNodeType>(*p);
172 * @brief == comparison operator override
174 * @param lhs iterator lhs
175 * @param rhs iterator rhs
176 * @retval true if match
177 * @retval false if mismatch
179 friend bool operator==(GraphNodeIterator const &lhs,
180 GraphNodeIterator const &rhs) {
181 return lhs.p == rhs.p;
185 * @brief != comparison operator override
187 * @param lhs iterator lhs
188 * @param rhs iterator rhs
189 * @retval true if mismatch
190 * @retval false if match
192 friend bool operator!=(GraphNodeIterator const &lhs,
193 GraphNodeIterator const &rhs) {
194 return lhs.p != rhs.p;
198 * @brief overload for prefix ++ operator
200 * @return GraphNodeIterator&
202 GraphNodeIterator &operator++() {
208 * @brief overload for postfix ++ operator
210 * @return GraphNodeIterator
212 GraphNodeIterator operator++(int) {
213 GraphNodeIterator temp(p);
219 * @brief overload for prefix -- operator
221 * @return GraphNodeIterator&
223 GraphNodeIterator &operator--() {
229 * @brief overload for postfix -- operator
231 * @return GraphNodeIterator
233 GraphNodeIterator operator--(int) {
234 GraphNodeIterator temp(p);
240 * @brief overload for subtracting an offset from the iterator
242 * @param offset offset to subtract
243 * @return GraphNodeIterator
245 GraphNodeIterator operator-(const difference_type offset) const {
246 return GraphNodeIterator(p - offset);
250 * @brief overload for the difference between two iterators
252 * @param other iterator to subtract
253 * @return difference_type
255 difference_type operator-(const GraphNodeIterator &other) const {
260 * @brief overload for the compound -= operator
262 * @param offset offset to subtract
263 * @return GraphNodeIterator&
265 GraphNodeIterator &operator-=(const difference_type offset) {
271 * @brief overload for adding an offset to the iterator
273 * @param offset offset to add
274 * @return GraphNodeIterator
276 GraphNodeIterator operator+(const difference_type offset) const {
277 return GraphNodeIterator(p + offset);
281 * @brief overload for the compound += operator
283 * @param offset offset to add
284 * @return GraphNodeIterator&
286 GraphNodeIterator &operator+=(const difference_type offset) {
293 * @brief Reverse Iterator for GraphNode which returns a LayerNode object upon
296 * @note This just extends GraphNodeIterator and is limited by its
299 template <typename T_iterator>
300 class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
303 * @brief Construct a new Graph Node Reverse Iterator object
305 * @param iter Iterator
307 explicit GraphNodeReverseIterator(T_iterator iter) :
308 std::reverse_iterator<T_iterator>(iter) {}
311 * @brief reference operator
313 * @return T_iterator::value_type
314 * @note unlike a standard iterator, this returns the value by copy, not a reference
316 typename T_iterator::value_type operator*() const {
317 auto temp = std::reverse_iterator<T_iterator>::current - 1;
322 * @brief pointer operator
324 * @return T_iterator::value_type
325 * @note unlike a standard iterator, this returns the value by copy, not a raw pointer
327 typename T_iterator::value_type operator->() const {
328 auto temp = std::reverse_iterator<T_iterator>::current - 1;
334 * @brief Const iterator to traverse the graph, yielding const std::shared_ptr<LayerNodeType>
336 template <class LayerNodeType>
337 using graph_const_iterator =
338 GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
341 * @brief Const reverse iterator to traverse the graph in reverse order
343 template <class LayerNodeType>
344 using graph_const_reverse_iterator = GraphNodeReverseIterator<
345 GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
347 } // namespace nntrainer
348 #endif // __GRAPH_NODE_H__