// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include <inference_engine.hpp>
#include <cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp>

using namespace ::testing;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
15 class AsyncInferRequestThreadSafeInternalTests : public ::testing::Test {
17 MockAsyncInferRequestThreadSafeInternal::Ptr testRequest;
20 bool _doesThrowExceptionWithMessage(std::function<void()> func, string refError) {
21 std::string whatMessage;
24 } catch (const InferenceEngineException &iee) {
25 whatMessage = iee.what();
27 return whatMessage.find(refError) != std::string::npos;
30 virtual void SetUp() {
31 testRequest = make_shared<MockAsyncInferRequestThreadSafeInternal>();
37 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnStartAsync) {
38 testRequest->setRequestBusy();
39 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->StartAsync(); }, REQUEST_BUSY_str));
42 TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfStartAsyncTaskFails) {
43 EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).Times(2)
44 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
47 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->StartAsync(); }, "compare"));
48 ASSERT_NO_THROW(testRequest->StartAsync());
51 TEST_F(AsyncInferRequestThreadSafeInternalTests, deviceBusyAfterStartAsync) {
52 EXPECT_CALL(*testRequest.get(), StartAsync_ThreadUnsafe()).WillOnce(Return());
54 ASSERT_NO_THROW(testRequest->StartAsync());
56 ASSERT_TRUE(testRequest->isRequestBusy());
60 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetUserData) {
61 testRequest->setRequestBusy();
62 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->GetUserData(nullptr); }, REQUEST_BUSY_str));
66 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetUserData) {
67 testRequest->setRequestBusy();
68 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetUserData(nullptr); }, REQUEST_BUSY_str));
72 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnInferNotStartedOnWait) {
73 testRequest->setRequestBusy();
75 EXPECT_CALL(*testRequest.get(), Wait(ms)).WillOnce(Return(INFER_NOT_STARTED));
77 StatusCode actual = testRequest->Wait(ms);
78 ASSERT_EQ(INFER_NOT_STARTED, actual);
82 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnInfer) {
83 testRequest->setRequestBusy();
84 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->Infer(); }, REQUEST_BUSY_str));
87 TEST_F(AsyncInferRequestThreadSafeInternalTests, canResetBusyStatusIfInferFails) {
88 EXPECT_CALL(*testRequest.get(), Infer_ThreadUnsafe()).Times(2)
89 .WillOnce(Throw(InferenceEngineException(__FILE__, __LINE__) << "compare"))
92 ASSERT_TRUE(_doesThrowExceptionWithMessage([&]() { testRequest->Infer(); }, "compare"));
93 ASSERT_NO_THROW(testRequest->Infer());
96 // GetPerformanceCounts
97 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetPerformanceCounts) {
98 testRequest->setRequestBusy();
99 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
100 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
101 testRequest->GetPerformanceCounts(info);
102 }, REQUEST_BUSY_str));
106 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnGetBlob) {
107 testRequest->setRequestBusy();
108 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() {
110 testRequest->GetBlob(nullptr, data);
111 }, REQUEST_BUSY_str));
115 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetBlob) {
116 testRequest->setRequestBusy();
117 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetBlob(nullptr, nullptr); }, REQUEST_BUSY_str));
// SetCompletionCallback
121 TEST_F(AsyncInferRequestThreadSafeInternalTests, returnRequestBusyOnSetCompletionCallback) {
122 testRequest->setRequestBusy();
123 ASSERT_TRUE(_doesThrowExceptionWithMessage([this]() { testRequest->SetCompletionCallback(nullptr); },