Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / plugin_base_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <ie_version.hpp>
8 #include "cpp_interfaces/mock_plugin_impl.hpp"
9 #include "cpp_interfaces/base/ie_plugin_base.hpp"
10
11 using namespace ::testing;
12 using namespace std;
13 using namespace InferenceEngine;
14 using namespace InferenceEngine::details;
15
16 class PluginBaseTests: public ::testing::Test {
17  protected:
18     std::shared_ptr<MockPluginImpl> mock_impl;
19     shared_ptr<IInferencePlugin> plugin;
20     ResponseDesc dsc;
21     virtual void TearDown() {
22     }
23     virtual void SetUp() {
24         mock_impl.reset(new MockPluginImpl());
25         plugin = details::shared_from_irelease(make_ie_compatible_plugin({1,6,"test", "version"}, mock_impl));
26     }
27 };
28
29 TEST_F(PluginBaseTests, canReportVersion) {
30     const Version *V;
31     plugin->GetVersion(V);
32
33     EXPECT_STREQ(V->buildNumber, "test");
34     EXPECT_STREQ(V->description, "version");
35     EXPECT_EQ(V->apiVersion.major, 1);
36     EXPECT_EQ(V->apiVersion.minor, 6);
37
38 }
39
40 TEST_F(PluginBaseTests, canForwardLoadNetwork) {
41
42     EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).Times(1);
43
44     ICNNNetwork * network = nullptr;
45     ASSERT_EQ(OK, plugin->LoadNetwork(*network, &dsc));
46 }
47
48
49 TEST_F(PluginBaseTests, canReportErrorInLoadNetwork) {
50
51     EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).WillOnce(Throw(std::runtime_error("compare")));
52
53     ICNNNetwork * network = nullptr;
54     ASSERT_NE(plugin->LoadNetwork(*network, &dsc), OK);
55
56     ASSERT_STREQ(dsc.msg, "compare");
57 }
58
59 TEST_F(PluginBaseTests, canCatchUnknownErrorInLoadNetwork) {
60
61     EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).WillOnce(Throw(5));
62     ICNNNetwork * network = nullptr;
63     ASSERT_EQ(UNEXPECTED, plugin->LoadNetwork(*network, nullptr));
64 }
65
66 TEST_F(PluginBaseTests, canForwardLoadExeNetwork) {
67
68     EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).Times(1);
69
70     ICNNNetwork * network = nullptr;
71     IExecutableNetwork::Ptr exeNetwork = nullptr;
72     ASSERT_EQ(OK, plugin->LoadNetwork(exeNetwork, *network, {}, &dsc));
73 }
74
75
76 TEST_F(PluginBaseTests, canReportErrorInLoadExeNetwork) {
77
78     EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).WillOnce(Throw(std::runtime_error("compare")));
79
80     ICNNNetwork * network = nullptr;
81     IExecutableNetwork::Ptr exeNetwork = nullptr;
82     ASSERT_NE(plugin->LoadNetwork(exeNetwork, *network, {}, &dsc), OK);
83
84     ASSERT_STREQ(dsc.msg, "compare");
85 }
86
87 TEST_F(PluginBaseTests, canCatchUnknownErrorInLoadExeNetwork) {
88
89     EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).WillOnce(Throw(5));
90     ICNNNetwork * network = nullptr;
91     IExecutableNetwork::Ptr exeNetwork = nullptr;
92     ASSERT_EQ(UNEXPECTED, plugin->LoadNetwork(exeNetwork, *network, {}, nullptr));
93 }
94
95 TEST_F(PluginBaseTests, canForwarInfer) {
96
97     TBlob<float>  input(Precision::FP32, NCHW);
98     TBlob<float>  result(Precision::FP32, NCHW);
99
100
101     EXPECT_CALL(*mock_impl.get(), Infer(Ref(input), Ref(result))).Times(1);
102
103     ASSERT_EQ(OK, plugin->Infer(input, result, &dsc));
104 }
105
106 TEST_F(PluginBaseTests, canReportErrorInInfer) {
107
108     EXPECT_CALL(*mock_impl.get(), Infer(_,_)).WillOnce(Throw(std::runtime_error("error")));
109
110     Blob * input = nullptr;
111     ASSERT_NE(plugin->Infer(*input, *input, &dsc), OK);
112
113     ASSERT_STREQ(dsc.msg, "error");
114 }
115
116 TEST_F(PluginBaseTests, canCatchUnknownErrorInInfer) {
117     EXPECT_CALL(*mock_impl.get(), Infer(_,_)).WillOnce(Throw(5));
118     Blob * input = nullptr;
119     ASSERT_EQ(UNEXPECTED, plugin->Infer(*input, *input, nullptr));
120 }
121
122 TEST_F(PluginBaseTests, canForwarBlobMapInfer) {
123     BlobMap  input;
124     BlobMap  result;
125
126     EXPECT_CALL(*mock_impl.get(), InferBlobMap(Ref(input), Ref(result))).Times(1);
127
128     ASSERT_EQ(OK, plugin->Infer(input, result, &dsc));
129 }
130
131 TEST_F(PluginBaseTests, canReportErrorInBlobMapInfer) {
132
133     EXPECT_CALL(*mock_impl.get(), InferBlobMap(_,_)).WillOnce(Throw(std::runtime_error("error")));
134
135     BlobMap * input = nullptr;
136     ASSERT_NE(plugin->Infer(*input, *input, &dsc), OK);
137
138     ASSERT_STREQ(dsc.msg, "error");
139 }
140
141 TEST_F(PluginBaseTests, canCatchUnknownErrorInBlobMapInfer) {
142     EXPECT_CALL(*mock_impl.get(), InferBlobMap(_,_)).WillOnce(Throw(5));
143     BlobMap * input = nullptr;
144     ASSERT_EQ(UNEXPECTED, plugin->Infer(*input, *input, nullptr));
145 }
146
147 TEST_F(PluginBaseTests, canForwarGetPerformanceCounts) {
148
149     std::map <std::string, InferenceEngineProfileInfo> profileInfo;
150
151     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(Ref(profileInfo))).Times(1);
152
153     ASSERT_EQ(OK, plugin->GetPerformanceCounts(profileInfo, &dsc));
154 }
155
156
157 TEST_F(PluginBaseTests, canReportErrorInGetPerformanceCounts) {
158
159     std::map <std::string, InferenceEngineProfileInfo> profileInfo;
160
161     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(std::runtime_error("error")));
162
163     ASSERT_NE(OK, plugin->GetPerformanceCounts(profileInfo, &dsc));
164
165     ASSERT_STREQ(dsc.msg, "error");
166 }
167
168 TEST_F(PluginBaseTests, canCatchUnknownErrorInGetPerformanceCounts) {
169     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(5));
170     std::map <std::string, InferenceEngineProfileInfo> profileInfo;
171     ASSERT_EQ(UNEXPECTED, plugin->GetPerformanceCounts(profileInfo, nullptr));
172 }
173
174 TEST_F(PluginBaseTests, canForwarSetConfig) {
175
176     const std::map <std::string, std::string> config;
177     EXPECT_CALL(*mock_impl.get(), SetConfig(Ref(config))).Times(1);
178     ASSERT_EQ(OK, plugin->SetConfig(config, &dsc));
179 }
180
181 TEST_F(PluginBaseTests, canReportErrorInSetConfig) {
182     const std::map <std::string, std::string> config;
183     EXPECT_CALL(*mock_impl.get(), SetConfig(_)).WillOnce(Throw(std::runtime_error("error")));
184
185     ASSERT_NE(OK, plugin->SetConfig(config, &dsc));
186     ASSERT_STREQ(dsc.msg, "error");
187 }
188
189 TEST_F(PluginBaseTests, canCatchUnknownErrorInSetConfig) {
190     EXPECT_CALL(*mock_impl.get(), SetConfig(_)).WillOnce(Throw(5));
191     const std::map <std::string, std::string> config;
192     ASSERT_EQ(UNEXPECTED, plugin->SetConfig(config, nullptr));
193 }