Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / graph_transformer.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief a header file with common functions for graph transformation
7  * @file graph_transformer.h
8  */
9
10 #pragma once
11
12 #include <map>
13 #include <vector>
14 #include <string>
15 #include <ie_icnn_network.hpp>
16 #include <details/caseless.hpp>
17 #include "cnn_network_impl.hpp"
18
19 namespace InferenceEngine {
20
21 /**
22  * @brief TBD
23  */
24 class INFERENCE_ENGINE_API_CLASS(ConstTransformer) {
25 public:
26     explicit ConstTransformer(details::CNNNetworkImpl* _network) {
27         if (!_network) THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
28         network = _network;
29         cnnNetwork = CNNNetwork(network);
30     }
31
32     /**
33      * @brief calculates const layers, combines const subgraph into a single const layers
34      */
35     void foldConstSubgraphs();
36
37     /**
38       * @brief folds Const Subgraphs and removes second input of Reshape-like layers (Interp, Gather, Resample, ...)
39       */
40     void fullTrim();
41
42 protected:
43     /**
44      * @brief collect all const layers with marking if it defines shape (1 - for shape, 0 - otherwise)
45      */
46     virtual const std::map<std::string, bool> getConstLayers(const std::vector<CNNLayerPtr>& sortedLayers);
47
48     /**
49      * @brief TBD
50      */
51     virtual const BlobMap
52         getConstData(const std::map<std::string, bool>& constLayers, const std::vector<CNNLayerPtr>& sortedLayers);
53
54     /**
55      * @brief TBD
56      */
57     virtual std::vector<std::string>
58     foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const BlobMap& constData,
59                                const std::vector<CNNLayerPtr>& sortedLayers);
60
61     /**
62      * @brief TBD
63      */
64     virtual void trimShapeInputs(const std::vector<std::string>& constLayers);
65
66 private:
67     const details::caseless_set<std::string> shapeTaking = {"Reshape", "Resample", "Interp"};
68     details::CNNNetworkImpl* network;
69     CNNNetwork cnnNetwork;
70 };
71
72 }  // namespace InferenceEngine