1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <cpp_interfaces/impl/mock_executable_thread_safe_default.hpp>
7 #include <cpp_interfaces/impl/mock_infer_request_internal.hpp>
8 #include <cpp_interfaces/base/ie_executable_network_base.hpp>
10 using namespace ::testing;
12 using namespace InferenceEngine;
13 using namespace InferenceEngine::details;
15 class ExecutableNetworkThreadSafeTests : public ::testing::Test {
17 shared_ptr<MockExecutableNetworkThreadSafe> mockExeNetwork;
18 shared_ptr<IExecutableNetwork> exeNetwork;
19 shared_ptr<MockInferRequestInternal> mockInferRequestInternal;
23 virtual void TearDown() {
24 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockInferRequestInternal.get()));
25 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetwork.get()));
28 virtual void SetUp() {
29 mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
30 exeNetwork = details::shared_from_irelease(
31 new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
32 InputsDataMap networkInputs;
33 OutputsDataMap networkOutputs;
34 mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
38 TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAndSetNetworkIO) {
39 IInferRequest::Ptr req;
40 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
41 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
42 auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
43 ASSERT_NE(threadSafeReq, nullptr);
46 TEST_F(ExecutableNetworkThreadSafeTests, returnErrorIfInferThrowsException) {
47 IInferRequest::Ptr req;
48 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
49 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
50 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
51 EXPECT_NO_THROW(sts = req->Infer(&dsc));
52 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
55 TEST_F(ExecutableNetworkThreadSafeTests, returnErrorIfStartAsyncThrowsException) {
56 IInferRequest::Ptr req;
57 EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
58 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
59 EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
60 EXPECT_NO_THROW(sts = req->StartAsync(&dsc));
61 ASSERT_TRUE(StatusCode::OK == sts) << dsc.msg;
62 EXPECT_NO_THROW(sts = req->Wait(IInferRequest::WaitMode::RESULT_READY, &dsc));
63 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;