// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <string>

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include <inference_engine.hpp>
#include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
#include <cpp_interfaces/impl/mock_async_infer_request_default.hpp>
#include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
#include <cpp_interfaces/mock_task_synchronizer.hpp>
#include <cpp_interfaces/mock_task_executor.hpp>
#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>

using namespace ::testing;
using namespace std;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
20 class TestAsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeDefault {
22 TestAsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request,
23 const ITaskExecutor::Ptr &taskExecutor,
24 const TaskSynchronizer::Ptr &taskSynchronizer,
25 const ITaskExecutor::Ptr &callbackExecutor)
26 : AsyncInferRequestThreadSafeDefault(request, taskExecutor, taskSynchronizer, callbackExecutor) {}
28 void setRequestBusy() {
29 AsyncInferRequestThreadSafeDefault::setIsRequestBusy(true);
33 class InferRequestThreadSafeDefaultTests : public ::testing::Test {
35 shared_ptr<TestAsyncInferRequestThreadSafeDefault> testRequest;
38 shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
39 MockTaskSynchronizer::Ptr mockTaskSync;
40 MockTaskExecutor::Ptr mockTaskExecutor;
43 virtual void TearDown() {
46 virtual void SetUp() {
47 InputsDataMap inputsInfo;
48 OutputsDataMap outputsInfo;
49 mockTaskSync = make_shared<MockTaskSynchronizer>();
50 mockTaskExecutor = make_shared<MockTaskExecutor>();
51 mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
52 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor,
53 mockTaskSync, mockTaskExecutor);
56 bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
57 std::string whatMessage;
60 } catch (const InferenceEngineException &iee) {
61 whatMessage = iee.what();
63 return whatMessage.find(refError) != std::string::npos;
68 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnStartAsync) {
69 testRequest->setRequestBusy();
70 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
73 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncTaskFails) {
74 MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
75 EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2).WillRepeatedly(Return());
76 EXPECT_CALL(mockAsync, startAsyncTask()).Times(2)
77 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
80 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
81 ASSERT_NO_THROW(mockAsync.StartAsync());
84 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInitNextTaskFails) {
85 MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
86 EXPECT_CALL(mockAsync, startAsyncTask()).Times(1).WillOnce(Return());
87 EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2)
88 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
91 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
92 ASSERT_NO_THROW(mockAsync.StartAsync());
96 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
97 testRequest->setRequestBusy();
98 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
102 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
103 testRequest->setRequestBusy();
104 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
108 TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
109 testRequest->setRequestBusy();
111 StatusCode actual = testRequest->Wait(ms);
112 ASSERT_EQ(INFER_NOT_STARTED, actual);
116 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnInfer) {
117 testRequest->setRequestBusy();
118 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
121 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInferFails) {
122 EXPECT_CALL(*mockTaskSync.get(), lock()).Times(2);
123 EXPECT_CALL(*mockTaskSync.get(), unlock()).Times(2);
124 EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(2)
125 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
127 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, "compare"));
128 ASSERT_NO_THROW(testRequest->Infer());
131 // GetPerformanceCounts
132 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetPerformanceCounts) {
133 testRequest->setRequestBusy();
134 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
135 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
136 testRequest->GetPerformanceCounts(info);
137 }, REQUEST_BUSY_str));
141 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetBlob) {
142 testRequest->setRequestBusy();
143 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
145 testRequest->GetBlob(nullptr, data);
146 }, REQUEST_BUSY_str));
150 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetBlob) {
151 testRequest->setRequestBusy();
152 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
155 // SetCompletionCallback
156 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallback) {
157 testRequest->setRequestBusy();
158 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
162 TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
163 auto taskExecutor = std::make_shared<TaskExecutor>();
164 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
165 mockTaskSync, taskExecutor);
167 IInferRequest::Ptr asyncRequest;
168 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
169 testRequest), [](IInferRequest *p) { p->Release(); });
170 testRequest->SetPointerToPublicInterface(asyncRequest);
172 testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
173 ASSERT_EQ((int) StatusCode::OK, status);
175 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
177 testRequest->StartAsync();
178 testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
181 TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
182 auto taskExecutor = std::make_shared<TaskExecutor>();
183 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
184 mockTaskSync, taskExecutor);
185 IInferRequest::Ptr asyncRequest;
186 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
187 testRequest), [](IInferRequest *p) { p->Release(); });
188 testRequest->SetPointerToPublicInterface(asyncRequest);
190 bool wasCalled = false;
191 InferRequest cppRequest(asyncRequest);
192 std::function<void(InferRequest, StatusCode)> callback =
193 [&](InferRequest request, StatusCode status) {
195 ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
197 cppRequest.SetCompletionCallback(callback);
198 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
200 testRequest->StartAsync();
201 EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
202 ASSERT_TRUE(wasCalled);
// No completion callback is set here. When InferImpl() throws during an asynchronous run, the test
// expects the failure to be catchable as a std::exception from Wait().
205 TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
206 auto taskExecutor = std::make_shared<TaskExecutor>();
207 testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
208 mockTaskSync, taskExecutor);
// Expose the request through the public IInferRequest interface; the deleter calls Release() instead of delete.
209 IInferRequest::Ptr asyncRequest;
210 asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
211 testRequest), [](IInferRequest *p) { p->Release(); });
212 testRequest->SetPointerToPublicInterface(asyncRequest);
// Make the single inference fail with a plain std::exception.
214 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
216 testRequest->StartAsync();
// The failure is expected to be rethrown from Wait().
217 EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);