Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / mocks / shape_infer / mock_reshaper_launcher.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <gmock/gmock.h>
8
9 #include <inference_engine/shape_infer/ie_reshape_launcher.hpp>
10 #include <inference_engine/shape_infer/ie_reshape_io_controllers.hpp>
11 #include <shape_infer/mock_ishape_infer_impl.hpp>
12 #include <shape_infer/mock_input_controller.hpp>
13 #include <shape_infer/mock_output_controller.hpp>
14
15
16 using namespace InferenceEngine;
17 using namespace ShapeInfer;
18
19 class MockReshapeLauncher : public ReshapeLauncher {
20 public:
21     using Ptr = std::shared_ptr<MockReshapeLauncher>;
22     class TestLauncherInitializer : public DefaultInitializer {
23     public:
24         void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override {}
25
26         InputController* createInputController(const CNNLayer* layer) override {
27             if (!_iController) {
28                 std::vector<DataPtr> data;
29                 if (layer) {
30                     for (auto const& insData: layer->insData) {
31                         data.push_back(insData.lock());
32                     }
33                 }
34                 _iController = new MockInputController(data);
35             }
36             return _iController;
37         }
38
39         OutputController* createOutputController(const CNNLayer* layer) override {
40             if (!_oController) {
41                 std::vector<DataPtr> data;
42                 if (layer) data = layer->outData;;
43                 _oController = new MockOutputController(data);
44             }
45             return _oController;
46         }
47
48         MockInputController* getInputController() {
49             return _iController;
50         }
51
52         MockOutputController* getOutputController() {
53             return _oController;
54         }
55
56     private:
57         MockInputController* _iController;
58         MockOutputController* _oController;
59     };
60
61     MockReshapeLauncher(const DefaultInitializer::Ptr& initializer = std::make_shared<TestLauncherInitializer>(),
62                         const CNNLayer* layer = nullptr,
63                         const IShapeInferImpl::Ptr& impl = std::make_shared<MockIShapeInferImpl>())
64             : ReshapeLauncher(layer, impl, initializer) {}
65
66     MOCK_METHOD2(setShapeByName, void(const SizeVector&, const std::string&));
67
68     MOCK_METHOD1(reshape, void(const std::set<ReshapeLauncher::Ptr>&));
69
70     MOCK_METHOD1(applyChanges, void(CNNLayer*));
71
72     MOCK_METHOD0(reset, void());
73
74     MOCK_QUALIFIED_METHOD0(getLayerName, const, std::string());
75
76     MOCK_METHOD1(setShapeInferImpl, void(const IShapeInferImpl::Ptr& ));
77
78     void realReset() {
79         ReshapeLauncher::reset();
80     }
81
82     void realReshape() {
83         ReshapeLauncher::reshape({});
84     }
85
86     std::string realGetLayerName() {
87         return ReshapeLauncher::getLayerName();
88     }
89 };
90