add another Mish graph simplifier.
authorZihao Mu <zihaomu@outlook.com>
Thu, 28 Jul 2022 03:21:29 +0000 (11:21 +0800)
committerZihao Mu <zihaomu@outlook.com>
Thu, 28 Jul 2022 03:21:29 +0000 (11:21 +0800)
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
modules/dnn/test/test_onnx_importer.cpp

index c6e54d6..5aad1c1 100644 (file)
@@ -531,6 +531,38 @@ public:
     }
 };
 
+class MishSubgraph2 : public Subgraph
+{
+public:
+    MishSubgraph2()
+    {
+        int input = addNodeToMatch("");
+        int exp = addNodeToMatch("Exp", input);
+        int addVal = addNodeToMatch("");
+        int add = addNodeToMatch("Add", addVal, exp);
+        int log = addNodeToMatch("Log", add);
+        int tanh = addNodeToMatch("Tanh", log);
+        addNodeToMatch("Mul", input, tanh);
+        setFusedNode("Mish", input);
+    }
+};
+
+class MishSubgraph3 : public Subgraph
+{
+public:
+    MishSubgraph3()
+    {
+        int input = addNodeToMatch("");
+        int exp = addNodeToMatch("Exp", input);
+        int addVal = addNodeToMatch("");
+        int add = addNodeToMatch("Add", exp, addVal);
+        int log = addNodeToMatch("Log", add);
+        int tanh = addNodeToMatch("Tanh", log);
+        addNodeToMatch("Mul", input, tanh);
+        setFusedNode("Mish", input);
+    }
+};
+
 class MulCastSubgraph : public Subgraph
 {
 public:
@@ -735,6 +767,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
     subgraphs.push_back(makePtr<ExpandSubgraph>());
     subgraphs.push_back(makePtr<MishSubgraph>());
+    subgraphs.push_back(makePtr<MishSubgraph2>());
+    subgraphs.push_back(makePtr<MishSubgraph3>());
     subgraphs.push_back(makePtr<NormalizeSubgraph4>());
     subgraphs.push_back(makePtr<NormalizeSubgraph5>());
 
index 578e044..39c635a 100644 (file)
@@ -1325,6 +1325,7 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6)
 TEST_P(Test_ONNX_layers, Mish)
 {
     testONNXModels("mish");
+    testONNXModels("mish_no_softplus");
 }
 
 TEST_P(Test_ONNX_layers, CalculatePads)