1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
10 #include <details/ie_exception.hpp>
11 #include <cpp_interfaces/ie_task.hpp>
12 #include <cpp_interfaces/ie_task_executor.hpp>
13 #include <cpp_interfaces/ie_task_with_stages.hpp>
14 #include <cpp_interfaces/ie_task_synchronizer.hpp>
15 #include "task_tests_utils.hpp"
18 using namespace ::testing;
20 using namespace InferenceEngine;
21 using namespace InferenceEngine::details;
// Value-parameterized fixture: every TEST_P below runs once per TaskFlavor value passed to
// INSTANTIATE_TEST_CASE_P at the bottom of the file.
// NOTE(review): original lines 31-33, 40-45 and 47-50 are not visible in this extract; presumably
// they declare the `_task` member used by the tests and dispatch on `flavor` — confirm against the
// full file.
14 class TaskCommonTests : public ::testing::Test, public testing::WithParamInterface<TaskFlavor> {
// Builds a task of the flavor selected by GetParam().
//   function  - body to wrap. When it is empty and forceNull is false, the default constructors
//               (make_shared<Task>() / make_shared<StagedTask>()) are used instead.
//   forceNull - pass `function` to the function-taking constructors even when it is empty, so a
//               test can exercise the null-function path (throwIfFunctionNull expects
//               InferenceEngineException from it).
// Throws std::logic_error("Specified non-existent flavor of task") for an unhandled flavor; the
// dispatch leading to that throw is on lines not shown here.
34 Task::Ptr createTask(std::function<void()> function = nullptr, bool forceNull = false) {
35 TaskFlavor flavor = GetParam();
// Choose the function-taking constructors when a function was supplied or the caller forces it.
36 bool condition = function || forceNull;
// NOTE(review): both task kinds (and the executor) are constructed eagerly, before `flavor` is
// used, so a throwing constructor makes createTask throw regardless of the selected flavor.
37 Task::Ptr baseTask = condition ? make_shared<Task>(function) : make_shared<Task>();
38 Task::Ptr stagedTask = condition ? make_shared<StagedTask>(function, 1) : make_shared<StagedTask>();
39 auto executor = make_shared<TaskExecutor>();
// Reached when the flavor is not handled by the (not visible) dispatch above.
46 throw logic_error("Specified non-existent flavor of task");
// A freshly created task must be constructible without throwing and must start in TS_INITIAL.
51 TEST_P(TaskCommonTests, canCreateTask) {
52 ASSERT_NO_THROW(_task = createTask());
53 ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
// occupy() on a new task must not throw and must move the task to TS_BUSY.
56 TEST_P(TaskCommonTests, canSetBusyStatus) {
57 ASSERT_NO_THROW(_task = createTask());
58 ASSERT_NO_THROW(_task->occupy());
59 ASSERT_EQ(_task->getStatus(), Task::TS_BUSY);
// occupy() succeeds (returns true) only once; a second call on an already occupied task returns
// false.
62 TEST_P(TaskCommonTests, firstOccupyReturnTrueSecondFalse) {
63 ASSERT_NO_THROW(_task = createTask());
64 ASSERT_TRUE(_task->occupy());
65 ASSERT_FALSE(_task->occupy());
// Running a task via runNoThrowNoBusyCheck() must not throw and must leave it in TS_DONE.
// NOTE(review): the line that assigns `_task` (original line 69) is not visible in this extract;
// presumably `_task = createTask();` (a default task) — confirm.
68 TEST_P(TaskCommonTests, canRunDefaultTask) {
70 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
71 ASSERT_EQ(_task->getStatus(), Task::TS_DONE);
// Forcing an empty std::function into the task constructors (forceNull = true) must throw
// InferenceEngineException.
74 TEST_P(TaskCommonTests, throwIfFunctionNull) {
75 ASSERT_THROW(_task = createTask(nullptr, true), InferenceEngineException);
// wait() on a task that was never run must return, with either timeout argument (-1 or 1),
// without throwing and without changing the status from TS_INITIAL.
// NOTE(review): original line 79, presumably `_task = createTask();`, is not visible here.
78 TEST_P(TaskCommonTests, canWaitWithoutRun) {
80 ASSERT_NO_THROW(_task->wait(-1));
81 ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
82 ASSERT_NO_THROW(_task->wait(1));
83 ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
// A task run from another thread (MetaThread, from task_tests_utils.hpp) must end in TS_DONE.
// NOTE(review): original lines 87-88 and 91-93 are not visible; presumably they create `_task`
// and start/join `metaThread` before the status check — confirm.
// NOTE(review): `[=]` inside a test body implicitly captures `this` to reach the `_task` member
// (deprecated in C++20); `[this]`, as used in canRunTaskFromThreadWithoutWait, states it explicitly.
86 TEST_P(TaskCommonTests, canRunTaskFromThread) {
89 MetaThread metaThread([=]() {
90 _task->runNoThrowNoBusyCheck();
94 ASSERT_EQ(Task::Status::TS_DONE, _task->getStatus());
// Runs a task whose body sleeps for 500 ms from a std::thread and joins that thread.
// NOTE(review): as visible, this test asserts nothing; it only checks that nothing crashes. It
// also joins the thread, so it does wait despite its name, and the 500 ms sleep makes it slow.
// Consider asserting `_task->getStatus() == Task::Status::TS_DONE` after the join. Original
// lines 101 and 104 are not visible here.
98 TEST_P(TaskCommonTests, canRunTaskFromThreadWithoutWait) {
99 _task = createTask([]() {
100 std::this_thread::sleep_for(std::chrono::milliseconds(500));
102 std::thread thread([this]() { _task->runNoThrowNoBusyCheck(); });
// A std::thread constructed with a callable is always joinable, so this check is only defensive.
103 if (thread.joinable()) thread.join();
// wait() on a task that was never run returns TS_INITIAL.
106 TEST_P(TaskCommonTests, waitReturnNotStartedIfTaskWasNotRun) {
107 _task = createTask();
108 Task::Status status = _task->wait(1);
109 ASSERT_EQ(status, Task::Status::TS_INITIAL);
// An InferenceEngineException thrown by the task body is captured rather than propagated:
// runNoThrowNoBusyCheck() does not throw, the status becomes TS_ERROR, and checkException()
// rethrows the stored exception.
112 TEST_P(TaskCommonTests, canCatchIEException) {
113 _task = createTask([]() { THROW_IE_EXCEPTION; });
114 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
115 Task::Status status = _task->getStatus();
116 ASSERT_EQ(status, Task::Status::TS_ERROR);
117 EXPECT_THROW(_task->checkException(), InferenceEngineException);
// After an occupied task fails, wait(-1) reports TS_ERROR and checkException() rethrows the
// captured InferenceEngineException.
120 TEST_P(TaskCommonTests, waitReturnErrorIfException) {
121 _task = createTask([]() { THROW_IE_EXCEPTION; });
122 ASSERT_NO_THROW(_task->occupy());
123 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
124 Task::Status status = _task->wait(-1);
125 ASSERT_EQ(status, Task::Status::TS_ERROR);
126 EXPECT_THROW(_task->checkException(), InferenceEngineException);
// Standard exceptions are captured as well, and checkException() rethrows them with their
// original type (std::bad_alloc here).
129 TEST_P(TaskCommonTests, canCatchStdException) {
130 _task = createTask([]() { throw std::bad_alloc(); });
131 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
132 Task::Status status = _task->getStatus();
133 ASSERT_EQ(status, Task::Status::TS_ERROR);
134 EXPECT_THROW(_task->checkException(), std::bad_alloc);
// Re-running a task that previously failed must clear the stored exception: the first run
// throws, so checkException() rethrows; the second run succeeds, so checkException() is quiet.
137 TEST_P(TaskCommonTests, canCleanExceptionPtr) {
// Captured by reference so the same task body can be toggled between failing and succeeding.
138 bool throwException = true;
139 _task = createTask([&throwException]() { if (throwException) throw std::bad_alloc(); else return; });
140 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
141 EXPECT_THROW(_task->checkException(), std::bad_alloc);
142 throwException = false;
143 ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
144 EXPECT_NO_THROW(_task->checkException());
// Name generator for INSTANTIATE_TEST_CASE_P: turns the TaskFlavor parameter into its enumerator
// name, so generated test names read e.g. ".../BASE_WITH_CALLBACK".
// NOTE(review): original lines 149-151 and 154-159 are not visible in this extract; presumably
// they hold the switch on obj.param, the remaining CASEs, a default, and possibly `#undef CASE` —
// confirm.
147 std::string getTestCaseName(testing::TestParamInfo<TaskFlavor> obj) {
// CASE(x) expands to `case x: return "x";`, stringifying the enumerator.
148 #define CASE(x) case x: return #x;
152 CASE(BASE_WITH_CALLBACK);
153 CASE(STAGED_WITH_CALLBACK);
// Runs every TaskCommonTests test once for BASE_TASK and once for STAGED_TASK, named via
// getTestCaseName.
// NOTE(review): BASE_WITH_CALLBACK / STAGED_WITH_CALLBACK appear in getTestCaseName but are not
// instantiated here — confirm whether that is intentional.
// NOTE(review): googletest >= 1.10 deprecates INSTANTIATE_TEST_CASE_P in favour of
// INSTANTIATE_TEST_SUITE_P; switch only if the bundled gtest provides it.
160 INSTANTIATE_TEST_CASE_P(Task, TaskCommonTests,
161 ::testing::ValuesIn(std::vector<TaskFlavor>{BASE_TASK, STAGED_TASK}), getTestCaseName);