replace new mish impl with softplus
authorZihao Mu <zihaomu@outlook.com>
Thu, 28 Jul 2022 05:19:06 +0000 (13:19 +0800)
committerZihao Mu <zihaomu@outlook.com>
Thu, 28 Jul 2022 05:19:06 +0000 (13:19 +0800)
modules/dnn/src/onnx/onnx_graph_simplifier.cpp

index 5aad1c1..091d2d4 100644 (file)
@@ -531,35 +531,32 @@ public:
     }
 };
 
-class MishSubgraph2 : public Subgraph
+// softplus(x) = log(exp(x) + 1)
+class SoftplusSubgraph: public Subgraph
 {
 public:
-    MishSubgraph2()
+    SoftplusSubgraph()
     {
         int input = addNodeToMatch("");
         int exp = addNodeToMatch("Exp", input);
         int addVal = addNodeToMatch("");
         int add = addNodeToMatch("Add", addVal, exp);
-        int log = addNodeToMatch("Log", add);
-        int tanh = addNodeToMatch("Tanh", log);
-        addNodeToMatch("Mul", input, tanh);
-        setFusedNode("Mish", input);
+        addNodeToMatch("Log", add);
+        setFusedNode("Softplus", input);
     }
 };
 
-class MishSubgraph3 : public Subgraph
+class SoftplusSubgraph2: public Subgraph
 {
 public:
-    MishSubgraph3()
+    SoftplusSubgraph2()
     {
         int input = addNodeToMatch("");
         int exp = addNodeToMatch("Exp", input);
         int addVal = addNodeToMatch("");
         int add = addNodeToMatch("Add", exp, addVal);
-        int log = addNodeToMatch("Log", add);
-        int tanh = addNodeToMatch("Tanh", log);
-        addNodeToMatch("Mul", input, tanh);
-        setFusedNode("Mish", input);
+        addNodeToMatch("Log", add);
+        setFusedNode("Softplus", input);
     }
 };
 
@@ -766,9 +763,9 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
     subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
     subgraphs.push_back(makePtr<ExpandSubgraph>());
+    subgraphs.push_back(makePtr<SoftplusSubgraph>());
+    subgraphs.push_back(makePtr<SoftplusSubgraph2>());
     subgraphs.push_back(makePtr<MishSubgraph>());
-    subgraphs.push_back(makePtr<MishSubgraph2>());
-    subgraphs.push_back(makePtr<MishSubgraph3>());
     subgraphs.push_back(makePtr<NormalizeSubgraph4>());
     subgraphs.push_back(makePtr<NormalizeSubgraph5>());