Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / inference_engine_plugin_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "tests_common.hpp"
6
7 #include <ie_plugin_ptr.hpp>
8 #include "details/ie_so_loader.h"
9 #include "inference_engine.hpp"
10 #include "mock_inference_engine.hpp"
11 #include "mock_error_listener.hpp"
12 #include "../tests/mock_engine/mock_plugin.hpp"
13
14
15 using namespace std;
16 using namespace InferenceEngine;
17 using namespace ::testing;
18 using namespace InferenceEngine::details;
19
20 class PluginTest: public TestsCommon {
21 protected:
22     unique_ptr<SharedObjectLoader> sharedObjectLoader;
23     std::function<IInferencePlugin*(IInferencePlugin*)> createPluginEngineProxy;
24     std::function<void(IInferencePlugin*)> injectPluginEngineProxy ;
25     InferenceEnginePluginPtr getPtr() ;
26     virtual void SetUp() {
27
28         std::string libraryName = get_mock_engine_name();
29         sharedObjectLoader.reset(new SharedObjectLoader(libraryName.c_str()));
30         createPluginEngineProxy = make_std_function<IInferencePlugin*(IInferencePlugin*)>("CreatePluginEngineProxy");
31         injectPluginEngineProxy = make_std_function<void(IInferencePlugin*)>("InjectProxyEngine");
32
33     }
34     template <class T>
35     std::function<T> make_std_function(const std::string& functionName) {
36         std::function <T> ptr (reinterpret_cast<T*>(sharedObjectLoader->get_symbol(functionName.c_str())));
37         return ptr;
38     }
39
40     MockInferenceEngine engine;
41     Listener error;
42 };
43
44 TEST_F(PluginTest, canCreatePlugin) {
45     auto ptr = make_std_function<IInferencePlugin*(IInferencePlugin*)>("CreatePluginEngineProxy");
46
47     unique_ptr<IInferencePlugin, std::function<void (IInferencePlugin*)>> smart_ptr(ptr(nullptr), [](IInferencePlugin *p) {
48         p->Release();
49     });
50
51     //expect that no error handler has been called
52     smart_ptr->SetLogCallback(error);
53     EXPECT_CALL(error, onError(_)).Times(0);
54 }
55
56 TEST_F(PluginTest, canCreatePluginUsingSmartPtr) {
57     ASSERT_NO_THROW(InferenceEnginePluginPtr ptr(get_mock_engine_name()));
58 }
59
60 TEST_F(PluginTest, shouldThrowExceptionIfPluginNotExist) {
61     EXPECT_THROW(InferenceEnginePluginPtr("unknown_plugin"), InferenceEngineException);
62 }
63
64 ACTION_TEMPLATE(CallListenerWithErrorMessage,
65                 HAS_1_TEMPLATE_PARAMS(int, k),
66                 AND_1_VALUE_PARAMS(pointer))
67 {
68     InferenceEngine::IErrorListener & data = ::std::get<k>(args);
69     data.onError(pointer);
70 }
71
72 TEST_F(PluginTest, canCallErrorHandlerIfNecessary) {
73
74     unique_ptr<IInferencePlugin, std::function<void(IInferencePlugin*)>> smart_ptr(createPluginEngineProxy(&engine), [](IInferencePlugin *p) {
75         p->Release();
76     });
77
78     const char * err = "my error forward";
79
80     EXPECT_CALL(error, onError(err)).Times(1);
81     EXPECT_CALL(engine, SetLogCallback(_)).WillOnce(CallListenerWithErrorMessage<0>(err));
82     EXPECT_CALL(engine, Release()).Times(1);
83
84     smart_ptr->SetLogCallback(error);
85 }
86
87 InferenceEnginePluginPtr PluginTest::getPtr() {
88     InferenceEnginePluginPtr smart_ptr(get_mock_engine_name());
89     return smart_ptr;
90 };
91
92 TEST_F(PluginTest, canForwardPluginEnginePtr) {
93
94     injectPluginEngineProxy(&engine);
95     InferenceEnginePluginPtr ptr3 = getPtr();
96
97     EXPECT_CALL(engine, Infer(_, A<Blob&>(), _)).WillOnce(Return(OK));
98     EXPECT_CALL(engine, Release()).Times(1);
99
100     TBlob <float> b1(Precision::FP32, NCHW);
101     TBlob <float> b2(Precision::FP32, NCHW);
102     ptr3->Infer(b1, b2, nullptr);
103 }
104
105
106 TEST_F(PluginTest, canSetConfiguration) {
107     InferenceEnginePluginPtr ptr = getPtr();
108     // TODO: dynamic->reinterpret because of calng/gcc cannot
109     // dynamically cast this MOCK object
110     ASSERT_TRUE(reinterpret_cast<MockPlugin*>(*ptr)->config.empty());
111
112     ResponseDesc resp;
113     std::map<std::string, std::string> config = { { "key", "value" } };
114     ASSERT_EQ(ptr->SetConfig(config, &resp), OK);
115     config.clear();
116
117     ASSERT_STREQ(reinterpret_cast<MockPlugin*>(*ptr)->config["key"].c_str(), "value");
118 }