33a431f6a9a4a556e90ee179e8be4443c69b08b7
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / memory_state_tests.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <ie_version.hpp>
9 #include <cpp/ie_executable_network.hpp>
10 #include <cpp_interfaces/base/ie_plugin_base.hpp>
11
12 #include <mock_icnn_network.hpp>
13 #include <mock_iexecutable_network.hpp>
14 #include <cpp_interfaces/interface/mock_imemory_state_internal.hpp>
15 #include <inference_engine/cpp_interfaces/base/ie_executable_network_base.hpp>
16 #include <cpp_interfaces/interface/mock_iexecutable_network_internal.hpp>
17 #include <inference_engine/cpp_interfaces/impl/ie_memory_state_internal.hpp>
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace InferenceEngine;
22 using namespace InferenceEngine::details;
23
24 class MemoryStateTests : public ::testing::Test {
25  protected:
26     shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
27     shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
28
29     virtual void SetUp() {
30         mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
31         mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
32     }
33 };
34
35 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
36
37     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
38     std::vector<IMemoryStateInternal::Ptr> toReturn(1);
39     toReturn[0] = mockMemoryStateInternal;
40
41     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
42
43     auto state = net.QueryState();
44     ASSERT_EQ(state.size(), 1);
45 }
46
47 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
48
49     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
50     std::vector<IMemoryStateInternal::Ptr> toReturn;
51
52     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
53
54     auto state = net.QueryState();
55     ASSERT_EQ(state.size(), 0);
56 }
57
58 TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
59
60     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
61     std::vector<IMemoryStateInternal::Ptr> toReturn;
62     toReturn.push_back(mockMemoryStateInternal);
63     toReturn.push_back(mockMemoryStateInternal);
64
65     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
66
67     auto state = net.QueryState();
68     ASSERT_EQ(state.size(), 2);
69 }
70
71 TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
72
73     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
74     std::vector<IMemoryStateInternal::Ptr> toReturn;
75     toReturn.push_back(mockMemoryStateInternal);
76
77     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
78     EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
79
80     auto state = net.QueryState();
81     state.front().Reset();
82 }
83
84 TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
85
86     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
87     std::vector<IMemoryStateInternal::Ptr> toReturn;
88     toReturn.push_back(mockMemoryStateInternal);
89
90     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
91     EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
92
93     auto state = net.QueryState();
94     EXPECT_ANY_THROW(state.front().Reset());
95 }
96
97 TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
98
99     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
100     std::vector<IMemoryStateInternal::Ptr> toReturn;
101     toReturn.push_back(mockMemoryStateInternal);
102
103     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
104     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
105
106     auto state = net.QueryState();
107     EXPECT_STREQ(state.front().GetName().c_str(), "someName");
108 }
109
110 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
111
112     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
113     std::vector<IMemoryStateInternal::Ptr> toReturn;
114     toReturn.push_back(mockMemoryStateInternal);
115
116     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
117     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
118
119     IMemoryState::Ptr pState;
120
121     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
122     char *name = reinterpret_cast<char *>(1);
123     EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
124 }
125
126
127 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
128
129     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
130     std::vector<IMemoryStateInternal::Ptr> toReturn;
131     toReturn.push_back(mockMemoryStateInternal);
132
133     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
134     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
135
136     IMemoryState::Ptr pState;
137
138     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
139     char name[1];
140     EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
141     EXPECT_STREQ(name, "");
142 }
143
144 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
145
146     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
147     std::vector<IMemoryStateInternal::Ptr> toReturn;
148     toReturn.push_back(mockMemoryStateInternal);
149
150     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
151     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
152
153     IMemoryState::Ptr pState;
154
155     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
156     char name[2];
157     EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
158     EXPECT_STREQ(name, "s");
159 }
160
161 TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
162
163     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
164     std::vector<IMemoryStateInternal::Ptr> toReturn;
165     Blob::Ptr saver;
166     toReturn.push_back(mockMemoryStateInternal);
167
168     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
169     EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
170
171     float data[] = {123, 124, 125};
172     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
173
174     EXPECT_NO_THROW(net.QueryState().front().SetState(stateBlob));
175     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
176     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
177     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
178 }
179
180 TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
181
182     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
183     std::vector<IMemoryStateInternal::Ptr> toReturn;
184
185     float data[] = {123, 124, 125};
186     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
187
188
189     toReturn.push_back(mockMemoryStateInternal);
190
191     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
192     EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
193
194
195     auto saver = net.QueryState().front().GetLastState();
196     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
197     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
198     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
199 }
200
201 class MemoryStateInternalMockImpl : public MemoryStateInternal {
202  public:
203     using MemoryStateInternal::MemoryStateInternal;
204     MOCK_METHOD0(Reset, void ());
205 };
206
207 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
208
209     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
210
211     ASSERT_STREQ(pState->GetName().c_str(), "name");
212 }
213
214
215 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
216
217     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
218     float data[] = {123, 124, 125};
219     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
220
221     pState->SetState(stateBlob);
222     auto saver = pState->GetLastState();
223
224     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
225     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
226     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 125);
227 }
228
229
230 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
231
232     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
233     float data[] = {123, 124, 125};
234     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
235
236     pState->SetState(stateBlob);
237
238     data[0] = 121;
239     data[1] = 122;
240     data[2] = 123;
241     auto saver = pState->GetLastState();
242
243     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
244     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
245     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);
246 }