Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / executable_network_base_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "cpp_interfaces/impl/mock_inference_plugin_internal.hpp"
8 #include "cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
9
10 #include <ie_version.hpp>
11 #include "cpp_interfaces/base/ie_plugin_base.hpp"
12
13 using namespace ::testing;
14 using namespace std;
15 using namespace InferenceEngine;
16 using namespace InferenceEngine::details;
17
18 class ExecutableNetworkBaseTests : public ::testing::Test {
19 protected:
20     shared_ptr<MockIExecutableNetworkInternal> mock_impl;
21     shared_ptr<IExecutableNetwork> exeNetwork;
22     ResponseDesc dsc;
23
24     virtual void TearDown() {
25     }
26
27     virtual void SetUp() {
28         mock_impl.reset(new MockIExecutableNetworkInternal());
29         exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
30     }
31 };
32
33 // CreateInferRequest
34 TEST_F(ExecutableNetworkBaseTests, canForwardCreateInferRequest) {
35     IInferRequest::Ptr req;
36     EXPECT_CALL(*mock_impl.get(), CreateInferRequest(Ref(req))).Times(1);
37     ASSERT_EQ(OK, exeNetwork->CreateInferRequest(req, &dsc));
38 }
39
40 TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequest) {
41     EXPECT_CALL(*mock_impl.get(), CreateInferRequest(_)).WillOnce(Throw(std::runtime_error("compare")));
42     IInferRequest::Ptr req;
43     ASSERT_NE(exeNetwork->CreateInferRequest(req, &dsc), OK);
44     ASSERT_STREQ(dsc.msg, "compare");
45 }
46
47 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInCreateInferRequest) {
48     EXPECT_CALL(*mock_impl.get(), CreateInferRequest(_)).WillOnce(Throw(5));
49     IInferRequest::Ptr req;
50     ASSERT_EQ(UNEXPECTED, exeNetwork->CreateInferRequest(req, nullptr));
51 }
52
53 // Export
54 TEST_F(ExecutableNetworkBaseTests, canForwardExport) {
55     const std::string modelFileName;
56     EXPECT_CALL(*mock_impl.get(), Export(Ref(modelFileName))).Times(1);
57     ASSERT_EQ(OK, exeNetwork->Export(modelFileName, &dsc));
58 }
59
60 TEST_F(ExecutableNetworkBaseTests, canReportErrorInExport) {
61     EXPECT_CALL(*mock_impl.get(), Export(_)).WillOnce(Throw(std::runtime_error("compare")));
62     ASSERT_NE(exeNetwork->Export({}, &dsc), OK);
63     ASSERT_STREQ(dsc.msg, "compare");
64 }
65
66 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInExport) {
67     EXPECT_CALL(*mock_impl.get(), Export(_)).WillOnce(Throw(5));
68     ASSERT_EQ(UNEXPECTED, exeNetwork->Export({}, nullptr));
69 }
70
71 // GetMappedTopology
72 TEST_F(ExecutableNetworkBaseTests, canForwardGetMappedTopology) {
73     std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
74     EXPECT_CALL(*mock_impl.get(), GetMappedTopology(Ref(deployedTopology))).Times(1);
75     ASSERT_EQ(OK, exeNetwork->GetMappedTopology(deployedTopology, &dsc));
76 }
77
78 TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequestGetMappedTopology) {
79     EXPECT_CALL(*mock_impl.get(), GetMappedTopology(_)).WillOnce(Throw(std::runtime_error("compare")));
80     std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
81     ASSERT_NE(exeNetwork->GetMappedTopology(deployedTopology, &dsc), OK);
82     ASSERT_STREQ(dsc.msg, "compare");
83 }
84
85 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInGetMappedTopology) {
86     EXPECT_CALL(*mock_impl.get(), GetMappedTopology(_)).WillOnce(Throw(5));
87     std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
88     ASSERT_EQ(UNEXPECTED, exeNetwork->GetMappedTopology(deployedTopology, nullptr));
89 }