arm_compute v18.05
[platform/upstream/armcl.git] / src / graph / detail / ExecutionHelpers.cpp
1 /*
2  * Copyright (c) 2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph/detail/ExecutionHelpers.h"
25
26 #include "arm_compute/graph/Graph.h"
27 #include "arm_compute/graph/GraphContext.h"
28 #include "arm_compute/graph/GraphManager.h"
29 #include "arm_compute/graph/Tensor.h"
30 #include "arm_compute/graph/backends/BackendRegistry.h"
31
32 namespace arm_compute
33 {
34 namespace graph
35 {
36 namespace detail
37 {
38 void default_initialize_backends()
39 {
40     for(const auto &backend : backends::BackendRegistry::get().backends())
41     {
42         backend.second->initialize_backend();
43     }
44 }
45
46 void validate_all_nodes(Graph &g)
47 {
48     auto &nodes = g.nodes();
49
50     // Create tasks
51     for(auto &node : nodes)
52     {
53         if(node != nullptr)
54         {
55             Target assigned_target = node->assigned_target();
56             auto   backend         = backends::BackendRegistry::get().find_backend(assigned_target);
57             ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
58             Status status = backend->validate_node(*node);
59             ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
60         }
61     }
62 }
63
64 void configure_all_tensors(Graph &g)
65 {
66     auto &tensors = g.tensors();
67
68     for(auto &tensor : tensors)
69     {
70         if(tensor)
71         {
72             Target target  = tensor->desc().target;
73             auto   backend = backends::BackendRegistry::get().find_backend(target);
74             ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
75             auto handle = backend->create_tensor(*tensor);
76             ARM_COMPUTE_ERROR_ON_MSG(!backend, "Couldn't create backend handle!");
77             tensor->set_handle(std::move(handle));
78         }
79     }
80 }
81
82 void allocate_all_input_tensors(INode &node)
83 {
84     for(unsigned int i = 0; i < node.num_inputs(); ++i)
85     {
86         Tensor *tensor = node.input(i);
87         if(tensor != nullptr && !tensor->bound_edges().empty())
88         {
89             ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
90             tensor->handle()->allocate();
91         }
92     }
93 }
94
95 void allocate_all_output_tensors(INode &node)
96 {
97     for(unsigned int i = 0; i < node.num_outputs(); ++i)
98     {
99         Tensor *tensor = node.output(i);
100         if(tensor != nullptr && !tensor->bound_edges().empty())
101         {
102             ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
103             tensor->handle()->allocate();
104         }
105     }
106 }
107
108 void allocate_const_tensors(Graph &g)
109 {
110     for(auto &node : g.nodes())
111     {
112         if(node != nullptr)
113         {
114             switch(node->type())
115             {
116                 case NodeType::Const:
117                 case NodeType::Input:
118                     allocate_all_output_tensors(*node);
119                     break;
120                 case NodeType::Output:
121                     allocate_all_input_tensors(*node);
122                 default:
123                     break;
124             }
125         }
126     }
127 }
128
129 void allocate_all_tensors(Graph &g)
130 {
131     auto &tensors = g.tensors();
132
133     for(auto &tensor : tensors)
134     {
135         if(tensor && !tensor->bound_edges().empty() && tensor->handle() != nullptr && tensor->handle()->tensor().info()->is_resizable() && tensor->handle()->tensor().is_used())
136         {
137             tensor->handle()->allocate();
138         }
139     }
140 }
141
142 ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx)
143 {
144     ExecutionWorkload workload;
145     workload.graph = &g;
146     workload.ctx   = &ctx;
147
148     auto &nodes = g.nodes();
149
150     // Create tasks
151     for(auto &node : nodes)
152     {
153         if(node != nullptr)
154         {
155             Target assigned_target = node->assigned_target();
156             auto   backend         = backends::BackendRegistry::get().find_backend(assigned_target);
157             ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
158             auto func = backend->configure_node(*node, ctx);
159             if(func != nullptr)
160             {
161                 ExecutionTask task;
162                 task.task = std::move(func);
163                 task.node = node.get();
164                 workload.tasks.push_back(std::move(task));
165             }
166         }
167     }
168
169     // Add inputs and outputs
170     for(auto &node : nodes)
171     {
172         if(node != nullptr && node->type() == NodeType::Input)
173         {
174             workload.inputs.push_back(node->output(0));
175         }
176
177         if(node != nullptr && node->type() == NodeType::Output)
178         {
179             workload.outputs.push_back(node->input(0));
180             continue;
181         }
182     }
183
184     return workload;
185 }
186
187 void release_unused_tensors(Graph &g)
188 {
189     for(auto &tensor : g.tensors())
190     {
191         if(tensor != nullptr && tensor->handle() != nullptr)
192         {
193             tensor->handle()->release_if_unused();
194         }
195     }
196 }
197
198 void call_tensor_accessor(Tensor *tensor)
199 {
200     ARM_COMPUTE_ERROR_ON(!tensor);
201     tensor->call_accessor();
202 }
203
204 void call_all_const_node_accessors(Graph &g)
205 {
206     auto &nodes = g.nodes();
207
208     for(auto &node : nodes)
209     {
210         if(node != nullptr && node->type() == NodeType::Const)
211         {
212             call_tensor_accessor(node->output(0));
213         }
214     }
215 }
216
217 void call_all_input_node_accessors(ExecutionWorkload &workload)
218 {
219     for(auto &input : workload.inputs)
220     {
221         if(input != nullptr)
222         {
223             input->call_accessor();
224         }
225     }
226 }
227
228 void prepare_all_tasks(ExecutionWorkload &workload)
229 {
230     ARM_COMPUTE_ERROR_ON(workload.graph == nullptr);
231     for(auto &task : workload.tasks)
232     {
233         task.prepare();
234         release_unused_tensors(*workload.graph);
235     }
236 }
237
238 void call_all_tasks(ExecutionWorkload &workload)
239 {
240     ARM_COMPUTE_ERROR_ON(workload.ctx == nullptr);
241
242     // Acquire memory for the transition buffers
243     for(auto &mm_ctx : workload.ctx->memory_managers())
244     {
245         if(mm_ctx.second.cross_group != nullptr)
246         {
247             mm_ctx.second.cross_group->acquire();
248         }
249     }
250
251     // Execute tasks
252     for(auto &task : workload.tasks)
253     {
254         task();
255     }
256
257     // Release memory for the transition buffers
258     for(auto &mm_ctx : workload.ctx->memory_managers())
259     {
260         if(mm_ctx.second.cross_group != nullptr)
261         {
262             mm_ctx.second.cross_group->release();
263         }
264     }
265 }
266
267 void call_all_output_node_accessors(ExecutionWorkload &workload)
268 {
269     for(auto &output : workload.outputs)
270     {
271         if(output != nullptr)
272         {
273             output->call_accessor();
274         }
275     }
276 }
277 } // namespace detail
278 } // namespace graph
279 } // namespace arm_compute