Merge remote-tracking branch 'upstream/3.4' into merge-3.4
[platform/upstream/opencv.git] modules/dnn/src/onnx/onnx_importer.cpp
index 3d7e33a..714d4bf 100644
@@ -30,7 +30,7 @@
 
 namespace cv {
 namespace dnn {
-CV__DNN_EXPERIMENTAL_NS_BEGIN
+CV__DNN_INLINE_NS_BEGIN
 
 
 class ONNXImporter
@@ -427,24 +427,57 @@ void ONNXImporter::populateNet(Net dstNet)
             }
             layerParams.type = "Slice";
         }
-        else if (layer_type == "Add" || layer_type == "Sum")
+        else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub")
         {
+            bool isSub = layer_type == "Sub";
+            CV_CheckEQ(node_proto.input_size(), 2, "");
             if (layer_id.find(node_proto.input(1)) == layer_id.end())
             {
                 Mat blob = getBlob(node_proto, constBlobs, 1);
                 blob = blob.reshape(1, 1);
                 if (blob.total() == 1) {
                     layerParams.type = "Power";
-                    layerParams.set("shift", blob.at<float>(0));
+                    layerParams.set("shift", (isSub ? -1 : 1) * blob.at<float>(0));
                 }
                 else {
                     layerParams.type = "Scale";
                     layerParams.set("bias_term", true);
-                    layerParams.blobs.push_back(blob);
+                    layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
                 }
             }
-            else {
+            else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
+            {
                 layerParams.type = "Eltwise";
+                if (isSub)
+                {
+                    static float subCoeffs[] = {1.f, -1.f};
+                    layerParams.set("coeff", DictValue::arrayReal<float*>(subCoeffs, 2));
+                }
+            }
+            else
+            {
+                if (isSub)
+                {
+                    LayerParams powerParams;
+                    powerParams.name = layerParams.name + "/neg";
+                    powerParams.type = "Power";
+                    powerParams.set("scale", -1);
+
+                    //Create Power layer
+                    int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
+                    //Connect to input
+                    layerId = layer_id.find(node_proto.input(1));
+                    CV_Assert(layerId != layer_id.end());
+                    dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
+                    //Add shape
+                    layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
+                    outShapes[powerParams.name] = outShapes[node_proto.input(1)];
+
+                    //Replace input to Power
+                    node_proto.set_input(1, powerParams.name);
+                }
+                layerParams.type = "Scale";
+                layerParams.set("bias_term", true);
             }
         }
         else if (layer_type == "Max")
@@ -452,19 +485,6 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.type = "Eltwise";
             layerParams.set("operation", "max");
         }
-        else if (layer_type == "Sub")
-        {
-            Mat blob = getBlob(node_proto, constBlobs, 1);
-            if (blob.total() == 1) {
-                layerParams.type = "Power";
-                layerParams.set("shift", -blob.at<float>(0));
-            }
-            else {
-                layerParams.type = "Scale";
-                layerParams.set("has_bias", true);
-                layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
-            }
-        }
         else if (layer_type == "Neg")
         {
             layerParams.type = "Power";
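For context (not part of the patch): the broadcast branch above rewrites x0 - x1 as x0 + (-1 * x1). The second input is negated by an auxiliary Power layer (scale = -1), and the Scale layer with bias_term = true then adds that negated operand to the first input with per-channel broadcasting. A minimal standalone sketch of that arithmetic follows; it uses plain C++, not the OpenCV API, and the names negate / addPerChannelBias are illustrative only.

#include <cstdio>
#include <vector>

// Illustration of the Sub decomposition: y = x0 - x1  ==>  y = x0 + (-1 * x1),
// with the second operand broadcast per channel (NCHW, N = 1, H*W flattened).

// Power layer with scale = -1: element-wise negation of the smaller operand.
static std::vector<float> negate(const std::vector<float>& x)
{
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        y[i] = -x[i];
    return y;
}

// Scale layer with bias_term = true and no weights: the second operand acts as
// a per-channel bias, broadcast over the spatial dimensions of the first.
static std::vector<float> addPerChannelBias(const std::vector<float>& x0,
                                            const std::vector<float>& bias,
                                            int channels, int spatial)
{
    std::vector<float> y(x0.size());
    for (int c = 0; c < channels; ++c)
        for (int i = 0; i < spatial; ++i)
            y[c * spatial + i] = x0[c * spatial + i] + bias[c];
    return y;
}

int main()
{
    const int channels = 2, spatial = 3;
    std::vector<float> x0 = {1, 2, 3, 4, 5, 6}; // shape 1x2x3
    std::vector<float> x1 = {10, 20};           // shape 1x2x1x1, broadcast over spatial

    std::vector<float> y = addPerChannelBias(x0, negate(x1), channels, spatial);
    for (float v : y)
        std::printf("%g ", v);                  // prints: -9 -8 -7 -16 -15 -14
    std::printf("\n");
    return 0;
}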
@@ -643,10 +663,35 @@ void ONNXImporter::populateNet(Net dstNet)
                     layerParams.type = "Scale";
                 }
             }
-            else {
+            else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
+            {
                 layerParams.type = "Eltwise";
                 layerParams.set("operation", isDiv ? "div" : "prod");
             }
+            else
+            {
+                if (isDiv)
+                {
+                    LayerParams powerParams;
+                    powerParams.name = layerParams.name + "/inv";
+                    powerParams.type = "Power";
+                    powerParams.set("power", -1);
+
+                    //Create Power layer
+                    int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
+                    //Connect to input
+                    layerId = layer_id.find(node_proto.input(1));
+                    CV_Assert(layerId != layer_id.end());
+                    dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
+                    //Add shape
+                    layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
+                    outShapes[powerParams.name] = outShapes[node_proto.input(1)];
+
+                    //Replace input to Power
+                    node_proto.set_input(1, powerParams.name);
+                }
+                layerParams.type = "Scale";
+            }
 
             if (!haveVariables)
             {
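Likewise for context (not part of the patch): the broadcast branch for Div rewrites x0 / x1 as x0 * pow(x1, -1). An auxiliary Power layer with power = -1 produces the reciprocal of the second input, which the Scale layer then multiplies into the first input with per-channel broadcasting. A minimal standalone sketch follows; again plain C++ with illustrative names, not the OpenCV API.

#include <cmath>
#include <cstdio>
#include <vector>

// Illustration of the Div decomposition: y = x0 / x1  ==>  y = x0 * pow(x1, -1),
// with the second operand broadcast per channel.

// Power layer with power = -1: element-wise reciprocal of the smaller operand.
static std::vector<float> reciprocal(const std::vector<float>& x)
{
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        y[i] = std::pow(x[i], -1.f);
    return y;
}

// Scale layer without bias: the second operand is a per-channel weight,
// broadcast over the spatial dimensions of the first operand.
static std::vector<float> mulPerChannel(const std::vector<float>& x0,
                                        const std::vector<float>& w,
                                        int channels, int spatial)
{
    std::vector<float> y(x0.size());
    for (int c = 0; c < channels; ++c)
        for (int i = 0; i < spatial; ++i)
            y[c * spatial + i] = x0[c * spatial + i] * w[c];
    return y;
}

int main()
{
    const int channels = 2, spatial = 3;
    std::vector<float> x0 = {2, 4, 6, 9, 12, 15}; // shape 1x2x3
    std::vector<float> x1 = {2, 3};               // shape 1x2x1x1

    std::vector<float> y = mulPerChannel(x0, reciprocal(x1), channels, spatial);
    for (float v : y)
        std::printf("%g ", v);                    // prints: 1 2 3 3 4 5
    std::printf("\n");
    return 0;
}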
@@ -996,7 +1041,7 @@ Mat readTensorFromONNX(const String& path)
     return mat;
 }
 
-CV__DNN_EXPERIMENTAL_NS_END
+CV__DNN_INLINE_NS_END
 }} // namespace
 
 #endif