Merge remote-tracking branch 'upstream/3.4' into merge-3.4
[platform/upstream/opencv.git] / modules / dnn / src / onnx / onnx_graph_simplifier.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2020, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../precomp.hpp"
9
10 #include "../graph_simplifier.hpp"
11 #include "onnx_graph_simplifier.hpp"
12
13 #include <queue>
14
15 namespace cv { namespace dnn {
16 CV__DNN_INLINE_NS_BEGIN
17
18 // This wrapper can behave differently for fake input nodes and real graph nodes.
19 class ONNXNodeWrapper : public ImportNodeWrapper
20 {
21 public:
22     ONNXNodeWrapper(opencv_onnx::NodeProto* _node = 0) : node(_node) {}
23
24     virtual int getNumInputs() const CV_OVERRIDE
25     {
26         return node ? node->input_size() : 0;
27     }
28
29     virtual std::string getInputName(int idx) const CV_OVERRIDE
30     {
31         CV_Assert_N(node, idx < node->input_size());
32         return node->input(idx);
33     }
34
35     virtual std::string getType() const CV_OVERRIDE
36     {
37         return node ? node->op_type() : "";
38     }
39
40     virtual void setType(const std::string& type) CV_OVERRIDE
41     {
42         CV_Assert(node);
43         node->set_op_type(type);
44     }
45
46     virtual void setInputNames(const std::vector<std::string>& inputs) CV_OVERRIDE
47     {
48         CV_Assert(node);
49         node->clear_input();
50         for (int i = 0; i < inputs.size(); ++i)
51             node->add_input(inputs[i]);
52     }
53
54     opencv_onnx::NodeProto* node;
55 };
56
57 // ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
58 class ONNXGraphWrapper : public ImportGraphWrapper
59 {
60 public:
61     ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
62     {
63         numInputs = net.input_size();
64         numInitializers = net.initializer_size();
65     }
66
67     virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE
68     {
69         opencv_onnx::NodeProto* node = 0;
70         if (idx >= numInputs + numInitializers)
71             node = net.mutable_node(idx - numInputs - numInitializers);
72         return makePtr<ONNXNodeWrapper>(node);
73     }
74
75     virtual int getNumNodes() const CV_OVERRIDE
76     {
77         return numInputs + numInitializers + net.node_size();
78     }
79
80     virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
81     {
82         if (nodeId < numInputs + numInitializers)
83             return 1;
84         else
85             return net.node(nodeId - numInputs - numInitializers).output_size();
86     }
87
88     virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
89     {
90         CV_Assert(outId < getNumOutputs(nodeId));
91         if (nodeId < numInputs)
92             return net.input(nodeId).name();
93         else if (nodeId < numInputs + numInitializers)
94             return net.initializer(nodeId - numInputs).name();
95         else
96             return net.node(nodeId - numInputs - numInitializers).output(outId);
97     }
98
99     virtual void removeNode(int idx) CV_OVERRIDE
100     {
101         CV_Assert(idx >= numInputs + numInitializers);
102         net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
103     }
104
105 private:
106     int numInputs, numInitializers;
107     opencv_onnx::GraphProto& net;
108 };
109
110 class SoftMaxSubgraph : public Subgraph
111 {
112 public:
113     SoftMaxSubgraph() : axis(1)
114     {
115         int input = addNodeToMatch("");
116         int inpExp = addNodeToMatch("Exp", input);
117         int sum = addNodeToMatch("ReduceSum", inpExp);
118         addNodeToMatch("Div", inpExp, sum);
119         setFusedNode("Softmax", input);
120     }
121
122     virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
123                        std::vector<int>& matchedNodesIds,
124                        std::vector<int>& targetNodesIds) CV_OVERRIDE
125     {
126         if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
127         {
128             Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
129             opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
130
131             for (int i = 0; i < node->attribute_size(); i++)
132             {
133                 opencv_onnx::AttributeProto attr = node->attribute(i);
134                 if (attr.name() != "axes")
135                     continue;
136                 if (attr.ints_size() != 1)
137                     CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
138                 axis = attr.ints(0);
139                 return true;
140             }
141             CV_Error(Error::StsNotImplemented, "Missed axes attribute");
142         }
143         return false;
144     }
145
146     virtual void finalize(const Ptr<ImportGraphWrapper>&,
147                           const Ptr<ImportNodeWrapper>& fusedNode,
148                           std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
149     {
150         opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
151         opencv_onnx::AttributeProto* attr = node->add_attribute();
152         attr->set_name("axis");
153         attr->set_i(axis);
154     }
155
156 private:
157     int axis;
158 };
159
160 class NormalizeSubgraphBase : public Subgraph
161 {
162 public:
163     NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
164
165     virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
166                        std::vector<int>& matchedNodesIds,
167                        std::vector<int>& targetNodesIds) CV_OVERRIDE
168     {
169         if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
170         {
171             Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
172             opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
173
174             for (int i = 0; i < node->attribute_size(); i++)
175             {
176                 opencv_onnx::AttributeProto attr = node->attribute(i);
177                 if (attr.name() != "axes")
178                     continue;
179                 if (attr.ints_size() != 1)
180                     CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
181                 axis = attr.ints(0);
182                 return true;
183             }
184             CV_Error(Error::StsNotImplemented, "Missed axes attribute");
185         }
186         return false;
187     }
188
189     virtual void finalize(const Ptr<ImportGraphWrapper>&,
190                           const Ptr<ImportNodeWrapper>& fusedNode,
191                           std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
192     {
193         opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
194         opencv_onnx::AttributeProto* axis_attr = node->add_attribute();
195         axis_attr->set_name("axis");
196         axis_attr->set_i(axis);
197
198         opencv_onnx::AttributeProto* end_axis_attr = node->add_attribute();
199         end_axis_attr->set_name("end_axis");
200         end_axis_attr->set_i(axis);
201     }
202
203 protected:
204     int axis, normNodeOrder;
205 };
206
207 class NormalizeSubgraph1 : public NormalizeSubgraphBase
208 {
209 public:
210     NormalizeSubgraph1()
211     {
212         int input = addNodeToMatch("");
213         int norm = addNodeToMatch("ReduceL2", input);
214         addNodeToMatch("Div", input, norm);
215         setFusedNode("Normalize", input);
216     }
217 };
218
219 class NormalizeSubgraph2 : public NormalizeSubgraphBase
220 {
221 public:
222     NormalizeSubgraph2()
223     {
224         int input = addNodeToMatch("");
225         int norm = addNodeToMatch("ReduceL2", input);
226         int clip = addNodeToMatch("Clip", norm);
227         int shape = addNodeToMatch("Shape", input);
228         int expand = addNodeToMatch("Expand", clip, shape);
229         addNodeToMatch("Div", input, expand);
230         setFusedNode("Normalize", input);
231     }
232 };
233
234 class NormalizeSubgraph3 : public NormalizeSubgraphBase
235 {
236 public:
237     NormalizeSubgraph3() : NormalizeSubgraphBase(1)
238     {
239         int input = addNodeToMatch("");
240         int power = addNodeToMatch("Constant");
241         int squared = addNodeToMatch("Pow", input, power);
242         int sum = addNodeToMatch("ReduceSum", squared);
243         int sqrtNode = addNodeToMatch("Sqrt", sum);
244         int eps = addNodeToMatch("Constant");
245         int add = addNodeToMatch("Add", sqrtNode, eps);
246
247         addNodeToMatch("Div", input, add);
248         setFusedNode("Normalize", input);
249     }
250 };
251
252 class GatherCastSubgraph : public Subgraph
253 {
254 public:
255     GatherCastSubgraph()
256     {
257         int input = addNodeToMatch("");
258         int index = addNodeToMatch("Constant");
259         int gather = addNodeToMatch("Gather", input, index);
260         addNodeToMatch("Cast", gather);
261         setFusedNode("Gather", input, index);
262     }
263 };
264
265 class MulCastSubgraph : public Subgraph
266 {
267 public:
268     MulCastSubgraph()
269     {
270         int input = addNodeToMatch("");
271         int scaleNode = addNodeToMatch("Constant");
272         int mul = addNodeToMatch("Mul", input, scaleNode);
273         addNodeToMatch("Cast", mul);
274         setFusedNode("Mul", input, scaleNode);
275     }
276 };
277
278 class ExtractScalesSubgraph : public Subgraph
279 {
280 public:
281     ExtractScalesSubgraph()
282     {
283         input = addNodeToMatch("");
284
285         int indexH = addNodeToMatch("Constant");
286         int shape1 = addNodeToMatch("Shape", input);
287         int gather1 = addNodeToMatch("Gather", shape1, indexH);
288         scaleHNode = addNodeToMatch("Constant");
289         int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
290         int floor1 = addNodeToMatch("Floor", mul1);
291
292         int indexW = addNodeToMatch("Constant");
293         int shape2 = addNodeToMatch("Shape", input);
294         int gather2 = addNodeToMatch("Gather", shape2, indexW);
295         scaleWNode = addNodeToMatch("Constant");
296         int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
297         int floor2 = addNodeToMatch("Floor", mul2);
298
299         int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
300         int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
301         concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2);
302     }
303
304     void finalize(const Ptr<ImportGraphWrapper>& net,
305                   const Ptr<ImportNodeWrapper>& fusedNode,
306                   std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE
307     {
308         opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
309         opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
310         Mat scaleW = getMatFromTensor(tensor_proto);
311         CV_Assert(scaleW.total() == 1);
312         scaleW.convertTo(scaleW, CV_32F);
313
314         constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
315         tensor_proto = constant_node->attribute(0).t();
316         Mat scaleH = getMatFromTensor(tensor_proto);
317         CV_Assert(scaleH.total() == 1);
318         scaleH.convertTo(scaleH, CV_32F);
319
320         opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
321         opencv_onnx::AttributeProto* attrH = node->add_attribute();
322         attrH->set_name("height_scale");
323         attrH->set_i(scaleH.at<float>(0));
324         opencv_onnx::AttributeProto* attrW = node->add_attribute();
325         attrW->set_name("width_scale");
326         attrW->set_i(scaleW.at<float>(0));
327
328         node->mutable_input()->DeleteSubrange(1, 2);  // Remove two last inputs
329     }
330
331 protected:
332     int input, concatId;
333     int scaleHNode, scaleWNode;
334 };
335
336 class UpsampleSubgraph : public ExtractScalesSubgraph
337 {
338 public:
339     UpsampleSubgraph() : ExtractScalesSubgraph()
340     {
341         int shape = addNodeToMatch("Shape", input);
342         int slice = addNodeToMatch("Slice", shape);
343
344         int castConcat = addNodeToMatch("Cast", concatId);
345         int castSlice = addNodeToMatch("Cast", slice);
346         int divide = addNodeToMatch("Div", castConcat, castSlice);
347
348         int constant = addNodeToMatch("Constant");
349         int concat = addNodeToMatch("Concat", constant, divide);
350
351         addNodeToMatch("Upsample", input, concat);
352         setFusedNode("Upsample", input, scaleWNode, scaleHNode);
353     }
354 };
355
356 class ResizeSubgraph1 : public ExtractScalesSubgraph
357 {
358 public:
359     ResizeSubgraph1() : ExtractScalesSubgraph()
360     {
361         int shape = addNodeToMatch("Shape", input);
362         int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
363
364         int castConcat = addNodeToMatch("Cast", concatId);
365         int concat = addNodeToMatch("Concat", slice, castConcat);
366         int constant = addNodeToMatch("Constant");
367
368         addNodeToMatch("Resize", input, constant, constant, concat);
369         setFusedNode("Upsample", input, scaleWNode, scaleHNode);
370     }
371 };
372
373 class ResizeSubgraph2 : public ExtractScalesSubgraph
374 {
375 public:
376     ResizeSubgraph2() : ExtractScalesSubgraph()
377     {
378         int constantConcat = addNodeToMatch("Constant");
379         int castConcat = addNodeToMatch("Cast", concatId);
380         int concat = addNodeToMatch("Concat", constantConcat, castConcat);
381         int constant = addNodeToMatch("Constant");
382
383         addNodeToMatch("Resize", input, constant, constant, concat);
384         setFusedNode("Upsample", input, scaleWNode, scaleHNode);
385     }
386 };
387
388 class BatchNormalizationSubgraphBase : public Subgraph
389 {
390 public:
391     BatchNormalizationSubgraphBase()
392     {
393         input  = addNodeToMatch("");
394         var    = addNodeToMatch("");
395         mean   = addNodeToMatch("");
396         weight = addNodeToMatch("");
397         bias   = addNodeToMatch("");
398         A      = addNodeToMatch("");
399         shape1 = addNodeToMatch("");
400         shape2 = addNodeToMatch("");
401     }
402 protected:
403     int input, var, mean, weight, bias, A, shape1, shape2;
404 };
405
406 class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase
407 {
408 public:
409     BatchNormalizationSubgraph1()
410     {
411         int reshape1 = addNodeToMatch("Reshape", weight, shape1);
412         int reshape2 = addNodeToMatch("Reshape", bias, shape2);
413         int shape3 = addNodeToMatch("Constant");
414         int reshape3 = addNodeToMatch("Reshape", var, shape3);
415         int shape4 = addNodeToMatch("Constant");
416         int reshape4 = addNodeToMatch("Reshape", mean, shape4);
417         int sqrtNode = addNodeToMatch("Sqrt", reshape3);
418         int divNode = addNodeToMatch("Div", A, sqrtNode);
419         int mul1 = addNodeToMatch("Mul", reshape1, divNode);
420         int mul2 = addNodeToMatch("Mul", reshape4, mul1);
421         int sub = addNodeToMatch("Sub", reshape2, mul2);
422         int mul3 = addNodeToMatch("Mul", input, mul1);
423         addNodeToMatch("Add", mul3, sub);
424         setFusedNode("BatchNormalization", input, weight, bias, mean, var);
425     }
426 };
427
428 class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
429 {
430 public:
431     BatchNormalizationSubgraph2()
432     {
433         int sqrtNode = addNodeToMatch("Sqrt", var);
434         int divNode = addNodeToMatch("Div", A, sqrtNode);
435         int mul1 = addNodeToMatch("Mul", weight, divNode);
436         int reshape2 = addNodeToMatch("Reshape", mul1, shape2);
437
438         int mulMean = addNodeToMatch("Mul", mean, mul1);
439         int sub = addNodeToMatch("Sub", bias, mulMean);
440         int reshape1 = addNodeToMatch("Reshape", sub, shape1);
441
442         int mulInput = addNodeToMatch("Mul", input, reshape2);
443         addNodeToMatch("Add", mulInput, reshape1);
444         setFusedNode("BatchNormalization", input, weight, bias, mean, var);
445     }
446 };
447
448 void simplifySubgraphs(opencv_onnx::GraphProto& net)
449 {
450     std::vector<Ptr<Subgraph> > subgraphs;
451     subgraphs.push_back(makePtr<GatherCastSubgraph>());
452     subgraphs.push_back(makePtr<MulCastSubgraph>());
453     subgraphs.push_back(makePtr<UpsampleSubgraph>());
454     subgraphs.push_back(makePtr<ResizeSubgraph1>());
455     subgraphs.push_back(makePtr<ResizeSubgraph2>());
456     subgraphs.push_back(makePtr<SoftMaxSubgraph>());
457     subgraphs.push_back(makePtr<NormalizeSubgraph1>());
458     subgraphs.push_back(makePtr<NormalizeSubgraph2>());
459     subgraphs.push_back(makePtr<NormalizeSubgraph3>());
460     subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
461     subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
462
463     simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
464 }
465
466 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
467 {
468     if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
469         tensor_proto.double_data().empty() && tensor_proto.int64_data().empty())
470         return Mat();
471
472     opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
473     Mat blob;
474     std::vector<int> sizes;
475     for (int i = 0; i < tensor_proto.dims_size(); i++) {
476             sizes.push_back(tensor_proto.dims(i));
477     }
478     if (sizes.empty())
479         sizes.assign(1, 1);
480     if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
481
482         if (!tensor_proto.float_data().empty()) {
483             const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
484             Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
485         }
486         else {
487             char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
488             Mat(sizes, CV_32FC1, val).copyTo(blob);
489         }
490     }
491     else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
492     {
493         const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
494         CV_Assert(!field.empty());
495         Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
496     }
497     else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
498     {
499         blob.create(sizes, CV_32SC1);
500         int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
501
502         if (!tensor_proto.int64_data().empty()) {
503             ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
504             convertInt64ToInt32(src, dst, blob.total());
505         }
506         else
507         {
508             const char* val = tensor_proto.raw_data().c_str();
509 #if CV_STRONG_ALIGNMENT
510             // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
511             // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
512             AutoBuffer<int64_t, 16> aligned_val;
513             if (!isAligned<sizeof(int64_t)>(val))
514             {
515                 size_t sz = tensor_proto.raw_data().size();
516                 aligned_val.allocate(divUp(sz, sizeof(int64_t)));
517                 memcpy(aligned_val.data(), val, sz);
518                 val = (const char*)aligned_val.data();
519             }
520 #endif
521             const int64_t* src = reinterpret_cast<const int64_t*>(val);
522             convertInt64ToInt32(src, dst, blob.total());
523         }
524     }
525     else
526         CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
527                         opencv_onnx::TensorProto_DataType_Name(datatype));
528     if (tensor_proto.dims_size() == 0)
529         blob.dims = 1;  // To force 1-dimensional cv::Mat for scalars.
530     return blob;
531 }
532
533 CV__DNN_INLINE_NS_END
534 }}  // namespace cv::dnn