[Memory Planner] Update the Memory Planner
[platform/core/ml/nntrainer.git] / nntrainer / graph / graph_node.h
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4  *
5  * @file   graph_node.h
6  * @date   1 April 2021
7  * @see    https://github.com/nnstreamer/nntrainer
8  * @author Parichay Kapoor <pk.kapoor@samsung.com>
9  * @bug    No known bugs except for NYI items
10  * @brief  This is the graph node interface for c++ API
11  */
12
13 #ifndef __GRAPH_NODE_H__
14 #define __GRAPH_NODE_H__
15
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 namespace nntrainer {
22
23 /**
24  * @class   Layer Base class for the graph node
25  * @brief   Base class for all layers
26  */
27 class GraphNode {
28 public:
29   /**
30    * @brief Provides the time/order at which the node will be executed.
31    * @details This time will be finalized once the graph has been calculated.
32    * Each element indicates the orders with which the below operations
33    * for each node are executed:
34    * 1. Forwarding
35    * 2. calcGradient
36    * 3. calcDerivative
37    * 4. ApplyGradient
38    * One constraint is that they must be sorted in ascending order.
39    * This ensures that the operations are executed in the order of their
40    * listing.
41    */
42   typedef std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>
43     ExecutionOrder;
44
45   /**
46    * @brief     Destructor of Layer Class
47    */
48   virtual ~GraphNode() = default;
49
50   /**
51    * @brief     Get the Name of the underlying object
52    *
53    * @return std::string Name of the underlying object
54    * @note name of each node in the graph must be unique
55    */
56   virtual const std::string getName() const noexcept = 0;
57
58   /**
59    * @brief     Set the Name of the underlying object
60    *
61    * @param[in] std::string Name for the underlying object
62    * @note name of each node in the graph must be unique, and caller must ensure
63    * that
64    */
65   virtual void setName(const std::string &name) = 0;
66
67   /**
68    * @brief     Get the Type of the underlying object
69    *
70    * @return const std::string type representation
71    */
72   virtual const std::string getType() const = 0;
73
74   /**
75    * @brief     Get the trainable parameter
76    *
77    * @return bool true / false
78    */
79   virtual bool getTrainable() const = 0;
80
81   /**
82    * @brief     Get the input connections for this node
83    *
84    * @return list of name of the nodes which form input connections
85    */
86   virtual const std::vector<std::string> getInputConnections() const = 0;
87
88   /**
89    * @brief     Get the output connections for this node
90    *
91    * @return list of name of the nodes which form output connections
92    */
93   virtual const std::vector<std::string> getOutputConnections() const = 0;
94
95   /**
96    * @brief     get the execution order/location of this node
97    *
98    * @retval    the execution order/location of this node
99    * @details   The two values represents the value for forward and backward
100    * respectively
101    */
102   virtual ExecutionOrder getExecutionOrder() const = 0;
103
104   /**
105    * @brief     set the execution order/location of this node
106    *
107    * @param     exec_order the execution order/location of this node
108    * @details   The two values represents the value for forward and backward
109    * respectively
110    */
111   virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
112 };
113
114 /**
115  * @brief   Iterator for GraphNode which return const
116  * std::shared_ptr<LayerNodeType> object upon realize
117  *
118  * @note    This does not include the complete list of required functions. Add
119  * them as per need.
120  *
121  * @note    GraphNodeType is to enable for both GraphNode and const GraphNode
122  */
123 template <typename LayerNodeType, typename GraphNodeType>
124 class GraphNodeIterator
125   : public std::iterator<std::random_access_iterator_tag, GraphNodeType> {
126   GraphNodeType *p; /** underlying object of GraphNode */
127
128 public:
129   /**
130    * @brief   iterator_traits types definition
131    *
132    * @note    these are not requried to be explicitly defined now, but maintains
133    *          forward compatibility for c++17 and later
134    *
135    * @note    value_type, pointer and reference are different from standard
136    * iterator
137    */
138   typedef const std::shared_ptr<LayerNodeType> value_type;
139   typedef std::random_access_iterator_tag iterator_category;
140   typedef std::ptrdiff_t difference_type;
141   typedef const std::shared_ptr<LayerNodeType> *pointer;
142   typedef const std::shared_ptr<LayerNodeType> &reference;
143
144   /**
145    * @brief Construct a new Graph Node Iterator object
146    *
147    * @param x underlying object of GraphNode
148    */
149   GraphNodeIterator(GraphNodeType *x) : p(x) {}
150
151   /**
152    * @brief reference operator
153    *
154    * @return value_type
155    * @note this is different from standard iterator
156    */
157   value_type operator*() const {
158     return std::static_pointer_cast<LayerNodeType>(*p);
159   }
160
161   /**
162    * @brief pointer operator
163    *
164    * @return value_type
165    * @note this is different from standard iterator
166    */
167   value_type operator->() const {
168     return std::static_pointer_cast<LayerNodeType>(*p);
169   }
170
171   /**
172    * @brief == comparison operator override
173    *
174    * @param lhs iterator lhs
175    * @param rhs iterator rhs
176    * @retval true if match
177    * @retval false if mismatch
178    */
179   friend bool operator==(GraphNodeIterator const &lhs,
180                          GraphNodeIterator const &rhs) {
181     return lhs.p == rhs.p;
182   }
183
184   /**
185    * @brief != comparison operator override
186    *
187    * @param lhs iterator lhs
188    * @param rhs iterator rhs
189    * @retval true if mismatch
190    * @retval false if match
191    */
192   friend bool operator!=(GraphNodeIterator const &lhs,
193                          GraphNodeIterator const &rhs) {
194     return lhs.p != rhs.p;
195   }
196
197   /**
198    * @brief override for ++ operator
199    *
200    * @return GraphNodeIterator&
201    */
202   GraphNodeIterator &operator++() {
203     p += 1;
204     return *this;
205   }
206
207   /**
208    * @brief override for operator++
209    *
210    * @return GraphNodeIterator
211    */
212   GraphNodeIterator operator++(int) {
213     GraphNodeIterator temp(p);
214     p += 1;
215     return temp;
216   }
217
218   /**
219    * @brief override for -- operator
220    *
221    * @return GraphNodeIterator&
222    */
223   GraphNodeIterator &operator--() {
224     p -= 1;
225     return *this;
226   }
227
228   /**
229    * @brief override for operator--
230    *
231    * @return GraphNodeIterator
232    */
233   GraphNodeIterator operator--(int) {
234     GraphNodeIterator temp(p);
235     p -= 1;
236     return temp;
237   }
238
239   /**
240    * @brief override for subtract operator
241    *
242    * @param offset offset to subtract
243    * @return GraphNodeIterator
244    */
245   GraphNodeIterator operator-(const difference_type offset) const {
246     return GraphNodeIterator(p - offset);
247   }
248
249   /**
250    * @brief override for subtract operator
251    *
252    * @param other iterator to subtract
253    * @return difference_type
254    */
255   difference_type operator-(const GraphNodeIterator &other) const {
256     return p - other.p;
257   }
258
259   /**
260    * @brief override for subtract and return result operator
261    *
262    * @param offset offset to subtract
263    * @return GraphNodeIterator&
264    */
265   GraphNodeIterator &operator-=(const difference_type offset) {
266     p -= offset;
267     return *this;
268   }
269
270   /**
271    * @brief override for add operator
272    *
273    * @param offset offset to add
274    * @return GraphNodeIterator
275    */
276   GraphNodeIterator operator+(const difference_type offset) const {
277     return GraphNodeIterator(p + offset);
278   }
279
280   /**
281    * @brief override for add and return result operator
282    *
283    * @param offset offset to add
284    * @return GraphNodeIterator&
285    */
286   GraphNodeIterator &operator+=(const difference_type offset) {
287     p += offset;
288     return *this;
289   }
290 };
291
292 /**
293  * @brief   Reverse Iterator for GraphNode which return LayerNode object upon
294  * realize
295  *
296  * @note    This just extends GraphNodeIterator and is limited by its
297  * functionality.
298  */
299 template <typename T_iterator>
300 class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
301 public:
302   /**
303    * @brief Construct a new Graph Node Reverse Iterator object
304    *
305    * @param iter Iterator
306    */
307   explicit GraphNodeReverseIterator(T_iterator iter) :
308     std::reverse_iterator<T_iterator>(iter) {}
309
310   /**
311    *  @brief reference operator
312    *
313    * @return T_iterator::value_type
314    * @note this is different from standard iterator
315    */
316   typename T_iterator::value_type operator*() const {
317     auto temp = std::reverse_iterator<T_iterator>::current - 1;
318     return *temp;
319   }
320
321   /**
322    *  @brief pointer operator
323    *
324    * @return T_iterator::value_type
325    * @note this is different from standard iterator
326    */
327   typename T_iterator::value_type operator->() const {
328     auto temp = std::reverse_iterator<T_iterator>::current - 1;
329     return *temp;
330   }
331 };
332
333 /**
334  * @brief     Iterators to traverse the graph
335  */
336 template <class LayerNodeType>
337 using graph_const_iterator =
338   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
339
340 /**
341  * @brief     Iterators to traverse the graph
342  */
343 template <class LayerNodeType>
344 using graph_const_reverse_iterator = GraphNodeReverseIterator<
345   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
346
347 } // namespace nntrainer
348 #endif // __GRAPH_NODE_H__