1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <ie_version.hpp>
8 #include "cpp_interfaces/mock_plugin_impl.hpp"
9 #include "cpp_interfaces/base/ie_plugin_base.hpp"
11 using namespace ::testing;
13 using namespace InferenceEngine;
14 using namespace InferenceEngine::details;
// Test fixture shared by all TEST_F cases below: SetUp() creates a MockPluginImpl
// and wraps it into an IInferencePlugin via make_ie_compatible_plugin.
16 class PluginBaseTests: public ::testing::Test {
// Mock implementation on which the tests set their EXPECT_CALL expectations.
18 std::shared_ptr<MockPluginImpl> mock_impl;
// Plugin object under test; built around mock_impl in SetUp().
19 shared_ptr<IInferencePlugin> plugin;
21 virtual void TearDown() {
23 virtual void SetUp() {
// Replace any previous mock with a new instance.
24 mock_impl.reset(new MockPluginImpl());
// The {1, 6, "test", "version"} initializer supplies the version data that
// canReportVersion later compares against (apiVersion 1.6, buildNumber "test",
// description "version").
25 plugin = details::shared_from_irelease(make_ie_compatible_plugin({1,6,"test", "version"}, mock_impl));
// Checks that GetVersion() reports the version data passed to
// make_ie_compatible_plugin in the fixture's SetUp().
29 TEST_F(PluginBaseTests, canReportVersion) {
// NOTE(review): the declaration of V is not visible in this excerpt; it is
// dereferenced below, so presumably GetVersion() points it at the version record.
31 plugin->GetVersion(V);
33 EXPECT_STREQ(V->buildNumber, "test");
34 EXPECT_STREQ(V->description, "version");
35 EXPECT_EQ(V->apiVersion.major, 1);
36 EXPECT_EQ(V->apiVersion.minor, 6);
// Checks that plugin->LoadNetwork(network, resp) reaches the mock's LoadNetwork
// exactly once and that the call returns OK.
40 TEST_F(PluginBaseTests, canForwardLoadNetwork) {
42 EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).Times(1);
// NOTE(review): *network below dereferences a null pointer to form the reference
// argument. The `_` matcher never inspects it, but this is formally undefined
// behavior; a real (mock) ICNNNetwork instance would be safer.
44 ICNNNetwork * network = nullptr;
45 ASSERT_EQ(OK, plugin->LoadNetwork(*network, &dsc));
// Checks that a std::runtime_error thrown by the mock's LoadNetwork results in a
// non-OK status and that dsc.msg holds the exception text.
49 TEST_F(PluginBaseTests, canReportErrorInLoadNetwork) {
51 EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).WillOnce(Throw(std::runtime_error("compare")));
// NOTE(review): *network dereferences a null pointer (formally undefined behavior);
// the `_` matcher ignores the argument.
53 ICNNNetwork * network = nullptr;
54 ASSERT_NE(plugin->LoadNetwork(*network, &dsc), OK);
56 ASSERT_STREQ(dsc.msg, "compare");
// Checks that a thrown value that is not a std::exception (the int 5) is reported
// as UNEXPECTED. nullptr is passed where other tests pass &dsc.
59 TEST_F(PluginBaseTests, canCatchUnknownErrorInLoadNetwork) {
61 EXPECT_CALL(*mock_impl.get(), LoadNetwork(_)).WillOnce(Throw(5));
// NOTE(review): *network dereferences a null pointer (formally undefined behavior).
62 ICNNNetwork * network = nullptr;
63 ASSERT_EQ(UNEXPECTED, plugin->LoadNetwork(*network, nullptr));
// Checks that the four-argument plugin->LoadNetwork(exeNetwork, network, config, resp)
// call reaches the mock's LoadExeNetwork exactly once and returns OK.
66 TEST_F(PluginBaseTests, canForwardLoadExeNetwork) {
68 EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).Times(1);
// NOTE(review): *network below dereferences a null pointer (formally undefined
// behavior); the `_` matchers ignore the arguments.
70 ICNNNetwork * network = nullptr;
71 IExecutableNetwork::Ptr exeNetwork = nullptr;
// {} passes an empty value for the third (config) argument.
72 ASSERT_EQ(OK, plugin->LoadNetwork(exeNetwork, *network, {}, &dsc));
// Checks that a std::runtime_error thrown by the mock's LoadExeNetwork results in a
// non-OK status and that dsc.msg holds the exception text.
76 TEST_F(PluginBaseTests, canReportErrorInLoadExeNetwork) {
78 EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).WillOnce(Throw(std::runtime_error("compare")));
// NOTE(review): *network dereferences a null pointer (formally undefined behavior).
80 ICNNNetwork * network = nullptr;
81 IExecutableNetwork::Ptr exeNetwork = nullptr;
82 ASSERT_NE(plugin->LoadNetwork(exeNetwork, *network, {}, &dsc), OK);
84 ASSERT_STREQ(dsc.msg, "compare");
// Checks that a thrown int (not a std::exception) from the mock's LoadExeNetwork is
// reported as UNEXPECTED. nullptr is passed where other tests pass &dsc.
87 TEST_F(PluginBaseTests, canCatchUnknownErrorInLoadExeNetwork) {
89 EXPECT_CALL(*mock_impl.get(), LoadExeNetwork(_,_,_)).WillOnce(Throw(5));
// NOTE(review): *network dereferences a null pointer (formally undefined behavior).
90 ICNNNetwork * network = nullptr;
91 IExecutableNetwork::Ptr exeNetwork = nullptr;
92 ASSERT_EQ(UNEXPECTED, plugin->LoadNetwork(exeNetwork, *network, {}, nullptr));
// Checks that plugin->Infer(input, result, resp) reaches the mock's Infer exactly
// once with the very same blob objects, and returns OK.
// NOTE(review): "canForwar" looks like a typo for "canForward"; left unchanged
// because the test name may be referenced by test filters.
95 TEST_F(PluginBaseTests, canForwarInfer) {
97 TBlob<float> input(Precision::FP32, NCHW);
98 TBlob<float> result(Precision::FP32, NCHW);
// Ref(...) matches by object identity, so the blobs must be passed through by
// reference rather than copied.
101 EXPECT_CALL(*mock_impl.get(), Infer(Ref(input), Ref(result))).Times(1);
103 ASSERT_EQ(OK, plugin->Infer(input, result, &dsc));
// Checks that a std::runtime_error thrown by the mock's Infer results in a non-OK
// status and that dsc.msg holds the exception text.
106 TEST_F(PluginBaseTests, canReportErrorInInfer) {
108 EXPECT_CALL(*mock_impl.get(), Infer(_,_)).WillOnce(Throw(std::runtime_error("error")));
// NOTE(review): *input dereferences a null pointer for both arguments (formally
// undefined behavior); the `_` matchers ignore them.
110 Blob * input = nullptr;
111 ASSERT_NE(plugin->Infer(*input, *input, &dsc), OK);
113 ASSERT_STREQ(dsc.msg, "error");
// Checks that a thrown int (not a std::exception) from the mock's Infer is reported
// as UNEXPECTED. nullptr is passed where other tests pass &dsc.
116 TEST_F(PluginBaseTests, canCatchUnknownErrorInInfer) {
117 EXPECT_CALL(*mock_impl.get(), Infer(_,_)).WillOnce(Throw(5));
// NOTE(review): *input dereferences a null pointer (formally undefined behavior).
118 Blob * input = nullptr;
119 ASSERT_EQ(UNEXPECTED, plugin->Infer(*input, *input, nullptr));
// Checks that plugin->Infer(input, result, resp) reaches the mock's InferBlobMap
// exactly once with the very same objects, and returns OK.
// NOTE(review): the declarations of input and result are not visible in this
// excerpt; presumably they are BlobMap objects, as in the sibling BlobMap tests.
// NOTE(review): "canForwar" looks like a typo for "canForward".
122 TEST_F(PluginBaseTests, canForwarBlobMapInfer) {
126 EXPECT_CALL(*mock_impl.get(), InferBlobMap(Ref(input), Ref(result))).Times(1);
128 ASSERT_EQ(OK, plugin->Infer(input, result, &dsc));
// Checks that a std::runtime_error thrown by the mock's InferBlobMap results in a
// non-OK status and that dsc.msg holds the exception text.
131 TEST_F(PluginBaseTests, canReportErrorInBlobMapInfer) {
133 EXPECT_CALL(*mock_impl.get(), InferBlobMap(_,_)).WillOnce(Throw(std::runtime_error("error")));
// NOTE(review): *input dereferences a null pointer for both arguments (formally
// undefined behavior); the `_` matchers ignore them.
135 BlobMap * input = nullptr;
136 ASSERT_NE(plugin->Infer(*input, *input, &dsc), OK);
138 ASSERT_STREQ(dsc.msg, "error");
// Checks that a thrown int (not a std::exception) from the mock's InferBlobMap is
// reported as UNEXPECTED. nullptr is passed where other tests pass &dsc.
141 TEST_F(PluginBaseTests, canCatchUnknownErrorInBlobMapInfer) {
142 EXPECT_CALL(*mock_impl.get(), InferBlobMap(_,_)).WillOnce(Throw(5));
// NOTE(review): *input dereferences a null pointer (formally undefined behavior).
143 BlobMap * input = nullptr;
144 ASSERT_EQ(UNEXPECTED, plugin->Infer(*input, *input, nullptr));
// Checks that plugin->GetPerformanceCounts(map, resp) reaches the mock's
// GetPerformanceCounts exactly once with the very same map object, and returns OK.
// NOTE(review): "canForwar" looks like a typo for "canForward".
147 TEST_F(PluginBaseTests, canForwarGetPerformanceCounts) {
149 std::map <std::string, InferenceEngineProfileInfo> profileInfo;
// Ref(profileInfo) matches by object identity, so the map must not be copied
// on its way to the implementation.
151 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(Ref(profileInfo))).Times(1);
153 ASSERT_EQ(OK, plugin->GetPerformanceCounts(profileInfo, &dsc));
// Checks that a std::runtime_error thrown by the mock's GetPerformanceCounts
// results in a non-OK status and that dsc.msg holds the exception text.
157 TEST_F(PluginBaseTests, canReportErrorInGetPerformanceCounts) {
159 std::map <std::string, InferenceEngineProfileInfo> profileInfo;
161 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(std::runtime_error("error")));
163 ASSERT_NE(OK, plugin->GetPerformanceCounts(profileInfo, &dsc));
165 ASSERT_STREQ(dsc.msg, "error");
// Checks that a thrown int (not a std::exception) from the mock's
// GetPerformanceCounts is reported as UNEXPECTED. nullptr is passed where other
// tests pass &dsc.
168 TEST_F(PluginBaseTests, canCatchUnknownErrorInGetPerformanceCounts) {
169 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(5));
170 std::map <std::string, InferenceEngineProfileInfo> profileInfo;
171 ASSERT_EQ(UNEXPECTED, plugin->GetPerformanceCounts(profileInfo, nullptr));
// Checks that plugin->SetConfig(config, resp) reaches the mock's SetConfig exactly
// once with the very same (empty) config map object, and returns OK.
// NOTE(review): "canForwar" looks like a typo for "canForward".
174 TEST_F(PluginBaseTests, canForwarSetConfig) {
176 const std::map <std::string, std::string> config;
// Ref(config) matches by object identity rather than by value.
177 EXPECT_CALL(*mock_impl.get(), SetConfig(Ref(config))).Times(1);
178 ASSERT_EQ(OK, plugin->SetConfig(config, &dsc));
// Checks that a std::runtime_error thrown by the mock's SetConfig results in a
// non-OK status and that dsc.msg holds the exception text.
181 TEST_F(PluginBaseTests, canReportErrorInSetConfig) {
182 const std::map <std::string, std::string> config;
183 EXPECT_CALL(*mock_impl.get(), SetConfig(_)).WillOnce(Throw(std::runtime_error("error")));
185 ASSERT_NE(OK, plugin->SetConfig(config, &dsc));
186 ASSERT_STREQ(dsc.msg, "error");
// Checks that a thrown int (not a std::exception) from the mock's SetConfig is
// reported as UNEXPECTED. nullptr is passed where other tests pass &dsc.
189 TEST_F(PluginBaseTests, canCatchUnknownErrorInSetConfig) {
190 EXPECT_CALL(*mock_impl.get(), SetConfig(_)).WillOnce(Throw(5));
191 const std::map <std::string, std::string> config;
192 ASSERT_EQ(UNEXPECTED, plugin->SetConfig(config, nullptr));