// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "tests_common.hpp"

#include <functional>
#include <map>
#include <memory>
#include <string>

#include <ie_plugin_ptr.hpp>
#include "details/ie_so_loader.h"
#include "inference_engine.hpp"
#include "mock_inference_engine.hpp"
#include "mock_error_listener.hpp"
#include "../tests/mock_engine/mock_plugin.hpp"

using namespace InferenceEngine;
using namespace ::testing;
using namespace InferenceEngine::details;
20 class PluginTest: public TestsCommon {
22 unique_ptr<SharedObjectLoader> sharedObjectLoader;
23 std::function<IInferencePlugin*(IInferencePlugin*)> createPluginEngineProxy;
24 std::function<void(IInferencePlugin*)> injectPluginEngineProxy ;
25 InferenceEnginePluginPtr getPtr() ;
26 virtual void SetUp() {
28 std::string libraryName = get_mock_engine_name();
29 sharedObjectLoader.reset(new SharedObjectLoader(libraryName.c_str()));
30 createPluginEngineProxy = make_std_function<IInferencePlugin*(IInferencePlugin*)>("CreatePluginEngineProxy");
31 injectPluginEngineProxy = make_std_function<void(IInferencePlugin*)>("InjectProxyEngine");
35 std::function<T> make_std_function(const std::string& functionName) {
36 std::function <T> ptr (reinterpret_cast<T*>(sharedObjectLoader->get_symbol(functionName.c_str())));
40 MockInferenceEngine engine;
44 TEST_F(PluginTest, canCreatePlugin) {
45 auto ptr = make_std_function<IInferencePlugin*(IInferencePlugin*)>("CreatePluginEngineProxy");
47 unique_ptr<IInferencePlugin, std::function<void (IInferencePlugin*)>> smart_ptr(ptr(nullptr), [](IInferencePlugin *p) {
51 //expect that no error handler has been called
52 smart_ptr->SetLogCallback(error);
53 EXPECT_CALL(error, onError(_)).Times(0);
56 TEST_F(PluginTest, canCreatePluginUsingSmartPtr) {
57 ASSERT_NO_THROW(InferenceEnginePluginPtr ptr(get_mock_engine_name()));
60 TEST_F(PluginTest, shouldThrowExceptionIfPluginNotExist) {
61 EXPECT_THROW(InferenceEnginePluginPtr("unknown_plugin"), InferenceEngineException);
64 ACTION_TEMPLATE(CallListenerWithErrorMessage,
65 HAS_1_TEMPLATE_PARAMS(int, k),
66 AND_1_VALUE_PARAMS(pointer))
68 InferenceEngine::IErrorListener & data = ::std::get<k>(args);
69 data.onError(pointer);
72 TEST_F(PluginTest, canCallErrorHandlerIfNecessary) {
74 unique_ptr<IInferencePlugin, std::function<void(IInferencePlugin*)>> smart_ptr(createPluginEngineProxy(&engine), [](IInferencePlugin *p) {
78 const char * err = "my error forward";
80 EXPECT_CALL(error, onError(err)).Times(1);
81 EXPECT_CALL(engine, SetLogCallback(_)).WillOnce(CallListenerWithErrorMessage<0>(err));
82 EXPECT_CALL(engine, Release()).Times(1);
84 smart_ptr->SetLogCallback(error);
87 InferenceEnginePluginPtr PluginTest::getPtr() {
88 InferenceEnginePluginPtr smart_ptr(get_mock_engine_name());
92 TEST_F(PluginTest, canForwardPluginEnginePtr) {
94 injectPluginEngineProxy(&engine);
95 InferenceEnginePluginPtr ptr3 = getPtr();
97 EXPECT_CALL(engine, Infer(_, A<Blob&>(), _)).WillOnce(Return(OK));
98 EXPECT_CALL(engine, Release()).Times(1);
100 TBlob <float> b1(Precision::FP32, NCHW);
101 TBlob <float> b2(Precision::FP32, NCHW);
102 ptr3->Infer(b1, b2, nullptr);
106 TEST_F(PluginTest, canSetConfiguration) {
107 InferenceEnginePluginPtr ptr = getPtr();
108 // TODO: dynamic->reinterpret because of calng/gcc cannot
109 // dynamically cast this MOCK object
110 ASSERT_TRUE(reinterpret_cast<MockPlugin*>(*ptr)->config.empty());
113 std::map<std::string, std::string> config = { { "key", "value" } };
114 ASSERT_EQ(ptr->SetConfig(config, &resp), OK);
117 ASSERT_STREQ(reinterpret_cast<MockPlugin*>(*ptr)->config["key"].c_str(), "value");