Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / async_infer_request_base_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <cpp_interfaces/impl/mock_inference_plugin_internal.hpp>
8 #include <cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp>
9
10 #include <ie_version.hpp>
11 #include <cpp_interfaces/base/ie_plugin_base.hpp>
12 #include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
13
14 using namespace ::testing;
15 using namespace std;
16 using namespace InferenceEngine;
17 using namespace InferenceEngine::details;
18
19 class InferRequestBaseTests : public ::testing::Test {
20 protected:
21     std::shared_ptr<MockIAsyncInferRequestInternal> mock_impl;
22     shared_ptr<IInferRequest> request;
23     ResponseDesc dsc;
24
25     virtual void TearDown() {
26     }
27
28     virtual void SetUp() {
29         mock_impl.reset(new MockIAsyncInferRequestInternal());
30         request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
31     }
32 };
33
34 // StartAsync
35 TEST_F(InferRequestBaseTests, canForwardStartAsync) {
36     EXPECT_CALL(*mock_impl.get(), StartAsync()).Times(1);
37     ASSERT_EQ(OK, request->StartAsync(&dsc));
38 }
39
40 TEST_F(InferRequestBaseTests, canReportErrorInStartAsync) {
41     EXPECT_CALL(*mock_impl.get(), StartAsync()).WillOnce(Throw(std::runtime_error("compare")));
42     ASSERT_NE(request->StartAsync(&dsc), OK);
43     ASSERT_STREQ(dsc.msg, "compare");
44 }
45
46 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInStartAsync) {
47     EXPECT_CALL(*mock_impl.get(), StartAsync()).WillOnce(Throw(5));
48     ASSERT_EQ(UNEXPECTED, request->StartAsync(nullptr));
49 }
50
51 // GetUserData
52 TEST_F(InferRequestBaseTests, canForwardGetUserData) {
53     void **data = nullptr;
54     EXPECT_CALL(*mock_impl.get(), GetUserData(data)).Times(1);
55     ASSERT_EQ(OK, request->GetUserData(data, &dsc));
56 }
57
58 TEST_F(InferRequestBaseTests, canReportErrorInGetUserData) {
59     EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
60     ASSERT_NE(request->GetUserData(nullptr, &dsc), OK);
61     ASSERT_STREQ(dsc.msg, "compare");
62 }
63
64 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetUserData) {
65     EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(5));
66     ASSERT_EQ(UNEXPECTED, request->GetUserData(nullptr, nullptr));
67 }
68
69 // SetUserData
70 TEST_F(InferRequestBaseTests, canForwardSetUserData) {
71     void *data = nullptr;
72     EXPECT_CALL(*mock_impl.get(), SetUserData(data)).Times(1);
73     ASSERT_EQ(OK, request->SetUserData(data, &dsc));
74 }
75
76 TEST_F(InferRequestBaseTests, canReportErrorInSetUserData) {
77     EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
78     ASSERT_NE(request->SetUserData(nullptr, &dsc), OK);
79     ASSERT_STREQ(dsc.msg, "compare");
80 }
81
82 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetUserData) {
83     EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(5));
84     ASSERT_EQ(UNEXPECTED, request->SetUserData(nullptr, nullptr));
85 }
86
87 // Wait
88 TEST_F(InferRequestBaseTests, canForwardWait) {
89     int64_t ms = 0;
90     EXPECT_CALL(*mock_impl.get(), Wait(ms)).WillOnce(Return(StatusCode::OK));
91     ASSERT_EQ(OK, request->Wait(ms, &dsc)) << dsc.msg;
92 }
93
94 TEST_F(InferRequestBaseTests, canReportErrorInWait) {
95     EXPECT_CALL(*mock_impl.get(), Wait(_)).WillOnce(Throw(std::runtime_error("compare")));
96     int64_t ms = 0;
97     ASSERT_NE(request->Wait(ms, &dsc), OK);
98     ASSERT_STREQ(dsc.msg, "compare");
99 }
100
101 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInWait) {
102     EXPECT_CALL(*mock_impl.get(), Wait(_)).WillOnce(Throw(5));
103     int64_t ms = 0;
104     ASSERT_EQ(UNEXPECTED, request->Wait(ms, nullptr));
105 }
106
107 // Infer
108 TEST_F(InferRequestBaseTests, canForwardInfer) {
109     EXPECT_CALL(*mock_impl.get(), Infer()).Times(1);
110     ASSERT_EQ(OK, request->Infer(&dsc));
111 }
112
113 TEST_F(InferRequestBaseTests, canReportErrorInInfer) {
114     EXPECT_CALL(*mock_impl.get(), Infer()).WillOnce(Throw(std::runtime_error("compare")));
115     ASSERT_NE(request->Infer(&dsc), OK);
116     ASSERT_STREQ(dsc.msg, "compare");
117 }
118
119 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInInfer) {
120     EXPECT_CALL(*mock_impl.get(), Infer()).WillOnce(Throw(5));
121     ASSERT_EQ(UNEXPECTED, request->Infer(nullptr));
122 }
123
124 // GetPerformanceCounts
125 TEST_F(InferRequestBaseTests, canForwardGetPerformanceCounts) {
126     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
127     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(Ref(info))).Times(1);
128     ASSERT_EQ(OK, request->GetPerformanceCounts(info, &dsc));
129 }
130
131 TEST_F(InferRequestBaseTests, canReportErrorInGetPerformanceCounts) {
132     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
133     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(std::runtime_error("compare")));
134     ASSERT_NE(request->GetPerformanceCounts(info, &dsc), OK);
135     ASSERT_STREQ(dsc.msg, "compare");
136 }
137
138 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetPerformanceCounts) {
139     std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
140     EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(5));
141     ASSERT_EQ(UNEXPECTED, request->GetPerformanceCounts(info, nullptr));
142 }
143
144 // GetBlob
145 TEST_F(InferRequestBaseTests, canForwardGetBlob) {
146     Blob::Ptr data;
147     const char *name = "";
148     EXPECT_CALL(*mock_impl.get(), GetBlob(name, Ref(data))).Times(1);
149     ASSERT_EQ(OK, request->GetBlob(name, data, &dsc));
150 }
151
152 TEST_F(InferRequestBaseTests, canReportErrorInGetBlob) {
153     EXPECT_CALL(*mock_impl.get(), GetBlob(_, _)).WillOnce(Throw(std::runtime_error("compare")));
154     Blob::Ptr data;
155     ASSERT_NE(request->GetBlob(nullptr, data, &dsc), OK);
156     ASSERT_STREQ(dsc.msg, "compare");
157 }
158
159 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetBlob) {
160     Blob::Ptr data;
161     EXPECT_CALL(*mock_impl.get(), GetBlob(_, _)).WillOnce(Throw(5));
162     ASSERT_EQ(UNEXPECTED, request->GetBlob(nullptr, data, nullptr));
163 }
164
165 // SetBlob
166 TEST_F(InferRequestBaseTests, canForwardSetBlob) {
167     Blob::Ptr data;
168     const char *name = "";
169     EXPECT_CALL(*mock_impl.get(), SetBlob(name, Ref(data))).Times(1);
170     ASSERT_EQ(OK, request->SetBlob(name, data, &dsc));
171 }
172
173 TEST_F(InferRequestBaseTests, canReportErrorInSetBlob) {
174     EXPECT_CALL(*mock_impl.get(), SetBlob(_, _)).WillOnce(Throw(std::runtime_error("compare")));
175     Blob::Ptr data;
176     ASSERT_NE(request->SetBlob(nullptr, data, &dsc), OK);
177     ASSERT_STREQ(dsc.msg, "compare");
178 }
179
180 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetBlob) {
181     Blob::Ptr data;
182     EXPECT_CALL(*mock_impl.get(), SetBlob(_, _)).WillOnce(Throw(5));
183     ASSERT_EQ(UNEXPECTED, request->SetBlob(nullptr, data, nullptr));
184 }
185
186 // SetCompletionCallback
187 TEST_F(InferRequestBaseTests, canForwardSetCompletionCallback) {
188     InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
189     EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(callback)).Times(1);
190     ASSERT_NO_THROW(request->SetCompletionCallback(callback));
191 }
192
193 TEST_F(InferRequestBaseTests, canReportErrorInSetCompletionCallback) {
194     EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
195     ASSERT_NO_THROW(request->SetCompletionCallback(nullptr));
196 }