Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / util_const_infer_test.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <gtest/gtest.h>
8
9 #include <initializer_list>
10 #include <string>
11 #include <utility>
12 #include <unordered_set>
13 #include <unordered_map>
14
15 #include <ie_util_internal.hpp>
16 #include <tests_common.hpp>
17 #include <graph_transformer.h>
18 #include "ie_utils.hpp"
19 #include "blob_factory.hpp"
20 #include "debug.h"
21 #include "util_test.hpp"
22 #include <details/ie_cnn_network_tools.h>
23
24 namespace IE = InferenceEngine;
25
26 class ConstTransformatorTest : public IE::ConstTransformer {
27 public:
28     explicit ConstTransformatorTest(IE::details::CNNNetworkImpl* network) : IE::ConstTransformer(network) {}
29
30     const std::map<std::string, bool>
31     getConstLayers(const std::vector<InferenceEngine::CNNLayerPtr>& sortedLayers) override {
32         return ConstTransformer::getConstLayers(sortedLayers);
33     }
34
35     const InferenceEngine::BlobMap getConstData(const std::map<std::string, bool>& constLayers,
36                                                     const std::vector<InferenceEngine::CNNLayerPtr>& sortedLayers) override {
37         return ConstTransformer::getConstData(constLayers, sortedLayers);
38     }
39
40     std::vector<std::string>
41     foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const IE::BlobMap& constData,
42                                const std::vector<IE::CNNLayerPtr>& sortedLayers) override {
43         return ConstTransformer::foldConstSubgraphsInternal(constLayers, constData, sortedLayers);
44     }
45
46     void trimShapeInputs(const std::vector<std::string>& constLayers) override {
47         ConstTransformer::trimShapeInputs(constLayers);
48     }
49
50 };
51
52 class RemoveLayerTests : public testing::Test {
53 protected:
54     void SetUp() override;
55
56     //
57     // I1-d1-L1-d4              I4
58     //       / \  \              \
59     //      |  d7  \            d10
60     //      |  |    \            /
61     //  I2-d2-L2-d5-L4-d6-L5-d9-L10
62     //        /           /
63     //       /  ____d8___/
64     //      /  /
65     // I3-d3-L3
66     //
67     IE::details::CNNNetworkImplPtr getNetwork();
68
69     IE::CNNLayerPtr getLayer(const std::string& name);
70
71     IE::DataPtr getData(const std::string& name);
72
73     IE::BlobMap fillConstData(const std::vector<std::string>& constLayers);
74
75     IE::BlobMap initConstLayers(const std::vector<std::string>& constLayers);
76
77     NetBuilder netBuilder;
78     IE::details::CNNNetworkImplPtr net;
79     size_t originalLayersNum;
80     std::unique_ptr<ConstTransformatorTest> testTransformator;
81 };
82
83 class AdvancedShapeInferTests : public RemoveLayerTests {
84 protected:
85     void SetUp() override {};
86 };