Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / async_infer_request_thread_safe_default_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <inference_engine.hpp>
8 #include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
9 #include <cpp_interfaces/impl/mock_async_infer_request_default.hpp>
10 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
11 #include <cpp_interfaces/mock_task_synchronizer.hpp>
12 #include <cpp_interfaces/mock_task_executor.hpp>
13 #include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
14
15 using namespace ::testing;
16 using namespace std;
17 using namespace InferenceEngine;
18 using namespace InferenceEngine::details;
19
20 class TestAsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeDefault {
21 public:
22     TestAsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request,
23                                            const ITaskExecutor::Ptr &taskExecutor,
24                                            const TaskSynchronizer::Ptr &taskSynchronizer,
25                                            const ITaskExecutor::Ptr &callbackExecutor)
26             : AsyncInferRequestThreadSafeDefault(request, taskExecutor, taskSynchronizer, callbackExecutor) {}
27
28     void setRequestBusy() {
29         AsyncInferRequestThreadSafeDefault::setIsRequestBusy(true);
30     }
31 };
32
33 class InferRequestThreadSafeDefaultTests : public ::testing::Test {
34 protected:
35     shared_ptr<TestAsyncInferRequestThreadSafeDefault> testRequest;
36     ResponseDesc dsc;
37
38     shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
39     MockTaskSynchronizer::Ptr mockTaskSync;
40     MockTaskExecutor::Ptr mockTaskExecutor;
41
42
43     virtual void TearDown() {
44     }
45
46     virtual void SetUp() {
47         InputsDataMap inputsInfo;
48         OutputsDataMap outputsInfo;
49         mockTaskSync = make_shared<MockTaskSynchronizer>();
50         mockTaskExecutor = make_shared<MockTaskExecutor>();
51         mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
52         testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor,
53                                                                           mockTaskSync, mockTaskExecutor);
54     }
55
56     bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
57         std::string whatMessage;
58         try {
59             func();
60         } catch (const InferenceEngineException &iee) {
61             whatMessage = iee.what();
62         }
63         return whatMessage.find(refError) != std::string::npos;
64     }
65 };
66
67 // StartAsync
68 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnStartAsync) {
69     testRequest->setRequestBusy();
70     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
71 }
72
73 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncTaskFails) {
74     MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
75     EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2).WillRepeatedly(Return());
76     EXPECT_CALL(mockAsync, startAsyncTask()).Times(2)
77             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
78             .WillOnce(Return());
79
80     ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
81     ASSERT_NO_THROW(mockAsync.StartAsync());
82 }
83
84 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInitNextTaskFails) {
85     MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
86     EXPECT_CALL(mockAsync, startAsyncTask()).Times(1).WillOnce(Return());
87     EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2)
88             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
89             .WillOnce(Return());
90
91     ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
92     ASSERT_NO_THROW(mockAsync.StartAsync());
93 }
94
95 // GetUserData
96 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
97     testRequest->setRequestBusy();
98     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
99 }
100
101 // SetUserData
102 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
103     testRequest->setRequestBusy();
104     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
105 }
106
107 // Wait
108 TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
109     testRequest->setRequestBusy();
110     int64_t ms = 0;
111     StatusCode actual = testRequest->Wait(ms);
112     ASSERT_EQ(INFER_NOT_STARTED, actual);
113 }
114
115 // Infer
116 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnInfer) {
117     testRequest->setRequestBusy();
118     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
119 }
120
121 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInferFails) {
122     EXPECT_CALL(*mockTaskSync.get(), lock()).Times(2);
123     EXPECT_CALL(*mockTaskSync.get(), unlock()).Times(2);
124     EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(2)
125             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
126             .WillOnce(Return());
127     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, "compare"));
128     ASSERT_NO_THROW(testRequest->Infer());
129 }
130
131 // GetPerformanceCounts
132 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetPerformanceCounts) {
133     testRequest->setRequestBusy();
134     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
135         std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
136         testRequest->GetPerformanceCounts(info);
137     }, REQUEST_BUSY_str));
138 }
139
140 // GetBlob
141 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetBlob) {
142     testRequest->setRequestBusy();
143     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
144         Blob::Ptr data;
145         testRequest->GetBlob(nullptr, data);
146     }, REQUEST_BUSY_str));
147 }
148
149 // SetBlob
150 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetBlob) {
151     testRequest->setRequestBusy();
152     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
153 }
154
155 // SetCompletionCallback
156 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallback) {
157     testRequest->setRequestBusy();
158     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
159                                                REQUEST_BUSY_str));
160 }
161
162 TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
163     auto taskExecutor = std::make_shared<TaskExecutor>();
164     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
165                                                                       mockTaskSync, taskExecutor);
166
167     IInferRequest::Ptr asyncRequest;
168     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
169             testRequest), [](IInferRequest *p) { p->Release(); });
170     testRequest->SetPointerToPublicInterface(asyncRequest);
171
172     testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
173         ASSERT_EQ((int) StatusCode::OK, status);
174     });
175     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
176
177     testRequest->StartAsync();
178     testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
179 }
180
181 TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
182     auto taskExecutor = std::make_shared<TaskExecutor>();
183     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
184                                                                       mockTaskSync, taskExecutor);
185     IInferRequest::Ptr asyncRequest;
186     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
187             testRequest), [](IInferRequest *p) { p->Release(); });
188     testRequest->SetPointerToPublicInterface(asyncRequest);
189
190     bool wasCalled = false;
191     InferRequest cppRequest(asyncRequest);
192     std::function<void(InferRequest, StatusCode)> callback =
193             [&](InferRequest request, StatusCode status) {
194                 wasCalled = true;
195                 ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
196             };
197     cppRequest.SetCompletionCallback(callback);
198     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
199
200     testRequest->StartAsync();
201     EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
202     ASSERT_TRUE(wasCalled);
203 }
204
205 TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
206     auto taskExecutor = std::make_shared<TaskExecutor>();
207     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
208                                                                       mockTaskSync, taskExecutor);
209     IInferRequest::Ptr asyncRequest;
210     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
211             testRequest), [](IInferRequest *p) { p->Release(); });
212     testRequest->SetPointerToPublicInterface(asyncRequest);
213
214     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
215
216     testRequest->StartAsync();
217     EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
218 }