Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / async_infer_request_thread_safe_default_tests.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <inference_engine.hpp>
9 #include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
10 #include <cpp_interfaces/impl/mock_async_infer_request_default.hpp>
11 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp>
12 #include <cpp_interfaces/mock_task_synchronizer.hpp>
13 #include <cpp_interfaces/mock_task_executor.hpp>
14 #include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace InferenceEngine;
19 using namespace InferenceEngine::details;
20
21 class TestAsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeDefault {
22 public:
23     TestAsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request,
24                                            const ITaskExecutor::Ptr &taskExecutor,
25                                            const TaskSynchronizer::Ptr &taskSynchronizer,
26                                            const ITaskExecutor::Ptr &callbackExecutor)
27             : AsyncInferRequestThreadSafeDefault(request, taskExecutor, taskSynchronizer, callbackExecutor) {}
28
29     void setRequestBusy() {
30         AsyncInferRequestThreadSafeDefault::setIsRequestBusy(true);
31     }
32 };
33
34 class InferRequestThreadSafeDefaultTests : public ::testing::Test {
35 protected:
36     shared_ptr<TestAsyncInferRequestThreadSafeDefault> testRequest;
37     ResponseDesc dsc;
38
39     shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
40     MockTaskSynchronizer::Ptr mockTaskSync;
41     MockTaskExecutor::Ptr mockTaskExecutor;
42
43
44     virtual void TearDown() {
45     }
46
47     virtual void SetUp() {
48         InputsDataMap inputsInfo;
49         OutputsDataMap outputsInfo;
50         mockTaskSync = make_shared<MockTaskSynchronizer>();
51         mockTaskExecutor = make_shared<MockTaskExecutor>();
52         mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
53         testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, mockTaskExecutor,
54                                                                           mockTaskSync, mockTaskExecutor);
55     }
56
57     bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
58         std::string whatMessage;
59         try {
60             func();
61         } catch (const InferenceEngineException &iee) {
62             whatMessage = iee.what();
63         }
64         return whatMessage.find(refError) != std::string::npos;
65     }
66 };
67
68 // StartAsync
69 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnStartAsync) {
70     testRequest->setRequestBusy();
71     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
72 }
73
74 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfStartAsyncTaskFails) {
75     MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
76     EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2).WillRepeatedly(Return());
77     EXPECT_CALL(mockAsync, startAsyncTask()).Times(2)
78             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
79             .WillOnce(Return());
80
81     ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
82     ASSERT_NO_THROW(mockAsync.StartAsync());
83 }
84
85 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInitNextTaskFails) {
86     MockAsyncInferRequestDefault mockAsync(mockInferRequestInternal, mockTaskExecutor, mockTaskSync, mockTaskExecutor);
87     EXPECT_CALL(mockAsync, startAsyncTask()).Times(1).WillOnce(Return());
88     EXPECT_CALL(mockAsync, initNextAsyncTask()).Times(2)
89             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
90             .WillOnce(Return());
91
92     ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { mockAsync.StartAsync(); }, "compare"));
93     ASSERT_NO_THROW(mockAsync.StartAsync());
94 }
95
96 // GetUserData
97 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetUserData) {
98     testRequest->setRequestBusy();
99     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
100 }
101
102 // SetUserData
103 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetUserData) {
104     testRequest->setRequestBusy();
105     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
106 }
107
108 // Wait
109 TEST_F(InferRequestThreadSafeDefaultTests, returnInferNotStartedOnWait) {
110     testRequest->setRequestBusy();
111     int64_t ms = 0;
112     StatusCode actual = testRequest->Wait(ms);
113     ASSERT_EQ(INFER_NOT_STARTED, actual);
114 }
115
116 // Infer
117 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnInfer) {
118     testRequest->setRequestBusy();
119     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
120 }
121
122 TEST_F(InferRequestThreadSafeDefaultTests, canResetBusyStatusIfInferFails) {
123     EXPECT_CALL(*mockTaskSync.get(), lock()).Times(2);
124     EXPECT_CALL(*mockTaskSync.get(), unlock()).Times(2);
125     EXPECT_CALL(*mockInferRequestInternal, InferImpl()).Times(2)
126             .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
127             .WillOnce(Return());
128     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, "compare"));
129     ASSERT_NO_THROW(testRequest->Infer());
130 }
131
132 // GetPerformanceCounts
133 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetPerformanceCounts) {
134     testRequest->setRequestBusy();
135     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
136         std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
137         testRequest->GetPerformanceCounts(info);
138     }, REQUEST_BUSY_str));
139 }
140
141 // GetBlob
142 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnGetBlob) {
143     testRequest->setRequestBusy();
144     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
145         Blob::Ptr data;
146         testRequest->GetBlob(nullptr, data);
147     }, REQUEST_BUSY_str));
148 }
149
150 // SetBlob
151 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetBlob) {
152     testRequest->setRequestBusy();
153     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
154 }
155
156 // SetCompletionCallback
157 TEST_F(InferRequestThreadSafeDefaultTests, returnRequestBusyOnSetCompletionCallback) {
158     testRequest->setRequestBusy();
159     ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },
160                                                REQUEST_BUSY_str));
161 }
162
163 TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
164     auto taskExecutor = std::make_shared<TaskExecutor>();
165     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
166                                                                       mockTaskSync, taskExecutor);
167
168     IInferRequest::Ptr asyncRequest;
169     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
170             testRequest), [](IInferRequest *p) { p->Release(); });
171     testRequest->SetPointerToPublicInterface(asyncRequest);
172
173     testRequest->SetCompletionCallback({[](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
174         ASSERT_EQ((int) StatusCode::OK, status);
175     }});
176     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).Times(1);
177
178     testRequest->StartAsync();
179     testRequest->Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
180 }
181
182 TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed) {
183     auto taskExecutor = std::make_shared<TaskExecutor>();
184     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
185                                                                       mockTaskSync, taskExecutor);
186     IInferRequest::Ptr asyncRequest;
187     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
188             testRequest), [](IInferRequest *p) { p->Release(); });
189     testRequest->SetPointerToPublicInterface(asyncRequest);
190
191     bool wasCalled = false;
192     InferRequest cppRequest(asyncRequest);
193     std::function<void(InferRequest, StatusCode)> callback =
194             [&](InferRequest request, StatusCode status) {
195                 wasCalled = true;
196                 ASSERT_EQ(StatusCode::GENERAL_ERROR, status);
197             };
198     cppRequest.SetCompletionCallback(callback);
199     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
200
201     testRequest->StartAsync();
202     EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
203     ASSERT_TRUE(wasCalled);
204 }
205
206 TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailedAndNoCallback) {
207     auto taskExecutor = std::make_shared<TaskExecutor>();
208     testRequest = make_shared<TestAsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor,
209                                                                       mockTaskSync, taskExecutor);
210     IInferRequest::Ptr asyncRequest;
211     asyncRequest.reset(new InferRequestBase<TestAsyncInferRequestThreadSafeDefault>(
212             testRequest), [](IInferRequest *p) { p->Release(); });
213     testRequest->SetPointerToPublicInterface(asyncRequest);
214
215     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
216
217     testRequest->StartAsync();
218     EXPECT_THROW(testRequest->Wait(IInferRequest::WaitMode::RESULT_READY), std::exception);
219 }