1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2020, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
8 #include "../precomp.hpp"
10 #include "../graph_simplifier.hpp"
11 #include "onnx_graph_simplifier.hpp"
15 namespace cv { namespace dnn {
16 CV__DNN_INLINE_NS_BEGIN
18 // This wrapper can behave differently for fake input nodes and real graph nodes.
19 class ONNXNodeWrapper : public ImportNodeWrapper
22 ONNXNodeWrapper(opencv_onnx::NodeProto* _node = 0) : node(_node) {}
24 virtual int getNumInputs() const CV_OVERRIDE
26 return node ? node->input_size() : 0;
29 virtual std::string getInputName(int idx) const CV_OVERRIDE
31 CV_Assert_N(node, idx < node->input_size());
32 return node->input(idx);
35 virtual std::string getType() const CV_OVERRIDE
37 return node ? node->op_type() : "";
40 virtual void setType(const std::string& type) CV_OVERRIDE
43 node->set_op_type(type);
46 virtual void setInputNames(const std::vector<std::string>& inputs) CV_OVERRIDE
50 for (int i = 0; i < inputs.size(); ++i)
51 node->add_input(inputs[i]);
54 opencv_onnx::NodeProto* node;
57 // ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
58 class ONNXGraphWrapper : public ImportGraphWrapper
61 ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
63 numInputs = net.input_size();
64 numInitializers = net.initializer_size();
67 virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE
69 opencv_onnx::NodeProto* node = 0;
70 if (idx >= numInputs + numInitializers)
71 node = net.mutable_node(idx - numInputs - numInitializers);
72 return makePtr<ONNXNodeWrapper>(node);
75 virtual int getNumNodes() const CV_OVERRIDE
77 return numInputs + numInitializers + net.node_size();
80 virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
82 if (nodeId < numInputs + numInitializers)
85 return net.node(nodeId - numInputs - numInitializers).output_size();
88 virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
90 CV_Assert(outId < getNumOutputs(nodeId));
91 if (nodeId < numInputs)
92 return net.input(nodeId).name();
93 else if (nodeId < numInputs + numInitializers)
94 return net.initializer(nodeId - numInputs).name();
96 return net.node(nodeId - numInputs - numInitializers).output(outId);
99 virtual void removeNode(int idx) CV_OVERRIDE
101 CV_Assert(idx >= numInputs + numInitializers);
102 net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
106 int numInputs, numInitializers;
107 opencv_onnx::GraphProto& net;
110 class SoftMaxSubgraph : public Subgraph
113 SoftMaxSubgraph() : axis(1)
115 int input = addNodeToMatch("");
116 int inpExp = addNodeToMatch("Exp", input);
117 int sum = addNodeToMatch("ReduceSum", inpExp);
118 addNodeToMatch("Div", inpExp, sum);
119 setFusedNode("Softmax", input);
122 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
123 std::vector<int>& matchedNodesIds,
124 std::vector<int>& targetNodesIds) CV_OVERRIDE
126 if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
128 Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
129 opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
131 for (int i = 0; i < node->attribute_size(); i++)
133 opencv_onnx::AttributeProto attr = node->attribute(i);
134 if (attr.name() != "axes")
136 if (attr.ints_size() != 1)
137 CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
141 CV_Error(Error::StsNotImplemented, "Missed axes attribute");
146 virtual void finalize(const Ptr<ImportGraphWrapper>&,
147 const Ptr<ImportNodeWrapper>& fusedNode,
148 std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
150 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
151 opencv_onnx::AttributeProto* attr = node->add_attribute();
152 attr->set_name("axis");
160 class NormalizeSubgraphBase : public Subgraph
163 NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
165 virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
166 std::vector<int>& matchedNodesIds,
167 std::vector<int>& targetNodesIds) CV_OVERRIDE
169 if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
171 Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
172 opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
174 for (int i = 0; i < node->attribute_size(); i++)
176 opencv_onnx::AttributeProto attr = node->attribute(i);
177 if (attr.name() != "axes")
179 if (attr.ints_size() != 1)
180 CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
184 CV_Error(Error::StsNotImplemented, "Missed axes attribute");
189 virtual void finalize(const Ptr<ImportGraphWrapper>&,
190 const Ptr<ImportNodeWrapper>& fusedNode,
191 std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
193 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
194 opencv_onnx::AttributeProto* axis_attr = node->add_attribute();
195 axis_attr->set_name("axis");
196 axis_attr->set_i(axis);
198 opencv_onnx::AttributeProto* end_axis_attr = node->add_attribute();
199 end_axis_attr->set_name("end_axis");
200 end_axis_attr->set_i(axis);
204 int axis, normNodeOrder;
207 class NormalizeSubgraph1 : public NormalizeSubgraphBase
212 int input = addNodeToMatch("");
213 int norm = addNodeToMatch("ReduceL2", input);
214 addNodeToMatch("Div", input, norm);
215 setFusedNode("Normalize", input);
219 class NormalizeSubgraph2 : public NormalizeSubgraphBase
224 int input = addNodeToMatch("");
225 int norm = addNodeToMatch("ReduceL2", input);
226 int clip = addNodeToMatch("Clip", norm);
227 int shape = addNodeToMatch("Shape", input);
228 int expand = addNodeToMatch("Expand", clip, shape);
229 addNodeToMatch("Div", input, expand);
230 setFusedNode("Normalize", input);
234 class NormalizeSubgraph3 : public NormalizeSubgraphBase
237 NormalizeSubgraph3() : NormalizeSubgraphBase(1)
239 int input = addNodeToMatch("");
240 int power = addNodeToMatch("Constant");
241 int squared = addNodeToMatch("Pow", input, power);
242 int sum = addNodeToMatch("ReduceSum", squared);
243 int sqrtNode = addNodeToMatch("Sqrt", sum);
244 int eps = addNodeToMatch("Constant");
245 int add = addNodeToMatch("Add", sqrtNode, eps);
247 addNodeToMatch("Div", input, add);
248 setFusedNode("Normalize", input);
252 class GatherCastSubgraph : public Subgraph
257 int input = addNodeToMatch("");
258 int index = addNodeToMatch("Constant");
259 int gather = addNodeToMatch("Gather", input, index);
260 addNodeToMatch("Cast", gather);
261 setFusedNode("Gather", input, index);
265 class MulCastSubgraph : public Subgraph
270 int input = addNodeToMatch("");
271 int scaleNode = addNodeToMatch("Constant");
272 int mul = addNodeToMatch("Mul", input, scaleNode);
273 addNodeToMatch("Cast", mul);
274 setFusedNode("Mul", input, scaleNode);
278 class ExtractScalesSubgraph : public Subgraph
281 ExtractScalesSubgraph()
283 input = addNodeToMatch("");
285 int indexH = addNodeToMatch("Constant");
286 int shape1 = addNodeToMatch("Shape", input);
287 int gather1 = addNodeToMatch("Gather", shape1, indexH);
288 scaleHNode = addNodeToMatch("Constant");
289 int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
290 int floor1 = addNodeToMatch("Floor", mul1);
292 int indexW = addNodeToMatch("Constant");
293 int shape2 = addNodeToMatch("Shape", input);
294 int gather2 = addNodeToMatch("Gather", shape2, indexW);
295 scaleWNode = addNodeToMatch("Constant");
296 int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
297 int floor2 = addNodeToMatch("Floor", mul2);
299 int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
300 int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
301 concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2);
304 void finalize(const Ptr<ImportGraphWrapper>& net,
305 const Ptr<ImportNodeWrapper>& fusedNode,
306 std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE
308 opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
309 opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
310 Mat scaleW = getMatFromTensor(tensor_proto);
311 CV_Assert(scaleW.total() == 1);
312 scaleW.convertTo(scaleW, CV_32F);
314 constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
315 tensor_proto = constant_node->attribute(0).t();
316 Mat scaleH = getMatFromTensor(tensor_proto);
317 CV_Assert(scaleH.total() == 1);
318 scaleH.convertTo(scaleH, CV_32F);
320 opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
321 opencv_onnx::AttributeProto* attrH = node->add_attribute();
322 attrH->set_name("height_scale");
323 attrH->set_i(scaleH.at<float>(0));
324 opencv_onnx::AttributeProto* attrW = node->add_attribute();
325 attrW->set_name("width_scale");
326 attrW->set_i(scaleW.at<float>(0));
328 node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
333 int scaleHNode, scaleWNode;
336 class UpsampleSubgraph : public ExtractScalesSubgraph
339 UpsampleSubgraph() : ExtractScalesSubgraph()
341 int shape = addNodeToMatch("Shape", input);
342 int slice = addNodeToMatch("Slice", shape);
344 int castConcat = addNodeToMatch("Cast", concatId);
345 int castSlice = addNodeToMatch("Cast", slice);
346 int divide = addNodeToMatch("Div", castConcat, castSlice);
348 int constant = addNodeToMatch("Constant");
349 int concat = addNodeToMatch("Concat", constant, divide);
351 addNodeToMatch("Upsample", input, concat);
352 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
356 class ResizeSubgraph1 : public ExtractScalesSubgraph
359 ResizeSubgraph1() : ExtractScalesSubgraph()
361 int shape = addNodeToMatch("Shape", input);
362 int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
364 int castConcat = addNodeToMatch("Cast", concatId);
365 int concat = addNodeToMatch("Concat", slice, castConcat);
366 int constant = addNodeToMatch("Constant");
368 addNodeToMatch("Resize", input, constant, constant, concat);
369 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
373 class ResizeSubgraph2 : public ExtractScalesSubgraph
376 ResizeSubgraph2() : ExtractScalesSubgraph()
378 int constantConcat = addNodeToMatch("Constant");
379 int castConcat = addNodeToMatch("Cast", concatId);
380 int concat = addNodeToMatch("Concat", constantConcat, castConcat);
381 int constant = addNodeToMatch("Constant");
383 addNodeToMatch("Resize", input, constant, constant, concat);
384 setFusedNode("Upsample", input, scaleWNode, scaleHNode);
388 class BatchNormalizationSubgraphBase : public Subgraph
391 BatchNormalizationSubgraphBase()
393 input = addNodeToMatch("");
394 var = addNodeToMatch("");
395 mean = addNodeToMatch("");
396 weight = addNodeToMatch("");
397 bias = addNodeToMatch("");
398 A = addNodeToMatch("");
399 shape1 = addNodeToMatch("");
400 shape2 = addNodeToMatch("");
403 int input, var, mean, weight, bias, A, shape1, shape2;
406 class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase
409 BatchNormalizationSubgraph1()
411 int reshape1 = addNodeToMatch("Reshape", weight, shape1);
412 int reshape2 = addNodeToMatch("Reshape", bias, shape2);
413 int shape3 = addNodeToMatch("Constant");
414 int reshape3 = addNodeToMatch("Reshape", var, shape3);
415 int shape4 = addNodeToMatch("Constant");
416 int reshape4 = addNodeToMatch("Reshape", mean, shape4);
417 int sqrtNode = addNodeToMatch("Sqrt", reshape3);
418 int divNode = addNodeToMatch("Div", A, sqrtNode);
419 int mul1 = addNodeToMatch("Mul", reshape1, divNode);
420 int mul2 = addNodeToMatch("Mul", reshape4, mul1);
421 int sub = addNodeToMatch("Sub", reshape2, mul2);
422 int mul3 = addNodeToMatch("Mul", input, mul1);
423 addNodeToMatch("Add", mul3, sub);
424 setFusedNode("BatchNormalization", input, weight, bias, mean, var);
428 class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
431 BatchNormalizationSubgraph2()
433 int sqrtNode = addNodeToMatch("Sqrt", var);
434 int divNode = addNodeToMatch("Div", A, sqrtNode);
435 int mul1 = addNodeToMatch("Mul", weight, divNode);
436 int reshape2 = addNodeToMatch("Reshape", mul1, shape2);
438 int mulMean = addNodeToMatch("Mul", mean, mul1);
439 int sub = addNodeToMatch("Sub", bias, mulMean);
440 int reshape1 = addNodeToMatch("Reshape", sub, shape1);
442 int mulInput = addNodeToMatch("Mul", input, reshape2);
443 addNodeToMatch("Add", mulInput, reshape1);
444 setFusedNode("BatchNormalization", input, weight, bias, mean, var);
448 void simplifySubgraphs(opencv_onnx::GraphProto& net)
450 std::vector<Ptr<Subgraph> > subgraphs;
451 subgraphs.push_back(makePtr<GatherCastSubgraph>());
452 subgraphs.push_back(makePtr<MulCastSubgraph>());
453 subgraphs.push_back(makePtr<UpsampleSubgraph>());
454 subgraphs.push_back(makePtr<ResizeSubgraph1>());
455 subgraphs.push_back(makePtr<ResizeSubgraph2>());
456 subgraphs.push_back(makePtr<SoftMaxSubgraph>());
457 subgraphs.push_back(makePtr<NormalizeSubgraph1>());
458 subgraphs.push_back(makePtr<NormalizeSubgraph2>());
459 subgraphs.push_back(makePtr<NormalizeSubgraph3>());
460 subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
461 subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
463 simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
466 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
468 if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
469 tensor_proto.double_data().empty() && tensor_proto.int64_data().empty())
472 opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
474 std::vector<int> sizes;
475 for (int i = 0; i < tensor_proto.dims_size(); i++) {
476 sizes.push_back(tensor_proto.dims(i));
480 if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
482 if (!tensor_proto.float_data().empty()) {
483 const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
484 Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
487 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
488 Mat(sizes, CV_32FC1, val).copyTo(blob);
491 else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
493 const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
494 CV_Assert(!field.empty());
495 Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
497 else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
499 blob.create(sizes, CV_32SC1);
500 int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
502 if (!tensor_proto.int64_data().empty()) {
503 ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
504 convertInt64ToInt32(src, dst, blob.total());
508 const char* val = tensor_proto.raw_data().c_str();
509 #if CV_STRONG_ALIGNMENT
510 // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
511 // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
512 AutoBuffer<int64_t, 16> aligned_val;
513 if (!isAligned<sizeof(int64_t)>(val))
515 size_t sz = tensor_proto.raw_data().size();
516 aligned_val.allocate(divUp(sz, sizeof(int64_t)));
517 memcpy(aligned_val.data(), val, sz);
518 val = (const char*)aligned_val.data();
521 const int64_t* src = reinterpret_cast<const int64_t*>(val);
522 convertInt64ToInt32(src, dst, blob.total());
526 CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
527 opencv_onnx::TensorProto_DataType_Name(datatype));
528 if (tensor_proto.dims_size() == 0)
529 blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
533 CV__DNN_INLINE_NS_END
534 }} // namespace cv::dnn