Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / shape_infer / reshaper_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-matchers.h>
7
8 #include <shape_infer/mock_ishape_infer_impl.hpp>
9 #include <shape_infer/mock_shape_infer_extension.hpp>
10 #include <mock_icnn_network.hpp>
11 #include <../graph_tools/graph_test_base.hpp>
12 #include <shape_infer/mock_reshaper_launcher.hpp>
13 #include <shape_infer/ie_reshaper.hpp>
14
15 using namespace InferenceEngine;
16 using namespace InferenceEngine::details;
17 using namespace ShapeInfer;
18 using namespace ::testing;
19 using namespace ::GraphTest;
20
21 class ReshaperTest : public GraphTestsBase {
22 protected:
23     class TestLauncherCreator : public LauncherCreator {
24     public:
25         struct Mocks {
26             MockReshapeLauncher::Ptr launcher;
27             MockInputController* iController;
28             MockOutputController* oController;
29             MockIShapeInferImpl::Ptr shapeInferImpl;
30
31             Mocks(const MockReshapeLauncher::Ptr& _launcher, MockInputController* _iController,
32                   MockOutputController* _oController, const MockIShapeInferImpl::Ptr& _shapeInferImpl) :
33                     launcher(_launcher), iController(_iController), oController(_oController),
34                     shapeInferImpl(_shapeInferImpl) {}
35         };
36
37         ReshapeLauncher::Ptr
38         createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
39             return createLauncher(layer);
40         }
41
42         ReshapeLauncher::Ptr
43         createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
44             return createLauncher(layer);
45         }
46
47         std::vector<Mocks> getMocks() {
48             return _mocks;
49         }
50
51     private:
52         ReshapeLauncher::Ptr createLauncher(const CNNLayer* layer) {
53             auto initializer = std::make_shared<MockReshapeLauncher::TestLauncherInitializer>();
54             auto shapeInferImpl = std::make_shared<MockIShapeInferImpl>();
55             auto mockLauncher = std::make_shared<MockReshapeLauncher>(initializer, layer, shapeInferImpl);
56             _mocks.emplace_back(mockLauncher, initializer->getInputController(), initializer->getOutputController(),
57                                 shapeInferImpl);
58             return mockLauncher;
59         }
60
61     private:
62         std::vector<Mocks> _mocks;
63     };
64
65     class TestEmptyLauncherCreator : public LauncherCreator {
66     public:
67         ReshapeLauncher::Ptr
68         createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
69             return std::make_shared<FakeReshapeLauncher>(layer, std::make_shared<MockIShapeInferImpl>());;
70         }
71
72         ReshapeLauncher::Ptr
73         createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
74             return std::make_shared<InputReshapeLauncher>(layer, std::make_shared<MockIShapeInferImpl>());
75         }
76     };
77
78     void prepareInputs(InputsDataMap& inputsMap, int batchSize = 1) override {
79         GraphTestsBase::prepareInputs(inputsMap);
80         for (auto layer = lhsLayers.begin(); layer != lhsLayers.end(); layer++) {
81             if ((*layer)->insData.empty()) {
82                 (*layer)->type = "Input";
83             }
84         }
85     }
86
87     void SetUp() override {
88         GraphTestsBase::SetUp();
89         impl = std::make_shared<MockIShapeInferImpl>();
90         CONNECT(0, 1);
91     };
92
93 public:
94     StatusCode sts = GENERAL_ERROR;
95     ResponseDesc resp;
96     static const std::string TEST_NAME;
97     MockIShapeInferImpl::Ptr impl;
98     ReshaperPtr reshaper;
99 };
100
101 const std::string ReshaperTest::TEST_NAME = "TEST_NAME";
102
103 TEST_F(ReshaperTest, canCreateReshaper) {
104     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
105         prepareInputs(maps);
106     })));
107     Reshaper reshaper(mockNet);
108 }
109
110 TEST_F(ReshaperTest, throwOnAddNullExtension) {
111     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
112         prepareInputs(maps);
113     })));
114     Reshaper reshaper(mockNet);
115     MockShapeInferExtension::Ptr extension;
116     ASSERT_THROW(reshaper.AddExtension(extension), InferenceEngineException);
117 }
118
119 TEST_F(ReshaperTest, canAddExtensionWithNotRegistered) {
120     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
121         prepareInputs(maps);
122     })));
123     Reshaper reshaper(mockNet);
124     auto extension = std::make_shared<MockShapeInferExtension>();
125     EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
126             WithArg<0>(Invoke([&](char**& type) {
127                 type = new char*[1];
128                 type[0] = new char[TEST_NAME.size() + 1];
129                 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
130                 type[0][TEST_NAME.size()] = '\0';
131             })),
132             WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
133             Return(OK)));
134     reshaper.AddExtension(extension);
135 }
136
137 TEST_F(ReshaperTest, throwOnExtensionWithAlreadyRegisteredImpl) {
138     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
139         prepareInputs(maps);
140     })));
141     Reshaper reshaper(mockNet);
142     auto extension = std::make_shared<MockShapeInferExtension>();
143     std::string conv_name = "Convolution";
144     EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
145             WithArg<0>(Invoke([&](char**& type) {
146                 type = new char*[2];
147                 type[0] = new char[TEST_NAME.size() + 1];
148                 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
149                 type[0][TEST_NAME.size()] = '\0';
150                 type[1] = new char[conv_name.size() + 1];
151                 std::copy(conv_name.begin(), conv_name.end(), type[1]);
152                 type[1][conv_name.size()] = '\0';
153             })),
154             WithArg<1>(Invoke([&](unsigned int& size) { size = 2; })),
155             Return(OK)));
156     ASSERT_THROW(reshaper.AddExtension(extension), InferenceEngineException);
157 }
158
159 TEST_F(ReshaperTest, canResetOnReshape) {
160     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
161         prepareInputs(maps);
162     })));
163     auto testCreator = std::make_shared<TestLauncherCreator>();
164     Reshaper reshaper(mockNet, testCreator);
165     auto mocks = testCreator->getMocks();
166     auto inputMock = mocks[0];
167     EXPECT_CALL(*(inputMock.launcher).get(), setShapeByName(_, _));
168     for (auto it:mocks) {
169         EXPECT_CALL(*(it.launcher).get(), getLayerName()).WillRepeatedly(Return(it.launcher->realGetLayerName()));
170         EXPECT_CALL(*(it.launcher).get(), reset());
171         EXPECT_CALL(*(it.launcher).get(), reshape(_));
172         EXPECT_CALL(*(it.launcher).get(), applyChanges(_));
173     }
174
175     auto extension = std::make_shared<MockShapeInferExtension>();
176     EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
177             WithArg<0>(Invoke([&](char**& type) {
178                 type = new char*[1];
179                 type[0] = new char[TEST_NAME.size() + 1];
180                 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
181                 type[0][TEST_NAME.size()] = '\0';
182             })),
183             WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
184             Return(OK)));
185     reshaper.AddExtension(extension);
186
187     reshaper.run({{"0", {2}}});
188 }
189
190 TEST_F(ReshaperTest, canUpdateFakeImpl) {
191     EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
192         prepareInputs(maps);
193     })));
194     auto testCreator = std::make_shared<TestEmptyLauncherCreator>();
195     Reshaper reshaper(mockNet, testCreator);
196     auto newImpl = std::make_shared<MockIShapeInferImpl>();
197
198     const char* registered[] = {""};
199     auto extension = std::make_shared<MockShapeInferExtension>();
200     EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
201             WithArg<0>(Invoke([&](char**& type) {
202                 type = new char*[1];
203                 type[0] = new char[1];
204                 type[0][0] = '\0';
205             })),
206             WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
207             Return(OK)));
208     EXPECT_CALL(*extension.get(), getShapeInferImpl(_, _, _)).WillOnce(DoAll(
209             WithArg<0>(Invoke([&](IShapeInferImpl::Ptr& impl) { impl = newImpl; })),
210             Return(OK)));
211     reshaper.AddExtension(extension);
212
213     EXPECT_CALL(*newImpl.get(), inferShapes(_, _, _, _, _)).
214             WillOnce(DoAll(
215             WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back({1, 2}); })), Return(OK)));
216     reshaper.run({{"0", {1, 2}}});
217 }