1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <cpp_interfaces/impl/mock_inference_plugin_internal.hpp>
8 #include <cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp>
10 #include <ie_version.hpp>
11 #include <cpp_interfaces/base/ie_plugin_base.hpp>
12 #include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
14 using namespace ::testing;
16 using namespace InferenceEngine;
17 using namespace InferenceEngine::details;
19 class InferRequestBaseTests : public ::testing::Test {
21 std::shared_ptr<MockIAsyncInferRequestInternal> mock_impl;
22 shared_ptr<IInferRequest> request;
25 virtual void TearDown() {
28 virtual void SetUp() {
29 mock_impl.reset(new MockIAsyncInferRequestInternal());
30 request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
35 TEST_F(InferRequestBaseTests, canForwardStartAsync) {
36 EXPECT_CALL(*mock_impl.get(), StartAsync()).Times(1);
37 ASSERT_EQ(OK, request->StartAsync(&dsc));
40 TEST_F(InferRequestBaseTests, canReportErrorInStartAsync) {
41 EXPECT_CALL(*mock_impl.get(), StartAsync()).WillOnce(Throw(std::runtime_error("compare")));
42 ASSERT_NE(request->StartAsync(&dsc), OK);
43 ASSERT_STREQ(dsc.msg, "compare");
46 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInStartAsync) {
47 EXPECT_CALL(*mock_impl.get(), StartAsync()).WillOnce(Throw(5));
48 ASSERT_EQ(UNEXPECTED, request->StartAsync(nullptr));
52 TEST_F(InferRequestBaseTests, canForwardGetUserData) {
53 void **data = nullptr;
54 EXPECT_CALL(*mock_impl.get(), GetUserData(data)).Times(1);
55 ASSERT_EQ(OK, request->GetUserData(data, &dsc));
58 TEST_F(InferRequestBaseTests, canReportErrorInGetUserData) {
59 EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
60 ASSERT_NE(request->GetUserData(nullptr, &dsc), OK);
61 ASSERT_STREQ(dsc.msg, "compare");
64 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetUserData) {
65 EXPECT_CALL(*mock_impl.get(), GetUserData(_)).WillOnce(Throw(5));
66 ASSERT_EQ(UNEXPECTED, request->GetUserData(nullptr, nullptr));
70 TEST_F(InferRequestBaseTests, canForwardSetUserData) {
72 EXPECT_CALL(*mock_impl.get(), SetUserData(data)).Times(1);
73 ASSERT_EQ(OK, request->SetUserData(data, &dsc));
76 TEST_F(InferRequestBaseTests, canReportErrorInSetUserData) {
77 EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(std::runtime_error("compare")));
78 ASSERT_NE(request->SetUserData(nullptr, &dsc), OK);
79 ASSERT_STREQ(dsc.msg, "compare");
82 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetUserData) {
83 EXPECT_CALL(*mock_impl.get(), SetUserData(_)).WillOnce(Throw(5));
84 ASSERT_EQ(UNEXPECTED, request->SetUserData(nullptr, nullptr));
88 TEST_F(InferRequestBaseTests, canForwardWait) {
90 EXPECT_CALL(*mock_impl.get(), Wait(ms)).WillOnce(Return(StatusCode::OK));
91 ASSERT_EQ(OK, request->Wait(ms, &dsc)) << dsc.msg;
94 TEST_F(InferRequestBaseTests, canReportErrorInWait) {
95 EXPECT_CALL(*mock_impl.get(), Wait(_)).WillOnce(Throw(std::runtime_error("compare")));
97 ASSERT_NE(request->Wait(ms, &dsc), OK);
98 ASSERT_STREQ(dsc.msg, "compare");
101 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInWait) {
102 EXPECT_CALL(*mock_impl.get(), Wait(_)).WillOnce(Throw(5));
104 ASSERT_EQ(UNEXPECTED, request->Wait(ms, nullptr));
108 TEST_F(InferRequestBaseTests, canForwardInfer) {
109 EXPECT_CALL(*mock_impl.get(), Infer()).Times(1);
110 ASSERT_EQ(OK, request->Infer(&dsc));
113 TEST_F(InferRequestBaseTests, canReportErrorInInfer) {
114 EXPECT_CALL(*mock_impl.get(), Infer()).WillOnce(Throw(std::runtime_error("compare")));
115 ASSERT_NE(request->Infer(&dsc), OK);
116 ASSERT_STREQ(dsc.msg, "compare");
119 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInInfer) {
120 EXPECT_CALL(*mock_impl.get(), Infer()).WillOnce(Throw(5));
121 ASSERT_EQ(UNEXPECTED, request->Infer(nullptr));
124 // GetPerformanceCounts
125 TEST_F(InferRequestBaseTests, canForwardGetPerformanceCounts) {
126 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
127 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(Ref(info))).Times(1);
128 ASSERT_EQ(OK, request->GetPerformanceCounts(info, &dsc));
131 TEST_F(InferRequestBaseTests, canReportErrorInGetPerformanceCounts) {
132 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
133 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(std::runtime_error("compare")));
134 ASSERT_NE(request->GetPerformanceCounts(info, &dsc), OK);
135 ASSERT_STREQ(dsc.msg, "compare");
138 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetPerformanceCounts) {
139 std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> info;
140 EXPECT_CALL(*mock_impl.get(), GetPerformanceCounts(_)).WillOnce(Throw(5));
141 ASSERT_EQ(UNEXPECTED, request->GetPerformanceCounts(info, nullptr));
145 TEST_F(InferRequestBaseTests, canForwardGetBlob) {
147 const char *name = "";
148 EXPECT_CALL(*mock_impl.get(), GetBlob(name, Ref(data))).Times(1);
149 ASSERT_EQ(OK, request->GetBlob(name, data, &dsc));
152 TEST_F(InferRequestBaseTests, canReportErrorInGetBlob) {
153 EXPECT_CALL(*mock_impl.get(), GetBlob(_, _)).WillOnce(Throw(std::runtime_error("compare")));
155 ASSERT_NE(request->GetBlob(nullptr, data, &dsc), OK);
156 ASSERT_STREQ(dsc.msg, "compare");
159 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInGetBlob) {
161 EXPECT_CALL(*mock_impl.get(), GetBlob(_, _)).WillOnce(Throw(5));
162 ASSERT_EQ(UNEXPECTED, request->GetBlob(nullptr, data, nullptr));
166 TEST_F(InferRequestBaseTests, canForwardSetBlob) {
168 const char *name = "";
169 EXPECT_CALL(*mock_impl.get(), SetBlob(name, Ref(data))).Times(1);
170 ASSERT_EQ(OK, request->SetBlob(name, data, &dsc));
173 TEST_F(InferRequestBaseTests, canReportErrorInSetBlob) {
174 EXPECT_CALL(*mock_impl.get(), SetBlob(_, _)).WillOnce(Throw(std::runtime_error("compare")));
176 ASSERT_NE(request->SetBlob(nullptr, data, &dsc), OK);
177 ASSERT_STREQ(dsc.msg, "compare");
180 TEST_F(InferRequestBaseTests, canCatchUnknownErrorInSetBlob) {
182 EXPECT_CALL(*mock_impl.get(), SetBlob(_, _)).WillOnce(Throw(5));
183 ASSERT_EQ(UNEXPECTED, request->SetBlob(nullptr, data, nullptr));
186 // SetCompletionCallback
187 TEST_F(InferRequestBaseTests, canForwardSetCompletionCallback) {
188 InferenceEngine::IInferRequest::CompletionCallback callback = nullptr;
189 EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(callback)).Times(1);
190 ASSERT_NO_THROW(request->SetCompletionCallback(callback));
193 TEST_F(InferRequestBaseTests, canReportErrorInSetCompletionCallback) {
194 EXPECT_CALL(*mock_impl.get(), SetCompletionCallback(_)).WillOnce(Throw(std::runtime_error("compare")));
195 ASSERT_NO_THROW(request->SetCompletionCallback(nullptr));