arm_compute v18.05
[platform/upstream/armcl.git] / arm_compute / graph / INode.h
1 /*
2  * Copyright (c) 2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef __ARM_COMPUTE_GRAPH_INODE_H__
25 #define __ARM_COMPUTE_GRAPH_INODE_H__
26
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/graph/TensorDescriptor.h"
29 #include "arm_compute/graph/Types.h"
30
31 #include <set>
32
33 namespace arm_compute
34 {
35 namespace graph
36 {
37 // Forward declarations
38 class Graph;
39 class Edge;
40 class INodeVisitor;
41 class Tensor;
42
43 /** Node interface */
44 class INode
45 {
46 public:
47     /** Constructor */
48     INode();
49     /** Destructor **/
50     virtual ~INode() = default;
51     /** Prevent instances of this class from being copied (As this class contains pointers) */
52     INode(const INode &) = delete;
53     /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
54     INode &operator=(const INode &) = delete;
55     /** Allow instances of this class to be moved */
56     INode(INode &&) = default;
57     /** Allow instances of this class to be move assigned */
58     INode &operator=(INode &&) = default;
59     /** Validate node
60      *
61      * @return Status containing any errors
62      */
63     virtual Status validate() const;
64     /** Returns node's type
65      *
66      * @return Node's type
67      */
68     virtual NodeType type() const = 0;
69     /** Accepts a node visitor
70      *
71      * @param[in] v Visitor to accept
72      */
73     virtual void accept(INodeVisitor &v) = 0;
74     /** Forwards descriptor information to outputs if possible
75      *
76      * @return True if descriptor information could be forwarded otherwise false
77      */
78     virtual bool forward_descriptors() = 0;
79     /** Calculates output configuration
80      *
81      * @param[in] idx Output index to configure
82      *
83      * @return Output descriptor configuration
84      */
85     virtual TensorDescriptor configure_output(size_t idx) const = 0;
86     /** Returns node's name
87      *
88      * @return Node name
89      */
90     std::string name() const;
91     /** Returns node's ID
92      *
93      * @return Node's ID
94      */
95     NodeID id() const;
96     /** Returns node's Graph
97      *
98      * @return Node's graph
99      */
100     const Graph *graph() const;
101     /** Returns node's Graph
102      *
103      * @return Node's graph
104      */
105     Graph *graph();
106     /** Sets the graph that this node is registered to
107      *
108      * @param[in] g Back reference to graph
109      */
110     void set_graph(Graph *g);
111     /** Sets the node id
112      *
113      * @param[in] id Node id
114      */
115     void set_id(NodeID id);
116     /** Sets common node parameters
117      *
118      * @param[in] common_params Common node parameters to set
119      */
120     void set_common_node_parameters(NodeParams common_params);
121     /** Sets target preference
122      *
123      * @note This is not the target that the graph executor might choose, its just an indication
124      *
125      * @param[in] target Target preference
126      */
127     void set_requested_target(Target target);
128     /** Sets the final execution target
129      *
130      * @note GraphManager might change this target
131      *
132      * @param[in] target Final execution target
133      */
134     void set_assigned_target(Target target);
135     /** Sets the output tensor of at a given index
136      *
137      * @note All edges will get updated
138      *
139      * @param[in] tid Tensor ID
140      * @param[in] idx Output index
141      */
142     void set_output_tensor(TensorID tid, size_t idx);
143     /** Returns inputs of the node
144      *
145      * @return Inputs of the node
146      */
147     const std::vector<TensorID> &inputs() const;
148     /** Returns outputs of the node
149      *
150      * @return Outputs of the node
151      */
152     const std::vector<TensorID> &outputs() const;
153     /** Returns input edge set
154      *
155      * @return Set of input edges
156      */
157     const std::vector<EdgeID> &input_edges() const;
158     /** Returns output edge set
159      *
160      * @return Set of output edges
161      */
162     const std::set<EdgeID> &output_edges() const;
163     /** Returns the tensor ID of a given input of the node
164      *
165      * @note Precondition : idx should be a valid input index
166      *
167      * @param[in] idx Index of the node input
168      *
169      * @return TensorID of the requested input
170      */
171     TensorID input_id(size_t idx) const;
172     /** Returns the tensor ID of a given output of the node
173      *
174      * @note Precondition : idx should be a valid output index
175      *
176      * @param[in] idx Index of the node output
177      *
178      * @return TensorID of the requested output
179      */
180     TensorID output_id(size_t idx) const;
181     /** Returns the tensor of a given input of the node
182      *
183      * @note Precondition : idx should be a valid input index
184      *
185      * @param[in] idx Index of the node input
186      *
187      * @return Tensor of the requested input
188      */
189     Tensor *input(size_t idx) const;
190     /** Returns the tensor of a given output of the node
191      *
192      * @note Precondition : idx should be a valid output index
193      *
194      * @param[in] idx Index of the node output
195      *
196      * @return Tensor of the requested output
197      */
198     Tensor *output(size_t idx) const;
199     /** Returns the edge ID of a given input of the node
200      *
201      * @note Precondition : idx should be a valid input index
202      *
203      * @param[in] idx Index of the node input
204      *
205      * @return EdgeID of the requested input
206      */
207     EdgeID input_edge_id(size_t idx) const;
208     /** Returns the edge of a given input of the node
209      *
210      * @note Precondition : idx should be a valid input index
211      *
212      * @param[in] idx Index of the node input
213      *
214      * @return Edge of the requested input
215      */
216     Edge *input_edge(size_t idx) const;
217     /** Returns number of inputs of the node
218      *
219      * @return Number of inputs
220      */
221     size_t num_inputs() const;
222     /** Returns number of outputs of the node
223      *
224      * @return Number of outputs
225      */
226     size_t num_outputs() const;
227     /** Returns requested target for this node
228      *
229      * @return Requested execution target
230      */
231     Target requested_target() const;
232     /** Returns assigned target for this node
233      *
234      * @return Assigned target of this node
235      */
236     Target assigned_target() const;
237
238 protected:
239     friend class Graph;
240
241 protected:
242     Graph                *_graph;           /**< Backward reference to graph owning the node */
243     NodeID                _id;              /**< Node ID */
244     NodeParams            _common_params;   /**< Node common params */
245     std::vector<TensorID> _outputs;         /**< Output of the node */
246     std::vector<EdgeID>   _input_edges;     /**< Inputs edge set */
247     std::set<EdgeID>      _output_edges;    /**< Output edge set */
248     Target                _assigned_target; /**< Assigned target by the Graph executor */
249 };
250 } // namespace graph
251 } // namespace arm_compute
252 #endif /* __ARM_COMPUTE_GRAPH_INODE_H__ */