Merge pull request #18296 from sl-sergei:fix_16783
authorSergei Slashchinin <62052793+sl-sergei@users.noreply.github.com>
Tue, 17 Nov 2020 09:52:08 +0000 (12:52 +0300)
committerGitHub <noreply@github.com>
Tue, 17 Nov 2020 09:52:08 +0000 (09:52 +0000)
Fix loading issue for Faster RCNN model from #16783

* Add a reproducer with multi-output Gather

* Fix an issue with ONNX graph simplifier

* fix build

* Move checks to correct class

* Minor changes for better code appearence

modules/dnn/src/onnx/onnx_graph_simplifier.cpp
modules/dnn/test/test_onnx_importer.cpp

index e8b237c..30c0b26 100644 (file)
@@ -260,6 +260,40 @@ public:
         addNodeToMatch("Cast", gather);
         setFusedNode("Gather", input, index);
     }
+
+    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
+                       std::vector<int>& matchedNodesIds,
+                       std::vector<int>& targetNodesIds) CV_OVERRIDE
+    {
+        bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
+        size_t matchedNodesNum = matchedNodesIds.size();
+        // Now we check if merging can be made for these Gather and Cast nodes
+        if (!retVal || matchedNodesNum < 2)
+            return retVal;
+        else {
+            int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
+            const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
+            if (node->getType() == "Cast") {
+                int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
+                const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
+                if (inpNode->getType() == "Gather") {
+                    int numNodes = net->getNumNodes();
+                    std::string inpNodeName = node->getInputName(0);
+                    for (int i = 0; i < numNodes; ++i) {
+                        const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
+                        int numInp = node_to_check->getNumInputs();
+                        for (int inp = 0; inp < numInp; ++inp) {
+                            if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) {
+                                // Another node has the same input node, so it cannot be merged.
+                                return false;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        return retVal;
+    }
 };
 
 class ExpandSubgraph : public Subgraph
index 5c6de55..14d2d28 100644 (file)
@@ -705,6 +705,11 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
     normAssert(ref, out, "", default_l1, default_lInf);
 }
 
+TEST_P(Test_ONNX_layers, GatherMultiOutput)
+{
+    testONNXModels("gather_multi_output");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers