Avoid redundant clone and reshape (#1376)
authorIlya Churaev <ilya.churaev@intel.com>
Wed, 29 Jul 2020 16:30:59 +0000 (19:30 +0300)
committerGitHub <noreply@github.com>
Wed, 29 Jul 2020 16:30:59 +0000 (19:30 +0300)
* Avoid redundant clone and reshape

* Removed some constructors

* Fixed output precision

inference-engine/include/cpp/ie_cnn_network.h
inference-engine/src/hetero_plugin/hetero_executable_network.cpp
inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp
inference-engine/src/legacy_api/src/ie_util_internal.cpp
inference-engine/tests/functional/inference_engine/cnn_network_test.cpp
inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp

index f71dd32..e1b2c71 100644 (file)
@@ -51,9 +51,11 @@ public:
 
     /**
      * @brief A constructor from ngraph::Function object
+     * This constructor wraps existing ngraph::Function
+     * If you want to avoid modification of original Function, please create a copy
      * @param network Pointer to the ngraph::Function object
      */
-    explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
+    explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network);
 
     /**
      * @brief A destructor
index e918612..36e027b 100644 (file)
@@ -143,7 +143,7 @@ void dumpGraph(InferenceEngine::ICNNNetwork &network,
 
 
 void dumpGraph(InferenceEngine::ICNNNetwork&                                network,
-               const std::vector<std::shared_ptr<const ngraph::Function>>&  subFunctions,
+               const std::vector<std::shared_ptr<ngraph::Function>>&  subFunctions,
                std::ostream&                                                stream) {
     static const std::array<const char *, 9> colors{{"#FFC405",
                                                      "#20F608",
@@ -665,13 +665,13 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     InputsDataMap externalInputsData;
     network.getInputsInfo(externalInputsData);
     networks.resize(orderedSubgraphs.size());
-    std::vector<std::shared_ptr<const ngraph::Function>> subFunctions(orderedSubgraphs.size());
+    std::vector<std::shared_ptr<ngraph::Function>> subFunctions(orderedSubgraphs.size());
     std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
     int id = 0;
     for (auto&& subgraph : orderedSubgraphs) {
         networks[id]._device = subgraph._affinity;
         subFunctions[id] =
-            std::make_shared<const ngraph::Function>(subgraph._results, subgraph._parameters,
+            std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
                                                      _name + '_' + std::to_string(id));
         networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
         // update of pre-processing info
index 8a5edef..678a303 100644 (file)
@@ -71,14 +71,13 @@ static std::shared_ptr<ngraph::Function> copyFunction(const std::shared_ptr<cons
     return specialized_function;
 }
 
-// WA: for cnnNetwork ngraph constructor
-CNNNetwork::CNNNetwork(const std::shared_ptr<const ngraph::Function>& graph) {
+CNNNetwork::CNNNetwork(const std::shared_ptr<ngraph::Function>& graph) {
     if (graph == nullptr) {
         THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
     }
 
-    // Copy nGraph function
-    network = std::make_shared<CNNNetworkNGraphImpl>(copyFunction(graph, false, {}));
+    // Create CNNNetworkNGraphImpl
+    network = std::make_shared<CNNNetworkNGraphImpl>(graph);
     actual = network.get();
     if (actual == nullptr) {
         THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
@@ -146,6 +145,36 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGra
     }
 }
 
+CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) {
+    if (network.getFunction() == nullptr) {
+        THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!";
+    }
+
+    _ngraph_function = copyFunction(network.getFunction(), false, {});
+    InputsDataMap inputs;
+    OutputsDataMap outputs;
+    network.getInputsInfo(inputs);
+    network.getOutputsInfo(outputs);
+
+    for (const auto& outputInfo : outputs) {
+        const auto& name = outputInfo.second->getName();
+        DataPtr output = std::make_shared<Data>(name, outputInfo.second->getTensorDesc());
+        _outputData[name] = output;
+        _data[name] = output;
+    }
+    for (const auto& inputInfo : inputs) {
+        InputInfo::Ptr info = std::make_shared<InputInfo>();
+        const auto& name = inputInfo.second->getInputData()->getName();
+        DataPtr input = std::make_shared<Data>(name, inputInfo.second->getInputData()->getTensorDesc());
+        _data[name] = input;
+        info->setInputData(input);
+        info->getPreProcess() = inputInfo.second->getPreProcess();
+        info->setPrecision(inputInfo.second->getPrecision());
+        info->setLayout(inputInfo.second->getLayout());
+        _inputData[name] = info;
+    }
+}
+
 void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
     if (cnnNetwork) cnnNetwork->setInputInfo(data);
     _inputData[data->name()] = data;
index 7b042c7..e776093 100644 (file)
@@ -43,6 +43,7 @@ namespace details {
 class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
 public:
     CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
+    CNNNetworkNGraphImpl(const ICNNNetwork& nGraph);
     ~CNNNetworkNGraphImpl() override = default;
 
     void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
index c8b4767..9fa73ea 100644 (file)
@@ -24,6 +24,7 @@
 #include "graph_tools.hpp"
 #include "net_pass.h"
 #include "precision_utils.h"
+#include "cnn_network_ngraph_impl.hpp"
 
 using std::string;
 
@@ -148,30 +149,8 @@ CNNLayerPtr clonelayer(const CNNLayer& source) {
 }
 
 std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
-    if (auto func = network.getFunction()) {
-        CNNNetwork net(func);
-
-        InputsDataMap originInputs;
-        OutputsDataMap originOutputs;
-        network.getInputsInfo(originInputs);
-        network.getOutputsInfo(originOutputs);
-        InputsDataMap clonedInputs = net.getInputsInfo();
-        OutputsDataMap clonedOutputs = net.getOutputsInfo();
-
-        for (const auto& outputInfo : originOutputs) {
-            if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end())
-                THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs";
-            clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision());
-            clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout());
-        }
-        for (const auto& inputInfo : originInputs) {
-            if (clonedInputs.find(inputInfo.first) == clonedInputs.end())
-                THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs";
-            clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision());
-            clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout());
-            clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess();
-        }
-        return net;
+    if (network.getFunction()) {
+        return std::make_shared<details::CNNNetworkNGraphImpl>(network);
     }
 
     return cloneNet(network);
index c506949..c07a599 100644 (file)
@@ -15,7 +15,7 @@ TEST_F(CNNNetworkTests, throwsOnInitWithNull) {
 }
 
 TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
-    std::shared_ptr<const ngraph::Function> nlptr = nullptr;
+    std::shared_ptr<ngraph::Function> nlptr = nullptr;
     ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
 }
 
index cd7fb7a..f02993e 100644 (file)
@@ -21,6 +21,7 @@
 #include <ngraph/op/relu.hpp>
 #include <ngraph/op/result.hpp>
 #include <ngraph/opsets/opset.hpp>
+#include <ngraph/graph_util.hpp>
 
 #include <ie_util_internal.hpp>
 #include <ie_core.hpp>
@@ -121,6 +122,39 @@ TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) {
 }
 
 TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
+    std::shared_ptr<const ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        ngraph::element::Type type(ngraph::element::Type_t::f32);
+        auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+        param->set_friendly_name("data");
+        auto relu = std::make_shared<ngraph::op::Relu>(param);
+        auto result = std::make_shared<ngraph::op::Result>(relu);
+
+        ngraph::ParameterVector params = {param};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<const ngraph::Function>(results, params);
+    }
+
+    ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+    ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+
+    CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph));
+    std::map<std::string, std::vector<size_t>> shapes;
+    shapes["data"] = {1, 3, 25, 25};
+
+    ASSERT_NO_THROW(cnnNetwork.reshape(shapes));
+
+    auto changedFunction = cnnNetwork.getFunction();
+    ASSERT_NE(nullptr, changedFunction);
+    ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+    ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+    ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+    ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+}
+
+TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) {
     std::shared_ptr<ngraph::Function> ngraph;
     {
         ngraph::PartialShape shape({1, 3, 22, 22});
@@ -149,8 +183,8 @@ TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
     ASSERT_NE(nullptr, changedFunction);
     ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
     ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
-    ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
-    ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
+    ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
+    ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
 }
 
 class CustomTestOp: public ngraph::op::Op {