1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-matchers.h>
8 #include <shape_infer/mock_reshaper_launcher.hpp>
9 #include <shape_infer/mock_ishape_infer_impl.hpp>
10 #include <shape_infer/ie_reshape_io_controllers.hpp>
12 using namespace InferenceEngine;
13 using namespace InferenceEngine::details;
14 using namespace ShapeInfer;
15 using namespace ::testing;
// Test fixture for OutputController. Each test gets its own instance, so the
// `notEmptyData` object (whose `inputTo` some tests overwrite) is fresh per test.
17 class OutputControllerTest : public ::testing::Test {
// Shared name used for the Data object, the controller and the consumer layers.
20 static const std::string TEST_NAME;
// A non-null output Data named TEST_NAME (UNSPECIFIED precision, layout C).
21 DataPtr notEmptyData = std::make_shared<Data>(TEST_NAME, Precision::UNSPECIFIED, Layout::C);
// Builds a CNNLayer from `params`.
// NOTE(review): the declaration of `params` (and of the `inDims` member used by
// the tests below) is not visible in this view; presumably `params.name = name`
// — confirm, since throwOnPropagateWithoutProperLauncher relies on the layer
// name differing from the launcher's.
24 CNNLayerPtr createLayer(const std::string& name) {
27 return std::make_shared<CNNLayer>(params);
31 const std::string OutputControllerTest::TEST_NAME = "TEST_NAME";
33 TEST_F(OutputControllerTest, failedToCreateWithEmptyOutData) {
34 std::vector<DataPtr> inData;
35 EXPECT_THROW(OutputController({}, TEST_NAME), InferenceEngineException);
38 TEST_F(OutputControllerTest, failedToCreateWithNullOutData) {
39 EXPECT_THROW(OutputController({nullptr}, TEST_NAME), InferenceEngineException);
42 TEST_F(OutputControllerTest, canCreateOutputController) {
43 ASSERT_NO_THROW(OutputController({notEmptyData}, TEST_NAME));
46 TEST_F(OutputControllerTest, canGetChanges) {
47 OutputController controller({notEmptyData}, TEST_NAME);
48 std::vector<SizeVector> shapes;
49 ASSERT_NO_THROW(shapes = controller.getShapes(false));
50 ASSERT_EQ(1, shapes.size());
53 TEST_F(OutputControllerTest, canSetShapes) {
54 OutputController controller({notEmptyData}, TEST_NAME);
55 auto shapes = {inDims, inDims};
56 ASSERT_NO_THROW(controller.setShapes(shapes));
57 ASSERT_EQ(shapes.size(), controller.getShapes(false).size());
60 TEST_F(OutputControllerTest, noThrowOnGetWithExcessShapes) {
61 OutputController controller({notEmptyData}, TEST_NAME);
62 ASSERT_NO_THROW(controller.setShapes({inDims, inDims}));
63 ASSERT_FALSE(controller.getShapes(false).empty());
66 TEST_F(OutputControllerTest, throwOnPropagateWithNotEnoughShapes) {
67 OutputController controller({notEmptyData, notEmptyData}, TEST_NAME);
68 controller.setShapes({inDims});
69 ASSERT_THROW(controller.propagateShapes({}), InferenceEngineException);
72 TEST_F(OutputControllerTest, throwOnPropagateWithExcessShapes) {
73 OutputController controller({notEmptyData}, TEST_NAME);
74 controller.setShapes({inDims, inDims});
75 ASSERT_THROW(controller.propagateShapes({}), InferenceEngineException);
78 TEST_F(OutputControllerTest, throwOnPropagateWithEmptyLaunchers) {
79 OutputController controller({notEmptyData}, TEST_NAME);
80 notEmptyData->inputTo = {{{}, createLayer(TEST_NAME)}};
81 controller.setShapes({inDims});
82 ASSERT_THROW(controller.propagateShapes({}), InferenceEngineException);
85 TEST_F(OutputControllerTest, throwOnPropagateWithoutProperLauncher) {
86 OutputController controller({notEmptyData}, TEST_NAME);
87 notEmptyData->inputTo = {{{}, createLayer(TEST_NAME + "another")}};
88 controller.setShapes({inDims});
89 auto launcher = std::make_shared<MockReshapeLauncher>();
90 EXPECT_CALL(*launcher.get(), getLayerName()).WillOnce(Return(TEST_NAME));
91 ASSERT_THROW(controller.propagateShapes({launcher}), InferenceEngineException);
94 TEST_F(OutputControllerTest, canPropagateShapes) {
95 OutputController controller({notEmptyData}, TEST_NAME);
96 notEmptyData->inputTo = {{{}, createLayer(TEST_NAME)}};
97 controller.setShapes({inDims});
98 auto launcher = std::make_shared<MockReshapeLauncher>();
99 EXPECT_CALL(*launcher.get(), setShapeByName(inDims, TEST_NAME));
100 EXPECT_CALL(*launcher.get(), getLayerName()).WillOnce(Return(TEST_NAME));
101 controller.propagateShapes({launcher});
104 TEST_F(OutputControllerTest, throwOnApplyWithNotEnoughShapes) {
105 OutputController controller({notEmptyData, notEmptyData}, TEST_NAME);
106 controller.setShapes({inDims});
107 ASSERT_THROW(controller.applyChanges(), InferenceEngineException);
110 TEST_F(OutputControllerTest, throwOnApplyWithExcessShapes) {
111 OutputController controller({notEmptyData}, TEST_NAME);
112 auto shapes = {inDims, inDims};
113 controller.setShapes(shapes);
114 ASSERT_THROW(controller.applyChanges(), InferenceEngineException);
117 TEST_F(OutputControllerTest, canApplyChanges) {
118 OutputController controller({notEmptyData}, TEST_NAME);
119 controller.setShapes({inDims});
120 ASSERT_NO_THROW(controller.applyChanges());
123 TEST_F(OutputControllerTest, canResetShapes) {
124 OutputController controller({notEmptyData}, TEST_NAME);
125 controller.setShapes({inDims});
126 ASSERT_NO_THROW(controller.reset());
127 ASSERT_TRUE(controller.getShapes(false).begin()->empty());