1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <inference_engine.hpp>
9 #include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
10 #include <cpp_interfaces/impl/mock_async_infer_request_default.hpp>
11 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
12 #include <cpp_interfaces/mock_task_synchronizer.hpp>
13 #include <cpp_interfaces/mock_task_executor.hpp>
14 #include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
16 using namespace ::testing;
18 using namespace InferenceEngine;
19 using namespace InferenceEngine::details;
// Test double for AsyncInferRequestThreadSafeDefault: forwards its collaborators to the base
// class unchanged and adds setRequestBusy(), which lets a test put the request into the
// "busy" state directly instead of starting a real inference.
21 class TestAsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeDefault {
// All arguments are passed straight through to the base-class constructor, in the same order.
23 TestAsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request,
24 const ITaskExecutor::Ptr &taskExecutor,
25 const TaskSynchronizer::Ptr &taskSynchronizer,
26 const ITaskExecutor::Ptr &callbackExecutor)
27 : AsyncInferRequestThreadSafeDefault(request, taskExecutor, taskSynchronizer, callbackExecutor) {}
// Sets the base-class busy flag to true. The busy-state tests call this before invoking the
// method under test.
29 void setRequestBusy() {
30 AsyncInferRequestThreadSafeDefault::setIsRequestBusy(true);
// Fixture shared by the tests in this file. SetUp() builds fresh mocks for every test and
// wires them into a TestAsyncInferRequestThreadSafeDefault stored in `testRequest`.
34 class InferRequestThreadSafeDefaultTests : public ::testing::Test {
36 shared_ptr<TestAsyncInferRequestThreadSafeDefault> testRequest;
39 shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
40 MockTaskSynchronizer::Ptr mockTaskSync;
41 MockTaskExecutor::Ptr mockTaskExecutor;
44 virtual void TearDown() {
// Creates the mocks and the request under test. The mock infer request is given empty
// inputs/outputs maps. The same mock executor is passed both as the task executor and as
// the callback executor.
47 virtual void SetUp() {
48 InputsDataMap inputsInfo;
49 OutputsDataMap outputsInfo;
50 mockTaskSync = make_shared<MockTaskSynchronizer>();
51 mockTaskExecutor = make_shared<MockTaskExecutor>();
52 mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
53 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor,
54 mockTaskSync, mockTaskExecutor);
// Runs `func` and returns whether the captured exception message contains `refError`.
//  - The only handler visible in this listing catches InferenceEngineException and stores
//    its what() text. Other exception types presumably propagate out; confirm against the
//    full file.
//  - In the visible lines, whatMessage is assigned only in that handler. If nothing is
//    thrown, the result is therefore false for any non-empty refError (and true for an
//    empty one, since find("") matches at position 0).
// NOTE(review): `func` and `refError` are taken by value; `const &` would avoid the copies.
57 bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
58 std::string whatMessage;
61 } catch (const InferenceEngineException &iee) {
62 whatMessage = iee.what();
64 return whatMessage.find(refError) != std::string::npos;
// A request marked busy must reject StartAsync() with an exception whose message contains
// REQUEST_BUSY_str.
69 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnStartAsync) {
70 testRequest->setRequestBusy();
71 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
// A failure in startAsyncTask() must not leave the request busy. The first StartAsync() must
// surface the "compare" exception, and a second StartAsync() on the same object must then
// succeed. This test drives a local MockAsyncInferRequestDefault rather than the fixture's
// testRequest.
74 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncTaskFails) {
75 MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
// initNextAsyncTask() is expected twice in total across the two StartAsync() calls.
76 EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2).WillRepeatedly(Return());
// startAsyncTask() throws "compare" on its first call. The action for its second call is on
// a line not shown in this listing; it presumably returns normally, given the
// ASSERT_NO_THROW below.
77 EXPECT_CALL(mockAsync, startAsyncTask()).Times(2)
78 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
81 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
82 ASSERT_NO_THROW(mockAsync.StartAsync());
// A failure in initNextAsyncTask() must not leave the request busy either. The first
// StartAsync() must surface the "compare" exception, and a second StartAsync() must then
// succeed.
85 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInitNextTaskFails) {
86 MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
// startAsyncTask() is expected exactly once across both StartAsync() calls, so the failing
// first call must not reach it.
87 EXPECT_CALL(mockAsync, startAsyncTask()).Times(1).WillOnce(Return());
// initNextAsyncTask() throws "compare" on its first call. The second call's action is on a
// line not shown in this listing (presumably a normal return, given the ASSERT_NO_THROW
// below).
88 EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2)
89 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
92 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
93 ASSERT_NO_THROW(mockAsync.StartAsync());
// A request marked busy must reject GetUserData() with an exception whose message contains
// REQUEST_BUSY_str. A nullptr argument is passed, so the busy check has to fire before the
// argument is used.
97 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
98 testRequest->setRequestBusy();
99 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
// A request marked busy must reject SetUserData() with an exception whose message contains
// REQUEST_BUSY_str. A nullptr argument is passed, so the busy check has to fire before the
// argument is used.
103 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
104 testRequest->setRequestBusy();
105 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
// Unlike the other busy-state tests, this one checks a return value rather than an
// exception: Wait() on a request marked busy is expected to return INFER_NOT_STARTED.
// NOTE(review): `ms` is declared on a line not shown in this listing.
109 TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
110 testRequest->setRequestBusy();
112 StatusCode actual = testRequest->Wait(ms);
113 ASSERT_EQ(INFER_NOT_STARTED, actual);
// A request marked busy must reject the synchronous Infer() with an exception whose message
// contains REQUEST_BUSY_str.
117 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnInfer) {
118 testRequest->setRequestBusy();
119 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
// Synchronous Infer() version of the busy-reset check. InferImpl() throws "compare" on the
// first Infer(), and a second Infer() must then succeed.
// The synchronizer's lock() and unlock() are each expected twice in total. Presumably that is
// one pair per Infer() call, which means the synchronizer must be unlocked even when
// InferImpl() throws.
122 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInferFails) {
123 EXPECT_CALL(*mockTaskSync.get(), lock()).Times(2);
124 EXPECT_CALL(*mockTaskSync.get(), unlock()).Times(2);
// InferImpl()'s second action is on a line not shown in this listing (presumably a normal
// return, given the ASSERT_NO_THROW below).
125 EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(2)
126 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
128 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, "compare"));
129 ASSERT_NO_THROW(testRequest->Infer());
132 // GetPerformanceCounts
// A request marked busy must reject GetPerformanceCounts() with an exception whose message
// contains REQUEST_BUSY_str. The output map is a throwaway local inside the lambda.
133 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetPerformanceCounts) {
134 testRequest->setRequestBusy();
135 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
136 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
137 testRequest->GetPerformanceCounts(info);
138 }, REQUEST_BUSY_str));
// A request marked busy must reject GetBlob() with an exception whose message contains
// REQUEST_BUSY_str. A nullptr name is passed, so the busy check has to fire before the name
// is used.
// NOTE(review): `data` is declared on a line not shown in this listing.
142 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetBlob) {
143 testRequest->setRequestBusy();
144 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
146 testRequest->GetBlob(nullptr, data);
147 }, REQUEST_BUSY_str));
// A request marked busy must reject SetBlob() with an exception whose message contains
// REQUEST_BUSY_str. Both the name and the blob are nullptr, so the busy check has to fire
// before either is used.
151 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetBlob) {
152 testRequest->setRequestBusy();
153 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
156 // SetCompletionCallback
// A request marked busy must reject SetCompletionCallback(nullptr) with the expected
// busy message.
// NOTE(review): the expected-message argument to _doesThrowExceptionWithMessage continues on
// a line not shown in this listing.
157 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallback) {
158 testRequest->setRequestBusy();
159 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
// Successful asynchronous run through a real TaskExecutor. The completion callback checks
// that it receives StatusCode::OK, and InferImpl() is expected to run exactly once.
// NOTE(review): the callback may run on the executor's thread. A failing ASSERT_EQ there
// only returns from the lambda and does not stop this test body.
// NOTE(review): Wait()'s return value is not checked.
163 TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
// Replace the fixture's request with one that uses a real TaskExecutor for both task and
// callback execution.
164 auto taskExecutor = std::make_shared<TaskExecutor>();
165 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
166 mockTaskSync, taskExecutor);
// Wrap the request in the public IInferRequest interface. The custom deleter calls
// Release() instead of deleting directly, and the wrapper is registered back on the request
// via SetPointerToPublicInterface().
168 IInferRequest::Ptr asyncRequest;
169 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
170 testRequest), [](IInferRequest *p) { p->Release(); });
171 testRequest->SetPointerToPublicInterface(asyncRequest);
// The callback verifies the completion status. InferImpl() is then expected once, and the
// test blocks in Wait() until RESULT_READY.
173 testRequest->SetCompletionCallback({[](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
174 ASSERT_EQ((int) StatusCode::OK, status);
176 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
178 testRequest->StartAsync();
179 testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
// Failed asynchronous run through a real TaskExecutor. InferImpl() throws std::exception.
// The completion callback, registered through the C++ InferRequest wrapper, is expected to
// receive GENERAL_ERROR, and Wait(RESULT_READY) is expected to throw std::exception.
// Finally, wasCalled must be true.
182 TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
// Replace the fixture's request with one that uses a real TaskExecutor for both task and
// callback execution.
183 auto taskExecutor = std::make_shared<TaskExecutor>();
184 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
185 mockTaskSync, taskExecutor);
// Wrap the request in the public IInferRequest interface (the deleter calls Release()) and
// register the wrapper back on the request.
186 IInferRequest::Ptr asyncRequest;
187 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
188 testRequest), [](IInferRequest *p) { p->Release(); });
189 testRequest->SetPointerToPublicInterface(asyncRequest);
// The statement that sets wasCalled inside the callback is on a line not shown in this
// listing.
// NOTE(review): wasCalled is a plain bool read on this thread after Wait(). The final
// ASSERT_TRUE assumes Wait() does not return or throw before the callback has run; confirm
// this against the executor's behavior.
191 bool wasCalled = false;
192 InferRequest cppRequest(asyncRequest);
193 std::function<void(InferRequest, StatusCode)> callback =
194 [&](InferRequest request, StatusCode status) {
196 ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
198 cppRequest.SetCompletionCallback(callback);
199 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
201 testRequest->StartAsync();
202 EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
203 ASSERT_TRUE(wasCalled);
// Failed asynchronous run with no completion callback registered. InferImpl() throws
// std::exception, and Wait(RESULT_READY) is expected to throw std::exception on the calling
// thread.
206 TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
// Replace the fixture's request with one that uses a real TaskExecutor for both task and
// callback execution.
207 auto taskExecutor = std::make_shared<TaskExecutor>();
208 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
209 mockTaskSync, taskExecutor);
// Wrap the request in the public IInferRequest interface (the deleter calls Release()) and
// register the wrapper back on the request.
210 IInferRequest::Ptr asyncRequest;
211 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
212 testRequest), [](IInferRequest *p) { p->Release(); });
213 testRequest->SetPointerToPublicInterface(asyncRequest);
215 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
217 testRequest->StartAsync();
218 EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);