Fix addOutput (#1888)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Fri, 21 Aug 2020 02:51:42 +0000 (05:51 +0300)
committerGitHub <noreply@github.com>
Fri, 21 Aug 2020 02:51:42 +0000 (05:51 +0300)
inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp

index 8137e1b..2efd5cf 100644 (file)
@@ -260,6 +260,8 @@ StatusCode CNNNetworkNGraphImpl::addOutput(const std::string& layerName, size_t
 void CNNNetworkNGraphImpl::addOutput(const ::ngraph::Output<::ngraph::Node> & output) {
     auto dataName = ngraph::op::util::create_ie_output_name(output);
     DataPtr data;
+    if (_data.count(dataName))
+        data = _data[dataName];
     createDataForResult(output, dataName, data);
     _data[dataName] = data;
     _outputData[dataName] = data;
index 02e138b..11c68af 100644 (file)
@@ -753,4 +753,37 @@ TEST(CNNNGraphImplTests, addOutput) {
     ASSERT_EQ(outputs["reshape"]->getLayout(), InferenceEngine::Layout::NCHW);
 }
 
+TEST(CNNNGraphImplTests, addOutputForParameter) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::PartialShape shape({1, 3, 22, 22});
+        ngraph::element::Type type(ngraph::element::Type_t::f32);
+        auto param = std::make_shared<ngraph::opset3::Parameter>(type, shape);
+        param->set_friendly_name("param");
+        auto relu = std::make_shared<ngraph::opset3::Relu>(param);
+        auto result = std::make_shared<ngraph::op::Result>(relu);
+
+        ngraph::ParameterVector params = {param};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    CNNNetwork cnnNetwork(ngraph);
+    cnnNetwork.addOutput("param");
+    {
+        auto output_info = cnnNetwork.getOutputsInfo();
+        ASSERT_EQ(output_info.count("param"), 1);
+        ASSERT_EQ(output_info["param"]->getTensorDesc().getDims(), SizeVector({1, 3, 22, 22}));
+    }
+
+    cnnNetwork.reshape({{"param", SizeVector({1, 3, 32, 64})}});
+    cnnNetwork.addOutput("param");
+    {
+        auto output_info = cnnNetwork.getOutputsInfo();
+        ASSERT_EQ(output_info.count("param"), 1);
+        ASSERT_EQ(output_info["param"]->getTensorDesc().getDims(), SizeVector({1, 3, 32, 64}));
+    }
+}
+
 IE_SUPPRESS_DEPRECATED_END