1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <ie_version.hpp>
8 #include <cpp/ie_executable_network.hpp>
9 #include <cpp_interfaces/base/ie_plugin_base.hpp>
11 #include <mock_icnn_network.hpp>
12 #include <mock_iexecutable_network.hpp>
13 #include <cpp_interfaces/interface/mock_imemory_state_internal.hpp>
14 #include <inference_engine/cpp_interfaces/base/ie_executable_network_base.hpp>
15 #include <cpp_interfaces/interface/mock_iexecutable_network_internal.hpp>
16 #include <inference_engine/cpp_interfaces/impl/ie_memory_state_internal.hpp>
18 using namespace ::testing;
20 using namespace InferenceEngine;
21 using namespace InferenceEngine::details;
23 class MemoryStateTests : public ::testing::Test {
25 shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
26 shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
28 virtual void SetUp() {
29 mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
30 mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
34 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
36 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
37 std::vector<IMemoryStateInternal::Ptr> toReturn(1);
38 toReturn[0] = mockMemoryStateInternal;
40 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
42 auto state = net.QueryState();
43 ASSERT_EQ(state.size(), 1);
46 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
48 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
49 std::vector<IMemoryStateInternal::Ptr> toReturn;
51 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
53 auto state = net.QueryState();
54 ASSERT_EQ(state.size(), 0);
57 TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
59 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
60 std::vector<IMemoryStateInternal::Ptr> toReturn;
61 toReturn.push_back(mockMemoryStateInternal);
62 toReturn.push_back(mockMemoryStateInternal);
64 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
66 auto state = net.QueryState();
67 ASSERT_EQ(state.size(), 2);
70 TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
72 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
73 std::vector<IMemoryStateInternal::Ptr> toReturn;
74 toReturn.push_back(mockMemoryStateInternal);
76 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
77 EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
79 auto state = net.QueryState();
80 state.front().Reset();
83 TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
85 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
86 std::vector<IMemoryStateInternal::Ptr> toReturn;
87 toReturn.push_back(mockMemoryStateInternal);
89 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
90 EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
92 auto state = net.QueryState();
93 EXPECT_ANY_THROW(state.front().Reset());
96 TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
98 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
99 std::vector<IMemoryStateInternal::Ptr> toReturn;
100 toReturn.push_back(mockMemoryStateInternal);
102 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
103 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
105 auto state = net.QueryState();
106 EXPECT_STREQ(state.front().GetName().c_str(), "someName");
109 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
111 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
112 std::vector<IMemoryStateInternal::Ptr> toReturn;
113 toReturn.push_back(mockMemoryStateInternal);
115 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
116 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
118 IMemoryState::Ptr pState;
120 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
121 char *name = reinterpret_cast<char *>(1);
122 EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
126 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
128 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
129 std::vector<IMemoryStateInternal::Ptr> toReturn;
130 toReturn.push_back(mockMemoryStateInternal);
132 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
133 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
135 IMemoryState::Ptr pState;
137 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
139 EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
140 EXPECT_STREQ(name, "");
143 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
145 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
146 std::vector<IMemoryStateInternal::Ptr> toReturn;
147 toReturn.push_back(mockMemoryStateInternal);
149 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
150 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
152 IMemoryState::Ptr pState;
154 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
156 EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
157 EXPECT_STREQ(name, "s");
160 TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
162 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
163 std::vector<IMemoryStateInternal::Ptr> toReturn;
165 toReturn.push_back(mockMemoryStateInternal);
167 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
168 EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
170 float data[] = {123, 124, 125};
171 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
173 EXPECT_NO_THROW(net.QueryState().front().SetState(stateBlob));
174 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
175 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
176 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
179 TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
181 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
182 std::vector<IMemoryStateInternal::Ptr> toReturn;
184 float data[] = {123, 124, 125};
185 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
188 toReturn.push_back(mockMemoryStateInternal);
190 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
191 EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
194 auto saver = net.QueryState().front().GetLastState();
195 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
196 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
197 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
200 class MemoryStateInternalMockImpl : public MemoryStateInternal {
202 using MemoryStateInternal::MemoryStateInternal;
203 MOCK_METHOD0(Reset, void ());
206 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
208 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
210 ASSERT_STREQ(pState->GetName().c_str(), "name");
214 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
216 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
217 float data[] = {123, 124, 125};
218 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
220 pState->SetState(stateBlob);
221 auto saver = pState->GetLastState();
223 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
224 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
225 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 125);
229 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
231 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
232 float data[] = {123, 124, 125};
233 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
235 pState->SetState(stateBlob);
240 auto saver = pState->GetLastState();
242 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
243 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
244 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);