cbc99fadad176bf05462c632c74bd36f1212b13c
[platform/upstream/hailort.git] /
1 /**
2  * Copyright (c) 2020-2022 Hailo Technologies Ltd. All rights reserved.
3  * Distributed under the MIT license (https://opensource.org/licenses/MIT)
4  **/
5 /**
6  * @file raw_async_streams_multi_thread_example
7  * This example demonstrates using low level async streams over c++
8  **/
9
10 #include "hailo/hailort.hpp"
11
12 #include <thread>
13 #include <iostream>
14
15 #if defined(__unix__)
16 #include <sys/mman.h>
17 #endif
18
19 constexpr auto TIMEOUT = std::chrono::milliseconds(1000);
20
21 using namespace hailort;
22
23 using AlignedBuffer = std::shared_ptr<uint8_t>;
24 static AlignedBuffer page_aligned_alloc(size_t size)
25 {
26 #if defined(__unix__)
27     auto addr = mmap(NULL, size, PROT_WRITE | PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
28     if (MAP_FAILED == addr) throw std::bad_alloc();
29     return AlignedBuffer(reinterpret_cast<uint8_t*>(addr), [size](void *addr) { munmap(addr, size); });
30 #elif defined(_MSC_VER)
31     auto addr = VirtualAlloc(NULL, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
32     if (!addr) throw std::bad_alloc();
33     return AlignedBuffer(reinterpret_cast<uint8_t*>(addr), [](void *addr){ VirtualFree(addr, 0, MEM_RELEASE); });
34 #else
35 #pragma error("Aligned alloc not supported")
36 #endif
37 }
38
39 Expected<std::shared_ptr<ConfiguredNetworkGroup>> configure_network_group(Device &device, const std::string &hef_path)
40 {
41     auto hef = Hef::create(hef_path);
42     if (!hef) {
43         return make_unexpected(hef.status());
44     }
45
46     auto configure_params = device.create_configure_params(hef.value());
47     if (!configure_params) {
48         return make_unexpected(configure_params.status());
49     }
50
51     // change stream_params here
52     for (auto &ng_name_params_pair : *configure_params) {
53         for (auto &stream_params_name_pair : ng_name_params_pair.second.stream_params_by_name) {
54             stream_params_name_pair.second.flags = HAILO_STREAM_FLAGS_ASYNC;
55         }
56     }
57
58     auto network_groups = device.configure(hef.value(), configure_params.value());
59     if (!network_groups) {
60         return make_unexpected(network_groups.status());
61     }
62
63     if (1 != network_groups->size()) {
64         std::cerr << "Invalid amount of network groups" << std::endl;
65         return make_unexpected(HAILO_INTERNAL_FAILURE);
66     }
67
68     return std::move(network_groups->at(0));
69 }
70
71 static void output_async_callback(const OutputStream::CompletionInfo &completion_info)
72 {
73     // Real applications can free the buffer or forward it to post-process/display.
74     if ((HAILO_SUCCESS != completion_info.status) && (HAILO_STREAM_ABORTED_BY_USER != completion_info.status)) {
75         // We will get HAILO_STREAM_ABORTED_BY_USER when activated_network_group is destructed.
76         std::cerr << "Got an unexpected status on callback. status=" << completion_info.status << std::endl;
77     }
78 }
79
80 static void input_async_callback(const InputStream::CompletionInfo &completion_info)
81 {
82     // Real applications can free the buffer or reuse it for next transfer.
83     if ((HAILO_SUCCESS != completion_info.status) && (HAILO_STREAM_ABORTED_BY_USER  != completion_info.status)) {
84         // We will get HAILO_STREAM_ABORTED_BY_USER  when activated_network_group is destructed.
85         std::cerr << "Got an unexpected status on callback. status=" << completion_info.status << std::endl;
86     }
87 }
88
89 static hailo_status infer(ConfiguredNetworkGroup &network_group)
90 {
91     // Assume one input and output
92     auto &output = network_group.get_output_streams()[0].get();
93     auto &input = network_group.get_input_streams()[0].get();
94
95     // Allocate buffers. The buffers sent to the async API must be page aligned.
96     // For simplicity, in this example, we pass one buffer for each stream (It may be problematic in output since the
97     // buffer will be overridden on each read).
98     // Note - the buffers can be freed only after all callbacks are called. The user can either wait for all
99     // callbacks, or as done in this example, call ConfiguredNetworkGroup::shutdown that will make sure all callbacks
100     // are called.
101     auto output_buffer = page_aligned_alloc(output.get_frame_size());
102     auto input_buffer = page_aligned_alloc(input.get_frame_size());
103
104     std::atomic<hailo_status> output_status(HAILO_UNINITIALIZED);
105     std::thread output_thread([&]() {
106         while (true) {
107             output_status = output.wait_for_async_ready(output.get_frame_size(), TIMEOUT);
108             if (HAILO_SUCCESS != output_status) { return; }
109
110             output_status = output.read_async(output_buffer.get(), output.get_frame_size(), output_async_callback);
111             if (HAILO_SUCCESS != output_status) { return; }
112         }
113     });
114
115     std::atomic<hailo_status> input_status(HAILO_UNINITIALIZED);
116     std::thread input_thread([&]() {
117         while (true) {
118             input_status = input.wait_for_async_ready(input.get_frame_size(), TIMEOUT);
119             if (HAILO_SUCCESS != input_status) { return; }
120
121             input_status = input.write_async(input_buffer.get(), input.get_frame_size(), input_async_callback);
122             if (HAILO_SUCCESS != input_status) { return; }
123         }
124     });
125
126     // After all async operations are launched, the inference is running.
127     std::this_thread::sleep_for(std::chrono::seconds(5));
128
129     // Calling shutdown on a network group will ensure that all async operations are done. All pending
130     // operations will be canceled and their callbacks will be called with status=HAILO_STREAM_ABORTED_BY_USER.
131     // Only after the shutdown is called, we can safely free the buffers and any variable captured inside the async
132     // callback lambda.
133     network_group.shutdown();
134
135     // Thread should be stopped with HAILO_STREAM_ABORTED_BY_USER status.
136     output_thread.join();
137     input_thread.join();
138
139     if ((HAILO_STREAM_ABORTED_BY_USER != output_status) || (HAILO_STREAM_ABORTED_BY_USER != input_status)) {
140         std::cerr << "Got unexpected statues from thread: " << output_status << ", " << input_status << std::endl;
141         return HAILO_INTERNAL_FAILURE;
142     }
143
144     return HAILO_SUCCESS;
145 }
146
147 int main()
148 {
149     auto device = Device::create();
150     if (!device) {
151         std::cerr << "Failed create device " << device.status() << std::endl;
152         return EXIT_FAILURE;
153     }
154
155     static const auto HEF_FILE = "hefs/shortcut_net.hef";
156     auto network_group = configure_network_group(*device.value(), HEF_FILE);
157     if (!network_group) {
158         std::cerr << "Failed to configure network group " << HEF_FILE << std::endl;
159         return EXIT_FAILURE;
160     }
161
162     auto activated_network_group = network_group.value()->activate();
163     if (!activated_network_group) {
164         std::cerr << "Failed to activate network group "  << activated_network_group.status() << std::endl;
165         return EXIT_FAILURE;
166     }
167
168     auto status = infer(*network_group.value());
169     if (HAILO_SUCCESS != status) {
170         return EXIT_FAILURE;
171     }
172
173     std::cout << "Inference finished successfully" << std::endl;
174     return EXIT_SUCCESS;
175 }