1 // SPDX-License-Identifier: Apache-2.0
3 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
7 * @see https://github.com/nnstreamer/nntrainer
8 * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 * @bug No known bugs except for NYI items
10 * @brief This is the graph node interface for c++ API
13 #ifndef __GRAPH_NODE_H__
14 #define __GRAPH_NODE_H__
 * @class GraphNode
 * @brief Base class for all nodes of the graph
30 * @brief Provides the time/order at which the node will be executed.
31 * @details This time will be finalized once the graph has been calculated.
32 * Each element indicates the orders with which the below operations
33 * for each node are executed:
38 * One constraint is that they must be sorted in ascending order.
39 * This ensures that the operations are executed in the order of their
using ExecutionOrder =
  std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>;
 * @brief Destructor of GraphNode Class
47 virtual ~GraphNode() = default;
50 * @brief Get the Name of the underlying object
52 * @return std::string Name of the underlying object
53 * @note name of each node in the graph must be unique
55 virtual const std::string getName() const noexcept = 0;
58 * @brief Set the Name of the underlying object
 * @param[in] name Name for the underlying object
61 * @note name of each node in the graph must be unique, and caller must ensure
64 virtual void setName(const std::string &name) = 0;
67 * @brief Get the Type of the underlying object
69 * @return const std::string type representation
71 virtual const std::string getType() const = 0;
74 * @brief Get the input connections for this node
76 * @return list of name of the nodes which form input connections
78 virtual const std::vector<std::string> getInputConnections() const = 0;
81 * @brief Get the output connections for this node
83 * @return list of name of the nodes which form output connections
85 virtual const std::vector<std::string> getOutputConnections() const = 0;
88 * @brief get the execution order/location of this node
90 * @retval the execution order/location of this node
 * @details Each of the four values holds the order of one of the operations listed for ExecutionOrder
94 virtual ExecutionOrder getExecutionOrder() const = 0;
97 * @brief set the execution order/location of this node
99 * @param exec_order the execution order/location of this node
 * @details Each of the four values holds the order of one of the operations listed for ExecutionOrder
103 virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
 * @brief Iterator for GraphNode which returns a const
 * std::shared_ptr<LayerNodeType> object upon dereference
110 * @note This does not include the complete list of required functions. Add
 * @note GraphNodeType allows this iterator to be used with both GraphNode and const GraphNode
115 template <typename LayerNodeType, typename GraphNodeType>
116 class GraphNodeIterator
117 : public std::iterator<std::random_access_iterator_tag, GraphNodeType> {
118 GraphNodeType *p; /** underlying object of GraphNode */
122 * @brief iterator_traits types definition
 * @note these are not required to be explicitly defined now, but they
 * maintain forward compatibility for c++17 and later, where the std::iterator base class is deprecated
127 * @note value_type, pointer and reference are different from standard
130 typedef const std::shared_ptr<LayerNodeType> value_type;
131 typedef std::random_access_iterator_tag iterator_category;
132 typedef std::ptrdiff_t difference_type;
133 typedef const std::shared_ptr<LayerNodeType> *pointer;
134 typedef const std::shared_ptr<LayerNodeType> &reference;
137 * @brief Construct a new Graph Node Iterator object
139 * @param x underlying object of GraphNode
141 GraphNodeIterator(GraphNodeType *x) : p(x) {}
144 * @brief reference operator
147 * @note this is different from standard iterator
149 value_type operator*() const {
150 return std::static_pointer_cast<LayerNodeType>(*p);
154 * @brief pointer operator
157 * @note this is different from standard iterator
159 value_type operator->() const {
160 return std::static_pointer_cast<LayerNodeType>(*p);
164 * @brief == comparison operator override
166 * @param lhs iterator lhs
167 * @param rhs iterator rhs
168 * @retval true if match
169 * @retval false if mismatch
171 friend bool operator==(GraphNodeIterator const &lhs,
172 GraphNodeIterator const &rhs) {
173 return lhs.p == rhs.p;
177 * @brief != comparison operator override
179 * @param lhs iterator lhs
180 * @param rhs iterator rhs
181 * @retval true if mismatch
182 * @retval false if match
184 friend bool operator!=(GraphNodeIterator const &lhs,
185 GraphNodeIterator const &rhs) {
186 return lhs.p != rhs.p;
 * @brief override for prefix ++ operator
192 * @return GraphNodeIterator&
194 GraphNodeIterator &operator++() {
 * @brief override for postfix ++ operator
202 * @return GraphNodeIterator
204 GraphNodeIterator operator++(int) {
205 GraphNodeIterator temp(p);
 * @brief override for prefix -- operator
213 * @return GraphNodeIterator&
215 GraphNodeIterator &operator--() {
 * @brief override for postfix -- operator
223 * @return GraphNodeIterator
225 GraphNodeIterator operator--(int) {
226 GraphNodeIterator temp(p);
232 * @brief override for subtract operator
234 * @param offset offset to subtract
235 * @return GraphNodeIterator
237 GraphNodeIterator operator-(const difference_type offset) const {
238 return GraphNodeIterator(p - offset);
242 * @brief override for subtract operator
244 * @param other iterator to subtract
245 * @return difference_type
247 difference_type operator-(const GraphNodeIterator &other) const {
 * @brief override for subtract-and-assign (-=) operator
254 * @param offset offset to subtract
255 * @return GraphNodeIterator&
257 GraphNodeIterator &operator-=(const difference_type offset) {
263 * @brief override for add operator
265 * @param offset offset to add
266 * @return GraphNodeIterator
268 GraphNodeIterator operator+(const difference_type offset) const {
269 return GraphNodeIterator(p + offset);
 * @brief override for add-and-assign (+=) operator
275 * @param offset offset to add
276 * @return GraphNodeIterator&
278 GraphNodeIterator &operator+=(const difference_type offset) {
 * @brief Reverse Iterator for GraphNode which returns a LayerNode object upon
288 * @note This just extends GraphNodeIterator and is limited by its
291 template <typename T_iterator>
292 class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
295 * @brief Construct a new Graph Node Reverse Iterator object
297 * @param iter Iterator
299 explicit GraphNodeReverseIterator(T_iterator iter) :
300 std::reverse_iterator<T_iterator>(iter) {}
303 * @brief reference operator
305 * @return T_iterator::value_type
306 * @note this is different from standard iterator
308 typename T_iterator::value_type operator*() const {
309 auto temp = std::reverse_iterator<T_iterator>::current - 1;
314 * @brief pointer operator
316 * @return T_iterator::value_type
317 * @note this is different from standard iterator
319 typename T_iterator::value_type operator->() const {
320 auto temp = std::reverse_iterator<T_iterator>::current - 1;
326 * @brief Iterators to traverse the graph
328 template <class LayerNodeType>
329 using graph_const_iterator =
330 GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
 * @brief Reverse iterators to traverse the graph
335 template <class LayerNodeType>
336 using graph_const_reverse_iterator = GraphNodeReverseIterator<
337 GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
339 } // namespace nntrainer
340 #endif // __GRAPH_NODE_H__