arm_compute v18.05
[platform/upstream/armcl.git] / arm_compute / graph / Graph.h
1 /*
2  * Copyright (c) 2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef __ARM_COMPUTE_GRAPH_GRAPH_H__
25 #define __ARM_COMPUTE_GRAPH_GRAPH_H__
26
27 #include "arm_compute/graph/Edge.h"
28 #include "arm_compute/graph/INode.h"
29 #include "arm_compute/graph/Tensor.h"
30 #include "arm_compute/graph/Types.h"
31
32 #include "support/Mutex.h"
33 #include "support/ToolchainSupport.h"
34
35 #include <map>
36 #include <memory>
37 #include <string>
38 #include <thread>
39 #include <utility>
40 #include <vector>
41
42 namespace arm_compute
43 {
44 namespace graph
45 {
46 /** Graph class
47  *
48  * Represents a multiple source - multiple sink directed graph
49  */
50 class Graph final
51 {
52 public:
53     Graph() = default;
54     /** Constructor
55      *
56      * @param[in] id   Graph identification number. Can be used to differentiate between graphs. Default value 0
57      * @param[in] name Graph name. Default value empty string
58      */
59     Graph(GraphID id, std::string name);
60     /** Prevent instances of this class from being copied (As this class contains pointers) */
61     Graph(const Graph &) = delete;
62     /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
63     Graph &operator=(const Graph &) = delete;
64     /** Allow instances of this class to be moved */
65     Graph(Graph &&) = default;
66     /** Allow instances of this class to be move assigned */
67     Graph &operator=(Graph &&) = default;
68     /** Adds a node to the graph
69      *
70      * @note Models a single output node
71      *
72      * @tparam NT Node operation
73      * @tparam Ts Arguments to operation
74      *
75      * @param args Node arguments
76      *
77      * @return ID of the node
78      */
79     template <typename NT, typename... Ts>
80     NodeID add_node(Ts &&... args);
81     /** Remove the node with the given ID
82      *
83      * @param[in] nid ID of the node to remove
84      *
85      * @return True if the removal took place else false
86      */
87     bool remove_node(NodeID nid);
88     /** Adds a connection between two nodes
89      *
90      * @param[in] source     ID of the source node
91      * @param[in] source_idx Output index of the source node
92      * @param[in] sink       ID of the sink node
93      * @param[in] sink_idx   Input index of the sink node
94      *
95      * @return ID of this connection
96      */
97     EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx);
98     /** Removes an edge (connection)
99      *
100      * @param[in] eid Connection to remove
101      *
102      * @return True if the removal took place else false
103      */
104     bool remove_connection(EdgeID eid);
105     /** Returns graph name
106      *
107      * @return Graph name
108      */
109     std::string name() const;
110     /** Returns graph id
111      *
112      * @return Graph id
113      */
114     GraphID id() const;
115     /** Returns graph input nodes
116      *
117      * @return vector containing the graph inputs
118      */
119     const std::vector<NodeID> &inputs();
120     /** Returns nodes of graph
121      *
122      * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
123      *
124      * @return Nodes of graph
125      */
126     std::vector<std::unique_ptr<INode>> &nodes();
127     /** Returns nodes of graph
128      *
129      * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
130      *
131      * @return Nodes of graph
132      */
133     const std::vector<std::unique_ptr<INode>> &nodes() const;
134     /** Returns edges of graph
135      *
136      * @warning Edges can be nullptr if they have been removed during the mutation steps of the graph
137      *
138      * @return Edges of graph
139      */
140     const std::vector<std::unique_ptr<Edge>> &edges() const;
141     /** Returns tensors of graph
142      *
143      * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
144      *
145      * @return Tensors of graph
146      */
147     std::vector<std::unique_ptr<Tensor>> &tensors();
148     /** Returns tensors of graph
149      *
150      * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
151      *
152      * @return Tensors of graph
153      */
154     const std::vector<std::unique_ptr<Tensor>> &tensors() const;
155     /** Get node object given its id
156      *
157      * @warning Can be nullptr if node was removed during the mutation steps of the graph
158      *
159      * @param[in] id Node ID
160      *
161      * @return The actual node object
162      */
163     const INode *node(NodeID id) const;
164     /** Get node object given its id
165      *
166      * @warning Can be nullptr if node was removed during the mutation steps of the graph
167      *
168      * @param[in] id Node ID
169      *
170      * @return The actual node object
171      */
172     INode *node(NodeID id);
173     /** Get edge object given its id
174      *
175      * @warning Can be nullptr if node was removed during the mutation steps of the graph
176      *
177      * @param[in] id Edge ID
178      *
179      * @return The actual edge object
180      */
181     const Edge *edge(EdgeID id) const;
182     /** Get edge object given its id
183      *
184      * @warning Can be nullptr if node was removed during the mutation steps of the graph
185      *
186      * @param[in] id Edge ID
187      *
188      * @return The actual edge object
189      */
190     Edge *edge(EdgeID id);
191     /** Get tensor object given its id
192      *
193      * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
194      *
195      * @param[in] id Tensor ID
196      *
197      * @return The actual tensor object
198      */
199     const Tensor *tensor(TensorID id) const;
200     /** Get tensor object given its id
201      *
202      * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
203      *
204      * @param[in] id Tensor ID
205      *
206      * @return The actual tensor object
207      */
208     Tensor *tensor(TensorID id);
209
210 private:
211     /** Creates a tensor object
212      *
213      * @param[in] desc Tensor descriptor
214      *
215      * @return Tensor ID
216      */
217     TensorID create_tensor(TensorDescriptor desc = TensorDescriptor());
218
219 private:
220     GraphID                              _id      = GraphID(0); /**< Graph id */
221     std::string                          _name    = {};         /**< Graph name */
222     std::vector<std::unique_ptr<INode>>  _nodes   = {};         /**< Graph nodes */
223     std::vector<std::unique_ptr<Edge>>   _edges   = {};         /**< Graph edges */
224     std::vector<std::unique_ptr<Tensor>> _tensors = {};         /**< Graph tensors */
225     std::map<NodeType, std::vector<NodeID>> _tagged_nodes = {}; /**< Graph nodes map with the node type as key */
226     arm_compute::Mutex _mtx = {};                               /**< Mutex used for graph construction */
227 };
228
229 template <typename NT, typename... Ts>
230 inline NodeID Graph::add_node(Ts &&... args)
231 {
232     std::lock_guard<arm_compute::Mutex> lock(_mtx);
233
234     // Create node
235     NodeID nid  = _nodes.size();
236     auto   node = support::cpp14::make_unique<NT>(std::forward<Ts>(args)...);
237     node->set_graph(this);
238     node->set_id(nid);
239
240     // Keep track of input nodes
241     if(node->type() == NodeType::Input)
242     {
243         _tagged_nodes[NodeType::Input].push_back(nid);
244     }
245
246     // Associate a new tensor with each output
247     for(auto &output : node->_outputs)
248     {
249         output = create_tensor();
250     }
251
252     // Propagate node shape if possible
253     node->forward_descriptors();
254
255     // Add node to the graph nodes
256     _nodes.push_back(std::move(node));
257
258     return nid;
259 }
260 } // namespace graph
261 } // namespace arm_compute
262 #endif /* __ARM_COMPUTE_GRAPH_GRAPH_H__ */