1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-matchers.h>
8 #include <shape_infer/mock_ishape_infer_impl.hpp>
9 #include <shape_infer/mock_shape_infer_extension.hpp>
10 #include <mock_icnn_network.hpp>
11 #include <../graph_tools/graph_test_base.hpp>
12 #include <shape_infer/mock_reshaper_launcher.hpp>
13 #include <shape_infer/ie_reshaper.hpp>
15 using namespace InferenceEngine;
16 using namespace InferenceEngine::details;
17 using namespace ShapeInfer;
18 using namespace ::testing;
19 using namespace ::GraphTest;
// Test fixture for Reshaper. Builds on GraphTestsBase (which supplies mockNet and
// lhsLayers used below) and provides two LauncherCreator test doubles that a test can
// pass to the Reshaper constructor.
21 class ReshaperTest : public GraphTestsBase {
// LauncherCreator double: every launcher it creates, input or not, is a
// MockReshapeLauncher. The mocks behind each launcher are appended to _mocks so a test
// can set expectations on them after constructing the Reshaper.
23 class TestLauncherCreator : public LauncherCreator {
// Mock objects associated with one created launcher.
26 MockReshapeLauncher::Ptr launcher;
// Raw pointers obtained from the launcher's TestLauncherInitializer.
// NOTE(review): ownership is not established here — presumably the initializer or
// launcher owns the controllers; confirm before dereferencing after teardown.
27 MockInputController* iController;
28 MockOutputController* oController;
29 MockIShapeInferImpl::Ptr shapeInferImpl;
31 Mocks(const MockReshapeLauncher::Ptr& _launcher, MockInputController* _iController,
32 MockOutputController* _oController, const MockIShapeInferImpl::Ptr& _shapeInferImpl) :
33 launcher(_launcher), iController(_iController), oController(_oController),
34 shapeInferImpl(_shapeInferImpl) {}
// Non-input layers: delegate to createLauncher; `extensions` is ignored.
38 createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
39 return createLauncher(layer);
// Input layers get the same kind of mock launcher; `extensions` is ignored.
43 createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
44 return createLauncher(layer);
// Returns the recorded mocks by value.
// NOTE(review): presumably returns _mocks — confirm.
47 std::vector<Mocks> getMocks() {
// Builds a MockReshapeLauncher for `layer`, backed by a fresh MockIShapeInferImpl and a
// TestLauncherInitializer, and records its controllers and impl in _mocks.
52 ReshapeLauncher::Ptr createLauncher(const CNNLayer* layer) {
53 auto initializer = std::make_shared<MockReshapeLauncher::TestLauncherInitializer>();
54 auto shapeInferImpl = std::make_shared<MockIShapeInferImpl>();
55 auto mockLauncher = std::make_shared<MockReshapeLauncher>(initializer, layer, shapeInferImpl);
56 _mocks.emplace_back(mockLauncher, initializer->getInputController(), initializer->getOutputController(),
// One entry per created launcher, appended in creation order.
62 std::vector<Mocks> _mocks;
// LauncherCreator double that uses the real launcher types: FakeReshapeLauncher for
// non-input layers and InputReshapeLauncher for input layers. Each launcher gets its
// own fresh MockIShapeInferImpl.
65 class TestEmptyLauncherCreator : public LauncherCreator {
68 createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
69 return std::make_shared<FakeReshapeLauncher>(layer, std::make_shared<MockIShapeInferImpl>());
73 createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions) override {
74 return std::make_shared<InputReshapeLauncher>(layer, std::make_shared<MockIShapeInferImpl>());
// Prepares inputs through the base fixture, then sets the type of every layer in
// lhsLayers that has no input data to "Input".
// NOTE(review): presumably this makes the Reshaper treat those layers as network
// inputs (createInputLauncher) — confirm.
78 void prepareInputs(InputsDataMap& inputsMap, int batchSize = 1) override {
// Forward batchSize. Previously it was dropped here, so the base call always used its
// own default regardless of what the caller requested.
79 GraphTestsBase::prepareInputs(inputsMap, batchSize);
80 for (auto layer = lhsLayers.begin(); layer != lhsLayers.end(); layer++) {
81 if ((*layer)->insData.empty()) {
82 (*layer)->type = "Input";
// Per-test setup: base fixture setup, then a fresh MockIShapeInferImpl stored in `impl`.
87 void SetUp() override {
88 GraphTestsBase::SetUp();
89 impl = std::make_shared<MockIShapeInferImpl>();
// Status holder initialised to GENERAL_ERROR; not read by the tests in this section.
94 StatusCode sts = GENERAL_ERROR;
// Shape-infer type name reported by the mock extensions in the tests.
96 static const std::string TEST_NAME;
// Shape-infer impl mock recreated in SetUp. It is not used by the tests in this section;
// canUpdateFakeImpl's lambda parameter `impl` shadows it.
97 MockIShapeInferImpl::Ptr impl;
// Definition of the fixture's shared type name; the mock extensions report this name
// from getShapeInferTypes.
101 const std::string ReshaperTest::TEST_NAME{"TEST_NAME"};
// Smoke test: a Reshaper can be constructed from mockNet.
// getInputsInfo is stubbed for any number of calls with a lambda that fills the
// InputsDataMap passed to it.
103 TEST_F(ReshaperTest, canCreateReshaper) {
104 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
107 Reshaper reshaper(mockNet);
// AddExtension must reject a null extension with InferenceEngineException.
110 TEST_F(ReshaperTest, throwOnAddNullExtension) {
111 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
114 Reshaper reshaper(mockNet);
// Default-constructed Ptr, i.e. null (per the test name).
115 MockShapeInferExtension::Ptr extension;
116 ASSERT_THROW(reshaper.AddExtension(extension), InferenceEngineException);
// An extension that reports a single shape-infer type, TEST_NAME, which per the test
// name is not already registered, can be added. The test fails if AddExtension throws.
// getShapeInferTypes is expected to be queried exactly once (single WillOnce).
119 TEST_F(ReshaperTest, canAddExtensionWithNotRegistered) {
120 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
123 Reshaper reshaper(mockNet);
124 auto extension = std::make_shared<MockShapeInferExtension>();
// Arg 0 receives a heap-allocated, NUL-terminated copy of TEST_NAME; arg 1 receives
// the count (1).
// NOTE(review): these raw allocations are handed to the code under test. Whether they
// are freed is not established here, so this may leak.
125 EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
126 WithArg<0>(Invoke([&](char**& type) {
128 type[0] = new char[TEST_NAME.size() + 1];
129 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
130 type[0][TEST_NAME.size()] = '\0';
132 WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
134 reshaper.AddExtension(extension);
// AddExtension must throw InferenceEngineException when the extension reports a type
// that, per the test name, already has a registered implementation. The mock reports
// two types: TEST_NAME and "Convolution".
137 TEST_F(ReshaperTest, throwOnExtensionWithAlreadyRegisteredImpl) {
138 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
141 Reshaper reshaper(mockNet);
142 auto extension = std::make_shared<MockShapeInferExtension>();
// The type name expected to clash with an existing registration.
143 std::string conv_name = "Convolution";
// Arg 0 receives two heap-allocated, NUL-terminated names (TEST_NAME, conv_name); arg 1
// receives the count (2).
144 EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
145 WithArg<0>(Invoke([&](char**& type) {
147 type[0] = new char[TEST_NAME.size() + 1];
148 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
149 type[0][TEST_NAME.size()] = '\0';
150 type[1] = new char[conv_name.size() + 1];
151 std::copy(conv_name.begin(), conv_name.end(), type[1]);
152 type[1][conv_name.size()] = '\0';
154 WithArg<1>(Invoke([&](unsigned int& size) { size = 2; })),
156 ASSERT_THROW(reshaper.AddExtension(extension), InferenceEngineException);
// Verifies that Reshaper::run drives every launcher through a full cycle. The launcher
// in mocks[0] receives the new shape via setShapeByName, and every launcher is reset,
// reshaped, and has applyChanges called exactly once (EXPECT_CALL without a
// cardinality means Times(1)).
159 TEST_F(ReshaperTest, canResetOnReshape) {
160 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
163 auto testCreator = std::make_shared<TestLauncherCreator>();
164 Reshaper reshaper(mockNet, testCreator);
// Mocks recorded by testCreator for the launchers the Reshaper created.
165 auto mocks = testCreator->getMocks();
// NOTE(review): assumes mocks is non-empty and that mocks[0] belongs to input "0" —
// confirm. Bound by reference to avoid copying the Mocks entry.
166 const auto& inputMock = mocks[0];
167 EXPECT_CALL(*(inputMock.launcher).get(), setShapeByName(_, _));
// Iterate by const reference. Each Mocks entry holds shared_ptrs, so the previous
// by-value loop copied them (atomic refcount traffic) on every iteration.
168 for (const auto& it : mocks) {
169 EXPECT_CALL(*(it.launcher).get(), getLayerName()).WillRepeatedly(Return(it.launcher->realGetLayerName()));
170 EXPECT_CALL(*(it.launcher).get(), reset());
171 EXPECT_CALL(*(it.launcher).get(), reshape(_));
172 EXPECT_CALL(*(it.launcher).get(), applyChanges(_));
// Extension reporting a single type, TEST_NAME (heap-allocated, NUL-terminated).
// NOTE(review): whether the code under test frees these buffers is not established
// here, so this may leak.
175 auto extension = std::make_shared<MockShapeInferExtension>();
176 EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
177 WithArg<0>(Invoke([&](char**& type) {
179 type[0] = new char[TEST_NAME.size() + 1];
180 std::copy(TEST_NAME.begin(), TEST_NAME.end(), type[0]);
181 type[0][TEST_NAME.size()] = '\0';
183 WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
185 reshaper.AddExtension(extension);
// Reshape input "0" to {2}. The expectations above are verified when the mocks are
// destroyed.
187 reshaper.run({{"0", {2}}});
// With TestEmptyLauncherCreator, non-input layers get a FakeReshapeLauncher. After an
// extension whose getShapeInferImpl yields newImpl is added, run() is expected to call
// newImpl->inferShapes. Per the test name, this checks that the fake launcher's
// implementation is replaced by the extension's.
190 TEST_F(ReshaperTest, canUpdateFakeImpl) {
191 EXPECT_CALL(mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
194 auto testCreator = std::make_shared<TestEmptyLauncherCreator>();
195 Reshaper reshaper(mockNet, testCreator);
// Implementation the extension hands out; expectations are set on it below.
196 auto newImpl = std::make_shared<MockIShapeInferImpl>();
// NOTE(review): `registered` is not referenced by the statements shown for this test;
// confirm it is still needed.
198 const char* registered[] = {""};
199 auto extension = std::make_shared<MockShapeInferExtension>();
// The extension reports one type name, stored in a 1-byte heap buffer.
// NOTE(review): presumably the empty string — confirm it is NUL-terminated.
200 EXPECT_CALL(*extension.get(), getShapeInferTypes(_, _, _)).WillOnce(DoAll(
201 WithArg<0>(Invoke([&](char**& type) {
203 type[0] = new char[1];
206 WithArg<1>(Invoke([&](unsigned int& size) { size = 1; })),
// getShapeInferImpl returns newImpl through its first argument. The lambda parameter
// `impl` shadows the fixture member of the same name.
208 EXPECT_CALL(*extension.get(), getShapeInferImpl(_, _, _)).WillOnce(DoAll(
209 WithArg<0>(Invoke([&](IShapeInferImpl::Ptr& impl) { impl = newImpl; })),
211 reshaper.AddExtension(extension);
// When invoked, newImpl appends output shape {1, 2} (argument 3) and returns OK.
213 EXPECT_CALL(*newImpl.get(), inferShapes(_, _, _, _, _)).
215 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back({1, 2}); })), Return(OK)));
216 reshaper.run({{"0", {1, 2}}});