Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / executable_network_thread_safe_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <cpp_interfaces/impl/mock_executable_thread_safe_default.hpp>
7 #include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
8 #include <cpp_interfaces/base/ie_executable_network_base.hpp>
9
10 using namespace ::testing;
11 using namespace std;
12 using namespace InferenceEngine;
13 using namespace InferenceEngine::details;
14
15 class ExecutableNetworkThreadSafeTests : public ::testing::Test {
16 protected:
17     shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetwork;
18     shared_ptr<IExecutableNetwork> exeNetwork;
19     shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
20     ResponseDesc dsc;
21     StatusCode sts;
22
23     virtual void TearDown() {
24         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
25         EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetwork.get()));
26     }
27
28     virtual void SetUp() {
29         mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
30         exeNetwork = details::shared_from_irelease(
31                 new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
32         InputsDataMap networkInputs;
33         OutputsDataMap networkOutputs;
34         mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
35     }
36 };
37
38 TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAndSetNetworkIO) {
39     IInferRequest::Ptr req;
40     EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
41     EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
42     auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
43     ASSERT_NE(threadSafeReq, nullptr);
44 }
45
46 TEST_F(ExecutableNetworkThreadSafeTests, returnErrorIfInferThrowsException) {
47     IInferRequest::Ptr req;
48     EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
49     EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
50     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
51     EXPECT_NO_THROW(sts = req->Infer(&dsc));
52     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
53 }
54
55 TEST_F(ExecutableNetworkThreadSafeTests, returnErrorIfStartAsyncThrowsException) {
56     IInferRequest::Ptr req;
57     EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
58     EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
59     EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
60     EXPECT_NO_THROW(sts = req->StartAsync(&dsc));
61     ASSERT_TRUE(StatusCode::OK == sts) << dsc.msg;
62     EXPECT_NO_THROW(sts = req->Wait(IInferRequest::WaitMode::RESULT_READY, &dsc));
63     ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
64 }