Merge remote-tracking branch 'upstream/3.4' into merge-3.4
[platform/upstream/opencv.git] / modules / dnn / src / onnx / onnx_importer.cpp
index 859b595..6c106e2 100644 (file)
@@ -1162,6 +1162,53 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
                     layerParams.type = "Scale";
                 }
             }
+            else if (!haveVariables)
+            {
+                Mat inp0 = getBlob(node_proto, 0);
+                Mat inp1 = getBlob(node_proto, 1);
+
+                if (inp0.size != inp1.size && (inp0.total() != 1 || inp1.total() != 1))
+                    CV_Error_(Error::StsNotImplemented, ("Different shapes case is not supported with constant inputs: %s", layer_type.c_str()));
+
+                if (inp0.total() == 1 && inp1.total() == 1 && inp0.dims != inp1.dims)
+                {
+                    if (inp0.dims < inp1.dims)
+                    {
+                        inp0 = inp0.reshape(1, inp1.dims, inp1.size);
+                        inp0.dims = inp1.dims;
+                    }
+                    else
+                    {
+                        inp1 = inp1.reshape(1, inp0.dims, inp0.size);
+                        inp1.dims = inp0.dims;
+                    }
+                }
+
+                Mat out;
+                if (inp0.total() != inp1.total())
+                {
+                    if (inp0.total() == 1)
+                    {
+                        float coeff = isDiv ? 1.0 / inp0.at<float>(0) : inp0.at<float>(0);
+                        multiply(inp1, coeff, out);
+                    }
+                    else
+                    {
+                        float coeff = isDiv ? 1.0 / inp1.at<float>(0) : inp1.at<float>(0);
+                        multiply(inp0, coeff, out);
+                    }
+
+                }
+                else
+                {
+                    out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
+                }
+
+                if (inp0.dims == 1 && inp1.dims == 1)
+                    out.dims = 1;  // to workaround dims == 1
+                addConstant(layerParams.name, out);
+                return;
+            }
             else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
             {
                 layerParams.type = "Eltwise";
@@ -1201,20 +1248,6 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
                 }
                 layerParams.type = "Scale";
             }
-
-            if (!haveVariables)
-            {
-                Mat inp0 = getBlob(node_proto, 0);
-                Mat inp1 = getBlob(node_proto, 1);
-                if (inp0.size != inp1.size && inp1.total() != 1)
-                    CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
-
-                Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
-                out = out.reshape(1, inp0.dims, inp0.size);
-                out.dims = inp0.dims;  // to workaround dims == 1
-                addConstant(layerParams.name, out);
-                return;
-            }
         }
         else if (layer_type == "Conv")
         {
@@ -1733,9 +1766,26 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
             if (!hasVariableInps)
             {
                 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
+                // Due constant folding we can get inputs with different number of dimensions
+                // Insert the missing dimension to inputs
+                MatShape inputShape;
                 for (size_t i = 0; i < inputs.size(); ++i)
                 {
                     inputs[i] = getBlob(node_proto, i);
+                    if (inputs[i].size.dims() > inputShape.size())
+                    {
+                        inputShape = shape(inputs[i]);
+                    }
+                }
+
+                // Concat-1 has default value for axis is 1: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Concat-1
+                int axis = layerParams.get<int>("axis", 1);
+                for (size_t i = 0; i < inputs.size(); ++i)
+                {
+                    MatShape targetShape = inputShape;
+                    targetShape[axis] = shape(inputs[i])[axis];
+                    CV_CheckEQ(total(targetShape), total(shape(inputs[i])), "");
+                    inputs[i] = inputs[i].reshape(0, targetShape);
                 }
                 runLayer(layerParams, inputs, concatenated);