b535cc97ee4715e4040d6687eae89fefff9e5663
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / task_common_tests.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <thread>
9
10 #include <ie_common.h>
11 #include <details/ie_exception.hpp>
12 #include <cpp_interfaces/ie_task.hpp>
13 #include <cpp_interfaces/ie_task_executor.hpp>
14 #include <cpp_interfaces/ie_task_with_stages.hpp>
15 #include <cpp_interfaces/ie_task_synchronizer.hpp>
16 #include "task_tests_utils.hpp"
17
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace InferenceEngine;
22 using namespace InferenceEngine::details;
23
/// Task variants that the TaskCommonTests fixture is parameterized over.
enum TaskFlavor {
    BASE_TASK = 0,            // plain Task
    STAGED_TASK = 1,          // StagedTask
    BASE_WITH_CALLBACK = 2,   // declared, but the fixture's createTask() rejects it
    STAGED_WITH_CALLBACK = 3  // declared, but the fixture's createTask() rejects it
};
30
31 class TaskCommonTests : public ::testing::Test, public testing::WithParamInterface<TaskFlavor> {
32 protected:
33     Task::Ptr _task;
34
35     Task::Ptr createTask(std::function<void()> function = nullptr, bool forceNull = false) {
36         TaskFlavor flavor = GetParam();
37         bool condition = function || forceNull;
38         Task::Ptr baseTask = condition ? make_shared<Task>(function) : make_shared<Task>();
39         Task::Ptr stagedTask = condition ? make_shared<StagedTask>(function, 1) : make_shared<StagedTask>();
40         auto executor = make_shared<TaskExecutor>();
41         switch (flavor) {
42             case BASE_TASK:
43                 return baseTask;
44             case STAGED_TASK:
45                 return stagedTask;
46             default:
47                 throw logic_error("Specified non-existent flavor of task");
48         }
49     }
50 };
51
52 TEST_P(TaskCommonTests, canCreateTask) {
53     ASSERT_NO_THROW(_task = createTask());
54     ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
55 }
56
57 TEST_P(TaskCommonTests, canSetBusyStatus) {
58     ASSERT_NO_THROW(_task = createTask());
59     ASSERT_NO_THROW(_task->occupy());
60     ASSERT_EQ(_task->getStatus(), Task::TS_BUSY);
61 }
62
63 TEST_P(TaskCommonTests, firstOccupyReturnTrueSecondFalse) {
64     ASSERT_NO_THROW(_task = createTask());
65     ASSERT_TRUE(_task->occupy());
66     ASSERT_FALSE(_task->occupy());
67 }
68
69 TEST_P(TaskCommonTests, canRunDefaultTask) {
70     _task = createTask();
71     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
72     ASSERT_EQ(_task->getStatus(), Task::TS_DONE);
73 }
74
75 TEST_P(TaskCommonTests, throwIfFunctionNull) {
76     ASSERT_THROW(_task = createTask(nullptr, true), InferenceEngineException);
77 }
78
79 TEST_P(TaskCommonTests, canWaitWithoutRun) {
80     _task = createTask();
81     ASSERT_NO_THROW(_task->wait(-1));
82     ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
83     ASSERT_NO_THROW(_task->wait(1));
84     ASSERT_EQ(_task->getStatus(), Task::TS_INITIAL);
85 }
86
87 TEST_P(TaskCommonTests, canRunTaskFromThread) {
88     _task = createTask();
89
90     MetaThread metaThread([=]() {
91         _task->runNoThrowNoBusyCheck();
92     });
93
94     metaThread.join();
95     ASSERT_EQ(Task::Status::TS_DONE, _task->getStatus());
96 }
97
98
99 TEST_P(TaskCommonTests, canRunTaskFromThreadWithoutWait) {
100     _task = createTask([]() {
101         std::this_thread::sleep_for(std::chrono::milliseconds(500));
102     });
103     std::thread thread([this]() { _task->runNoThrowNoBusyCheck(); });
104     if (thread.joinable()) thread.join();
105 }
106
107 TEST_P(TaskCommonTests, waitReturnNotStartedIfTaskWasNotRun) {
108     _task = createTask();
109     Task::Status status = _task->wait(1);
110     ASSERT_EQ(status, Task::Status::TS_INITIAL);
111 }
112
113 TEST_P(TaskCommonTests, canCatchIEException) {
114     _task = createTask([]() { THROW_IE_EXCEPTION; });
115     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
116     Task::Status status = _task->getStatus();
117     ASSERT_EQ(status, Task::Status::TS_ERROR);
118     EXPECT_THROW(_task->checkException(), InferenceEngineException);
119 }
120
121 TEST_P(TaskCommonTests, waitReturnErrorIfException) {
122     _task = createTask([]() { THROW_IE_EXCEPTION; });
123     ASSERT_NO_THROW(_task->occupy());
124     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
125     Task::Status status = _task->wait(-1);
126     ASSERT_EQ(status, Task::Status::TS_ERROR);
127     EXPECT_THROW(_task->checkException(), InferenceEngineException);
128 }
129
130 TEST_P(TaskCommonTests, canCatchStdException) {
131     _task = createTask([]() { throw std::bad_alloc(); });
132     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
133     Task::Status status = _task->getStatus();
134     ASSERT_EQ(status, Task::Status::TS_ERROR);
135     EXPECT_THROW(_task->checkException(), std::bad_alloc);
136 }
137
138 TEST_P(TaskCommonTests, canCleanExceptionPtr) {
139     bool throwException = true;
140     _task = createTask([&throwException]() { if (throwException) throw std::bad_alloc(); else return; });
141     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
142     EXPECT_THROW(_task->checkException(), std::bad_alloc);
143     throwException = false;
144     ASSERT_NO_THROW(_task->runNoThrowNoBusyCheck());
145     EXPECT_NO_THROW(_task->checkException());
146 }
147
148 std::string getTestCaseName(testing::TestParamInfo<TaskFlavor> obj) {
149 #define CASE(x) case x: return #x;
150     switch (obj.param) {
151         CASE(BASE_TASK);
152         CASE(STAGED_TASK);
153         CASE(BASE_WITH_CALLBACK);
154         CASE(STAGED_WITH_CALLBACK);
155         default :
156             return "EMPTY";
157 #undef CASE
158     }
159 }
160
161 INSTANTIATE_TEST_CASE_P(Task, TaskCommonTests,
162                         ::testing::ValuesIn(std::vector<TaskFlavor>{BASE_TASK, STAGED_TASK}), getTestCaseName);