1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include <ie_version.hpp>
9 #include <cpp/ie_executable_network.hpp>
10 #include <cpp_interfaces/base/ie_plugin_base.hpp>
12 #include <mock_icnn_network.hpp>
13 #include <mock_iexecutable_network.hpp>
14 #include <cpp_interfaces/interface/mock_imemory_state_internal.hpp>
15 #include <inference_engine/cpp_interfaces/base/ie_executable_network_base.hpp>
16 #include <cpp_interfaces/interface/mock_iexecutable_network_internal.hpp>
17 #include <inference_engine/cpp_interfaces/impl/ie_memory_state_internal.hpp>
19 using namespace ::testing;
21 using namespace InferenceEngine;
22 using namespace InferenceEngine::details;
24 class MemoryStateTests : public ::testing::Test {
26 shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
27 shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
29 virtual void SetUp() {
30 mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
31 mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
35 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
37 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
38 std::vector<IMemoryStateInternal::Ptr> toReturn(1);
39 toReturn[0] = mockMemoryStateInternal;
41 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
43 auto state = net.QueryState();
44 ASSERT_EQ(state.size(), 1);
47 TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
49 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
50 std::vector<IMemoryStateInternal::Ptr> toReturn;
52 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
54 auto state = net.QueryState();
55 ASSERT_EQ(state.size(), 0);
58 TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
60 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
61 std::vector<IMemoryStateInternal::Ptr> toReturn;
62 toReturn.push_back(mockMemoryStateInternal);
63 toReturn.push_back(mockMemoryStateInternal);
65 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
67 auto state = net.QueryState();
68 ASSERT_EQ(state.size(), 2);
71 TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
73 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
74 std::vector<IMemoryStateInternal::Ptr> toReturn;
75 toReturn.push_back(mockMemoryStateInternal);
77 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
78 EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
80 auto state = net.QueryState();
81 state.front().Reset();
84 TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
86 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
87 std::vector<IMemoryStateInternal::Ptr> toReturn;
88 toReturn.push_back(mockMemoryStateInternal);
90 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
91 EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
93 auto state = net.QueryState();
94 EXPECT_ANY_THROW(state.front().Reset());
97 TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
99 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
100 std::vector<IMemoryStateInternal::Ptr> toReturn;
101 toReturn.push_back(mockMemoryStateInternal);
103 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
104 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
106 auto state = net.QueryState();
107 EXPECT_STREQ(state.front().GetName().c_str(), "someName");
110 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
112 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
113 std::vector<IMemoryStateInternal::Ptr> toReturn;
114 toReturn.push_back(mockMemoryStateInternal);
116 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
117 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
119 IMemoryState::Ptr pState;
121 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
122 char *name = reinterpret_cast<char *>(1);
123 EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
127 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
129 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
130 std::vector<IMemoryStateInternal::Ptr> toReturn;
131 toReturn.push_back(mockMemoryStateInternal);
133 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
134 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
136 IMemoryState::Ptr pState;
138 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
140 EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
141 EXPECT_STREQ(name, "");
144 TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
146 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
147 std::vector<IMemoryStateInternal::Ptr> toReturn;
148 toReturn.push_back(mockMemoryStateInternal);
150 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
151 EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
153 IMemoryState::Ptr pState;
155 static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
157 EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
158 EXPECT_STREQ(name, "s");
161 TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
163 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
164 std::vector<IMemoryStateInternal::Ptr> toReturn;
166 toReturn.push_back(mockMemoryStateInternal);
168 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
169 EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
171 float data[] = {123, 124, 125};
172 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
174 EXPECT_NO_THROW(net.QueryState().front().SetState(stateBlob));
175 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
176 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
177 ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
180 TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
182 auto net = ExecutableNetwork(make_executable_network(mockExeNetworkInternal));
183 std::vector<IMemoryStateInternal::Ptr> toReturn;
185 float data[] = {123, 124, 125};
186 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
189 toReturn.push_back(mockMemoryStateInternal);
191 EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
192 EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
195 auto saver = net.QueryState().front().GetLastState();
196 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
197 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
198 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
201 class MemoryStateInternalMockImpl : public MemoryStateInternal {
203 using MemoryStateInternal::MemoryStateInternal;
204 MOCK_METHOD0(Reset, void ());
207 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
209 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
211 ASSERT_STREQ(pState->GetName().c_str(), "name");
215 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
217 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
218 float data[] = {123, 124, 125};
219 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
221 pState->SetState(stateBlob);
222 auto saver = pState->GetLastState();
224 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
225 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
226 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 125);
230 TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
232 IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
233 float data[] = {123, 124, 125};
234 auto stateBlob = make_shared_blob<float>(Precision::FP32, C, {3}, data, sizeof(data) / sizeof(*data));
236 pState->SetState(stateBlob);
241 auto saver = pState->GetLastState();
243 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
244 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
245 ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);