From 5e57bdc429c9752543826bfa2f692ec83d94c483 Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Mon, 17 Aug 2020 21:14:39 +0200 Subject: [PATCH] [FIX] Fix data layout for reshaped network (#1748) * [FIX] Fix data layout for reshaped network * [PATCH] Don't change compatible layouts * Add UT for reshaped network * FIX no. 1 --- .../inference_engine/cnn_network_ngraph_impl.cpp | 24 ++++++++++++++++++++-- .../inference_engine/ngraph_reshape_tests.cpp | 24 ++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp index 232d61c..9fd5d73 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp @@ -67,6 +67,24 @@ CNNNetwork::CNNNetwork(const std::shared_ptr& graph) { void CNNNetworkNGraphImpl::createDataForResult(const ::ngraph::Output<::ngraph::Node>& output, const std::string& outName, DataPtr& ptr) { + const auto isCompatible = [](size_t size, const Layout& l) -> bool { + switch (size) { + case 0: + return l == Layout::SCALAR; + case 1: + return l == Layout::C; + case 2: + return l == Layout::CN || l == Layout::HW || l == Layout::NC; + case 3: + return l == Layout::CHW; + case 4: + return l == Layout::NCHW || l == Layout::NHWC; + case 5: + return l == Layout::NCDHW || l == Layout::NDHWC; + default: + return false; + } + }; // query shape from ngraph::Parameter output shape and check there are no zeros in it SizeVector dims; if (output.get_partial_shape().is_static()) { @@ -78,10 +96,12 @@ void CNNNetworkNGraphImpl::createDataForResult(const ::ngraph::Output<::ngraph:: } if (ptr) { - ptr->reshape(dims, ptr->getTensorDesc().getLayout()); + const auto origLayout = ptr->getTensorDesc().getLayout(); + const auto layout = isCompatible(dims.size(), origLayout) ? origLayout : TensorDesc::getLayoutByDims(dims); + ptr->reshape(dims, layout); } else { - const auto precision = details::convertPrecision(output.get_element_type()); const auto layout = TensorDesc::getLayoutByDims(dims); + const auto precision = details::convertPrecision(output.get_element_type()); ptr.reset(new Data(outName, {precision, dims, layout})); } } diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp index 8600e16..3657c43 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp @@ -59,6 +59,30 @@ TEST_F(NGraphReshapeTests, getBatchSize) { ASSERT_EQ(1, cnnNetwork.getBatchSize()); } +TEST_F(NGraphReshapeTests, ReshapedDynamicShapeLayout) { + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({-1, 3, 22, 22}); + ngraph::element::Type type(ngraph::element::Type_t::f32); + auto param = std::make_shared(type, shape); + param->set_friendly_name("A"); + auto relu = std::make_shared(param); + + ngraph::ParameterVector params = {param}; + + ngraph = std::make_shared(relu, params); + } + + CNNNetwork cnnNetwork(ngraph); + ASSERT_EQ(Layout::SCALAR, cnnNetwork.getInputsInfo()["A"]->getLayout()); + + ICNNNetwork::InputShapes new_shape; + new_shape["A"] = ngraph::Shape{1, 3, 22, 22}; + cnnNetwork.reshape(new_shape); + + ASSERT_EQ(Layout::NCHW, cnnNetwork.getInputsInfo()["A"]->getLayout()); +} + TEST_F(NGraphReshapeTests, ReshapeBatchReLU) { std::shared_ptr ngraph; { -- 2.7.4