1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <cpp_interfaces/impl/mock_executable_thread_safe_async_only.hpp>
8 #include <cpp_interfaces/impl/mock_async_infer_request_internal.hpp>
10 #include <ie_version.hpp>
11 #include <cpp_interfaces/base/ie_executable_network_base.hpp>
13 using namespace ::testing;
15 using namespace InferenceEngine;
16 using namespace InferenceEngine::details;
18 class ExecutableNetworkThreadSafeAsyncOnlyTests : public ::testing::Test {
20 shared_ptr<MockExecutableNetworkThreadSafeAsyncOnly> mockExeNetwork;
21 shared_ptr<MockAsyncInferRequestInternal> mockAsyncInferRequestInternal;
22 shared_ptr<IExecutableNetwork> exeNetwork;
26 virtual void TearDown() {
27 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockAsyncInferRequestInternal.get()));
28 EXPECT_TRUE(Mock::VerifyAndClearExpectations(mockExeNetwork.get()));
31 virtual void SetUp() {
32 mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
33 exeNetwork = details::shared_from_irelease(
34 new ExecutableNetworkBase<MockExecutableNetworkThreadSafeAsyncOnly>(mockExeNetwork));
35 InputsDataMap networkInputs;
36 OutputsDataMap networkOutputs;
37 mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
41 TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsThreadSafeImplAndSetNetworkIO) {
42 IInferRequest::Ptr req;
43 EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
44 Return(mockAsyncInferRequestInternal));
45 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
46 auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestInternal>>(req);
47 ASSERT_NE(threadSafeReq, nullptr);
50 TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfInferThrowsException) {
51 IInferRequest::Ptr req;
52 EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
53 Return(mockAsyncInferRequestInternal));
54 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
55 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::runtime_error("")));
56 EXPECT_NO_THROW(sts = req->Infer(&dsc));
57 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
60 TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, returnErrorIfStartAsyncThrowsException) {
61 IInferRequest::Ptr req;
62 EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
63 Return(mockAsyncInferRequestInternal));
64 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
65 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).WillOnce(Throw(std::runtime_error("")));
66 EXPECT_NO_THROW(sts = req->StartAsync(&dsc));
67 ASSERT_EQ(StatusCode::GENERAL_ERROR, sts) << dsc.msg;
70 TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardStartAsyncAndInfer) {
71 IInferRequest::Ptr req;
72 EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
73 Return(mockAsyncInferRequestInternal));
74 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
75 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
76 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
78 EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;
79 EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
82 TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, canForwardInferAndStartAsync) {
83 IInferRequest::Ptr req;
84 EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
85 Return(mockAsyncInferRequestInternal));
86 EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
87 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), StartAsyncImpl()).Times(1);
88 EXPECT_CALL(*mockAsyncInferRequestInternal.get(), InferImpl()).Times(1);
89 EXPECT_NO_THROW(req->Infer(&dsc)) << dsc.msg;
90 EXPECT_NO_THROW(req->StartAsync(&dsc)) << dsc.msg;