1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <ie_version.hpp>
9 #include <inference_engine/cnn_network_impl.hpp>
10 #include <cpp_interfaces/base/ie_plugin_base.hpp>
12 #include <mock_icnn_network.hpp>
13 #include <mock_iexecutable_network.hpp>
14 #include <mock_not_empty_icnn_network.hpp>
15 #include <cpp_interfaces/mock_plugin_impl.hpp>
16 #include <cpp_interfaces/impl/mock_inference_plugin_internal.hpp>
17 #include <cpp_interfaces/impl/mock_executable_thread_safe_default.hpp>
18 #include <cpp_interfaces/interface/mock_iinfer_request_internal.hpp>
19 #include <mock_iasync_infer_request.hpp>
21 using namespace ::testing;
23 using namespace InferenceEngine;
24 using namespace InferenceEngine::details;
// Test fixture: an IInferencePlugin built by make_ie_compatible_plugin around a
// MockInferencePluginInternal, plus the mock executable networks / infer request the tests wire up.
// NOTE(review): `dsc` and `sts` are used below and in the tests but are not declared in this view —
// presumably `ResponseDesc dsc;` / `StatusCode sts;` fixture members; confirm.
26 class InferenceEnginePluginInternalTest : public ::testing::Test {
28 shared_ptr<IInferencePlugin> plugin;
29 shared_ptr<MockInferencePluginInternal> mock_plugin_impl;
30 shared_ptr<MockExecutableNetworkInternal> mockExeNetworkInternal;
31 shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetworkTS;
32 shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
33 MockNotEmptyICNNNetwork mockNotEmptyNet;
// Verifies (and clears) the expectations of every mock after each test.
// mockExeNetworkTS and mockInferRequestInternal are only assigned in
// getInferRequestWithMockImplInside, so in other tests these pointers are still null here.
38 virtual void TearDown() {
39 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_plugin_impl.get()));
40 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkInternal.get()));
41 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkTS.get()));
42 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
// Creates a fresh mock plugin implementation, wraps it into an IInferencePlugin with
// version {1, 2, "test", "version"}, and creates the MockExecutableNetworkInternal
// returned by LoadExeNetworkImpl in loadExeNetworkCallsSetNetworkIO.
45 virtual void SetUp() {
46 mock_plugin_impl.reset(new MockInferencePluginInternal());
47 plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 2, "test", "version"}, mock_plugin_impl));
48 mockExeNetworkInternal = make_shared<MockExecutableNetworkInternal>();
// Loads mockNotEmptyNet through the plugin (LoadExeNetworkImpl is mocked to return
// mockExeNetworkTS) and creates an infer request whose implementation is
// mockInferRequestInternal, built from the mock network's inputs/outputs info.
// @param request receives the created IInferRequest.
// NOTE: this helper uses ASSERT_*; a fatal failure only returns from the helper, not from the
// calling test — callers should wrap the call in ASSERT_NO_FATAL_FAILURE(...).
51 void getInferRequestWithMockImplInside(IInferRequest::Ptr &request) {
52 IExecutableNetwork::Ptr exeNetwork;
53 InputsDataMap inputsInfo;
54 mockNotEmptyNet.getInputsInfo(inputsInfo);
55 OutputsDataMap outputsInfo;
56 mockNotEmptyNet.getOutputsInfo(outputsInfo);
57 mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
58 mockExeNetworkTS = make_shared<MockExecutableNetworkThreadSafe>();
59 EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(_, _)).WillOnce(Return(mockExeNetworkTS));
60 EXPECT_CALL(*mockExeNetworkTS.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
61 sts = plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, {}, &dsc);
62 ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
63 ASSERT_NE(exeNetwork, nullptr) << dsc.msg;
64 sts = exeNetwork->CreateInferRequest(request, &dsc);
65 ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
69 MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
70 auto a = arg.begin()->second.get();
71 return (float *) (arg.begin()->second->buffer()) == (float *) (ref_blob->buffer());
74 TEST_F(InferenceEnginePluginInternalTest, canUseNewInferViaOldAPI) {
75 shared_ptr<Blob> inblob(new TBlob<float>(Precision::FP32, NCHW));
76 shared_ptr<Blob> resblob(new TBlob<float>(Precision::FP32, NCHW));
78 inblob->Resize({1}, Layout::C);
79 resblob->Resize({1}, Layout::C);
88 EXPECT_CALL(*mock_plugin_impl.get(), Infer(Matcher<const BlobMap &>(blob_in_map_pointer_is_same(inblob)),
89 Matcher<BlobMap &>(blob_in_map_pointer_is_same(resblob)))).Times(1);
91 EXPECT_NO_THROW(plugin->Infer((Blob &) *inblob.get(), (Blob &) *resblob.get(), nullptr));
94 TEST_F(InferenceEnginePluginInternalTest, loadExeNetworkCallsSetNetworkIO) {
95 IExecutableNetwork::Ptr exeNetwork;
96 map<string, string> config;
97 EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkInputs(_)).Times(1);
98 EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkOutputs(_)).Times(1);
99 EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(Ref(mockNotEmptyNet), Ref(config))).WillOnce(
100 Return(mockExeNetworkInternal));
101 EXPECT_NO_THROW(plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, config, nullptr));
104 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithInCorrectName) {
105 Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
107 string inputName = "not_input";
108 std::string refError = NOT_FOUND_str + "Failed to find input or output with name: \'" + inputName + "\'";
109 IInferRequest::Ptr inferRequest;
110 getInferRequestWithMockImplInside(inferRequest);
112 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
113 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
114 dsc.msg[refError.length()] = '\0';
115 ASSERT_EQ(refError, dsc.msg);
118 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithNullPtr) {
119 Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
121 string inputName = "not_input";
122 std::string refError = NOT_FOUND_str + "Failed to set blob with empty name";
123 IInferRequest::Ptr inferRequest;
124 getInferRequestWithMockImplInside(inferRequest);
126 ASSERT_NO_THROW(sts = inferRequest->SetBlob(nullptr, inBlob, &dsc));
127 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
128 dsc.msg[refError.length()] = '\0';
129 ASSERT_EQ(refError, dsc.msg);
132 TEST_F(InferenceEnginePluginInternalTest, failToSetNullPtr) {
133 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
134 std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
135 IInferRequest::Ptr inferRequest;
136 getInferRequestWithMockImplInside(inferRequest);
137 Blob::Ptr inBlob = nullptr;
139 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
140 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
141 dsc.msg[refError.length()] = '\0';
142 ASSERT_EQ(refError, dsc.msg);
145 TEST_F(InferenceEnginePluginInternalTest, failToSetEmptyBlob) {
147 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
148 std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
149 IInferRequest::Ptr inferRequest;
150 getInferRequestWithMockImplInside(inferRequest);
152 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
153 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
154 dsc.msg[refError.length()] = '\0';
155 ASSERT_EQ(refError, dsc.msg);
158 TEST_F(InferenceEnginePluginInternalTest, failToSetNotAllocatedBlob) {
159 string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
160 std::string refError = "Input data was not allocated. Input name: \'" + inputName + "\'";
161 IInferRequest::Ptr inferRequest;
162 getInferRequestWithMockImplInside(inferRequest);
163 Blob::Ptr blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
165 ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), blob, &dsc));
166 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
167 dsc.msg[refError.length()] = '\0';
168 ASSERT_EQ(refError, dsc.msg);
// Test fixture: an IInferencePlugin built by make_ie_compatible_plugin around a
// MockInferencePluginInternal2, whose LoadNetwork is arranged by the tests to produce mockExeNetwork.
// mockEmptyNet is the network used by loadExeNetworkWithEmptyNetworkReturnsError;
// mockNotEmptyNet is loaded by the other tests.
// NOTE(review): `dsc`/`sts` used by the tests are not declared in this view — presumably fixture
// members; confirm.
171 class InferenceEnginePluginInternal2Test : public ::testing::Test {
173 shared_ptr<IInferencePlugin> plugin;
174 shared_ptr<MockInferencePluginInternal2> mockPluginImpl;
175 shared_ptr<MockIExecutableNetwork> mockExeNetwork;
176 shared_ptr<MockIInferRequest> mockIInferRequest;
177 MockICNNNetwork mockEmptyNet;
178 MockNotEmptyICNNNetwork mockNotEmptyNet;
183 virtual void TearDown() {}
// Creates a fresh mock plugin implementation, wraps it with version {1, 2, "test", "version"},
// and creates the MockIExecutableNetwork that LoadNetwork expectations hand back.
185 virtual void SetUp() {
186 mockPluginImpl = make_shared<MockInferencePluginInternal2>();
187 plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 2, "test", "version"}, mockPluginImpl));
188 mockExeNetwork = make_shared<MockIExecutableNetwork>();
// Sets expectations so that plugin->LoadNetwork(mockNotEmptyNet, ...) receives mockExeNetwork
// from the implementation and obtains mockRequest from CreateInferRequest, then performs that load.
// NOTE(review): the LoadNetwork status is ignored, and the function's return statement and closing
// braces are not visible here — presumably it returns mockRequest; confirm.
191 shared_ptr<MockIInferRequest> getMockIInferRequestPtr() {
192 auto mockRequest = make_shared<MockIInferRequest>();
193 EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
194 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(DoAll(SetArgReferee<0>(mockRequest),
195 Return(StatusCode::OK)));
196 plugin->LoadNetwork(mockNotEmptyNet, nullptr);
201 TEST_F(InferenceEnginePluginInternal2Test, loadExeNetworkWithEmptyNetworkReturnsError) {
202 string refError = "The network doesn't have inputs/outputs";
203 EXPECT_CALL(mockEmptyNet, getInputsInfo(_)).Times(1);
204 EXPECT_CALL(mockEmptyNet, getOutputsInfo(_)).Times(1);
205 EXPECT_NO_THROW(sts = plugin->LoadNetwork(mockEmptyNet, &dsc));
206 ASSERT_EQ(GENERAL_ERROR, sts);
207 dsc.msg[refError.length()] = '\0';
208 ASSERT_EQ(refError, dsc.msg);
211 TEST_F(InferenceEnginePluginInternal2Test, canForwardGetPerfCount) {
212 mockIInferRequest = getMockIInferRequestPtr();
213 map<string, InferenceEngineProfileInfo> profileInfo;
214 EXPECT_CALL(*mockIInferRequest.get(), GetPerformanceCounts(Ref(profileInfo), _)).WillOnce(Return(StatusCode::OK));
215 ASSERT_EQ(OK, plugin->GetPerformanceCounts(profileInfo, &dsc)) << dsc.msg;
218 TEST_F(InferenceEnginePluginInternal2Test, deprecatedInferCallsSetterAndInfer) {
219 mockIInferRequest = getMockIInferRequestPtr();
221 Blob::Ptr inBlob, resBlob;
222 BlobMap inBlobMap, resBlobMap;
223 inBlobMap[MockNotEmptyICNNNetwork::INPUT_BLOB_NAME] = inBlob;
224 resBlobMap[MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME] = resBlob;
226 EXPECT_CALL(*mockIInferRequest.get(), SetBlob(StrEq(MockNotEmptyICNNNetwork::INPUT_BLOB_NAME), inBlob, _)).WillOnce(
227 Return(StatusCode::OK));
228 EXPECT_CALL(*mockIInferRequest.get(),
229 SetBlob(StrEq(MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME), resBlob, _)).WillOnce(Return(StatusCode::OK));
230 EXPECT_CALL(*mockIInferRequest.get(), Infer(_)).WillOnce(Return(StatusCode::OK));
232 ASSERT_EQ(OK, plugin->Infer(inBlobMap, resBlobMap, &dsc)) << dsc.msg;
235 TEST_F(InferenceEnginePluginInternal2Test, deprecatedLoadNetworkCallsCreateInferRequest) {
236 EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
237 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(Return(StatusCode::OK));
238 ASSERT_EQ(OK, plugin->LoadNetwork(mockNotEmptyNet, &dsc)) << dsc.msg;