1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-matchers.h>
8 #include <shape_infer/mock_ishape_infer_impl.hpp>
9 #include <shape_infer/mock_reshaper_launcher.hpp>
11 using namespace InferenceEngine;
12 using namespace InferenceEngine::details;
13 using namespace ShapeInfer;
14 using namespace ::testing;
// Fixture for the InputReshapeLauncher tests. Before each test it provides a
// fresh mock shape-inference implementation (`impl`) and a fresh output Data
// object (`notEmptyData`), plus helpers that build layers using them.
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// (17, 21-23, 25, 28, 31-32, 38, 41-44), so the access specifiers, the
// declaration of the `notEmptyData` member, the closing braces and the end of
// createLayer() are not visible here. Confirm against the original source.
16 class InputReshapeLauncherTest : public ::testing::Test {
// Runs before every test: rebuilds the shared output Data and creates a new
// mock shape-inference implementation, so tests do not share state.
18 void SetUp() override {
19 notEmptyData = getNotEmptyData();
20 impl = std::make_shared<MockIShapeInferImpl>();
// Default name for created layers. It is also the name of the Data built by
// getNotEmptyData(). It is defined out of class with the value "TEST_NAME".
24 static const std::string TEST_NAME;
// Mock shape-inference implementation handed to every launcher under test.
26 MockIShapeInferImpl::Ptr impl;
// Shape used for outputs throughout these tests. If SizeVector is a
// std::vector, brace-init selects the initializer-list constructor, giving the
// one-element shape {2} rather than two default-valued dimensions.
27 SizeVector outDims{2};
// Returns a new Data object named TEST_NAME with UNSPECIFIED precision and
// layout C.
29 DataPtr getNotEmptyData() {
30 return std::make_shared<Data>(TEST_NAME, Precision::UNSPECIFIED, Layout::C);
// Creates a layer named `name` with type `type` (default "Input"). Its only
// output is the fixture's `notEmptyData`, whose dims are then set to outDims.
// Every call reuses the same `notEmptyData`, so all layers created within one
// test share a single output Data object.
// NOTE(review): the function's return statement is not visible in this
// listing (presumably `return layer;`).
35 auto layer = std::make_shared<CNNLayer>(params);
// NOTE(review): std::make_shared reports allocation failure by throwing
// std::bad_alloc, not by returning null, so this branch cannot be taken.
36 if (layer == nullptr) {
37 THROW_IE_EXCEPTION << "InputReshapeLauncherTest::createLayer(). Could not create CNNLayer";
39 layer->outData = {notEmptyData};
// This mutates the shared Data, so it also changes the dims seen by any layer
// created earlier in the same test.
40 notEmptyData->setDims(outDims);
45 const std::string InputReshapeLauncherTest::TEST_NAME = "TEST_NAME";
47 TEST_F(InputReshapeLauncherTest, failedToCreateWithNullLayer) {
48 const CNNLayer* layer = nullptr;
49 ASSERT_THROW(InputReshapeLauncher launcher(layer, impl), InferenceEngineException);
52 TEST_F(InputReshapeLauncherTest, failedToCreateWithEmptyOutData) {
54 ASSERT_THROW(InputReshapeLauncher launcher(&layer, impl), InferenceEngineException);
57 TEST_F(InputReshapeLauncherTest, failedToCreateWithNullOutData) {
59 layer.outData = {nullptr};
60 ASSERT_THROW(InputReshapeLauncher launcher(&layer, impl), InferenceEngineException);
63 TEST_F(InputReshapeLauncherTest, failedToCreateWithNotInputType) {
65 layer.outData = {notEmptyData};
66 ASSERT_THROW(InputReshapeLauncher launcher(&layer, impl), InferenceEngineException);
69 TEST_F(InputReshapeLauncherTest, canCreateReshapeLauncher) {
70 ASSERT_NO_THROW(InputReshapeLauncher launcher(createLayer().get(), impl));
73 TEST_F(InputReshapeLauncherTest, canPushShapes) {
74 InputReshapeLauncher launcher(createLayer().get(), impl);
75 ASSERT_NO_THROW(launcher.setShapeByName(outDims, TEST_NAME));
78 TEST_F(InputReshapeLauncherTest, canPropagateWithNotEnoughShapes) {
79 InputReshapeLauncher launcher(createLayer().get(), impl);
83 TEST_F(InputReshapeLauncherTest, throwOnPropagateWithEmptyLaunchers) {
84 auto layer = createLayer();
85 layer->outData[0]->inputTo = {{{}, createLayer(TEST_NAME, TEST_NAME)}};
86 InputReshapeLauncher launcher(layer.get(), impl);
87 launcher.setShapeByName(outDims, TEST_NAME);
89 ASSERT_THROW(launcher.reshape({}), InferenceEngineException);
92 TEST_F(InputReshapeLauncherTest, throwOnPropagateWithoutProperLauncher) {
93 auto layer = createLayer();
94 layer->outData[0]->inputTo = {{{}, createLayer(TEST_NAME + "another", TEST_NAME)}};
95 InputReshapeLauncher inLauncher(layer.get(), impl);
96 inLauncher.setShapeByName(outDims, TEST_NAME);
97 auto launcher = std::make_shared<MockReshapeLauncher>();
98 EXPECT_CALL(*launcher.get(), getLayerName()).WillOnce(Return(TEST_NAME));
99 ASSERT_THROW(inLauncher.reshape({{launcher}}), InferenceEngineException);
102 TEST_F(InputReshapeLauncherTest, canPropagate) {
103 auto layer = createLayer();
104 layer->outData[0]->inputTo = {{{}, createLayer(TEST_NAME, TEST_NAME)}};
105 InputReshapeLauncher inLauncher(layer.get(), impl);
106 auto launcher = std::make_shared<MockReshapeLauncher>();
107 EXPECT_CALL(*launcher.get(), setShapeByName(outDims, TEST_NAME));
108 EXPECT_CALL(*launcher.get(), getLayerName()).WillOnce(Return(TEST_NAME));
109 inLauncher.setShapeByName(outDims, TEST_NAME);
110 inLauncher.reshape({{launcher}});
113 TEST_F(InputReshapeLauncherTest, canReset) {
114 auto layer = createLayer();
115 InputReshapeLauncher launcher(layer.get(), impl);
116 ASSERT_NO_THROW(launcher.reset());
119 TEST_F(InputReshapeLauncherTest, canApplyWithoutSettingShapes) {
120 auto layer = createLayer();
121 layer->outData.push_back(notEmptyData);
122 InputReshapeLauncher launcher(layer.get(), impl);
123 ASSERT_NO_THROW(launcher.applyChanges(layer.get()));
126 TEST_F(InputReshapeLauncherTest, canNotApplyForLayerWithAnotherName) {
127 auto layer1 = createLayer("");
128 auto layer2 = createLayer();
129 InputReshapeLauncher launcher(layer1.get(), impl);
130 launcher.setShapeByName(outDims, TEST_NAME);
131 ASSERT_THROW(launcher.applyChanges(layer2.get()), InferenceEngineException);
134 TEST_F(InputReshapeLauncherTest, canApplyChanges) {
135 auto layer = createLayer();
136 InputReshapeLauncher launcher(layer.get(), impl);
137 launcher.setShapeByName(outDims, TEST_NAME);
138 launcher.applyChanges(layer.get());
140 auto outData = layer->outData;
141 ASSERT_EQ(1, outData.size());
142 auto out0Data = outData[0];
143 ASSERT_NE(nullptr, out0Data);
144 ASSERT_EQ(outDims, out0Data->getDims());
147 TEST_F(InputReshapeLauncherTest, canGetShapesFromLayer) {
149 layer.outData = {notEmptyData};
150 notEmptyData->setDims(outDims);
151 auto initializer = std::make_shared<MockReshapeLauncher::TestLauncherInitializer>();
152 InputReshapeLauncher launcher(&layer, impl, initializer);
153 auto outputController = initializer->getOutputController();
154 EXPECT_CALL(*outputController, getIRShapes()).WillOnce(Return(std::vector<SizeVector>{outDims}));
155 EXPECT_CALL(*outputController, getShapes(false)).WillOnce(Return(std::vector<SizeVector>{SizeVector()}));
156 EXPECT_CALL(*outputController, setShapeByIndex(outDims, 0));
157 EXPECT_CALL(*outputController, propagateShapes(_));
158 launcher.reshape({});