fix scale layer can not handle 1x1 weight correctly.
authorZihao Mu <zihaomu@outlook.com>
Wed, 13 Jul 2022 03:25:27 +0000 (11:25 +0800)
committerZihao Mu <zihaomu@outlook.com>
Wed, 13 Jul 2022 03:25:27 +0000 (11:25 +0800)
modules/dnn/src/layers/scale_layer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 594b0bb..406d27b 100644 (file)
@@ -91,6 +91,16 @@ public:
         if (hasWeights && hasBias)
             CV_CheckEQ(weights.total(), bias.total(), "Incompatible weights/bias blobs");
 
+        if (weights.total() == 1)
+        {
+            // The total() of bias should be same as weights.
+            if (hasBias)
+                inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0), bias.at<float>(0));
+            else
+                inpBlob.convertTo(outBlob, CV_32F, weights.at<float>(0));
+            return;
+        }
+
         int endAxis;
         for (endAxis = axis + 1; endAxis <= inpBlob.dims; ++endAxis)
         {
index 7a0532f..8221855 100644 (file)
@@ -1818,6 +1818,8 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
 
 void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis, int& broadAxis)
 {
+    // Currently, this function can only complete 1-dimensional expansion of broadShape.
+    // If there are two dimensions in broadShape that need to be expended, it will fail.
     const size_t diff = outShape.size() - broadShape.size();
 
     // find the first non-one element of the broadcasting shape
@@ -1982,25 +1984,30 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro
         const MatShape& outShape = outShapes[node_proto.input(0)];
 
         size_t axis = 0;
-        int broadAxis = -1;
-        findBroadAxis(broadShape, outShape, axis, broadAxis);
-
-        // if there is a one dimension in the middle that should be broadcasted, broadcast it
-        if (broadAxis != -1)
+        if (total(broadShape) != 1)
         {
-            opencv_onnx::NodeProto concat_node_proto = node_proto;
-            const std::string& input1 = concat_node_proto.input(1);
+            // If broadShape is a scalar, we set axis as 0.
+            // Other-wise, we check broadcast is available.
+            int broadAxis = -1;
+            findBroadAxis(broadShape, outShape, axis, broadAxis);
+
+            // if there is a one dimension in the middle that should be broadcasted, broadcast it
+            if (broadAxis != -1)
+            {
+                opencv_onnx::NodeProto concat_node_proto = node_proto;
+                const std::string& input1 = concat_node_proto.input(1);
 
-            expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
+                expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]);
 
-            LayerParams concatLP;
-            concatLP.name = layerParams.name + "/concat";
-            concatLP.set("axis", broadAxis);
-            concatLP.type = "Concat";
-            concat_node_proto.set_output(0, concatLP.name);
+                LayerParams concatLP;
+                concatLP.name = layerParams.name + "/concat";
+                concatLP.set("axis", broadAxis);
+                concatLP.type = "Concat";
+                concat_node_proto.set_output(0, concatLP.name);
 
-            addLayer(concatLP, concat_node_proto);
-            node_proto.set_input(1, concatLP.name);
+                addLayer(concatLP, concat_node_proto);
+                node_proto.set_input(1, concatLP.name);
+            }
         }
 
         CV_Assert(axis != outShape.size());
index 56203cb..3d1dd38 100644 (file)
@@ -725,6 +725,8 @@ TEST_P(Test_ONNX_layers, Div)
 
     normAssert(ref, out, "", default_l1,  default_lInf);
     expectNoFallbacksFromIE(net);
+
+    testONNXModels("div_test_1x1",npy, 0, 0, false, true, 2);
 }
 
 TEST_P(Test_ONNX_layers, DynamicReshape)