22cee1fb2c1cad4d87c2df62a9fb26be0695c719
[platform/core/ml/nntrainer.git] / nntrainer / graph / graph_node.h
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4  *
5  * @file   graph_node.h
6  * @date   1 April 2021
7  * @see    https://github.com/nnstreamer/nntrainer
8  * @author Parichay Kapoor <pk.kapoor@samsung.com>
9  * @bug    No known bugs except for NYI items
10  * @brief  This is the graph node interface for c++ API
11  */
12
13 #ifndef __GRAPH_NODE_H__
14 #define __GRAPH_NODE_H__
15
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 namespace nntrainer {
22
23 /**
24  * @class   Layer Base class for the graph node
25  * @brief   Base class for all layers
26  */
27 class GraphNode {
28 public:
29   /**
30    * @brief Provides the time/order at which the node will be executed.
31    * @details This time will be finalized once the graph has been calculated.
32    * Each element indicates the orders with which the below operations
33    * for each node are executed:
34    * 1. Forwarding
35    * 2. calcGradient
36    * 3. calcDerivative
37    * 4. ApplyGradient
38    * One constraint is that they must be sorted in ascending order.
39    * This ensures that the operations are executed in the order of their
40    * listing.
41    */
42   typedef std::tuple<unsigned int, unsigned int, unsigned int, unsigned int> ExecutionOrder;
43
44   /**
45    * @brief     Destructor of Layer Class
46    */
47   virtual ~GraphNode() = default;
48
49   /**
50    * @brief     Get the Name of the underlying object
51    *
52    * @return std::string Name of the underlying object
53    * @note name of each node in the graph must be unique
54    */
55   virtual const std::string getName() const noexcept = 0;
56
57   /**
58    * @brief     Set the Name of the underlying object
59    *
60    * @param[in] std::string Name for the underlying object
61    * @note name of each node in the graph must be unique, and caller must ensure
62    * that
63    */
64   virtual void setName(const std::string &name) = 0;
65
66   /**
67    * @brief     Get the Type of the underlying object
68    *
69    * @return const std::string type representation
70    */
71   virtual const std::string getType() const = 0;
72
73   /**
74    * @brief     Get the input connections for this node
75    *
76    * @return list of name of the nodes which form input connections
77    */
78   virtual const std::vector<std::string> getInputConnections() const = 0;
79
80   /**
81    * @brief     Get the output connections for this node
82    *
83    * @return list of name of the nodes which form output connections
84    */
85   virtual const std::vector<std::string> getOutputConnections() const = 0;
86
87   /**
88    * @brief     get the execution order/location of this node
89    *
90    * @retval    the execution order/location of this node
91    * @details   The two values represents the value for forward and backward
92    * respectively
93    */
94   virtual ExecutionOrder getExecutionOrder() const = 0;
95
96   /**
97    * @brief     set the execution order/location of this node
98    *
99    * @param     exec_order the execution order/location of this node
100    * @details   The two values represents the value for forward and backward
101    * respectively
102    */
103   virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
104 };
105
106 /**
107  * @brief   Iterator for GraphNode which return const
108  * std::shared_ptr<LayerNodeType> object upon realize
109  *
110  * @note    This does not include the complete list of required functions. Add
111  * them as per need.
112  *
113  * @note    GraphNodeType is to enable for both GraphNode and const GraphNode
114  */
115 template <typename LayerNodeType, typename GraphNodeType>
116 class GraphNodeIterator
117   : public std::iterator<std::random_access_iterator_tag, GraphNodeType> {
118   GraphNodeType *p; /** underlying object of GraphNode */
119
120 public:
121   /**
122    * @brief   iterator_traits types definition
123    *
124    * @note    these are not requried to be explicitly defined now, but maintains
125    *          forward compatibility for c++17 and later
126    *
127    * @note    value_type, pointer and reference are different from standard
128    * iterator
129    */
130   typedef const std::shared_ptr<LayerNodeType> value_type;
131   typedef std::random_access_iterator_tag iterator_category;
132   typedef std::ptrdiff_t difference_type;
133   typedef const std::shared_ptr<LayerNodeType> *pointer;
134   typedef const std::shared_ptr<LayerNodeType> &reference;
135
136   /**
137    * @brief Construct a new Graph Node Iterator object
138    *
139    * @param x underlying object of GraphNode
140    */
141   GraphNodeIterator(GraphNodeType *x) : p(x) {}
142
143   /**
144    * @brief reference operator
145    *
146    * @return value_type
147    * @note this is different from standard iterator
148    */
149   value_type operator*() const {
150     return std::static_pointer_cast<LayerNodeType>(*p);
151   }
152
153   /**
154    * @brief pointer operator
155    *
156    * @return value_type
157    * @note this is different from standard iterator
158    */
159   value_type operator->() const {
160     return std::static_pointer_cast<LayerNodeType>(*p);
161   }
162
163   /**
164    * @brief == comparison operator override
165    *
166    * @param lhs iterator lhs
167    * @param rhs iterator rhs
168    * @retval true if match
169    * @retval false if mismatch
170    */
171   friend bool operator==(GraphNodeIterator const &lhs,
172                          GraphNodeIterator const &rhs) {
173     return lhs.p == rhs.p;
174   }
175
176   /**
177    * @brief != comparison operator override
178    *
179    * @param lhs iterator lhs
180    * @param rhs iterator rhs
181    * @retval true if mismatch
182    * @retval false if match
183    */
184   friend bool operator!=(GraphNodeIterator const &lhs,
185                          GraphNodeIterator const &rhs) {
186     return lhs.p != rhs.p;
187   }
188
189   /**
190    * @brief override for ++ operator
191    *
192    * @return GraphNodeIterator&
193    */
194   GraphNodeIterator &operator++() {
195     p += 1;
196     return *this;
197   }
198
199   /**
200    * @brief override for operator++
201    *
202    * @return GraphNodeIterator
203    */
204   GraphNodeIterator operator++(int) {
205     GraphNodeIterator temp(p);
206     p += 1;
207     return temp;
208   }
209
210   /**
211    * @brief override for -- operator
212    *
213    * @return GraphNodeIterator&
214    */
215   GraphNodeIterator &operator--() {
216     p -= 1;
217     return *this;
218   }
219
220   /**
221    * @brief override for operator--
222    *
223    * @return GraphNodeIterator
224    */
225   GraphNodeIterator operator--(int) {
226     GraphNodeIterator temp(p);
227     p -= 1;
228     return temp;
229   }
230
231   /**
232    * @brief override for subtract operator
233    *
234    * @param offset offset to subtract
235    * @return GraphNodeIterator
236    */
237   GraphNodeIterator operator-(const difference_type offset) const {
238     return GraphNodeIterator(p - offset);
239   }
240
241   /**
242    * @brief override for subtract operator
243    *
244    * @param other iterator to subtract
245    * @return difference_type
246    */
247   difference_type operator-(const GraphNodeIterator &other) const {
248     return p - other.p;
249   }
250
251   /**
252    * @brief override for subtract and return result operator
253    *
254    * @param offset offset to subtract
255    * @return GraphNodeIterator&
256    */
257   GraphNodeIterator &operator-=(const difference_type offset) {
258     p -= offset;
259     return *this;
260   }
261
262   /**
263    * @brief override for add operator
264    *
265    * @param offset offset to add
266    * @return GraphNodeIterator
267    */
268   GraphNodeIterator operator+(const difference_type offset) const {
269     return GraphNodeIterator(p + offset);
270   }
271
272   /**
273    * @brief override for add and return result operator
274    *
275    * @param offset offset to add
276    * @return GraphNodeIterator&
277    */
278   GraphNodeIterator &operator+=(const difference_type offset) {
279     p += offset;
280     return *this;
281   }
282 };
283
284 /**
285  * @brief   Reverse Iterator for GraphNode which return LayerNode object upon
286  * realize
287  *
288  * @note    This just extends GraphNodeIterator and is limited by its
289  * functionality.
290  */
291 template <typename T_iterator>
292 class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
293 public:
294   /**
295    * @brief Construct a new Graph Node Reverse Iterator object
296    *
297    * @param iter Iterator
298    */
299   explicit GraphNodeReverseIterator(T_iterator iter) :
300     std::reverse_iterator<T_iterator>(iter) {}
301
302   /**
303    *  @brief reference operator
304    *
305    * @return T_iterator::value_type
306    * @note this is different from standard iterator
307    */
308   typename T_iterator::value_type operator*() const {
309     auto temp = std::reverse_iterator<T_iterator>::current - 1;
310     return *temp;
311   }
312
313   /**
314    *  @brief pointer operator
315    *
316    * @return T_iterator::value_type
317    * @note this is different from standard iterator
318    */
319   typename T_iterator::value_type operator->() const {
320     auto temp = std::reverse_iterator<T_iterator>::current - 1;
321     return *temp;
322   }
323 };
324
325 /**
326  * @brief     Iterators to traverse the graph
327  */
328 template <class LayerNodeType>
329 using graph_const_iterator =
330   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
331
332 /**
333  * @brief     Iterators to traverse the graph
334  */
335 template <class LayerNodeType>
336 using graph_const_reverse_iterator = GraphNodeReverseIterator<
337   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
338
339 } // namespace nntrainer
340 #endif // __GRAPH_NODE_H__