a23fce30f58df40b7fcfa49d1acfacda057bf880
[platform/upstream/opencv.git] / modules / dnn / src / graph_simplifier.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2020, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "precomp.hpp"
9
10 #include "graph_simplifier.hpp"
11
12 #include <queue>
13
14 namespace cv { namespace dnn {
15
16 Subgraph::~Subgraph() {}
17
18 int Subgraph::addNodeToMatch(const std::string& op, int input_0, int input_1,
19                              int input_2, int input_3)
20 {
21     int nodeInputs[] = {input_0, input_1, input_2, input_3};
22     int numInputs = 0;
23     for (int i = 0; i < 4; ++i)
24     {
25         numInputs += (int)(nodeInputs[i] != -1);
26     }
27     return addNodeToMatch(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
28 }
29
30 int Subgraph::addNodeToMatch(const std::string& op, const std::vector<int>& inputs_)
31 {
32     for (int i = 0; i < inputs_.size(); ++i)
33     {
34         CV_Assert(inputs_[i] < (int)nodes.size());
35     }
36     nodes.push_back(op);
37     inputs.push_back(inputs_);
38     return nodes.size() - 1;
39 }
40
41 void Subgraph::setFusedNode(const std::string& op, int input_0, int input_1,
42                             int input_2, int input_3, int input_4, int input_5)
43 {
44     int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5};
45     int numInputs = 0;
46     for (int i = 0; i < 6; ++i)
47     {
48         CV_Assert(nodeInputs[i] < (int)nodes.size());
49         numInputs += (int)(nodeInputs[i] != -1);
50     }
51     setFusedNode(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
52 }
53
54 void Subgraph::setFusedNode(const std::string& op, const std::vector<int>& inputs_)
55 {
56     fusedNodeInputs = inputs_;
57     fusedNodeOp = op;
58 }
59
60 int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
61                              const Ptr<ImportNodeWrapper>& node,
62                              int inpId)
63 {
64     CV_Assert(inpId < node->getNumInputs());
65     std::string name = node->getInputName(inpId);
66     const int numNodes = net->getNumNodes();
67     for (int i = 0; i < numNodes; ++i)
68     {
69         const int numOutputs = net->getNumOutputs(i);
70         for (int j = 0; j < numOutputs; j++)
71         {
72             if (net->getOutputName(i, j) == name)
73                 return i;
74         }
75     }
76     CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
77 }
78
79 bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
80                      std::vector<int>& matchedNodesIds,
81                      std::vector<int>& targetNodesIds)
82 {
83     matchedNodesIds.clear();
84     targetNodesIds.clear();
85
86     std::queue<int> nodesToMatch;
87     std::queue<int> targetNodes;
88     nodesToMatch.push(nodeId);
89     targetNodes.push(nodes.size() - 1);
90     while (!nodesToMatch.empty())
91     {
92         int nodeToMatch = nodesToMatch.front();
93         int targetNodeId = targetNodes.front();
94         nodesToMatch.pop();
95         targetNodes.pop();
96
97         if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
98             matchedNodesIds.end())
99             continue;
100
101         const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
102         if (node->getType() != nodes[targetNodeId])
103             return false;
104
105         std::vector<int>& inputNodes = inputs[targetNodeId];
106         if (inputNodes.size() != node->getNumInputs())
107             return false;
108
109         for (int j = 0; j < inputNodes.size(); ++j)
110         {
111             if (nodes[inputNodes[j]].empty())  // Unknown input node type.
112                 continue;
113             nodeId = getInputNodeId(net, node, j);
114             const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
115             if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
116             {
117                 nodesToMatch.push(nodeId);
118                 targetNodes.push(inputNodes[j]);
119             }
120             else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
121                 return false;
122         }
123         matchedNodesIds.push_back(nodeToMatch);
124         targetNodesIds.push_back(targetNodeId);
125     }
126
127     const int n = matchedNodesIds.size();
128     std::vector<std::pair<int, int> > elements(n);
129     for (int i = 0; i < n; ++i)
130         elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
131     std::sort(elements.begin(), elements.end());
132     for (int i = 0; i < n; ++i)
133     {
134         matchedNodesIds[i] = elements[i].first;
135         targetNodesIds[i] = elements[i].second;
136     }
137     return true;
138 }
139
140 void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
141                        const std::vector<int>& targetNodesIds)
142 {
143     // Extract names of input nodes.
144     std::vector<std::string> inputsNames(fusedNodeInputs.size());
145     for (int i = 0; i < fusedNodeInputs.size(); ++i)
146     {
147         std::string inpName;
148         // Find input node name looking at inputs of fused nodes.
149         for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
150         {
151             Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
152             std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
153
154             CV_Assert(node->getNumInputs() == inpIndices.size());
155             for (int k = 0; k < inpIndices.size(); ++k)
156             {
157                 if (inpIndices[k] == fusedNodeInputs[i])
158                 {
159                     inpName = node->getInputName(k);
160                     break;
161                 }
162             }
163         }
164         CV_Assert(!inpName.empty());
165         inputsNames[i] = inpName;
166     }
167
168     // Remove matched nodes except the last one. Indices in ascending order are expected.
169     Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds.back());
170     for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
171         net->removeNode(matchedNodesIds[i]);
172
173     // Modify the last node to be a fused one.
174     node->setType(fusedNodeOp);
175     node->setInputNames(inputsNames);
176
177     std::vector<Ptr<ImportNodeWrapper> > inputNodes(inputsNames.size());
178     for (int i = 0; i < inputsNames.size(); ++i)
179     {
180         inputNodes[i] = net->getNode(getInputNodeId(net, node, i));
181     }
182     finalize(net, node, inputNodes);
183 }
184
185 void Subgraph::finalize(const Ptr<ImportGraphWrapper>& net,
186                         const Ptr<ImportNodeWrapper>& fusedNode,
187                         std::vector<Ptr<ImportNodeWrapper> >& inputs) {}
188
189 void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
190                        const std::vector<Ptr<Subgraph> >& patterns)
191 {
192     int numNodes = net->getNumNodes();
193     std::vector<int> matchedNodesIds, targetNodesIds;
194     for (int j = 0; j < patterns.size(); ++j)
195     {
196         for (int i = 0; i < numNodes; ++i)
197         {
198             if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
199             {
200                 patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
201                 numNodes -= matchedNodesIds.size() - 1;  // #matchedNodes removed and one added.
202             }
203         }
204     }
205 }
206
207 }}  // namespace cv::dnn