Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshaper.hpp
index 4f18507..834abe3 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -13,7 +13,7 @@
 
 #include <ie_layers.h>
 #include <ie_context.hpp>
-#include "../ie_network.hpp"
+#include <builders/ie_network_builder.hpp>
 #include "details/caseless.hpp"
 #include "ie_reshape_launcher.hpp"
 #include "ie_icnn_network.hpp"
@@ -60,9 +60,12 @@ public:
      * @param network - const reference to the ICNNNetwork for performing shape inference
      */
     explicit Reshaper(ICNNNetwork& network,
-                      const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>());
+            const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>());
 
-    Reshaper(const Context& context, details::Network::Ptr& network);
+    explicit Reshaper(std::vector<DataPtr> inputs,
+            const LauncherCreator::Ptr& launcherCreator = std::make_shared<LauncherCreator>());
+
+    Reshaper(Builder::Network* network);
 
     virtual ~Reshaper() = default;
 
@@ -78,6 +81,25 @@ public:
      * @param inputShapes - Map of input names (data) to their input shapes.
      */
     StatusCode run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
+
+    /**
+     * @brief Perform shape inference for the given input shapes but not apply it.
+     * In case of cusses call apply() method.
+     * @param inputShapes - Map of input names (data) to their input shapes.
+     * @throws exception if shape infer failed without corruption of original shapes
+     */
+    StatusCode runNoApply(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
+
+    /**
+     * @brief Apply shapes pre calculated by runNoApply() method.
+     */
+    StatusCode apply(ResponseDesc* resp = nullptr);
+
+    /**
+     * @brief Return newly calculated shape for provided data.
+     */
+    SizeVector getResultShapeFor(DataPtr &data, ResponseDesc* resp = nullptr);
+
 private:
     ReshapeLauncher::Ptr getLauncherByLayerName(const std::string& layerName) const;
 
@@ -91,8 +113,7 @@ private:
     std::set<CNNLayerPtr> _inputLayers{};
     InferenceEngine::details::caseless_set<std::string> _allTypes;
 
-    Context ctx;
-    details::Network::Ptr network;
+    Builder::Network* network;
 };
 
 }  // namespace ShapeInfer