1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2020, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
8 #include "../precomp.hpp"
10 #include "../graph_simplifier.hpp"
11 #include "onnx_graph_simplifier.hpp"
15 namespace cv { namespace dnn {
16 CV__DNN_INLINE_NS_BEGIN
18 // This wrapper can behave differently for fake input nodes and real graph nodes.
19 class ONNXNodeWrapper : public ImportNodeWrapper
22 ONNXNodeWrapper(opencv_onnx::NodeProto* _node = 0) : node(_node) {}
24 virtual int getNumInputs() const CV_OVERRIDE
26 return node ? node->input_size() : 0;
29 virtual std::string getInputName(int idx) const CV_OVERRIDE
31 CV_Assert_N(node, idx < node->input_size());
32 return node->input(idx);
35 virtual std::string getType() const CV_OVERRIDE
37 return node ? node->op_type() : "";
40 virtual void setType(const std::string& type) CV_OVERRIDE
43 node->set_op_type(type);
46 virtual void setInputNames(const std::vector<std::string>& inputs) CV_OVERRIDE
50 for (int i = 0; i < inputs.size(); ++i)
51 node->add_input(inputs[i]);
54 opencv_onnx::NodeProto* node;
57 // ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
58 class ONNXGraphWrapper : public ImportGraphWrapper
61 ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
63 numInputs = net.input_size();
64 numInitializers = net.initializer_size();
67 virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE
69 opencv_onnx::NodeProto* node = 0;
70 if (idx >= numInputs + numInitializers)
71 node = net.mutable_node(idx - numInputs - numInitializers);
72 return makePtr<ONNXNodeWrapper>(node);
75 virtual int getNumNodes() const CV_OVERRIDE
77 return numInputs + numInitializers + net.node_size();
80 virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
82 if (nodeId < numInputs + numInitializers)
85 return net.node(nodeId - numInputs - numInitializers).output_size();
88 virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
90 CV_Assert(outId < getNumOutputs(nodeId));
91 if (nodeId < numInputs)
92 return net.input(nodeId).name();
93 else if (nodeId < numInputs + numInitializers)
94 return net.initializer(nodeId - numInputs).name();
96 return net.node(nodeId - numInputs - numInitializers).output(outId);
99 virtual void removeNode(int idx) CV_OVERRIDE
101 CV_Assert(idx >= numInputs + numInitializers);
102 net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
106 int numInputs, numInitializers;
107 opencv_onnx::GraphProto& net;
110 class SoftMaxSubgraph : public Subgraph
113 SoftMaxSubgraph() : axis(1)
115 int input = addNodeToMatch("");
116 int inpExp = addNodeToMatch("Exp", input);
117 int sum = addNodeToMatch("ReduceSum", inpExp);
118 addNodeToMatch("Div", inpExp, sum);
119 setFusedNode("Softmax", input);
122 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
123 std::vector<int>& matchedNodesIds,
124 std::vector<int>& targetNodesIds) CV_OVERRIDE
126 if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
128 Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
129 opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
131 for (int i = 0; i < node->attribute_size(); i++)
133 opencv_onnx::AttributeProto attr = node->attribute(i);
134 if (attr.name() != "axes")
136 if (attr.ints_size() != 1)
137 CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
141 CV_Error(Error::StsNotImplemented, "Missed axes attribute");
146 virtual void finalize(const Ptr<ImportGraphWrapper>&,
147 const Ptr<ImportNodeWrapper>& fusedNode,
148 std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
150 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
151 opencv_onnx::AttributeProto* attr = node->add_attribute();
152 attr->set_name("axis");
160 class NormalizeSubgraphBase : public Subgraph
163 NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
165 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
166 std::vector<int>& matchedNodesIds,
167 std::vector<int>& targetNodesIds) CV_OVERRIDE
169 if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
171 Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
172 opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
174 for (int i = 0; i < node->attribute_size(); i++)
176 opencv_onnx::AttributeProto attr = node->attribute(i);
177 if (attr.name() != "axes")
179 if (attr.ints_size() != 1)
180 CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
184 CV_Error(Error::StsNotImplemented, "Missed axes attribute");
189 virtual void finalize(const Ptr<ImportGraphWrapper>&,
190 const Ptr<ImportNodeWrapper>& fusedNode,
191 std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
193 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
194 opencv_onnx::AttributeProto* axis_attr = node->add_attribute();
195 axis_attr->set_name("axis");
196 axis_attr->set_i(axis);
198 opencv_onnx::AttributeProto* end_axis_attr = node->add_attribute();
199 end_axis_attr->set_name("end_axis");
200 end_axis_attr->set_i(axis);
204 int axis, normNodeOrder;
207 class NormalizeSubgraph1 : public NormalizeSubgraphBase
212 int input = addNodeToMatch("");
213 int norm = addNodeToMatch("ReduceL2", input);
214 addNodeToMatch("Div", input, norm);
215 setFusedNode("Normalize", input);
219 class NormalizeSubgraph2 : public NormalizeSubgraphBase
224 int input = addNodeToMatch("");
225 int norm = addNodeToMatch("ReduceL2", input);
226 int clip = addNodeToMatch("Clip", norm);
227 int shape = addNodeToMatch("Shape", input);
228 int expand = addNodeToMatch("Expand", clip, shape);
229 addNodeToMatch("Div", input, expand);
230 setFusedNode("Normalize", input);
234 class NormalizeSubgraph3 : public NormalizeSubgraphBase
237 NormalizeSubgraph3() : NormalizeSubgraphBase(1)
239 int input = addNodeToMatch("");
240 int power = addNodeToMatch("Constant");
241 int squared = addNodeToMatch("Pow", input, power);
242 int sum = addNodeToMatch("ReduceSum", squared);
243 int sqrtNode = addNodeToMatch("Sqrt", sum);
244 int eps = addNodeToMatch("Constant");
245 int add = addNodeToMatch("Add", sqrtNode, eps);
247 addNodeToMatch("Div", input, add);
248 setFusedNode("Normalize", input);
252 class GatherCastSubgraph : public Subgraph
257 int input = addNodeToMatch("");
258 int index = addNodeToMatch("Constant");
259 int gather = addNodeToMatch("Gather", input, index);
260 addNodeToMatch("Cast", gather);
261 setFusedNode("Gather", input, index);
264 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
265 std::vector<int>& matchedNodesIds,
266 std::vector<int>& targetNodesIds) CV_OVERRIDE
268 bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
269 size_t matchedNodesNum = matchedNodesIds.size();
270 // Now we check if merging can be made for these Gather and Cast nodes
271 if (!retVal || matchedNodesNum < 2)
274 int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
275 const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
276 if (node->getType() == "Cast") {
277 int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
278 const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
279 if (inpNode->getType() == "Gather") {
280 int numNodes = net->getNumNodes();
281 std::string inpNodeName = node->getInputName(0);
282 for (int i = 0; i < numNodes; ++i) {
283 const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
284 int numInp = node_to_check->getNumInputs();
285 for (int inp = 0; inp < numInp; ++inp) {
286 if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) {
287 // Another node has the same input node, so it cannot be merged.
299 class ExpandSubgraph : public Subgraph
304 int input = addNodeToMatch("");
305 int values = addNodeToMatch("");
306 int init = addNodeToMatch("ConstantOfShape", values);
307 int coeff = addNodeToMatch("Constant");
308 int mul = addNodeToMatch("Mul", init, coeff);
309 int shape = addNodeToMatch("Constant");
310 int condition = addNodeToMatch("Equal", shape, mul);
311 int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
312 addNodeToMatch("Expand", input, where);
313 setFusedNode("Expand", input, shape);
317 class MulCastSubgraph : public Subgraph
322 int input = addNodeToMatch("");
323 int scaleNode = addNodeToMatch("Constant");
324 int mul = addNodeToMatch("Mul", input, scaleNode);
325 addNodeToMatch("Cast", mul);
326 setFusedNode("Mul", input, scaleNode);
330 class ExtractScalesSubgraph : public Subgraph
333 ExtractScalesSubgraph()
335 input = addNodeToMatch("");
337 int indexH = addNodeToMatch("Constant");
338 int shape1 = addNodeToMatch("Shape", input);
339 int gather1 = addNodeToMatch("Gather", shape1, indexH);
340 scaleHNode = addNodeToMatch("Constant");
341 int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
342 int floor1 = addNodeToMatch("Floor", mul1);
344 int indexW = addNodeToMatch("Constant");
345 int shape2 = addNodeToMatch("Shape", input);
346 int gather2 = addNodeToMatch("Gather", shape2, indexW);
347 scaleWNode = addNodeToMatch("Constant");
348 int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
349 int floor2 = addNodeToMatch("Floor", mul2);
351 int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
352 int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
353 concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2);
356 void finalize(const Ptr<ImportGraphWrapper>& net,
357 const Ptr<ImportNodeWrapper>& fusedNode,
358 std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE
360 opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
361 opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
362 Mat scaleW = getMatFromTensor(tensor_proto);
363 CV_Assert(scaleW.total() == 1);
364 scaleW.convertTo(scaleW, CV_32F);
366 constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
367 tensor_proto = constant_node->attribute(0).t();
368 Mat scaleH = getMatFromTensor(tensor_proto);
369 CV_Assert(scaleH.total() == 1);
370 scaleH.convertTo(scaleH, CV_32F);
372 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
373 opencv_onnx::AttributeProto* attrH = node->add_attribute();
374 attrH->set_name("height_scale");
375 attrH->set_i(scaleH.at<float>(0));
376 opencv_onnx::AttributeProto* attrW = node->add_attribute();
377 attrW->set_name("width_scale");
378 attrW->set_i(scaleW.at<float>(0));
380 node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
385 int scaleHNode, scaleWNode;
388 class UpsampleSubgraph : public ExtractScalesSubgraph
391 UpsampleSubgraph() : ExtractScalesSubgraph()
393 int shape = addNodeToMatch("Shape", input);
394 int slice = addNodeToMatch("Slice", shape);
396 int castConcat = addNodeToMatch("Cast", concatId);
397 int castSlice = addNodeToMatch("Cast", slice);
398 int divide = addNodeToMatch("Div", castConcat, castSlice);
400 int constant = addNodeToMatch("Constant");
401 int concat = addNodeToMatch("Concat", constant, divide);
403 addNodeToMatch("Upsample", input, concat);
404 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
408 class ResizeSubgraph1 : public ExtractScalesSubgraph
411 ResizeSubgraph1() : ExtractScalesSubgraph()
413 int shape = addNodeToMatch("Shape", input);
414 int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
416 int castConcat = addNodeToMatch("Cast", concatId);
417 int concat = addNodeToMatch("Concat", slice, castConcat);
418 int constant = addNodeToMatch("Constant");
420 addNodeToMatch("Resize", input, constant, constant, concat);
421 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
425 class ResizeSubgraph2 : public ExtractScalesSubgraph
428 ResizeSubgraph2() : ExtractScalesSubgraph()
430 int constantConcat = addNodeToMatch("Constant");
431 int castConcat = addNodeToMatch("Cast", concatId);
432 int concat = addNodeToMatch("Concat", constantConcat, castConcat);
433 int constant = addNodeToMatch("Constant");
435 addNodeToMatch("Resize", input, constant, constant, concat);
436 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
440 class BatchNormalizationSubgraphBase : public Subgraph
443 BatchNormalizationSubgraphBase()
445 input = addNodeToMatch("");
446 var = addNodeToMatch("");
447 mean = addNodeToMatch("");
448 weight = addNodeToMatch("");
449 bias = addNodeToMatch("");
450 A = addNodeToMatch("");
451 shape1 = addNodeToMatch("");
452 shape2 = addNodeToMatch("");
455 int input, var, mean, weight, bias, A, shape1, shape2;
458 class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase
461 BatchNormalizationSubgraph1()
463 int reshape1 = addNodeToMatch("Reshape", weight, shape1);
464 int reshape2 = addNodeToMatch("Reshape", bias, shape2);
465 int shape3 = addNodeToMatch("Constant");
466 int reshape3 = addNodeToMatch("Reshape", var, shape3);
467 int shape4 = addNodeToMatch("Constant");
468 int reshape4 = addNodeToMatch("Reshape", mean, shape4);
469 int sqrtNode = addNodeToMatch("Sqrt", reshape3);
470 int divNode = addNodeToMatch("Div", A, sqrtNode);
471 int mul1 = addNodeToMatch("Mul", reshape1, divNode);
472 int mul2 = addNodeToMatch("Mul", reshape4, mul1);
473 int sub = addNodeToMatch("Sub", reshape2, mul2);
474 int mul3 = addNodeToMatch("Mul", input, mul1);
475 addNodeToMatch("Add", mul3, sub);
476 setFusedNode("BatchNormalization", input, weight, bias, mean, var);
480 class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
483 BatchNormalizationSubgraph2()
485 int sqrtNode = addNodeToMatch("Sqrt", var);
486 int divNode = addNodeToMatch("Div", A, sqrtNode);
487 int mul1 = addNodeToMatch("Mul", weight, divNode);
488 int reshape2 = addNodeToMatch("Reshape", mul1, shape2);
490 int mulMean = addNodeToMatch("Mul", mean, mul1);
491 int sub = addNodeToMatch("Sub", bias, mulMean);
492 int reshape1 = addNodeToMatch("Reshape", sub, shape1);
494 int mulInput = addNodeToMatch("Mul", input, reshape2);
495 addNodeToMatch("Add", mulInput, reshape1);
496 setFusedNode("BatchNormalization", input, weight, bias, mean, var);
500 void simplifySubgraphs(opencv_onnx::GraphProto& net)
502 std::vector<Ptr<Subgraph> > subgraphs;
503 subgraphs.push_back(makePtr<GatherCastSubgraph>());
504 subgraphs.push_back(makePtr<MulCastSubgraph>());
505 subgraphs.push_back(makePtr<UpsampleSubgraph>());
506 subgraphs.push_back(makePtr<ResizeSubgraph1>());
507 subgraphs.push_back(makePtr<ResizeSubgraph2>());
508 subgraphs.push_back(makePtr<SoftMaxSubgraph>());
509 subgraphs.push_back(makePtr<NormalizeSubgraph1>());
510 subgraphs.push_back(makePtr<NormalizeSubgraph2>());
511 subgraphs.push_back(makePtr<NormalizeSubgraph3>());
512 subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
513 subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
514 subgraphs.push_back(makePtr<ExpandSubgraph>());
516 simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
519 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
521 if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
522 tensor_proto.double_data().empty() && tensor_proto.int64_data().empty())
525 opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
527 std::vector<int> sizes;
528 for (int i = 0; i < tensor_proto.dims_size(); i++) {
529 sizes.push_back(tensor_proto.dims(i));
533 if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
535 if (!tensor_proto.float_data().empty()) {
536 const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
537 Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
540 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
541 Mat(sizes, CV_32FC1, val).copyTo(blob);
544 else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
546 const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
547 CV_Assert(!field.empty());
548 Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
550 else if (datatype == opencv_onnx::TensorProto_DataType_INT32)
552 if (!tensor_proto.int32_data().empty())
554 const ::google::protobuf::RepeatedField<int32_t> field = tensor_proto.int32_data();
555 Mat(sizes, CV_32SC1, (void*)field.data()).copyTo(blob);
559 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
560 Mat(sizes, CV_32SC1, val).copyTo(blob);
563 else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
565 blob.create(sizes, CV_32SC1);
566 int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
568 if (!tensor_proto.int64_data().empty()) {
569 ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
570 convertInt64ToInt32(src, dst, blob.total());
574 const char* val = tensor_proto.raw_data().c_str();
575 #if CV_STRONG_ALIGNMENT
576 // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
577 // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
578 AutoBuffer<int64_t, 16> aligned_val;
579 if (!isAligned<sizeof(int64_t)>(val))
581 size_t sz = tensor_proto.raw_data().size();
582 aligned_val.allocate(divUp(sz, sizeof(int64_t)));
583 memcpy(aligned_val.data(), val, sz);
584 val = (const char*)aligned_val.data();
587 const int64_t* src = reinterpret_cast<const int64_t*>(val);
588 convertInt64ToInt32(src, dst, blob.total());
592 CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
593 opencv_onnx::TensorProto_DataType_Name(datatype));
594 if (tensor_proto.dims_size() == 0)
595 blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
599 CV__DNN_INLINE_NS_END
600 }} // namespace cv::dnn