Corrected DNN elementwise multiplication
authorAnastasia Murzova <anastasia.murzova@xperience.ai>
Mon, 22 Mar 2021 19:37:49 +0000 (22:37 +0300)
committerAnastasia Murzova <anastasia.murzova@xperience.ai>
Wed, 24 Mar 2021 07:53:11 +0000 (10:53 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 53d62fc9f7f70d84bf8c705127ae8badd65ff3a3..bdab0663a77ca43b29c1955cf9c716b45db8e5c9 100644 (file)
@@ -12,6 +12,7 @@ Implementation of Tensorflow models parser
 #include "../precomp.hpp"
 
 #include <opencv2/core/utils/logger.defines.hpp>
+#include <opencv2/dnn/shape_utils.hpp>
 #undef CV_LOG_STRIP_LEVEL
 #define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_DEBUG + 1
 #include <opencv2/core/utils/logger.hpp>
@@ -1825,6 +1826,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
             {
                 // Check if all the inputs have the same shape.
                 bool equalInpShapes = true;
+                bool isShapeOnes = false;
                 MatShape outShape0;
                 for (int ii = 0; ii < num_inputs && !netInputShapes.empty(); ii++)
                 {
@@ -1845,12 +1847,14 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
                     else if (outShape != outShape0)
                     {
                         equalInpShapes = false;
+                        isShapeOnes = isAllOnes(outShape, 2, outShape.size()) ||
+                                      isAllOnes(outShape0, 2, outShape0.size());
                         break;
                     }
                 }
 
                 int id;
-                if (equalInpShapes || netInputShapes.empty())
+                if (equalInpShapes || netInputShapes.empty() || (!equalInpShapes && isShapeOnes))
                 {
                     layerParams.set("operation", type == "RealDiv" ? "div" : "prod");
                     id = dstNet.addLayer(name, "Eltwise", layerParams);
index 62a559a672643f0f8a9be1c3c973aa45ce1ee363..5e45a5c0f0ad3ba4bfae30e7dcd786e49b8cad25 100644 (file)
@@ -210,6 +210,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_vec)
     runTensorFlowNet("eltwise_add_vec");
 }
 
+TEST_P(Test_TensorFlow_layers, eltwise_mul_vec)
+{
+    runTensorFlowNet("eltwise_mul_vec");
+}
+
+
 TEST_P(Test_TensorFlow_layers, channel_broadcast)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)