2 * Copyright (c) 2018 ARM Limited.
4 * SPDX-License-Identifier: MIT
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 #ifndef __ARM_COMPUTE_GRAPH_INODE_H__
25 #define __ARM_COMPUTE_GRAPH_INODE_H__
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/graph/TensorDescriptor.h"
29 #include "arm_compute/graph/Types.h"
37 // Forward declarations
/// Special member functions of INode: copying is disabled, moving is allowed.
/// The destructor is virtual and defaulted; the class declares virtual member functions below.
50 virtual ~INode() = default;
51 /** Copy construction is deleted (the class holds pointers, see the Graph back reference member) */
52 INode(const INode &) = delete;
53 /** Copy assignment is deleted (the class holds pointers, see the Graph back reference member) */
54 INode &operator=(const INode &) = delete;
55 /** Move construction is allowed and uses the compiler-generated (defaulted) implementation */
56 INode(INode &&) = default;
57 /** Move assignment is allowed and uses the compiler-generated (defaulted) implementation */
58 INode &operator=(INode &&) = default;
61 * @return Status containing any errors
/// Virtual but not pure: declared here without an inline body, so derived nodes may override it.
/// Const-qualified; the Status is returned by value.
63 virtual Status validate() const;
64 /** Returns node's type
/// Pure virtual and const-qualified: every concrete node class must implement it.
/// The NodeType is returned by value.
68 virtual NodeType type() const = 0;
69 /** Accepts a node visitor
71 * @param[in] v Visitor to accept
/// Pure virtual and not const-qualified. The visitor is taken by non-const reference,
/// so the call is allowed to modify both the node and the visitor.
73 virtual void accept(INodeVisitor &v) = 0;
74 /** Forwards descriptor information to outputs if possible
76 * @return True if descriptor information could be forwarded otherwise false
/// Pure virtual and not const-qualified, so an implementation is allowed to modify the node.
/// Failure is reported through the bool return value.
78 virtual bool forward_descriptors() = 0;
79 /** Calculates output configuration
81 * @param[in] idx Output index to configure
83 * @return Output descriptor configuration
/// Pure virtual and const-qualified; the TensorDescriptor is returned by value.
/// NOTE(review): no precondition on idx is visible here - confirm how implementations treat an out-of-range index.
85 virtual TensorDescriptor configure_output(size_t idx) const = 0;
86 /** Returns node's name
/// Non-virtual, const-qualified getter. The name is returned as a std::string by value, i.e. the caller gets a copy.
90 std::string name() const;
96 /** Returns node's Graph
98 * @return Node's graph
/// Const overload: returns a pointer to const Graph, so the graph cannot be modified through the result.
/// NOTE(review): whether the returned pointer can be null is not visible from this declaration.
100 const Graph *graph() const;
101 /** Returns node's Graph
103 * @return Node's graph
106 /** Sets the graph that this node is registered to
108 * @param[in] g Back reference to graph
/// Takes a raw pointer to a non-const Graph.
/// NOTE(review): presumably stored in the _graph back reference; null handling is not visible here - confirm in the definition.
110 void set_graph(Graph *g);
113 * @param[in] id Node id
/// The NodeID is passed by value.
/// NOTE(review): presumably stored in _id - confirm in the definition.
115 void set_id(NodeID id);
116 /** Sets common node parameters
118 * @param[in] common_params Common node parameters to set
/// NodeParams is passed by value, so the caller's object is not referenced after the call.
/// NOTE(review): presumably stored in _common_params - confirm in the definition.
120 void set_common_node_parameters(NodeParams common_params);
121 /** Sets target preference
123 * @note This is not the target that the graph executor might choose, it's just an indication
125 * @param[in] target Target preference
/// The Target is passed by value.
/// NOTE(review): no dedicated requested-target member is visible in this listing; presumably it lives in _common_params - confirm.
127 void set_requested_target(Target target);
128 /** Sets the final execution target
130 * @note GraphManager might change this target
132 * @param[in] target Final execution target
/// The Target is passed by value.
/// NOTE(review): presumably stored in _assigned_target - confirm in the definition.
134 void set_assigned_target(Target target);
135 /** Sets the output tensor at a given index
137 * @note All edges will get updated
139 * @param[in] tid Tensor ID
140 * @param[in] idx Output index
/// Argument order is (tensor ID, output index).
/// NOTE(review): no precondition on idx is visible here - confirm the behaviour for an out-of-range index.
142 void set_output_tensor(TensorID tid, size_t idx);
143 /** Returns inputs of the node
145 * @return Inputs of the node
/// Returns a const reference to a std::vector<TensorID>, not a copy.
/// NOTE(review): no input-tensor member is visible in this listing, so the backing container cannot be identified from here.
147 const std::vector<TensorID> &inputs() const;
148 /** Returns outputs of the node
150 * @return Outputs of the node
/// Returns a const reference to a std::vector<TensorID>, not a copy.
/// NOTE(review): presumably refers to the _outputs member - confirm in the definition.
152 const std::vector<TensorID> &outputs() const;
153 /** Returns input edge set
155 * @return Set of input edges
/// Although the description speaks of a "set", the returned container is a std::vector<EdgeID>.
/// It is returned by const reference, not as a copy.
157 const std::vector<EdgeID> &input_edges() const;
158 /** Returns output edge set
160 * @return Set of output edges
/// Returns a const reference to a std::set<EdgeID>, so each edge ID appears at most once.
162 const std::set<EdgeID> &output_edges() const;
163 /** Returns the tensor ID of a given input of the node
165 * @note Precondition : idx should be a valid input index
167 * @param[in] idx Index of the node input
169 * @return TensorID of the requested input
/// Const-qualified; the TensorID is returned by value.
/// NOTE(review): the idx precondition is not expressed in the signature - confirm whether the definition checks it.
171 TensorID input_id(size_t idx) const;
172 /** Returns the tensor ID of a given output of the node
174 * @note Precondition : idx should be a valid output index
176 * @param[in] idx Index of the node output
178 * @return TensorID of the requested output
/// Const-qualified; the TensorID is returned by value.
/// NOTE(review): the idx precondition is not expressed in the signature - confirm whether the definition checks it.
180 TensorID output_id(size_t idx) const;
181 /** Returns the tensor of a given input of the node
183 * @note Precondition : idx should be a valid input index
185 * @param[in] idx Index of the node input
187 * @return Tensor of the requested input
/// Const member function that returns a pointer to a non-const Tensor.
/// NOTE(review): ownership of the Tensor and whether null can be returned are not visible here - confirm in the definition.
189 Tensor *input(size_t idx) const;
190 /** Returns the tensor of a given output of the node
192 * @note Precondition : idx should be a valid output index
194 * @param[in] idx Index of the node output
196 * @return Tensor of the requested output
/// Const member function that returns a pointer to a non-const Tensor.
/// NOTE(review): ownership of the Tensor and whether null can be returned are not visible here - confirm in the definition.
198 Tensor *output(size_t idx) const;
199 /** Returns the edge ID of a given input of the node
201 * @note Precondition : idx should be a valid input index
203 * @param[in] idx Index of the node input
205 * @return EdgeID of the requested input
/// Const-qualified; the EdgeID is returned by value.
/// NOTE(review): the idx precondition is not expressed in the signature - confirm whether the definition checks it.
207 EdgeID input_edge_id(size_t idx) const;
208 /** Returns the edge of a given input of the node
210 * @note Precondition : idx should be a valid input index
212 * @param[in] idx Index of the node input
214 * @return Edge of the requested input
/// Const member function that returns a pointer to a non-const Edge.
/// NOTE(review): ownership of the Edge and whether null can be returned are not visible here - confirm in the definition.
216 Edge *input_edge(size_t idx) const;
217 /** Returns number of inputs of the node
219 * @return Number of inputs
/// The count is a size_t, the same type the indexed input accessors take for idx.
221 size_t num_inputs() const;
222 /** Returns number of outputs of the node
224 * @return Number of outputs
/// The count is a size_t, the same type the indexed output accessors take for idx.
226 size_t num_outputs() const;
227 /** Returns requested target for this node
229 * @return Requested execution target
/// Const-qualified getter paired with set_requested_target(); the Target is returned by value.
231 Target requested_target() const;
232 /** Returns assigned target for this node
234 * @return Assigned target of this node
/// Const-qualified getter paired with set_assigned_target(); the Target is returned by value.
236 Target assigned_target() const;
/// Data members of the node.
/// NOTE(review): the access specifier that applies to these members is not visible in this listing.
242 Graph *_graph; /**< Backward reference (raw pointer) to the graph owning the node */
243 NodeID _id; /**< Node ID */
244 NodeParams _common_params; /**< Node common params */
245 std::vector<TensorID> _outputs; /**< Output tensor IDs of the node */
246 std::vector<EdgeID> _input_edges; /**< Input edge IDs (held in a vector) */
247 std::set<EdgeID> _output_edges; /**< Output edge IDs (held in a set, so each ID is unique) */
248 Target _assigned_target; /**< Assigned target by the Graph executor */
251 } // namespace arm_compute
252 #endif /* __ARM_COMPUTE_GRAPH_INODE_H__ */