#include <vector>
#include <string>
#include <memory>
+#include <future>
#include "ie_extension.h"
#include <condition_variable>
#include "functional_test_utils/layer_test_utils.hpp"
const int NUM_ITER = 10;
struct TestUserData {
- int numIter = NUM_ITER;
- bool startAsyncOK = true;
- std::atomic<int> numIsCalled{0};
- std::mutex mutex_block_emulation;
- std::condition_variable cv_block_emulation;
- bool isBlocked = true;
+ std::atomic<int> numIter = {0};
+ std::promise<InferenceEngine::StatusCode> promise;
};
TestUserData data;
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
[&](InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode status) {
- // HSD_1805940120: Wait on starting callback return HDDL_ERROR_INVAL_TASK_HANDLE
- if (targetDevice != CommonTestUtils::DEVICE_HDDL) {
- ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), status);
- }
- if (--data.numIter) {
- InferenceEngine::StatusCode sts = request->StartAsync(nullptr);
- if (sts != InferenceEngine::StatusCode::OK) {
- data.startAsyncOK = false;
+ if (status != InferenceEngine::StatusCode::OK) {
+ data.promise.set_value(status);
+ } else {
+ if (data.numIter.fetch_add(1) != NUM_ITER) {
+ InferenceEngine::StatusCode sts = request->StartAsync(nullptr);
+ if (sts != InferenceEngine::StatusCode::OK) {
+ data.promise.set_value(sts);
+ }
+ } else {
+ data.promise.set_value(InferenceEngine::StatusCode::OK);
}
}
- data.numIsCalled++;
- if (!data.numIter) {
- data.isBlocked = false;
- data.cv_block_emulation.notify_all();
- }
});
-
+ auto future = data.promise.get_future();
req.StartAsync();
- InferenceEngine::ResponseDesc responseWait;
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
- // intentionally block until notification from callback
- std::unique_lock<std::mutex> lock(data.mutex_block_emulation);
- data.cv_block_emulation.wait(lock, [&]() { return !data.isBlocked; });
-
- ASSERT_EQ((int) InferenceEngine::StatusCode::OK, waitStatus) << responseWait.msg;
- ASSERT_EQ(NUM_ITER, data.numIsCalled);
- ASSERT_TRUE(data.startAsyncOK);
+ ASSERT_EQ((int) InferenceEngine::StatusCode::OK, waitStatus);
+ future.wait();
+ auto callbackStatus = future.get();
+ ASSERT_EQ((int) InferenceEngine::StatusCode::OK, callbackStatus);
+ auto dataNumIter = data.numIter - 1;
+ ASSERT_EQ(NUM_ITER, dataNumIter);
}
TEST_P(CallbackTests, inferDoesNotCallCompletionCallback) {