Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / async_infer_request_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <gmock/gmock-generated-actions.h>
8 #include <ie_version.hpp>
9 #include <mock_iasync_infer_request.hpp>
10 #include <cpp/ie_infer_request.hpp>
11 #include <inference_engine/cpp_interfaces/exception2status.hpp>
12 #include <cpp_interfaces/impl/mock_async_infer_request_internal.hpp>
13 #include <mock_not_empty_icnn_network.hpp>
14 #include <inference_engine/cpp_interfaces/base/ie_infer_async_request_base.hpp>
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace InferenceEngine;
19 using namespace InferenceEngine::details;
20
21 class InferRequestTests : public ::testing::Test {
22 protected:
23     std::shared_ptr<MockIInferRequest> mock_request;
24     InferRequest::Ptr requestWrapper;
25     ResponseDesc dsc;
26
27     shared_ptr<MockAsyncInferRequestInternal> mockInferRequestInternal;
28     MockNotEmptyICNNNetwork mockNotEmptyNet;
29     std::string _incorrectName;
30     std::string _inputName;
31     std::string _failedToFindInOutError;
32     std::string _inputDataNotAllocatedError;
33
34     virtual void TearDown() {
35         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
36         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_request.get()));
37         EXPECT_TRUE(Mock::VerifyAndClearExpectations(requestWrapper.get()));
38     }
39
40     virtual void SetUp() {
41         mock_request = make_shared<MockIInferRequest>();
42         requestWrapper = make_shared<InferRequest>(mock_request);
43         _incorrectName = "incorrect_name";
44         _inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
45         _failedToFindInOutError =
46                 NOT_FOUND_str + "Failed to find input or output with name: \'" + _incorrectName + "\'";
47         _inputDataNotAllocatedError = std::string("Input data was not allocated. Input name: \'")
48                                       + _inputName + "\'";
49     }
50
51     InferRequest::Ptr getInferRequestWithMockImplInside() {
52         IInferRequest::Ptr inferRequest;
53         InputsDataMap inputsInfo;
54         mockNotEmptyNet.getInputsInfo(inputsInfo);
55         OutputsDataMap outputsInfo;
56         mockNotEmptyNet.getOutputsInfo(outputsInfo);
57         mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
58         inferRequest = shared_from_irelease(
59                 new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
60         return make_shared<InferRequest>(inferRequest);
61     }
62
63     std::string getExceptionMessage(std::function<void()> function) {
64         std::string actualError;
65         try {
66             function();
67         } catch (const InferenceEngineException &iie) {
68             actualError = iie.what();
69         }
70         return actualError;
71     }
72
73     BlobMap getBlobMapWithIncorrectName() const {
74         Blob::Ptr Blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
75         Blob->allocate();
76         return BlobMap{{_incorrectName, Blob}};
77     }
78
79     BlobMap getBlobMapWithNotAllocatedInput() const {
80         Blob::Ptr Blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
81         return BlobMap{{_inputName, Blob}};
82     }
83 };
84
85 // StartAsync
86 TEST_F(InferRequestTests, canForwardStartAsync) {
87     EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(OK));
88     ASSERT_NO_THROW(requestWrapper->StartAsync());
89 }
90
91 TEST_F(InferRequestTests, throwsIfStartAsyncReturnNotOK) {
92     EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(GENERAL_ERROR));
93     ASSERT_THROW(requestWrapper->StartAsync(), InferenceEngineException);
94 }
95
96 // Wait
97 TEST_F(InferRequestTests, canForwardWait) {
98     int64_t ms = 0;
99     EXPECT_CALL(*mock_request.get(), Wait(ms, _)).WillOnce(Return(OK));
100     ASSERT_TRUE(OK == requestWrapper->Wait(ms));
101 }
102
103 TEST_F(InferRequestTests, canForwardStatusFromWait) {
104     EXPECT_CALL(*mock_request.get(), Wait(_, _)).WillOnce(Return(RESULT_NOT_READY));
105     ASSERT_EQ(requestWrapper->Wait(0), RESULT_NOT_READY);
106 }
107
108 // Infer
109 TEST_F(InferRequestTests, canForwardInfer) {
110     EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(OK));
111     ASSERT_NO_THROW(requestWrapper->Infer());
112 }
113
114 TEST_F(InferRequestTests, throwsIfInferReturnNotOK) {
115     EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(GENERAL_ERROR));
116     ASSERT_THROW(requestWrapper->Infer(), InferenceEngineException);
117 }
118
119 // GetPerformanceCounts
120 TEST_F(InferRequestTests, canForwardGetPerformanceCounts) {
121     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
122     EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(OK));
123     ASSERT_NO_THROW(info = requestWrapper->GetPerformanceCounts());
124 }
125
126 TEST_F(InferRequestTests, throwsIfGetPerformanceCountsReturnNotOK) {
127     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
128     EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(GENERAL_ERROR));
129     ASSERT_THROW(info = requestWrapper->GetPerformanceCounts(), InferenceEngineException);
130 }
131
132 MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
133     auto a = arg.begin()->second.get();
134     return (float *) (arg.begin()->second->buffer()) == (float *) (ref_blob->buffer());
135 }
136
137 // SetInput
138 TEST_F(InferRequestTests, getInputCallsSetBlob) {
139     Blob::Ptr inblob;
140     std::string blobName1 = "blob1";
141     std::string blobName2 = "blob2";
142     BlobMap blobMap{{blobName1, inblob},
143                     {blobName2, inblob}};
144
145     EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
146     EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
147     ASSERT_NO_THROW(requestWrapper->SetInput(blobMap));
148 }
149
150 TEST_F(InferRequestTests, throwsIfSetInputReturnNotOK) {
151     EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
152     BlobMap blobMap{{{}, {}}};
153     ASSERT_THROW(requestWrapper->SetInput(blobMap), InferenceEngineException);
154 }
155
156 // SetOutput
157 TEST_F(InferRequestTests, getOutputCallsSetBlob) {
158     Blob::Ptr inblob;
159     std::string blobName1 = "blob1";
160     std::string blobName2 = "blob2";
161     BlobMap blobMap{{blobName1, inblob},
162                     {blobName2, inblob}};
163
164     EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
165     EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
166     ASSERT_NO_THROW(requestWrapper->SetOutput(blobMap));
167 }
168
169 // GetBlob
170 TEST_F(InferRequestTests, canForwardGetBlob) {
171     Blob::Ptr blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
172     blob->allocate();
173     std::string name = "blob1";
174
175     EXPECT_CALL(*mock_request.get(), GetBlob(StrEq(name.c_str()), _, _)).WillOnce(DoAll(SetArgReferee<1>(blob), Return(OK)));
176     ASSERT_NO_THROW(requestWrapper->GetBlob(name));
177 }
178
179 TEST_F(InferRequestTests, throwsIfGetBlobReturnNotOK) {
180     Blob::Ptr blob;
181     std::string name = "blob1";
182
183     EXPECT_CALL(*mock_request.get(), GetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
184     ASSERT_THROW(blob = requestWrapper->GetBlob(name), InferenceEngineException);
185 }
186
187 // SetBlob
188 TEST_F(InferRequestTests, canForwardSetBlob) {
189     Blob::Ptr blob;
190     std::string name = "blob1";
191
192     EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(name.c_str()), blob, _)).WillOnce(Return(OK));
193     ASSERT_NO_THROW(requestWrapper->SetBlob(name, blob));
194 }
195
196 TEST_F(InferRequestTests, throwsIfSetBlobReturnNotOK) {
197     Blob::Ptr blob;
198     std::string name = "blob1";
199
200     EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
201     ASSERT_THROW(requestWrapper->SetBlob(name, blob), InferenceEngineException);
202 }
203
204 TEST_F(InferRequestTests, throwsIfSetOutputReturnNotOK) {
205     EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
206     BlobMap blobMap{{{}, {}}};
207     ASSERT_THROW(requestWrapper->SetOutput(blobMap), InferenceEngineException);
208 }
209
210 // SetCompletionCallback API
211 void callme(InferenceEngine::IInferRequest::Ptr p, InferenceEngine::StatusCode) {
212     void *data = nullptr;
213     p->GetUserData(&data, nullptr);
214     ASSERT_NE(nullptr, data);
215 }
216
217 TEST_F(InferRequestTests, canForwardCompletionCallback) {
218     void *data = nullptr;
219     EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
220             DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
221     EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
222             DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
223                 *pData = data;
224             }), Return(OK)));
225     EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
226     ASSERT_NO_THROW(requestWrapper->SetCompletionCallback(&callme));
227 }
228
229 TEST_F(InferRequestTests, canForwardAnyCallback) {
230     void *data = nullptr;
231     EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
232             DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
233     EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
234             DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
235                 *pData = data;
236             }), Return(OK)));
237     EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
238
239     ASSERT_NO_THROW(requestWrapper->SetCompletionCallback([&]() {
240         // data used to store callback pointer
241         ASSERT_NE(data, nullptr);
242     }));
243 }
244
245 TEST_F(InferRequestTests, failToSetInputWithInCorrectName) {
246     auto InferRequest = getInferRequestWithMockImplInside();
247     auto blobMap = getBlobMapWithIncorrectName();
248     auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetInput(blobMap); });
249     ASSERT_EQ(_failedToFindInOutError, exceptionMessage.substr(0, _failedToFindInOutError.size()));
250 }
251
252 TEST_F(InferRequestTests, failToSetOutputWithInCorrectName) {
253     auto InferRequest = getInferRequestWithMockImplInside();
254     auto blobMap = getBlobMapWithIncorrectName();
255     auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetOutput(blobMap); });
256     ASSERT_EQ(_failedToFindInOutError, exceptionMessage.substr(0, _failedToFindInOutError.size()));
257 }
258
259 TEST_F(InferRequestTests, failToSetInputWithNotAllocatedInput) {
260     auto InferRequest = getInferRequestWithMockImplInside();
261     auto blobMap = getBlobMapWithNotAllocatedInput();
262     auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetInput(blobMap); });
263     ASSERT_EQ(_inputDataNotAllocatedError, exceptionMessage.substr(0, _inputDataNotAllocatedError.size()));
264 }