1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "cpp_interfaces/impl/mock_inference_plugin_internal.hpp"
8 #include "cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
10 #include <ie_version.hpp>
11 #include "cpp_interfaces/base/ie_plugin_base.hpp"
13 using namespace ::testing;
15 using namespace InferenceEngine;
16 using namespace InferenceEngine::details;
18 class ExecutableNetworkBaseTests : public ::testing::Test {
20 shared_ptr<MockIExecutableNetworkInternal> mock_impl;
21 shared_ptr<IExecutableNetwork> exeNetwork;
24 virtual void TearDown() {
27 virtual void SetUp() {
28 mock_impl.reset(new MockIExecutableNetworkInternal());
29 exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
34 TEST_F(ExecutableNetworkBaseTests, canForwardCreateInferRequest) {
35 IInferRequest::Ptr req;
36 EXPECT_CALL(*mock_impl.get(), CreateInferRequest(Ref(req))).Times(1);
37 ASSERT_EQ(OK, exeNetwork->CreateInferRequest(req, &dsc));
40 TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequest) {
41 EXPECT_CALL(*mock_impl.get(), CreateInferRequest(_)).WillOnce(Throw(std::runtime_error("compare")));
42 IInferRequest::Ptr req;
43 ASSERT_NE(exeNetwork->CreateInferRequest(req, &dsc), OK);
44 ASSERT_STREQ(dsc.msg, "compare");
47 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInCreateInferRequest) {
48 EXPECT_CALL(*mock_impl.get(), CreateInferRequest(_)).WillOnce(Throw(5));
49 IInferRequest::Ptr req;
50 ASSERT_EQ(UNEXPECTED, exeNetwork->CreateInferRequest(req, nullptr));
54 TEST_F(ExecutableNetworkBaseTests, canForwardExport) {
55 const std::string modelFileName;
56 EXPECT_CALL(*mock_impl.get(), Export(Ref(modelFileName))).Times(1);
57 ASSERT_EQ(OK, exeNetwork->Export(modelFileName, &dsc));
60 TEST_F(ExecutableNetworkBaseTests, canReportErrorInExport) {
61 EXPECT_CALL(*mock_impl.get(), Export(_)).WillOnce(Throw(std::runtime_error("compare")));
62 ASSERT_NE(exeNetwork->Export({}, &dsc), OK);
63 ASSERT_STREQ(dsc.msg, "compare");
66 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInExport) {
67 EXPECT_CALL(*mock_impl.get(), Export(_)).WillOnce(Throw(5));
68 ASSERT_EQ(UNEXPECTED, exeNetwork->Export({}, nullptr));
72 TEST_F(ExecutableNetworkBaseTests, canForwardGetMappedTopology) {
73 std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
74 EXPECT_CALL(*mock_impl.get(), GetMappedTopology(Ref(deployedTopology))).Times(1);
75 ASSERT_EQ(OK, exeNetwork->GetMappedTopology(deployedTopology, &dsc));
78 TEST_F(ExecutableNetworkBaseTests, canReportErrorInCreateInferRequestGetMappedTopology) {
79 EXPECT_CALL(*mock_impl.get(), GetMappedTopology(_)).WillOnce(Throw(std::runtime_error("compare")));
80 std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
81 ASSERT_NE(exeNetwork->GetMappedTopology(deployedTopology, &dsc), OK);
82 ASSERT_STREQ(dsc.msg, "compare");
85 TEST_F(ExecutableNetworkBaseTests, canCatchUnknownErrorInGetMappedTopology) {
86 EXPECT_CALL(*mock_impl.get(), GetMappedTopology(_)).WillOnce(Throw(5));
87 std::map<std::string, std::vector<PrimitiveInfo::Ptr>> deployedTopology;
88 ASSERT_EQ(UNEXPECTED, exeNetwork->GetMappedTopology(deployedTopology, nullptr));