701085c8b5ac773f4f3be925df2c10e8dc3fde3d
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / iinference_plugin_internal_tests.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <ie_version.hpp>
9 #include <inference_engine/cnn_network_impl.hpp>
10 #include <cpp_interfaces/base/ie_plugin_base.hpp>
11
12 #include <mock_icnn_network.hpp>
13 #include <mock_iexecutable_network.hpp>
14 #include <mock_not_empty_icnn_network.hpp>
15 #include <cpp_interfaces/mock_plugin_impl.hpp>
16 #include <cpp_interfaces/impl/mock_inference_plugin_internal.hpp>
17 #include <cpp_interfaces/impl/mock_executable_thread_safe_default.hpp>
18 #include <cpp_interfaces/interface/mock_iinfer_request_internal.hpp>
19 #include <mock_iasync_infer_request.hpp>
20
21 using namespace ::testing;
22 using namespace std;
23 using namespace InferenceEngine;
24 using namespace InferenceEngine::details;
25
26 class InferenceEnginePluginInternalTest : public ::testing::Test {
27 protected:
28     shared_ptr<IInferencePlugin> plugin;
29     shared_ptr<MockInferencePluginInternal> mock_plugin_impl;
30     shared_ptr<MockExecutableNetworkInternal> mockExeNetworkInternal;
31     shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetworkTS;
32     shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
33     MockNotEmptyICNNNetwork mockNotEmptyNet;
34
35     ResponseDesc dsc;
36     StatusCode sts;
37
38     virtual void TearDown() {
39         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mock_plugin_impl.get()));
40         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkInternal.get()));
41         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetworkTS.get()));
42         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
43     }
44
45     virtual void SetUp() {
46         mock_plugin_impl.reset(new MockInferencePluginInternal());
47         plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 2, "test", "version"}, mock_plugin_impl));
48         mockExeNetworkInternal = make_shared<MockExecutableNetworkInternal>();
49     }
50
51     void getInferRequestWithMockImplInside(IInferRequest::Ptr &request) {
52         IExecutableNetwork::Ptr exeNetwork;
53         InputsDataMap inputsInfo;
54         mockNotEmptyNet.getInputsInfo(inputsInfo);
55         OutputsDataMap outputsInfo;
56         mockNotEmptyNet.getOutputsInfo(outputsInfo);
57         mockInferRequestInternal = make_shared<MockInferRequestInternal>(inputsInfo, outputsInfo);
58         mockExeNetworkTS = make_shared<MockExecutableNetworkThreadSafe>();
59         EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(_, _)).WillOnce(Return(mockExeNetworkTS));
60         EXPECT_CALL(*mockExeNetworkTS.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
61         sts = plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, {}, &dsc);
62         ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
63         ASSERT_NE(exeNetwork, nullptr) << dsc.msg;
64         sts = exeNetwork->CreateInferRequest(request, &dsc);
65         ASSERT_EQ((int) StatusCode::OK, sts) << dsc.msg;
66     }
67 };
68
69 MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
70     auto a = arg.begin()->second.get();
71     return (float *) (arg.begin()->second->buffer()) == (float *) (ref_blob->buffer());
72 }
73
74 TEST_F(InferenceEnginePluginInternalTest, canUseNewInferViaOldAPI) {
75     shared_ptr<Blob> inblob(new TBlob<float>(Precision::FP32, NCHW));
76     shared_ptr<Blob> resblob(new TBlob<float>(Precision::FP32, NCHW));
77
78     inblob->Resize({1}, Layout::C);
79     resblob->Resize({1}, Layout::C);
80
81     inblob->allocate();
82     resblob->allocate();
83
84     BlobMap blbi;
85     blbi[""] = inblob;
86     BlobMap blbo;
87     blbo[""] = resblob;
88     EXPECT_CALL(*mock_plugin_impl.get(), Infer(Matcher<const BlobMap &>(blob_in_map_pointer_is_same(inblob)),
89                                                Matcher<BlobMap &>(blob_in_map_pointer_is_same(resblob)))).Times(1);
90
91     EXPECT_NO_THROW(plugin->Infer((Blob &) *inblob.get(), (Blob &) *resblob.get(), nullptr));
92 }
93
94 TEST_F(InferenceEnginePluginInternalTest, loadExeNetworkCallsSetNetworkIO) {
95     IExecutableNetwork::Ptr exeNetwork;
96     map<string, string> config;
97     EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkInputs(_)).Times(1);
98     EXPECT_CALL(*mockExeNetworkInternal.get(), setNetworkOutputs(_)).Times(1);
99     EXPECT_CALL(*mock_plugin_impl.get(), LoadExeNetworkImpl(Ref(mockNotEmptyNet), Ref(config))).WillOnce(
100             Return(mockExeNetworkInternal));
101     EXPECT_NO_THROW(plugin->LoadNetwork(exeNetwork, mockNotEmptyNet, config, nullptr));
102 }
103
104 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithInCorrectName) {
105     Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
106     inBlob->allocate();
107     string inputName = "not_input";
108     std::string refError = NOT_FOUND_str + "Failed to find input or output with name: \'" + inputName + "\'";
109     IInferRequest::Ptr inferRequest;
110     getInferRequestWithMockImplInside(inferRequest);
111
112     ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
113     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
114     dsc.msg[refError.length()] = '\0';
115     ASSERT_EQ(refError, dsc.msg);
116 }
117
118 TEST_F(InferenceEnginePluginInternalTest, failToSetBlobWithNullPtr) {
119     Blob::Ptr inBlob = make_shared_blob<float>(Precision::FP32, NCHW, {});
120     inBlob->allocate();
121     string inputName = "not_input";
122     std::string refError = NOT_FOUND_str + "Failed to set blob with empty name";
123     IInferRequest::Ptr inferRequest;
124     getInferRequestWithMockImplInside(inferRequest);
125
126     ASSERT_NO_THROW(sts = inferRequest->SetBlob(nullptr, inBlob, &dsc));
127     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
128     dsc.msg[refError.length()] = '\0';
129     ASSERT_EQ(refError, dsc.msg);
130 }
131
132 TEST_F(InferenceEnginePluginInternalTest, failToSetNullPtr) {
133     string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
134     std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
135     IInferRequest::Ptr inferRequest;
136     getInferRequestWithMockImplInside(inferRequest);
137     Blob::Ptr inBlob = nullptr;
138
139     ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
140     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
141     dsc.msg[refError.length()] = '\0';
142     ASSERT_EQ(refError, dsc.msg);
143 }
144
145 TEST_F(InferenceEnginePluginInternalTest, failToSetEmptyBlob) {
146     Blob::Ptr inBlob;
147     string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
148     std::string refError = NOT_ALLOCATED_str + "Failed to set empty blob with name: \'" + inputName + "\'";
149     IInferRequest::Ptr inferRequest;
150     getInferRequestWithMockImplInside(inferRequest);
151
152     ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), inBlob, &dsc));
153     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
154     dsc.msg[refError.length()] = '\0';
155     ASSERT_EQ(refError, dsc.msg);
156 }
157
158 TEST_F(InferenceEnginePluginInternalTest, failToSetNotAllocatedBlob) {
159     string inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
160     std::string refError = "Input data was not allocated. Input name: \'" + inputName + "\'";
161     IInferRequest::Ptr inferRequest;
162     getInferRequestWithMockImplInside(inferRequest);
163     Blob::Ptr blob = make_shared_blob<float>(Precision::FP32, NCHW, {});
164
165     ASSERT_NO_THROW(sts = inferRequest->SetBlob(inputName.c_str(), blob, &dsc));
166     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts);
167     dsc.msg[refError.length()] = '\0';
168     ASSERT_EQ(refError, dsc.msg);
169 }
170
171 class InferenceEnginePluginInternal2Test : public ::testing::Test {
172 protected:
173     shared_ptr<IInferencePlugin> plugin;
174     shared_ptr<MockInferencePluginInternal2> mockPluginImpl;
175     shared_ptr<MockIExecutableNetwork> mockExeNetwork;
176     shared_ptr<MockIInferRequest> mockIInferRequest;
177     MockICNNNetwork mockEmptyNet;
178     MockNotEmptyICNNNetwork mockNotEmptyNet;
179
180     ResponseDesc dsc;
181     StatusCode sts;
182
183     virtual void TearDown() {}
184
185     virtual void SetUp() {
186         mockPluginImpl = make_shared<MockInferencePluginInternal2>();
187         plugin = details::shared_from_irelease(make_ie_compatible_plugin({1, 2, "test", "version"}, mockPluginImpl));
188         mockExeNetwork = make_shared<MockIExecutableNetwork>();
189     }
190
191     shared_ptr<MockIInferRequest> getMockIInferRequestPtr() {
192         auto mockRequest = make_shared<MockIInferRequest>();
193         EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
194         EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(DoAll(SetArgReferee<0>(mockRequest),
195                                                                                     Return(StatusCode::OK)));
196         plugin->LoadNetwork(mockNotEmptyNet, nullptr);
197         return mockRequest;
198     }
199 };
200
201 TEST_F(InferenceEnginePluginInternal2Test, loadExeNetworkWithEmptyNetworkReturnsError) {
202     string refError = "The network doesn't have inputs/outputs";
203     EXPECT_CALL(mockEmptyNet, getInputsInfo(_)).Times(1);
204     EXPECT_CALL(mockEmptyNet, getOutputsInfo(_)).Times(1);
205     EXPECT_NO_THROW(sts = plugin->LoadNetwork(mockEmptyNet, &dsc));
206     ASSERT_EQ(GENERAL_ERROR, sts);
207     dsc.msg[refError.length()] = '\0';
208     ASSERT_EQ(refError, dsc.msg);
209 }
210
211 TEST_F(InferenceEnginePluginInternal2Test, canForwardGetPerfCount) {
212     mockIInferRequest = getMockIInferRequestPtr();
213     map<string, InferenceEngineProfileInfo> profileInfo;
214     EXPECT_CALL(*mockIInferRequest.get(), GetPerformanceCounts(Ref(profileInfo), _)).WillOnce(Return(StatusCode::OK));
215     ASSERT_EQ(OK, plugin->GetPerformanceCounts(profileInfo, &dsc)) << dsc.msg;
216 }
217
218 TEST_F(InferenceEnginePluginInternal2Test, deprecatedInferCallsSetterAndInfer) {
219     mockIInferRequest = getMockIInferRequestPtr();
220
221     Blob::Ptr inBlob, resBlob;
222     BlobMap inBlobMap, resBlobMap;
223     inBlobMap[MockNotEmptyICNNNetwork::INPUT_BLOB_NAME] = inBlob;
224     resBlobMap[MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME] = resBlob;
225
226     EXPECT_CALL(*mockIInferRequest.get(), SetBlob(StrEq(MockNotEmptyICNNNetwork::INPUT_BLOB_NAME), inBlob, _)).WillOnce(
227             Return(StatusCode::OK));
228     EXPECT_CALL(*mockIInferRequest.get(),
229                 SetBlob(StrEq(MockNotEmptyICNNNetwork::OUTPUT_BLOB_NAME), resBlob, _)).WillOnce(Return(StatusCode::OK));
230     EXPECT_CALL(*mockIInferRequest.get(), Infer(_)).WillOnce(Return(StatusCode::OK));
231
232     ASSERT_EQ(OK, plugin->Infer(inBlobMap, resBlobMap, &dsc)) << dsc.msg;
233 }
234
235 TEST_F(InferenceEnginePluginInternal2Test, deprecatedLoadNetworkCallsCreateInferRequest) {
236     EXPECT_CALL(*mockPluginImpl.get(), LoadNetwork(_, _, _)).WillOnce(SetArgReferee<0>(mockExeNetwork));
237     EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequest(_, _)).WillOnce(Return(StatusCode::OK));
238     ASSERT_EQ(OK, plugin->LoadNetwork(mockNotEmptyNet, &dsc)) << dsc.msg;
239 }