1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <gmock/gmock.h>
9 #include <inference_engine/shape_infer/ie_reshape_launcher.hpp>
10 #include <inference_engine/shape_infer/ie_reshape_io_controllers.hpp>
11 #include <shape_infer/mock_ishape_infer_impl.hpp>
12 #include <shape_infer/mock_input_controller.hpp>
13 #include <shape_infer/mock_output_controller.hpp>
16 using namespace InferenceEngine;
17 using namespace ShapeInfer;
19 class MockReshapeLauncher : public ReshapeLauncher {
21 using Ptr = std::shared_ptr<MockReshapeLauncher>;
22 class TestLauncherInitializer : public DefaultInitializer {
24 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override {}
26 InputController* createInputController(const CNNLayer* layer) override {
28 std::vector<DataPtr> data;
30 for (auto const& insData: layer->insData) {
31 data.push_back(insData.lock());
34 _iController = new MockInputController(data);
39 OutputController* createOutputController(const CNNLayer* layer) override {
41 std::vector<DataPtr> data;
42 if (layer) data = layer->outData;;
43 _oController = new MockOutputController(data);
48 MockInputController* getInputController() {
52 MockOutputController* getOutputController() {
57 MockInputController* _iController;
58 MockOutputController* _oController;
61 MockReshapeLauncher(const DefaultInitializer::Ptr& initializer = std::make_shared<TestLauncherInitializer>(),
62 const CNNLayer* layer = nullptr,
63 const IShapeInferImpl::Ptr& impl = std::make_shared<MockIShapeInferImpl>())
64 : ReshapeLauncher(layer, impl, initializer) {}
66 MOCK_METHOD2(setShapeByName, void(const SizeVector&, const std::string&));
68 MOCK_METHOD1(reshape, void(const std::set<ReshapeLauncher::Ptr>&));
70 MOCK_METHOD1(applyChanges, void(CNNLayer*));
72 MOCK_METHOD0(reset, void());
74 MOCK_QUALIFIED_METHOD0(getLayerName, const, std::string());
76 MOCK_METHOD1(setShapeInferImpl, void(const IShapeInferImpl::Ptr& ));
79 ReshapeLauncher::reset();
83 ReshapeLauncher::reshape({});
86 std::string realGetLayerName() {
87 return ReshapeLauncher::getLayerName();