e402a8dd9fa29eba6046291e6bbff3d6e117bd1c
[platform/upstream/hailort.git] /
1 /**
2  * Copyright (c) 2020-2022 Hailo Technologies Ltd. All rights reserved.
3  * Distributed under the MIT license (https://opensource.org/licenses/MIT)
4  **/
5 /**
6  * @file raw_async_streams_single_thread_example
7  * This example demonstrates using low level async streams using single thread over c++.
8  **/
9
10 #include "hailo/hailort.hpp"
11
12 #include <thread>
13 #include <iostream>
14 #include <queue>
15 #include <condition_variable>
16
17 #if defined(__unix__)
18 #include <sys/mman.h>
19 #endif
20
21 using namespace hailort;
22
23 using AlignedBuffer = std::shared_ptr<uint8_t>;
24 static AlignedBuffer page_aligned_alloc(size_t size)
25 {
26 #if defined(__unix__)
27     auto addr = mmap(NULL, size, PROT_WRITE | PROT_READ, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
28     if (MAP_FAILED == addr) throw std::bad_alloc();
29     return AlignedBuffer(reinterpret_cast<uint8_t*>(addr), [size](void *addr) { munmap(addr, size); });
30 #elif defined(_MSC_VER)
31     auto addr = VirtualAlloc(NULL, size, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
32     if (!addr) throw std::bad_alloc();
33     return AlignedBuffer(reinterpret_cast<uint8_t*>(addr), [](void *addr){ VirtualFree(addr, 0, MEM_RELEASE); });
34 #else
35 #pragma error("Aligned alloc not supported")
36 #endif
37 }
38
39 static hailo_status infer(ConfiguredNetworkGroup &network_group)
40 {
41     // Assume one input and output
42     auto &output = network_group.get_output_streams()[0].get();
43     auto &input = network_group.get_input_streams()[0].get();
44
45     auto input_queue_size = input.get_async_max_queue_size();
46     auto output_queue_size = output.get_async_max_queue_size();
47     if (!input_queue_size || !output_queue_size) {
48         std::cerr << "Failed getting async queue size" << std::endl;
49         return HAILO_INTERNAL_FAILURE;
50     }
51
52     // Allocate buffers. The buffers sent to the async API must be page aligned.
53     // Note - the buffers can be freed only after all callbacks are called. The user can either wait for all
54     // callbacks, or as done in this example, call ConfiguredNetworkGroup::shutdown that will make sure all callbacks
55     // are called.
56     std::vector<AlignedBuffer> buffer_guards;
57
58     OutputStream::TransferDoneCallback read_done = [&output, &read_done](const OutputStream::CompletionInfo &completion_info) {
59         hailo_status status = HAILO_UNINITIALIZED;
60         switch (completion_info.status) {
61         case HAILO_SUCCESS:
62             // Real applications can forward the buffer to post-process/display. Here we just re-launch new async read.
63             status = output.read_async(completion_info.buffer_addr, completion_info.buffer_size, read_done);
64             if ((HAILO_SUCCESS != status) && (HAILO_STREAM_ABORTED_BY_USER != status)) {
65                 std::cerr << "Failed read async with status=" << status << std::endl;
66             }
67             break;
68         case HAILO_STREAM_ABORTED_BY_USER:
69             // Transfer was canceled, finish gracefully.
70             break;
71         default:
72             std::cerr << "Got an unexpected status on callback. status=" << completion_info.status << std::endl;
73         }
74     };
75
76     InputStream::TransferDoneCallback write_done = [&input, &write_done](const InputStream::CompletionInfo &completion_info) {
77         hailo_status status = HAILO_UNINITIALIZED;
78         switch (completion_info.status) {
79         case HAILO_SUCCESS:
80             // Real applications may free the buffer and replace it with new buffer ready to be sent. Here we just
81             // re-launch new async write.
82             status = input.write_async(completion_info.buffer_addr, completion_info.buffer_size, write_done);
83             if ((HAILO_SUCCESS != status) && (HAILO_STREAM_ABORTED_BY_USER != status)) {
84                 std::cerr << "Failed read async with status=" << status << std::endl;
85             }
86             break;
87         case HAILO_STREAM_ABORTED_BY_USER:
88             // Transfer was canceled, finish gracefully.
89             break;
90         default:
91             std::cerr << "Got an unexpected status on callback. status=" << completion_info.status << std::endl;
92         }
93     };
94
95     // We launch "*output_queue_size" async read operation. On each async callback, we launch a new async read operation.
96     for (size_t i = 0; i < *output_queue_size; i++) {
97         // Buffers read from async operation must be page aligned.
98         auto buffer = page_aligned_alloc(output.get_frame_size());
99         auto status = output.read_async(buffer.get(), output.get_frame_size(), read_done);
100         if (HAILO_SUCCESS != status) {
101             std::cerr << "read_async failed with status=" << status << std::endl;
102             return status;
103         }
104
105         buffer_guards.emplace_back(buffer);
106     }
107
108     // We launch "*input_queue_size" async write operation. On each async callback, we launch a new async write operation.
109     for (size_t i = 0; i < *input_queue_size; i++) {
110         // Buffers written to async operation must be page aligned.
111         auto buffer = page_aligned_alloc(input.get_frame_size());
112         auto status = input.write_async(buffer.get(), input.get_frame_size(), write_done);
113         if (HAILO_SUCCESS != status) {
114             std::cerr << "write_async failed with status=" << status << std::endl;
115             return status;
116         }
117
118         buffer_guards.emplace_back(buffer);
119     }
120
121     std::this_thread::sleep_for(std::chrono::seconds(5));
122
123     // Calling shutdown on a network group will ensure that all async operations are done. All pending
124     // operations will be canceled and their callbacks will be called with status=HAILO_STREAM_ABORTED_BY_USER.
125     // Only after the shutdown is called, we can safely free the buffers and any variable captured inside the async
126     // callback lambda.
127     network_group.shutdown();
128
129     return HAILO_SUCCESS;
130 }
131
132
133 static Expected<std::shared_ptr<ConfiguredNetworkGroup>> configure_network_group(Device &device, const std::string &hef_path)
134 {
135     auto hef = Hef::create(hef_path);
136     if (!hef) {
137         return make_unexpected(hef.status());
138     }
139
140     auto configure_params = device.create_configure_params(hef.value());
141     if (!configure_params) {
142         return make_unexpected(configure_params.status());
143     }
144
145     // change stream_params to operate in async mode
146     for (auto &ng_name_params_pair : *configure_params) {
147         for (auto &stream_params_name_pair : ng_name_params_pair.second.stream_params_by_name) {
148             stream_params_name_pair.second.flags = HAILO_STREAM_FLAGS_ASYNC;
149         }
150     }
151
152     auto network_groups = device.configure(hef.value(), configure_params.value());
153     if (!network_groups) {
154         return make_unexpected(network_groups.status());
155     }
156
157     if (1 != network_groups->size()) {
158         std::cerr << "Invalid amount of network groups" << std::endl;
159         return make_unexpected(HAILO_INTERNAL_FAILURE);
160     }
161
162     return std::move(network_groups->at(0));
163 }
164
165 int main()
166 {
167     auto device = Device::create();
168     if (!device) {
169         std::cerr << "Failed to create device " << device.status() << std::endl;
170         return EXIT_FAILURE;
171     }
172
173     static const auto HEF_FILE = "hefs/shortcut_net.hef";
174     auto network_group = configure_network_group(*device.value(), HEF_FILE);
175     if (!network_group) {
176         std::cerr << "Failed to configure network group" << HEF_FILE << std::endl;
177         return EXIT_FAILURE;
178     }
179
180     auto activated_network_group = network_group.value()->activate();
181     if (!activated_network_group) {
182         std::cerr << "Failed to activate network group "  << activated_network_group.status() << std::endl;
183         return EXIT_FAILURE;
184     }
185
186     // Now start the inference
187     auto status = infer(*network_group.value());
188     if (HAILO_SUCCESS != status) {
189         std::cerr << "Inference failed with " << status << std::endl;
190         return EXIT_FAILURE;
191     }
192
193     std::cout << "Inference finished successfully" << std::endl;
194     return EXIT_SUCCESS;
195 }