1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include <cstdint>

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include <gmock/gmock-generated-actions.h>
#include <ie_version.hpp>
#include <mock_iasync_infer_request.hpp>
#include <cpp/ie_infer_request.hpp>
#include <inference_engine/cpp_interfaces/exception2status.hpp>
#include <cpp_interfaces/impl/mock_async_infer_request_internal.hpp>
#include <mock_not_empty_icnn_network.hpp>
#include <inference_engine/cpp_interfaces/base/ie_infer_async_request_base.hpp>
16 using namespace ::testing;
18 using namespace InferenceEngine;
19 using namespace InferenceEngine::details;
21 class InferRequestTests : public ::testing::Test {
23 std::shared_ptr<MockIInferRequest> mock_request;
24 InferRequest::Ptr requestWrapper;
27 shared_ptr<MockAsyncInferRequestInternal> mockInferRequestInternal;
28 MockNotEmptyICNNNetwork mockNotEmptyNet;
29 std::string _incorrectName;
30 std::string _inputName;
31 std::string _failedToFindInOutError;
32 std::string _inputDataNotAllocatedError;
34 virtual void TearDown() {
35 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
36 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_request.get()));
37 EXPECT_TRUE(Mock::VerifyAndClearExpectations(requestWrapper.get()));
40 virtual void SetUp() {
41 mock_request = make_shared<MockIInferRequest>();
42 requestWrapper = make_shared<InferRequest>(mock_request);
43 _incorrectName = "incorrect_name";
44 _inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
45 _failedToFindInOutError =
46 NOT_FOUND_str + "Failed to find input or output with name: \'" + _incorrectName + "\'";
47 _inputDataNotAllocatedError = std::string("Input data was not allocated. Input name: \'")
51 InferRequest::Ptr getInferRequestWithMockImplInside() {
52 IInferRequest::Ptr inferRequest;
53 InputsDataMap inputsInfo;
54 mockNotEmptyNet.getInputsInfo(inputsInfo);
55 OutputsDataMap outputsInfo;
56 mockNotEmptyNet.getOutputsInfo(outputsInfo);
57 mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
58 inferRequest = shared_from_irelease(
59 new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
60 return make_shared<InferRequest>(inferRequest);
63 std::string getExceptionMessage(std::function<void()> function) {
64 std::string actualError;
67 } catch (const InferenceEngineException &iie) {
68 actualError = iie.what();
73 BlobMap getBlobMapWithIncorrectName() const {
74 Blob::Ptr Blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
76 return BlobMap{{_incorrectName, Blob}};
79 BlobMap getBlobMapWithNotAllocatedInput() const {
80 Blob::Ptr Blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
81 return BlobMap{{_inputName, Blob}};
86 TEST_F(InferRequestTests, canForwardStartAsync) {
87 EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(OK));
88 ASSERT_NO_THROW(requestWrapper->StartAsync());
91 TEST_F(InferRequestTests, throwsIfStartAsyncReturnNotOK) {
92 EXPECT_CALL(*mock_request.get(), StartAsync(_)).WillOnce(Return(GENERAL_ERROR));
93 ASSERT_THROW(requestWrapper->StartAsync(), InferenceEngineException);
97 TEST_F(InferRequestTests, canForwardWait) {
99 EXPECT_CALL(*mock_request.get(), Wait(ms, _)).WillOnce(Return(OK));
100 ASSERT_TRUE(OK == requestWrapper->Wait(ms));
103 TEST_F(InferRequestTests, canForwardStatusFromWait) {
104 EXPECT_CALL(*mock_request.get(), Wait(_, _)).WillOnce(Return(RESULT_NOT_READY));
105 ASSERT_EQ(requestWrapper->Wait(0), RESULT_NOT_READY);
109 TEST_F(InferRequestTests, canForwardInfer) {
110 EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(OK));
111 ASSERT_NO_THROW(requestWrapper->Infer());
114 TEST_F(InferRequestTests, throwsIfInferReturnNotOK) {
115 EXPECT_CALL(*mock_request.get(), Infer(_)).WillOnce(Return(GENERAL_ERROR));
116 ASSERT_THROW(requestWrapper->Infer(), InferenceEngineException);
119 // GetPerformanceCounts
120 TEST_F(InferRequestTests, canForwardGetPerformanceCounts) {
121 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
122 EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(OK));
123 ASSERT_NO_THROW(info = requestWrapper->GetPerformanceCounts());
126 TEST_F(InferRequestTests, throwsIfGetPerformanceCountsReturnNotOK) {
127 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
128 EXPECT_CALL(*mock_request.get(), GetPerformanceCounts(_, _)).WillOnce(Return(GENERAL_ERROR));
129 ASSERT_THROW(info = requestWrapper->GetPerformanceCounts(), InferenceEngineException);
132 MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
133 auto a = arg.begin()->second.get();
134 return (float *) (arg.begin()->second->buffer()) == (float *) (ref_blob->buffer());
138 TEST_F(InferRequestTests, getInputCallsSetBlob) {
140 std::string blobName1 = "blob1";
141 std::string blobName2 = "blob2";
142 BlobMap blobMap{{blobName1, inblob},
143 {blobName2, inblob}};
145 EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
146 EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
147 ASSERT_NO_THROW(requestWrapper->SetInput(blobMap));
150 TEST_F(InferRequestTests, throwsIfSetInputReturnNotOK) {
151 EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
152 BlobMap blobMap{{{}, {}}};
153 ASSERT_THROW(requestWrapper->SetInput(blobMap), InferenceEngineException);
157 TEST_F(InferRequestTests, getOutputCallsSetBlob) {
159 std::string blobName1 = "blob1";
160 std::string blobName2 = "blob2";
161 BlobMap blobMap{{blobName1, inblob},
162 {blobName2, inblob}};
164 EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName1.c_str()), inblob, _)).WillOnce(Return(OK));
165 EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(blobName2.c_str()), inblob, _)).WillOnce(Return(OK));
166 ASSERT_NO_THROW(requestWrapper->SetOutput(blobMap));
170 TEST_F(InferRequestTests, canForwardGetBlob) {
171 Blob::Ptr blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
173 std::string name = "blob1";
175 EXPECT_CALL(*mock_request.get(), GetBlob(StrEq(name.c_str()), _, _)).WillOnce(DoAll(SetArgReferee<1>(blob), Return(OK)));
176 ASSERT_NO_THROW(requestWrapper->GetBlob(name));
179 TEST_F(InferRequestTests, throwsIfGetBlobReturnNotOK) {
181 std::string name = "blob1";
183 EXPECT_CALL(*mock_request.get(), GetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
184 ASSERT_THROW(blob = requestWrapper->GetBlob(name), InferenceEngineException);
188 TEST_F(InferRequestTests, canForwardSetBlob) {
190 std::string name = "blob1";
192 EXPECT_CALL(*mock_request.get(), SetBlob(StrEq(name.c_str()), blob, _)).WillOnce(Return(OK));
193 ASSERT_NO_THROW(requestWrapper->SetBlob(name, blob));
196 TEST_F(InferRequestTests, throwsIfSetBlobReturnNotOK) {
198 std::string name = "blob1";
200 EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
201 ASSERT_THROW(requestWrapper->SetBlob(name, blob), InferenceEngineException);
204 TEST_F(InferRequestTests, throwsIfSetOutputReturnNotOK) {
205 EXPECT_CALL(*mock_request.get(), SetBlob(_, _, _)).WillOnce(Return(GENERAL_ERROR));
206 BlobMap blobMap{{{}, {}}};
207 ASSERT_THROW(requestWrapper->SetOutput(blobMap), InferenceEngineException);
210 // SetCompletionCallback API
211 void callme(InferenceEngine::IInferRequest::Ptr p, InferenceEngine::StatusCode) {
212 void *data = nullptr;
213 p->GetUserData(&data, nullptr);
214 ASSERT_NE(nullptr, data);
217 TEST_F(InferRequestTests, canForwardCompletionCallback) {
218 void *data = nullptr;
219 EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
220 DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
221 EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
222 DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
225 EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
226 ASSERT_NO_THROW(requestWrapper->SetCompletionCallback(&callme));
229 TEST_F(InferRequestTests, canForwardAnyCallback) {
230 void *data = nullptr;
231 EXPECT_CALL(*mock_request.get(), SetCompletionCallback(_)).WillOnce(
232 DoAll(InvokeArgument<0>(static_pointer_cast<IInferRequest>(mock_request), OK), Return(OK)));
233 EXPECT_CALL(*mock_request.get(), GetUserData(_, _)).WillRepeatedly(
234 DoAll(Invoke([&](void **pData, ResponseDesc *resp) {
237 EXPECT_CALL(*mock_request.get(), SetUserData(_, _)).WillOnce(DoAll(SaveArg<0>(&data), Return(OK)));
239 ASSERT_NO_THROW(requestWrapper->SetCompletionCallback([&]() {
240 // data used to store callback pointer
241 ASSERT_NE(data, nullptr);
245 TEST_F(InferRequestTests, failToSetInputWithInCorrectName) {
246 auto InferRequest = getInferRequestWithMockImplInside();
247 auto blobMap = getBlobMapWithIncorrectName();
248 auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetInput(blobMap); });
249 ASSERT_EQ(_failedToFindInOutError, exceptionMessage.substr(0, _failedToFindInOutError.size()));
252 TEST_F(InferRequestTests, failToSetOutputWithInCorrectName) {
253 auto InferRequest = getInferRequestWithMockImplInside();
254 auto blobMap = getBlobMapWithIncorrectName();
255 auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetOutput(blobMap); });
256 ASSERT_EQ(_failedToFindInOutError, exceptionMessage.substr(0, _failedToFindInOutError.size()));
259 TEST_F(InferRequestTests, failToSetInputWithNotAllocatedInput) {
260 auto InferRequest = getInferRequestWithMockImplInside();
261 auto blobMap = getBlobMapWithNotAllocatedInput();
262 auto exceptionMessage = getExceptionMessage([&]() { InferRequest->SetInput(blobMap); });
263 ASSERT_EQ(_inputDataNotAllocatedError, exceptionMessage.substr(0, _inputDataNotAllocatedError.size()));