1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-matchers.h>
8 #include <inference_engine/shape_infer/ie_reshape_launcher.hpp>
9 #include <inference_engine/blob_factory.hpp>
10 #include <shape_infer/mock_ishape_infer_impl.hpp>
11 #include <shape_infer/mock_reshaper_launcher.hpp>
13 using namespace InferenceEngine;
14 using namespace InferenceEngine::details;
15 using namespace ShapeInfer;
16 using namespace ::testing;
// Test fixture shared by the ReshapeLauncher tests: holds a mocked shape-inference
// implementation plus the data object, dims and params that the tests reuse.
18 class ReshapeLauncherTest : public ::testing::Test {
// Runs before every test: creates a fresh Data object and a fresh MockIShapeInferImpl.
20 void SetUp() override {
21 notEmptyData = getNotEmptyData();
22 impl = std::make_shared<MockIShapeInferImpl>();
// Builds one FP32 blob per entry of `shapes`. The layout of every blob is derived
// from its dims through TensorDesc::getLayoutByDims; the blobs are gathered in `inBlobs`.
24 std::vector<Blob::CPtr> getBlobs(const std::vector<SizeVector>& shapes) {
25 std::vector<Blob::CPtr> inBlobs;
26 for (auto const& dims : shapes) {
27 TensorDesc desc(Precision::FP32, dims, TensorDesc::getLayoutByDims(dims));
28 auto blob = make_blob_with_precision(desc);
29 inBlobs.push_back(blob);
// Status holder, initialised to GENERAL_ERROR.
34 StatusCode sts = GENERAL_ERROR;
// Name reused across the tests: given to the Data constructor, passed to
// setShapeByName(), assigned as a layer name and used as the params key/value.
36 static const std::string TEST_NAME;
// Mocked shape-inference implementation handed to the launchers; recreated in SetUp().
38 MockIShapeInferImpl::Ptr impl;
// Output dims that the stubbed inferShapes() actions in the tests append to the
// output-shapes argument.
40 SizeVector outDims{2};
// Params map with a single TEST_NAME -> TEST_NAME entry; the tests assign it to layer2.params.
41 std::map<std::string, std::string> changedParams{{TEST_NAME, TEST_NAME}};
// Creates a new FP32 Data object named TEST_NAME with layout C.
43 DataPtr getNotEmptyData() {
44 return std::make_shared<Data>(TEST_NAME, Precision::FP32, Layout::C);
// Out-of-class definition of the fixture's static TEST_NAME member.
48 const std::string ReshapeLauncherTest::TEST_NAME = "TEST_NAME";
// A null layer pointer must be rejected: constructing the launcher is expected to
// throw InferenceEngineException.
50 TEST_F(ReshapeLauncherTest, failedToCreateWithNullLayer) {
51 const CNNLayer* layer = nullptr;
52 ASSERT_THROW(ReshapeLauncher launcher(layer, impl), InferenceEngineException);
// Only outData is assigned on the layer; constructing the launcher is expected to
// throw InferenceEngineException.
// NOTE(review): relies on a freshly declared `layer` having no insData — confirm.
55 TEST_F(ReshapeLauncherTest, failedToCreateWithNullInsData) {
57 layer.outData = {notEmptyData};
58 ASSERT_THROW(ReshapeLauncher launcher(&layer, impl), InferenceEngineException);
// The layer's only input is a weak reference whose Data object no longer exists;
// constructing the launcher is expected to throw InferenceEngineException.
61 TEST_F(ReshapeLauncherTest, failedToCreateWithExpiredInsData) {
63 layer.outData = {notEmptyData};
// The make_shared temporary is the sole owner and is destroyed at the end of this
// statement, so `expired` no longer refers to a live object
// (assumes DataWeakPtr is a std::weak_ptr<Data> — TODO confirm).
64 DataWeakPtr expired = std::make_shared<Data>(TEST_NAME, Precision::UNSPECIFIED);
65 layer.insData = {expired};
66 ASSERT_THROW(ReshapeLauncher launcher(&layer, impl), InferenceEngineException);
// Only insData is assigned on the layer; constructing the launcher is expected to
// throw InferenceEngineException.
// NOTE(review): relies on a freshly declared `layer` having empty outData — confirm.
69 TEST_F(ReshapeLauncherTest, failedToCreateWithEmptyOutData) {
71 layer.insData = {notEmptyData};
72 ASSERT_THROW(ReshapeLauncher launcher(&layer, impl), InferenceEngineException);
// outData holds exactly one entry and that entry is null; constructing the launcher
// is expected to throw InferenceEngineException.
75 TEST_F(ReshapeLauncherTest, failedToCreateWithNullOutData) {
77 layer.insData = {notEmptyData};
78 layer.outData = {nullptr};
79 ASSERT_THROW(ReshapeLauncher launcher(&layer, impl), InferenceEngineException);
// The layer has both input and output data set, yet constructing the launcher is
// still expected to throw InferenceEngineException.
// NOTE(review): the test name implies `impl` is empty when the constructor runs,
// whereas SetUp() assigns a mock to it — confirm `impl` is reset before the ASSERT_THROW.
82 TEST_F(ReshapeLauncherTest, failedToCreateWithEmptyImpl) {
84 layer.outData = {notEmptyData};
85 layer.insData = {notEmptyData};
87 ASSERT_THROW(ReshapeLauncher launcher(&layer, impl), InferenceEngineException);
// Happy path: with input data, output data and the mocked implementation in place
// the constructor is simply invoked. There is no explicit assertion — an exception
// escaping the constructor is what would fail this test.
90 TEST_F(ReshapeLauncherTest, canCreateReshapeLauncher) {
92 layer.outData = {notEmptyData};
93 layer.insData = {notEmptyData};
94 ReshapeLauncher launcher(&layer, impl);
// The layer declares two inputs but setShapeByName() is called only once before
// reshape(); reshape() is expected not to return normally.
// NOTE(review): "Wiht" in the test name is a typo for "With"; renaming it would
// change the name used in gtest filters.
97 TEST_F(ReshapeLauncherTest, throwOnReshapeWihtNotEnoughShapes) {
99 layer.outData = {notEmptyData};
100 layer.insData = {notEmptyData, notEmptyData};
101 ReshapeLauncher launcher(&layer, impl);
103 launcher.setShapeByName(inDims, TEST_NAME);
105 launcher.reshape({});
// Reached only when reshape() returned without throwing.
106 FAIL() << "Reshape should be failed!";
// Builds the launcher with a test initializer that exposes mocked input/output
// controllers, then checks that setShapeByName() followed by reshape() reaches the
// input controller, the output controller and the shape-infer implementation.
110 TEST_F(ReshapeLauncherTest, implIsCalledOnReshape) {
112 layer.insData = {notEmptyData};
113 auto initializer = std::make_shared<MockReshapeLauncher::TestLauncherInitializer>();
114 ReshapeLauncher launcher(&layer, impl, initializer);
115 auto inputController = initializer->getInputController();
116 auto outputController = initializer->getOutputController();
117 std::vector<SizeVector> shapes{inDims};
118 auto blobs = getBlobs(shapes);
// Each expectation below must be satisfied exactly once (gMock's default cardinality
// when no Times() clause is given).
119 EXPECT_CALL(*inputController, setShapeByName(inDims, TEST_NAME));
// The input controller is stubbed to hand back `blobs` ...
120 EXPECT_CALL(*inputController, getBlobs(true)).WillOnce(Return(blobs));
121 EXPECT_CALL(*outputController, setShapes(_));
122 EXPECT_CALL(*outputController, propagateShapes(_));
// ... and the implementation must receive those same blobs as its first argument.
123 EXPECT_CALL(*impl.get(), inferShapes(blobs, _, _, _, _)).WillOnce(Return(OK));
124 launcher.setShapeByName(inDims, TEST_NAME);
125 launcher.reshape({});
// End-to-end happy path: set the input shape, run reshape() with a stubbed
// inferShapes() that reports `outDims`, apply the result to the layer and verify
// the dims stored on the layer's input and output data.
128 TEST_F(ReshapeLauncherTest, canApplyChanges) {
// The output gets its own Data object (not the shared `notEmptyData` used as input),
// so input and output dims can be checked independently at the end.
130 layer.outData = {getNotEmptyData()};
131 layer.insData = {notEmptyData};
132 ReshapeLauncher launcher(&layer, impl);
133 launcher.setShapeByName(inDims, TEST_NAME);
// The stubbed action appends `outDims` to inferShapes() argument #3 (0-based), which
// the lambda receives as std::vector<SizeVector>&, and then returns OK.
135 EXPECT_CALL(*impl.get(), inferShapes(_, _, _, _, _)).
137 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back(outDims); })), Return(OK)));
138 launcher.reshape({});
139 launcher.applyChanges(&layer);
// Exactly one input and one output are expected, both non-null, carrying inDims and
// outDims respectively.
141 auto insData = layer.insData;
142 auto outData = layer.outData;
143 ASSERT_EQ(1, insData.size());
144 ASSERT_EQ(1, outData.size());
145 auto ins0Data = insData[0].lock();
146 auto out0Data = outData[0];
147 ASSERT_NE(nullptr, ins0Data);
148 ASSERT_NE(nullptr, out0Data);
149 ASSERT_EQ(inDims, ins0Data->getDims());
150 ASSERT_EQ(outDims, out0Data->getDims());
// The layer has a single output while the stubbed inferShapes() appends two output
// shapes; both reshape() and the following applyChanges() are expected to throw
// InferenceEngineException.
153 TEST_F(ReshapeLauncherTest, throwOnApplyingWithNotEnoughOutput) {
155 layer.outData = {notEmptyData};
156 layer.insData = {notEmptyData};
157 ReshapeLauncher launcher(&layer, impl);
158 launcher.setShapeByName(inDims, TEST_NAME);
// The action below pushes `outDims` twice into inferShapes() argument #3 (0-based).
159 EXPECT_CALL(*impl.get(), inferShapes(_, _, _, _, _)).
161 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) {
162 outShape.push_back(outDims);
163 outShape.push_back(outDims);
166 ASSERT_THROW(launcher.reshape({}), InferenceEngineException);
167 ASSERT_THROW(launcher.applyChanges(&layer), InferenceEngineException);
// Mirror of the "not enough output" case: the layer has two outputs while the
// stubbed inferShapes() appends only one output shape; both reshape() and the
// following applyChanges() are expected to throw InferenceEngineException.
170 TEST_F(ReshapeLauncherTest, throwOnApplyingWithNotEnoughShapes) {
172 layer.outData = {notEmptyData, notEmptyData};
173 layer.insData = {notEmptyData};
174 ReshapeLauncher launcher(&layer, impl);
175 launcher.setShapeByName(inDims, TEST_NAME);
// The action below pushes `outDims` once into inferShapes() argument #3 (0-based).
176 EXPECT_CALL(*impl.get(), inferShapes(_, _, _, _, _)).
178 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back(outDims); })),
180 ASSERT_THROW(launcher.reshape({}), InferenceEngineException);
181 ASSERT_THROW(launcher.applyChanges(&layer), InferenceEngineException);
// The launcher is created for layer1 and reshaped, but applyChanges() is then given
// layer2, whose name is set to TEST_NAME; the call is expected to throw
// InferenceEngineException.
// NOTE(review): presumably layer1's name differs from TEST_NAME — confirm at its declaration.
184 TEST_F(ReshapeLauncherTest, canNotApplyForLayerWithAnotherName) {
186 layer1.outData = {notEmptyData};
187 layer1.insData = {notEmptyData};
189 layer2.name = TEST_NAME;
190 ReshapeLauncher launcher(&layer1, impl);
191 { // to not fail because of empty input and output shapes
192 launcher.setShapeByName(inDims, TEST_NAME);
193 EXPECT_CALL(*impl.get(), inferShapes(_, _, _, _, _)).
195 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back(outDims); })),
197 launcher.reshape({});
199 ASSERT_THROW(launcher.applyChanges(&layer2), InferenceEngineException);
// The launcher is created for layer1 and reshaped, but applyChanges() is then given
// layer2, whose params are set to `changedParams`; the call is expected to throw
// InferenceEngineException.
// The DISABLED_ prefix makes GoogleTest compile this test but skip it unless
// --gtest_also_run_disabled_tests is passed.
202 TEST_F(ReshapeLauncherTest, DISABLED_canNotApplyForLayerWithAnotherParams) {
204 layer1.outData = {notEmptyData};
205 layer1.insData = {notEmptyData};
207 layer2.params = changedParams;
208 ReshapeLauncher launcher(&layer1, impl);
209 { // to not fail because of empty input and output shapes
210 launcher.setShapeByName(inDims, TEST_NAME);
211 EXPECT_CALL(*impl.get(), inferShapes(_, _, _, _, _)).
213 WithArg<3>(Invoke([&](std::vector<SizeVector>& outShape) { outShape.push_back(outDims); })),
215 launcher.reshape({});
217 ASSERT_THROW(launcher.applyChanges(&layer2), InferenceEngineException);
// No input shape is ever set on the launcher and reshape() is never called; after
// layer1's input/output lists are cleared, applyChanges(&layer2) is expected to
// throw InferenceEngineException.
220 TEST_F(ReshapeLauncherTest, canNotApplyForLayerWithEmptyInShapes) {
222 layer1.outData = {notEmptyData};
223 layer1.insData = {notEmptyData};
225 layer2.params = changedParams;
226 ReshapeLauncher launcher(&layer1, impl);
227 { // to not fail because of inconsistent number of input/outputs
228 layer1.insData.clear();
229 layer1.outData.clear();
231 ASSERT_THROW(launcher.applyChanges(&layer2), InferenceEngineException);
// An input shape is set on the launcher but reshape() is never called, so the
// stubbed implementation gets no chance to report output shapes; after layer1's
// output list is cleared, applyChanges(&layer2) is expected to throw
// InferenceEngineException.
234 TEST_F(ReshapeLauncherTest, canNotApplyForLayerWithEmptyOutShapes) {
236 layer1.outData = {notEmptyData};
237 layer1.insData = {notEmptyData};
239 layer2.params = changedParams;
240 ReshapeLauncher launcher(&layer1, impl);
241 { // to not fail because of inconsistent number of input/outputs
242 launcher.setShapeByName(inDims, TEST_NAME);
243 layer1.outData.clear();
245 ASSERT_THROW(launcher.applyChanges(&layer2), InferenceEngineException);
// Calling realReset() on a MockReshapeLauncher built from the test initializer must
// call reset() exactly once on the input controller and once on the output controller.
248 TEST_F(ReshapeLauncherTest, canReset) {
249 auto initializer = std::make_shared<MockReshapeLauncher::TestLauncherInitializer>();
250 MockReshapeLauncher launcher(initializer);
251 auto inputController = initializer->getInputController();
252 auto outputController = initializer->getOutputController();
253 EXPECT_CALL(*inputController, reset()).Times(1);
254 EXPECT_CALL(*outputController, reset()).Times(1);
255 launcher.realReset();