Fixed set/get batch size logic for scalar inputs (#1837)
authorIlya Churaev <ilya.churaev@intel.com>
Wed, 19 Aug 2020 03:37:04 +0000 (06:37 +0300)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 03:37:04 +0000 (06:37 +0300)
inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/src/legacy_api/src/cnn_network_impl.cpp
inference-engine/tests/functional/inference_engine/cnn_network/cnn_ngraph_impl_tests.cpp

index 9fd5d73..f4fbad4 100644 (file)
@@ -275,15 +275,15 @@ size_t CNNNetworkNGraphImpl::getBatchSize() const noexcept {
         return cnnNetwork->getBatchSize();
     }
     auto params = _ngraph_function->get_parameters();
-    if (params.empty() || !params[0]->get_partial_shape().is_static()) return 0;
-
-    auto shape = _ngraph_function->get_parameters()[0]->get_shape();
-
-    // WA: for speech recognition layouts (copy-past from CNNNetwork)
-    if (shape.size() == 3 || shape.size() == 1) {
-        return 1;
+    for (const auto& param : params) {
+        if (param->get_partial_shape().is_dynamic())
+            continue;
+        auto shape = param->get_shape();
+        // WA: for speech recognition and scalar layouts (copy-past from CNNNetwork)
+        if (!shape.empty() && shape.size() != 3 && shape.size() != 1)
+            return shape[0];
     }
-    return shape[0];
+    return 1;
 }
 
 std::shared_ptr<ngraph::Function> CNNNetworkNGraphImpl::cloneFunction(bool constFolding) const {
index 7358cde..a8d0a6f 100644 (file)
@@ -435,9 +435,9 @@ StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc)
 
         SizeVector dims = _inputData.cbegin()->second->getTensorDesc().getDims();
 
-        // 3D input layout doesn't have batch notation
-        if (dims.size() == 3 || dims.size() == 1) {
-            return DescriptionBuffer(PARAMETER_MISMATCH, responseDesc) << "Cannot set batch for 1D/3D input";
+        // 3D/1D/0D input layouts don't have batch notation
+        if (dims.size() == 3 || dims.size() == 1 || dims.empty()) {
+            return DescriptionBuffer(PARAMETER_MISMATCH, responseDesc) << "Cannot set batch for 0D/1D/3D input";
         }
 
         const std::map<CNNLayer*, bool> layersMap = getConstLayersMap(*this);
index e9aea78..00cfa6d 100644 (file)
@@ -160,6 +160,45 @@ TEST(CNNNGraphImplTests, TestSetBatch) {
     ASSERT_EQ(nullptr, cnnNet.getFunction());
 }
 
+TEST(CNNNGraphImplTests, TestGetBatchScalar) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::Shape shape({});
+        ngraph::element::Type type(ngraph::element::Type_t::f32);
+        auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+        auto relu = std::make_shared<ngraph::op::Relu>(param);
+        auto result = std::make_shared<ngraph::op::Result>(relu);
+
+        ngraph::ParameterVector params = {param};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+    ASSERT_EQ(1, cnnNet.getBatchSize());
+}
+
+TEST(CNNNGraphImplTests, TestSetBatchScalar) {
+    std::shared_ptr<ngraph::Function> ngraph;
+    {
+        ngraph::Shape shape({});
+        ngraph::element::Type type(ngraph::element::Type_t::f32);
+        auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
+        auto relu = std::make_shared<ngraph::op::Relu>(param);
+        auto result = std::make_shared<ngraph::op::Result>(relu);
+
+        ngraph::ParameterVector params = {param};
+        ngraph::ResultVector results = {result};
+
+        ngraph = std::make_shared<ngraph::Function>(results, params);
+    }
+
+    InferenceEngine::details::CNNNetworkNGraphImpl cnnNet(ngraph);
+    ASSERT_EQ(1, cnnNet.getBatchSize());
+    ASSERT_EQ(PARAMETER_MISMATCH, cnnNet.setBatchSize(2, nullptr));  // triggers conversion
+}
+
 TEST(CNNNGraphImplTests, TestSaveAffinity) {
     const std::string testAffinity = "testAffinity";
     std::shared_ptr<ngraph::Function> ngraph;