Merge pull request #16722 from l-bat:reshape_opset_11
authorLiubov Batanina <piccione-mail@yandex.ru>
Wed, 4 Mar 2020 08:27:10 +0000 (11:27 +0300)
committerGitHub <noreply@github.com>
Wed, 4 Mar 2020 08:27:10 +0000 (11:27 +0300)
* Supported Div op for constants

* Added Mul test

modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 6f3ac04..3d7e33a 100644 (file)
@@ -465,31 +465,6 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
             }
         }
-        else if (layer_type == "Div")
-        {
-            if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
-            {
-                layerParams.type = "Eltwise";
-                layerParams.set("operation", "div");
-            }
-            else
-            {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
-                CV_Assert_N(blob.type() == CV_32F, blob.total());
-                if (blob.total() == 1)
-                {
-                    layerParams.set("scale", 1.0f / blob.at<float>(0));
-                    layerParams.type = "Power";
-                }
-                else
-                {
-                    layerParams.type = "Scale";
-                    divide(1.0, blob, blob);
-                    layerParams.blobs.push_back(blob);
-                    layerParams.set("bias_term", false);
-                }
-            }
-        }
         else if (layer_type == "Neg")
         {
             layerParams.type = "Power";
@@ -638,24 +613,58 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("bias_term", false);
             layerParams.set("num_output", layerParams.blobs[0].size[0]);
         }
-        else if (layer_type == "Mul")
+        else if (layer_type == "Mul" || layer_type == "Div")
         {
             CV_Assert(node_proto.input_size() == 2);
-            if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
+
+            bool isDiv = layer_type == "Div";
+            int constId = -1;
+            bool haveVariables = false;
+            for (int i = 0; i < 2; ++i)
+            {
+                if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
+                    constId = i;
+                else
+                    haveVariables = true;
+            }
+            if (constId != -1 && haveVariables)
+            {
+                Mat blob = getBlob(node_proto, constBlobs, constId);
                 blob = blob.reshape(1, 1);
                 if (blob.total() == 1) {
-                    layerParams.set("scale", blob.at<float>(0));
+                    float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
+                    layerParams.set("scale", coeff);
                     layerParams.type = "Power";
                 }
                 else {
+                    if (isDiv)
+                        divide(1.0, blob, blob);
                     layerParams.blobs.push_back(blob);
                     layerParams.type = "Scale";
                 }
             }
             else {
                 layerParams.type = "Eltwise";
-                layerParams.set("operation", "prod");
+                layerParams.set("operation", isDiv ? "div" : "prod");
+            }
+
+            if (!haveVariables)
+            {
+                Mat inp0 = getBlob(node_proto, constBlobs, 0);
+                Mat inp1 = getBlob(node_proto, constBlobs, 1);
+                if (inp0.size != inp1.size)
+                    CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
+
+                Mat out;
+                if (isDiv)
+                    divide(inp0, inp1, out);
+                else
+                    multiply(inp0, inp1, out);
+
+                out = out.reshape(1, inp0.dims, inp0.size);
+                out.dims = inp0.dims;  // to workaround dims == 1
+                constBlobs.insert(std::make_pair(layerParams.name, out));
+                continue;
             }
         }
         else if (layer_type == "Conv")
index bb7cba1..2838a72 100644 (file)
@@ -382,6 +382,8 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
         if (target == DNN_TARGET_OPENCL)      applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
     }
     testONNXModels("dynamic_reshape");
+    testONNXModels("dynamic_reshape_opset_11");
+    testONNXModels("flatten_by_prod");
 }
 
 TEST_P(Test_ONNX_layers, Reshape)