add new (Log)SoftMax simplification passes
authorSmirnov Egor <s.e.a.98@yandex.ru>
Tue, 30 Nov 2021 12:20:52 +0000 (15:20 +0300)
committerSmirnov Egor <s.e.a.98@yandex.ru>
Tue, 30 Nov 2021 12:20:52 +0000 (15:20 +0300)
modules/dnn/src/onnx/onnx_graph_simplifier.cpp

index 76937e0..e4cf73f 100644 (file)
@@ -107,17 +107,10 @@ private:
     opencv_onnx::GraphProto& net;
 };
 
-class SoftMaxSubgraph : public Subgraph
+class SoftMaxSubgraphBase : public Subgraph
 {
 public:
-    SoftMaxSubgraph() : axis(1)
-    {
-        int input = addNodeToMatch("");
-        int inpExp = addNodeToMatch("Exp", input);
-        int sum = addNodeToMatch("ReduceSum", inpExp);
-        addNodeToMatch("Div", inpExp, sum);
-        setFusedNode("Softmax", input);
-    }
+    SoftMaxSubgraphBase() : axis(1), id(-1) {}
 
     virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                        std::vector<int>& matchedNodesIds,
@@ -125,7 +118,8 @@ public:
     {
         if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
         {
-            Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
+            CV_Assert(id >= 0 && id < matchedNodesIds.size());
+            Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
             opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
 
             for (int i = 0; i < node->attribute_size(); i++)
@@ -153,8 +147,60 @@ public:
         attr->set_i(axis);
     }
 
-private:
+protected:
     int axis;
+    int id;
+};
+
+class SoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+    SoftMaxSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int inpExp = addNodeToMatch("Exp", input);
+
+        int sum = addNodeToMatch("ReduceSum", inpExp);
+        id = 1;
+
+        addNodeToMatch("Div", inpExp, sum);
+        setFusedNode("Softmax", input);
+    }
+};
+
+class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
+public:
+    SoftMaxSubgraph2() {
+        int input = addNodeToMatch("");
+
+        int reducemax = addNodeToMatch("ReduceMax", input);
+        id = 0;
+
+        int sub = addNodeToMatch("Sub", input, reducemax);
+        int exp = addNodeToMatch("Exp", sub);
+        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+        addNodeToMatch("Div", exp, reducesum);
+        setFusedNode("Softmax", input);
+    }
+};
+
+class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+    LogSoftMaxSubgraph()
+    {
+        int input = addNodeToMatch("");
+
+        int reducemax = addNodeToMatch("ReduceMax", input);
+        id = 0;
+
+        int sub_1 = addNodeToMatch("Sub", input, reducemax);
+        int exp = addNodeToMatch("Exp", sub_1);
+        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+        int log = addNodeToMatch("Log", reducesum);
+        addNodeToMatch("Sub", sub_1, log);
+        setFusedNode("LogSoftmax", input);
+    }
 };
 
 class NormalizeSubgraphBase : public Subgraph
@@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr<ResizeSubgraph1>());
     subgraphs.push_back(makePtr<ResizeSubgraph2>());
     subgraphs.push_back(makePtr<SoftMaxSubgraph>());
+    subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
+    subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
     subgraphs.push_back(makePtr<NormalizeSubgraph1>());
     subgraphs.push_back(makePtr<NormalizeSubgraph2>());
     subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());