26d890d6f8f849560f0bb1f0e56b3f72fece9558
[platform/upstream/hailort.git] /
1 /**
2  * Copyright (c) 2020-2022 Hailo Technologies Ltd. All rights reserved.
3  * Distributed under the MIT license (https://opensource.org/licenses/MIT)
4  **/
5 /**
6  * @file switch_network_groups_manually_example.cpp
7  * This example demonstrates basic usage of HailoRT streaming api over multiple networks, using vstreams.
8  * It loads several HEF networks with single/multiple inputs and single/multiple outputs into a Hailo VDevice and performs a
9  * short inference on each one.
10  * After inference is finished, the example switches to the next HEF and start inference again.
11  **/
12
13 #include "hailo/hailort.hpp"
14
15 #include <iostream>
16 #include <chrono>
17
18 constexpr bool QUANTIZED = true;
19 constexpr hailo_format_type_t FORMAT_TYPE = HAILO_FORMAT_TYPE_AUTO;
20
21 constexpr size_t INFER_FRAME_COUNT = 100;
22 constexpr size_t RUN_COUNT = 10;
23 constexpr std::chrono::milliseconds WAIT_FOR_ACTIVATION_TIMEOUT_MS(10);
24 constexpr uint32_t DEVICE_COUNT = 1;
25
26 using namespace hailort;
27
28 #include <mutex>
29 #include <condition_variable>
30
31 class SyncObject final {
32 /* Synchronization class used to make sure I/O threads are blocking while their network_group is not activated  */
33 public:
34     explicit SyncObject(size_t count) : m_original_count(count), m_count(count), m_all_arrived(false), m_mutex(), m_cv(), m_is_active(true)
35         {};
36
37     /* In main thread we wait until I/O threads are done (0 == m_count),
38        signaling the I/O threads only after deactivating their network_group and resetting m_count to m_original_count */
39     void wait_all(std::unique_ptr<ActivatedNetworkGroup> &&activated_network_group)
40     {
41         if (!m_is_active.load()) {
42             return;
43         }
44         std::unique_lock<std::mutex> lock(m_mutex);
45         m_cv.wait(lock, [this] { return ((0 == m_count) || !m_is_active); });
46         activated_network_group.reset();
47         m_count = m_original_count;
48         m_all_arrived = true;
49         m_cv.notify_all();
50     }
51
52     /* In I/O threads we wait until signaled by main thread (true == m_all_arrived),
53        resetting m_all_arrived to false to make sure it was setted by 'wait_all' call */
54     void notify_and_wait()
55     {
56         if (!m_is_active.load()) {
57             return;
58         }
59         std::unique_lock<std::mutex> lock(m_mutex);
60         m_all_arrived = false;
61         --m_count;
62         m_cv.notify_all();
63         m_cv.wait(lock, [this] { return ((m_all_arrived) || !m_is_active); });
64     }
65
66     void terminate()
67     {
68         {
69             std::unique_lock<std::mutex> lock(m_mutex);
70             m_is_active.store(false);
71         }
72         m_cv.notify_all();
73     }
74
75 private:
76     const size_t m_original_count;
77     std::atomic_size_t m_count;
78     std::atomic_bool m_all_arrived;
79
80     std::mutex m_mutex;
81     std::condition_variable m_cv;
82     std::atomic_bool m_is_active;
83 };
84
85
86 void write_all(std::shared_ptr<ConfiguredNetworkGroup> network_group, InputVStream &input_vstream,
87     std::shared_ptr<SyncObject> sync_object, std::shared_ptr<std::atomic_bool> should_threads_run, hailo_status &status_out)
88 {
89     std::vector<uint8_t> buff(input_vstream.get_frame_size());
90
91     auto status = HAILO_UNINITIALIZED;
92     while (true) {
93         if (!(*should_threads_run)) {
94             break;
95         }
96         status = network_group->wait_for_activation(WAIT_FOR_ACTIVATION_TIMEOUT_MS);
97         if (HAILO_TIMEOUT == status) {
98             continue;
99         } else if (HAILO_SUCCESS != status) {
100             std::cerr << "Wait for network group activation failed. status = " << status << std::endl;
101             status_out = status;
102             return;
103         }
104
105         for (size_t i = 0; i < INFER_FRAME_COUNT; i++) {
106             status = input_vstream.write(MemoryView(buff.data(), buff.size()));
107             if (HAILO_SUCCESS != status) {
108                 status_out = status;
109                 return;
110             }
111         }
112         sync_object->notify_and_wait();
113     }
114     return;
115 }
116
117 void read_all(std::shared_ptr<ConfiguredNetworkGroup> network_group, OutputVStream &output_vstream,
118     std::shared_ptr<SyncObject> sync_object, std::shared_ptr<std::atomic_bool> should_threads_run, hailo_status &status_out)
119 {
120     std::vector<uint8_t> buff(output_vstream.get_frame_size());
121
122     auto status = HAILO_UNINITIALIZED;
123     while (true) {
124         if (!(*should_threads_run)) {
125             break;
126         }
127         status = network_group->wait_for_activation(WAIT_FOR_ACTIVATION_TIMEOUT_MS);
128         if (HAILO_TIMEOUT == status) {
129             continue;
130         } else if (HAILO_SUCCESS != status) {
131             std::cerr << "Wait for network group activation failed. status = " << status << std::endl;
132             status_out = status;
133             return;
134         }
135
136         for (size_t i = 0; i < INFER_FRAME_COUNT; i++) {
137             status = output_vstream.read(MemoryView(buff.data(), buff.size()));
138             if (HAILO_SUCCESS != status) {
139                 status_out = status;
140                 return;
141             }
142         }
143         sync_object->notify_and_wait();
144     }
145     return;
146 }
147
148 void network_group_thread_main(std::shared_ptr<ConfiguredNetworkGroup> network_group, std::shared_ptr<SyncObject> sync_object,
149     std::shared_ptr<std::atomic_bool> should_threads_run, hailo_status &status_out)
150 {
151     // Create VStreams
152     auto vstreams_exp = VStreamsBuilder::create_vstreams(*network_group, QUANTIZED, FORMAT_TYPE);
153     if (!vstreams_exp) {
154         std::cerr << "Failed to create vstreams, status = " << vstreams_exp.status() << std::endl;
155         status_out = vstreams_exp.status();
156         return;
157     }
158
159     // Create send/recv loops
160     std::vector<std::unique_ptr<std::thread>> recv_ths;
161     std::vector<hailo_status> read_results;
162     read_results.reserve(vstreams_exp->second.size());
163     for (auto &vstream : vstreams_exp->second) {
164         read_results.push_back(HAILO_SUCCESS); // Success oriented
165         recv_ths.emplace_back(std::make_unique<std::thread>(read_all,
166             network_group, std::ref(vstream), sync_object, should_threads_run, std::ref(read_results.back())));
167     }
168
169     std::vector<std::unique_ptr<std::thread>> send_ths;
170     std::vector<hailo_status> write_results;
171     write_results.reserve(vstreams_exp->first.size());
172     for (auto &vstream : vstreams_exp->first) {
173         write_results.push_back(HAILO_SUCCESS); // Success oriented
174         send_ths.emplace_back(std::make_unique<std::thread>(write_all,
175             network_group, std::ref(vstream), std::ref(sync_object), should_threads_run, std::ref(write_results.back())));
176     }
177
178     for (auto &send_th : send_ths) {
179         if (send_th->joinable()) {
180             send_th->join();
181         }
182     }
183     for (auto &recv_th : recv_ths) {
184         if (recv_th->joinable()) {
185             recv_th->join();
186         }
187     }
188
189     for (auto &status : read_results) {
190         if (HAILO_SUCCESS != status) {
191             status_out = status;
192             return;
193         }
194     }
195     for (auto &status : write_results) {
196         if (HAILO_SUCCESS != status) {
197             status_out = status;
198             return;
199         }
200     }
201     status_out = HAILO_SUCCESS;
202     return;
203 }
204
205 int main()
206 {
207     hailo_vdevice_params_t params;
208     auto status = hailo_init_vdevice_params(&params);
209     if (HAILO_SUCCESS != status) {
210         std::cerr << "Failed init vdevice_params, status = " << status << std::endl;
211         return status;
212     }
213
214     params.scheduling_algorithm = HAILO_SCHEDULING_ALGORITHM_NONE;
215     params.device_count = DEVICE_COUNT;
216     auto vdevice_exp = VDevice::create(params);
217     if (!vdevice_exp) {
218         std::cerr << "Failed create vdevice, status = " << vdevice_exp.status() << std::endl;
219         return vdevice_exp.status();
220     }
221     auto vdevice = vdevice_exp.release();
222
223     std::vector<std::string> hef_paths = {"hefs/shortcut_net.hef", "hefs/shortcut_net.hef"};
224     std::vector<std::shared_ptr<ConfiguredNetworkGroup>> configured_network_groups;
225
226     for (const auto &path : hef_paths) {
227         auto hef_exp = Hef::create(path);
228         if (!hef_exp) {
229             std::cerr << "Failed to create hef: " << path  << ", status = " << hef_exp.status() << std::endl;
230             return hef_exp.status();
231         }
232         auto hef = hef_exp.release();
233
234         auto added_network_groups = vdevice->configure(hef);
235         if (!added_network_groups) {
236             std::cerr << "Failed to configure vdevice, status = " << added_network_groups.status() << std::endl;
237             return added_network_groups.status();
238         }
239         configured_network_groups.insert(configured_network_groups.end(), added_network_groups->begin(), added_network_groups->end());
240     }
241
242     auto should_threads_run = std::make_shared<std::atomic_bool>(true);
243
244     std::vector<std::shared_ptr<SyncObject>> sync_objects;
245     sync_objects.reserve(configured_network_groups.size());
246     std::vector<hailo_status> threads_results;
247     threads_results.reserve(configured_network_groups.size());
248     std::vector<std::unique_ptr<std::thread>> network_group_threads;
249     network_group_threads.reserve(configured_network_groups.size());
250
251     for (auto network_group : configured_network_groups) {
252         threads_results.push_back(HAILO_UNINITIALIZED);
253         auto vstream_infos = network_group->get_all_vstream_infos();
254         if (!vstream_infos) {
255             std::cerr << "Failed to get vstream infos, status = " << vstream_infos.status() << std::endl;
256             return vstream_infos.status();
257         }
258         sync_objects.emplace_back((std::make_shared<SyncObject>(vstream_infos->size())));
259         network_group_threads.emplace_back(std::make_unique<std::thread>(network_group_thread_main,
260             network_group, sync_objects.back(), should_threads_run, std::ref(threads_results.back())));
261     }
262
263     for (size_t i = 0; i < RUN_COUNT; i++) {
264         for (size_t network_group_idx = 0; network_group_idx < configured_network_groups.size(); network_group_idx++) {
265             auto activated_network_group_exp = configured_network_groups[network_group_idx]->activate();
266             if (!activated_network_group_exp) {
267                 std::cerr << "Failed to activate network group, status = "  << activated_network_group_exp.status() << std::endl;
268                 return activated_network_group_exp.status();
269             }
270             sync_objects[network_group_idx]->wait_all(activated_network_group_exp.release());
271         }
272     }
273
274     *should_threads_run = false;
275     for (auto &sync_object : sync_objects) {
276         sync_object->terminate();
277     }
278
279     for (auto &th : network_group_threads) {
280         if (th->joinable()) {
281             th->join();
282         }
283     }
284
285     for (auto &thread_status : threads_results) {
286         if (HAILO_SUCCESS != thread_status) {
287             std::cerr << "Inference failed, status = "  << thread_status << std::endl;
288             return thread_status;
289         }
290     }
291
292     std::cout << "Inference finished successfully" << std::endl;
293     return HAILO_SUCCESS;
294 }