Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / cpp_interfaces / memory_state_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include <ie_version.hpp>
8 #include <cpp/ie_executable_network.hpp>
9 #include <cpp_interfaces/base/ie_plugin_base.hpp>
10
11 #include <mock_icnn_network.hpp>
12 #include <mock_iexecutable_network.hpp>
13 #include <cpp_interfaces/interface/mock_imemory_state_internal.hpp>
14 #include <inference_engine/cpp_interfaces/base/ie_executable_network_base.hpp>
15 #include <cpp_interfaces/interface/mock_iexecutable_network_internal.hpp>
16 #include <inference_engine/cpp_interfaces/impl/ie_memory_state_internal.hpp>
17
18 using namespace ::testing;
19 using namespace std;
20 using namespace InferenceEngine;
21 using namespace InferenceEngine::details;
22
23 class MemoryStateTests : public ::testing::Test {
24  protected:
25     shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
26     shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
27
28     virtual void SetUp() {
29         mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
30         mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
31     }
32 };
33
34 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
35
36     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
37     std::vector<IMemoryStateInternal::Ptr> toReturn(1);
38     toReturn[0] = mockMemoryStateInternal;
39
40     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
41
42     auto state = net.QueryState();
43     ASSERT_EQ(state.size(), 1);
44 }
45
46 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
47
48     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
49     std::vector<IMemoryStateInternal::Ptr> toReturn;
50
51     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
52
53     auto state = net.QueryState();
54     ASSERT_EQ(state.size(), 0);
55 }
56
57 TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
58
59     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
60     std::vector<IMemoryStateInternal::Ptr> toReturn;
61     toReturn.push_back(mockMemoryStateInternal);
62     toReturn.push_back(mockMemoryStateInternal);
63
64     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
65
66     auto state = net.QueryState();
67     ASSERT_EQ(state.size(), 2);
68 }
69
70 TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
71
72     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
73     std::vector<IMemoryStateInternal::Ptr> toReturn;
74     toReturn.push_back(mockMemoryStateInternal);
75
76     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
77     EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
78
79     auto state = net.QueryState();
80     state.front().Reset();
81 }
82
83 TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
84
85     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
86     std::vector<IMemoryStateInternal::Ptr> toReturn;
87     toReturn.push_back(mockMemoryStateInternal);
88
89     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
90     EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
91
92     auto state = net.QueryState();
93     EXPECT_ANY_THROW(state.front().Reset());
94 }
95
96 TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
97
98     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
99     std::vector<IMemoryStateInternal::Ptr> toReturn;
100     toReturn.push_back(mockMemoryStateInternal);
101
102     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
103     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
104
105     auto state = net.QueryState();
106     EXPECT_STREQ(state.front().GetName().c_str(), "someName");
107 }
108
109 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
110
111     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
112     std::vector<IMemoryStateInternal::Ptr> toReturn;
113     toReturn.push_back(mockMemoryStateInternal);
114
115     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
116     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
117
118     IMemoryState::Ptr pState;
119
120     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
121     char *name = reinterpret_cast<char *>(1);
122     EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
123 }
124
125
126 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
127
128     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
129     std::vector<IMemoryStateInternal::Ptr> toReturn;
130     toReturn.push_back(mockMemoryStateInternal);
131
132     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
133     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
134
135     IMemoryState::Ptr pState;
136
137     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
138     char name[1];
139     EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
140     EXPECT_STREQ(name, "");
141 }
142
143 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
144
145     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
146     std::vector<IMemoryStateInternal::Ptr> toReturn;
147     toReturn.push_back(mockMemoryStateInternal);
148
149     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
150     EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
151
152     IMemoryState::Ptr pState;
153
154     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
155     char name[2];
156     EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
157     EXPECT_STREQ(name, "s");
158 }
159
160 TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
161
162     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
163     std::vector<IMemoryStateInternal::Ptr> toReturn;
164     Blob::Ptr saver;
165     toReturn.push_back(mockMemoryStateInternal);
166
167     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
168     EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
169
170     float data[] = {123, 124, 125};
171     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
172
173     EXPECT_NO_THROW(net.QueryState().front().SetState(stateBlob));
174     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
175     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
176     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
177 }
178
179 TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
180
181     auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
182     std::vector<IMemoryStateInternal::Ptr> toReturn;
183
184     float data[] = {123, 124, 125};
185     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
186
187
188     toReturn.push_back(mockMemoryStateInternal);
189
190     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
191     EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
192
193
194     auto saver = net.QueryState().front().GetLastState();
195     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
196     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
197     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
198 }
199
200 class MemoryStateInternalMockImpl : public MemoryStateInternal {
201  public:
202     using MemoryStateInternal::MemoryStateInternal;
203     MOCK_METHOD0(Reset, void ());
204 };
205
206 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
207
208     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
209
210     ASSERT_STREQ(pState->GetName().c_str(), "name");
211 }
212
213
214 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
215
216     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
217     float data[] = {123, 124, 125};
218     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
219
220     pState->SetState(stateBlob);
221     auto saver = pState->GetLastState();
222
223     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
224     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
225     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 125);
226 }
227
228
229 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
230
231     IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
232     float data[] = {123, 124, 125};
233     auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
234
235     pState->SetState(stateBlob);
236
237     data[0] = 121;
238     data[1] = 122;
239     data[2] = 123;
240     auto saver = pState->GetLastState();
241
242     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
243     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
244     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);
245 }