1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <ie_version.hpp>
8 #include <inference_engine/cnn_network_impl.hpp>
9 #include <cpp_interfaces/base/ie_plugin_base.hpp>
11 #include <mock_icnn_network.hpp>
12 #include <mock_iexecutable_network.hpp>
13 #include <mock_not_empty_icnn_network.hpp>
14 #include <cpp_interfaces/mock_plugin_impl.hpp>
15 #include <cpp_interfaces/impl/mock_inference_plugin_internal.hpp>
16 #include <cpp_interfaces/impl/mock_executable_thread_safe_default.hpp>
17 #include <cpp_interfaces/interface/mock_iinfer_request_internal.hpp>
18 #include <mock_iasync_infer_request.hpp>
20 using namespace ::testing;
22 using namespace InferenceEngine;
23 using namespace InferenceEngine::details;
25 class InferenceEnginePluginInternalTest : public ::testing::Test {
27 shared_ptr<IInferencePlugin> plugin;
28 shared_ptr<MockInferencePluginInternal> mock_plugin_impl;
29 shared_ptr<MockExecutableNetworkInternal> mockExeNetworkInternal;
30 shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetworkTS;
31 shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
32 MockNotEmptyICNNNetwork mockNotEmptyNet;
37 virtual void TearDown() {
38 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_plugin_impl.get()));
39 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkInternal.get()));
40 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkTS.get()));
41 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
44 virtual void SetUp() {
45 mock_plugin_impl.reset(new MockInferencePluginInternal());
46 plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 6, "test", "version"}, mock_plugin_impl));
47 mockExeNetworkInternal = make_shared<MockExecutableNetworkInternal>();
50 void getInferRequestWithMockImplInside(IInferRequest::Ptr &request) {
51 IExecutableNetwork::Ptr exeNetwork;
52 InputsDataMap inputsInfo;
53 mockNotEmptyNet.getInputsInfo(inputsInfo);
54 OutputsDataMap outputsInfo;
55 mockNotEmptyNet.getOutputsInfo(outputsInfo);
56 mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
57 mockExeNetworkTS = make_shared<MockExecutableNetworkThreadSafe>();
58 EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(_, _)).WillOnce(Return(mockExeNetworkTS));
59 EXPECT_CALL(*mockExeNetworkTS.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
60 sts = plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, {}, &dsc);
61 ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
62 ASSERT_NE(exeNetwork, nullptr) << dsc.msg;
63 sts = exeNetwork->CreateInferRequest(request, &dsc);
64 ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
68 MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
69 auto a = arg.begin()->second.get();
70 return (float *) (arg.begin()->second->buffer()) == (float *) (ref_blob->buffer());
73 TEST_F(InferenceEnginePluginInternalTest, canUseNewInferViaOldAPI) {
74 shared_ptr<Blob> inblob(new TBlob<float>(Precision::FP32, NCHW));
75 shared_ptr<Blob> resblob(new TBlob<float>(Precision::FP32, NCHW));
77 inblob->Resize({1}, Layout::C);
78 resblob->Resize({1}, Layout::C);
87 EXPECT_CALL(*mock_plugin_impl.get(), Infer(Matcher<const BlobMap &>(blob_in_map_pointer_is_same(inblob)),
88 Matcher<BlobMap &>(blob_in_map_pointer_is_same(resblob)))).Times(1);
90 EXPECT_NO_THROW(plugin->Infer((Blob &) *inblob.get(), (Blob &) *resblob.get(), nullptr));
93 TEST_F(InferenceEnginePluginInternalTest, loadExeNetworkCallsSetNetworkIO) {
94 IExecutableNetwork::Ptr exeNetwork;
95 map<string, string> config;
96 EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkInputs(_)).Times(1);
97 EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkOutputs(_)).Times(1);
98 EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(Ref(mockNotEmptyNet), Ref(config))).WillOnce(
99 Return(mockExeNetworkInternal));
100 EXPECT_NO_THROW(plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, config, nullptr));
103 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithInCorrectName) {
104 Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
106 string inputName = "not_input";
107 std::string refError = NOT_FOUND_str + "Failed to find input or output with name: \'" + inputName + "\'";
108 IInferRequest::Ptr inferRequest;
109 getInferRequestWithMockImplInside(inferRequest);
111 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
112 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
113 dsc.msg[refError.length()] = '\0';
114 ASSERT_EQ(refError, dsc.msg);
117 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithNullPtr) {
118 Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
120 string inputName = "not_input";
121 std::string refError = NOT_FOUND_str + "Failed to set blob with empty name";
122 IInferRequest::Ptr inferRequest;
123 getInferRequestWithMockImplInside(inferRequest);
125 ASSERT_NO_THROW(sts = inferRequest->SetBlob(nullptr, inBlob, &dsc));
126 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
127 dsc.msg[refError.length()] = '\0';
128 ASSERT_EQ(refError, dsc.msg);
131 TEST_F(InferenceEnginePluginInternalTest, failToSetNullPtr) {
132 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
133 std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
134 IInferRequest::Ptr inferRequest;
135 getInferRequestWithMockImplInside(inferRequest);
136 Blob::Ptr inBlob = nullptr;
138 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
139 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
140 dsc.msg[refError.length()] = '\0';
141 ASSERT_EQ(refError, dsc.msg);
144 TEST_F(InferenceEnginePluginInternalTest, failToSetEmptyBlob) {
146 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
147 std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
148 IInferRequest::Ptr inferRequest;
149 getInferRequestWithMockImplInside(inferRequest);
151 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
152 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
153 dsc.msg[refError.length()] = '\0';
154 ASSERT_EQ(refError, dsc.msg);
157 TEST_F(InferenceEnginePluginInternalTest, failToSetNotAllocatedBlob) {
158 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
159 std::string refError = "Input data was not allocated. Input name: \'" + inputName + "\'";
160 IInferRequest::Ptr inferRequest;
161 getInferRequestWithMockImplInside(inferRequest);
162 Blob::Ptr blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
164 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), blob, &dsc));
165 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
166 dsc.msg[refError.length()] = '\0';
167 ASSERT_EQ(refError, dsc.msg);
170 class InferenceEnginePluginInternal2Test : public ::testing::Test {
172 shared_ptr<IInferencePlugin> plugin;
173 shared_ptr<MockInferencePluginInternal2> mockPluginImpl;
174 shared_ptr<MockIExecutableNetwork> mockExeNetwork;
175 shared_ptr<MockIInferRequest> mockIInferRequest;
176 MockICNNNetwork mockEmptyNet;
177 MockNotEmptyICNNNetwork mockNotEmptyNet;
182 virtual void TearDown() {}
184 virtual void SetUp() {
185 mockPluginImpl = make_shared<MockInferencePluginInternal2>();
186 plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 6, "test", "version"}, mockPluginImpl));
187 mockExeNetwork = make_shared<MockIExecutableNetwork>();
190 shared_ptr<MockIInferRequest> getMockIInferRequestPtr() {
191 auto mockRequest = make_shared<MockIInferRequest>();
192 EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
193 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(DoAll(SetArgReferee<0>(mockRequest),
194 Return(StatusCode::OK)));
195 plugin->LoadNetwork(mockNotEmptyNet, nullptr);
200 TEST_F(InferenceEnginePluginInternal2Test, loadExeNetworkWithEmptyNetworkReturnsError) {
201 string refError = "The network doesn't have inputs/outputs";
202 EXPECT_CALL(mockEmptyNet, getInputsInfo(_)).Times(1);
203 EXPECT_CALL(mockEmptyNet, getOutputsInfo(_)).Times(1);
204 EXPECT_NO_THROW(sts = plugin->LoadNetwork(mockEmptyNet, &dsc));
205 ASSERT_EQ(GENERAL_ERROR, sts);
206 dsc.msg[refError.length()] = '\0';
207 ASSERT_EQ(refError, dsc.msg);
210 TEST_F(InferenceEnginePluginInternal2Test, canForwardGetPerfCount) {
211 mockIInferRequest = getMockIInferRequestPtr();
212 map<string, InferenceEngineProfileInfo> profileInfo;
213 EXPECT_CALL(*mockIInferRequest.get(), GetPerformanceCounts(Ref(profileInfo), _)).WillOnce(Return(StatusCode::OK));
214 ASSERT_EQ(OK, plugin->GetPerformanceCounts(profileInfo, &dsc)) << dsc.msg;
217 TEST_F(InferenceEnginePluginInternal2Test, deprecatedInferCallsSetterAndInfer) {
218 mockIInferRequest = getMockIInferRequestPtr();
220 Blob::Ptr inBlob, resBlob;
221 BlobMap inBlobMap, resBlobMap;
222 inBlobMap[MockNotEmptyICNNNetwork::INPUT_BLOB_NAME] = inBlob;
223 resBlobMap[MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME] = resBlob;
225 EXPECT_CALL(*mockIInferRequest.get(), SetBlob(StrEq(MockNotEmptyICNNNetwork::INPUT_BLOB_NAME), inBlob, _)).WillOnce(
226 Return(StatusCode::OK));
227 EXPECT_CALL(*mockIInferRequest.get(),
228 SetBlob(StrEq(MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME), resBlob, _)).WillOnce(Return(StatusCode::OK));
229 EXPECT_CALL(*mockIInferRequest.get(), Infer(_)).WillOnce(Return(StatusCode::OK));
231 ASSERT_EQ(OK, plugin->Infer(inBlobMap, resBlobMap, &dsc)) << dsc.msg;
234 TEST_F(InferenceEnginePluginInternal2Test, deprecatedLoadNetworkCallsCreateInferRequest) {
235 EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
236 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(Return(StatusCode::OK));
237 ASSERT_EQ(OK, plugin->LoadNetwork(mockNotEmptyNet, &dsc)) << dsc.msg;