Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / graph_transformer.h
index 9d8014d..d984535 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -9,16 +9,64 @@
 
 #pragma once
 
+#include <map>
+#include <vector>
+#include <string>
 #include <ie_icnn_network.hpp>
+#include <details/caseless.hpp>
+#include "cnn_network_impl.hpp"
 
 namespace InferenceEngine {
 
 /**
- * @brief Replaces layer with newLayer in network
- * @param network  - graph containing the layer
- * @param layer    - layer which need to replace
- * @param newLayer - new layer instead of layer; it must have same name like a layer for replace
+ * @brief TBD
  */
-void replaceLayerWithNewLayer(ICNNNetwork &network, const CNNLayerPtr &layer, const CNNLayerPtr &newLayer);
+class INFERENCE_ENGINE_API_CLASS(ConstTransformer) {
+public:
+    explicit ConstTransformer(details::CNNNetworkImpl* _network) {
+        if (!_network) THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
+        network = _network;
+        cnnNetwork = CNNNetwork(network);
+    }
+
+    /**
+     * @brief calculates const layers, combines const subgraph into a single const layers
+     */
+    void foldConstSubgraphs();
+
+    /**
+      * @brief folds Const Subgraphs and removes second input of Reshape-like layers (Interp, Gather, Resample, ...)
+      */
+    void fullTrim();
+
+protected:
+    /**
+     * @brief collect all const layers with marking if it defines shape (1 - for shape, 0 - otherwise)
+     */
+    virtual const std::map<std::string, bool> getConstLayers(const std::vector<CNNLayerPtr>& sortedLayers);
+
+    /**
+     * @brief TBD
+     */
+    virtual const BlobMap
+        getConstData(const std::map<std::string, bool>& constLayers, const std::vector<CNNLayerPtr>& sortedLayers);
+
+    /**
+     * @brief TBD
+     */
+    virtual std::vector<std::string>
+    foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const BlobMap& constData,
+                               const std::vector<CNNLayerPtr>& sortedLayers);
+
+    /**
+     * @brief TBD
+     */
+    virtual void trimShapeInputs(const std::vector<std::string>& constLayers);
+
+private:
+    const details::caseless_set<std::string> shapeTaking = {"Reshape", "Resample", "Interp"};
+    details::CNNNetworkImpl* network;
+    CNNNetwork cnnNetwork;
+};
 
 }  // namespace InferenceEngine