1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
8 #include "../precomp.hpp"
9 #include <opencv2/dnn/shape_utils.hpp>
20 #if defined(__GNUC__) && __GNUC__ >= 5
21 #pragma GCC diagnostic push
22 #pragma GCC diagnostic ignored "-Wsuggest-override"
24 #include "opencv-onnx.pb.h"
25 #if defined(__GNUC__) && __GNUC__ >= 5
26 #pragma GCC diagnostic pop
29 #include "onnx_graph_simplifier.hpp"
33 CV__DNN_INLINE_NS_BEGIN
38 opencv_onnx::ModelProto model_proto;
42 LayerInfo(int _layerId, int _outputId) : layerId(_layerId), outputId(_outputId) {}
45 std::map<std::string, Mat> getGraphTensors(
46 const opencv_onnx::GraphProto& graph_proto);
47 Mat getBlob(const opencv_onnx::NodeProto& node_proto, const std::map<std::string, Mat>& constBlobs, int index);
49 LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
50 bool isCeilMode(const LayerParams& layerParams);
54 ONNXImporter(const char *onnxFile)
56 std::fstream input(onnxFile, std::ios::in | std::ios::binary);
58 if (!model_proto.ParseFromIstream(&input))
59 CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model");
62 ONNXImporter(const char* buffer, size_t sizeBuffer)
64 struct _Buf : public std::streambuf
66 _Buf(const char* buffer, size_t sizeBuffer)
68 char* p = const_cast<char*>(buffer);
69 setg(p, p, p + sizeBuffer);
73 _Buf buf(buffer, sizeBuffer);
74 std::istream input(&buf);
76 if (!model_proto.ParseFromIstream(&input))
77 CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
80 void populateNet(Net dstNet);
83 inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
85 if (layerParams.has(oldKey)) {
86 layerParams.set(newKey, layerParams.get(oldKey));
87 layerParams.erase(oldKey);
91 void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
93 if (!tensor_proto.raw_data().empty()) {
94 delete tensor_proto.release_raw_data();
98 template<typename T1, typename T2>
99 void convertInt64ToInt32(const T1& src, T2& dst, int size)
101 for (int i = 0; i < size; i++) {
102 if (src[i] < std::numeric_limits<int32_t>::min() || src[i] > std::numeric_limits<int32_t>::max()) {
103 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
105 dst[i] = saturate_cast<int32_t>(src[i]);
109 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
111 CV_Assert(!tensor_proto.raw_data().empty() || !tensor_proto.float_data().empty()
112 || !tensor_proto.double_data().empty() || !tensor_proto.int64_data().empty());
114 opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
116 std::vector<int> sizes;
117 for (int i = 0; i < tensor_proto.dims_size(); i++) {
118 sizes.push_back(tensor_proto.dims(i));
122 if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
124 if (!tensor_proto.float_data().empty()) {
125 const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
126 Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
129 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
130 Mat(sizes, CV_32FC1, val).copyTo(blob);
133 else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
135 const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
136 CV_Assert(!field.empty());
137 Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
139 else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
141 blob.create(sizes, CV_32SC1);
142 int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
144 if (!tensor_proto.int64_data().empty()) {
145 ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
146 convertInt64ToInt32(src, dst, blob.total());
150 const char* val = tensor_proto.raw_data().c_str();
151 // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
152 // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
153 AutoBuffer<int64_t, 16> aligned_val;
154 if (!isAligned<sizeof(int64_t)>(val))
156 size_t sz = tensor_proto.raw_data().size();
157 aligned_val.allocate(divUp(sz, sizeof(int64_t)));
158 memcpy(aligned_val.data(), val, sz);
159 val = (const char*)aligned_val.data();
161 const int64_t* src = reinterpret_cast<const int64_t*>(val);
162 convertInt64ToInt32(src, dst, blob.total());
166 CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
167 opencv_onnx::TensorProto_DataType_Name(datatype));
168 if (tensor_proto.dims_size() == 0)
169 blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
173 void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
174 std::vector<Mat>& outputs)
176 Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
177 CV_Assert((bool)layer);
179 std::vector<MatShape> inpShapes(inputs.size());
181 for (size_t i = 0; i < inputs.size(); ++i)
183 inpShapes[i] = shape(inputs[i]);
184 if (i > 0 && ddepth != inputs[i].depth())
185 CV_Error(Error::StsNotImplemented, "Mixed input data types.");
186 ddepth = inputs[i].depth();
189 std::vector<MatShape> outShapes, internalShapes;
190 layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
192 std::vector<Mat> internals(internalShapes.size());
193 outputs.resize(outShapes.size());
194 for (size_t i = 0; i < outShapes.size(); ++i)
195 outputs[i].create(outShapes[i], ddepth);
196 for (size_t i = 0; i < internalShapes.size(); ++i)
197 internals[i].create(internalShapes[i], ddepth);
199 layer->finalize(inputs, outputs);
200 layer->forward(inputs, outputs, internals);
203 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
204 const opencv_onnx::GraphProto& graph_proto)
206 opencv_onnx::TensorProto tensor_proto;
207 std::map<std::string, Mat> layers_weights;
209 for (int i = 0; i < graph_proto.initializer_size(); i++)
211 tensor_proto = graph_proto.initializer(i);
212 Mat mat = getMatFromTensor(tensor_proto);
213 releaseONNXTensor(tensor_proto);
214 layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
216 return layers_weights;
219 static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
220 std::vector<int32_t> dst(src.size());
221 convertInt64ToInt32(src, dst, src.size());
222 return DictValue::arrayInt(&dst[0], src.size());
225 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
228 for(int i = 0; i < node_proto.attribute_size(); i++)
230 opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
231 std::string attribute_name = attribute_proto.name();
233 if(attribute_name == "kernel_shape")
235 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
236 lp.set("kernel_size", parse(attribute_proto.ints()));
238 else if(attribute_name == "strides")
240 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
241 lp.set("stride", parse(attribute_proto.ints()));
243 else if(attribute_name == "pads")
245 if (node_proto.op_type() == "Pad")
248 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
249 // We need to shuffle it to begin0, end0, begin1, end1, ...
250 CV_Assert(attribute_proto.ints_size() % 2 == 0);
251 const int dims = attribute_proto.ints_size() / 2;
252 std::vector<int32_t> paddings;
253 paddings.reserve(attribute_proto.ints_size());
254 for (int i = 0; i < dims; ++i)
256 paddings.push_back(attribute_proto.ints(i));
257 paddings.push_back(attribute_proto.ints(dims + i));
259 lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
263 // Convolution or pooling.
264 CV_Assert(attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
265 lp.set("pad", parse(attribute_proto.ints()));
268 else if(attribute_name == "auto_pad")
270 if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
271 lp.set("pad_mode", "SAME");
273 else if (attribute_proto.s() == "VALID") {
274 lp.set("pad_mode", "VALID");
277 else if(attribute_name == "dilations")
279 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
280 lp.set("dilation", parse(attribute_proto.ints()));
282 else if (attribute_proto.has_i())
284 ::google::protobuf::int64 src = attribute_proto.i();
285 if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
286 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
288 lp.set(attribute_name, saturate_cast<int32_t>(src));
290 else if (attribute_proto.has_f())
292 lp.set(attribute_name, attribute_proto.f());
294 else if (attribute_proto.has_s())
296 lp.set(attribute_name, attribute_proto.s());
298 else if (attribute_proto.floats_size() > 0)
300 lp.set(attribute_name, DictValue::arrayReal(
301 attribute_proto.floats().data(), attribute_proto.floats_size()));
303 else if (attribute_proto.ints_size() > 0)
305 lp.set(attribute_proto.name(), parse(attribute_proto.ints()));
307 else if (attribute_proto.has_t())
309 opencv_onnx::TensorProto tensor = attribute_proto.t();
310 Mat blob = getMatFromTensor(tensor);
311 lp.blobs.push_back(blob);
313 else if (attribute_proto.has_g() || attribute_proto.strings_size() > 0 ||
314 attribute_proto.tensors_size() > 0 || attribute_proto.graphs_size() > 0)
316 CV_Error(Error::StsNotImplemented, "Unexpected attribute type");
319 CV_Error(Error::StsNotImplemented, "Unsupported attribute type");
324 Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
325 const std::map<std::string, Mat>& constBlobs, int index)
327 CV_Assert(index < node_proto.input_size());
328 std::map<std::string, Mat>::const_iterator constBlob;
329 constBlob = constBlobs.find(node_proto.input(index));
330 if (constBlob == constBlobs.end()) {
331 CV_Error(Error::StsObjectNotFound,
332 "Blob " + node_proto.input(index) + " not found in const blobs");
334 return constBlob->second;
337 void ONNXImporter::populateNet(Net dstNet)
339 CV_Assert(model_proto.has_graph());
340 opencv_onnx::GraphProto graph_proto = model_proto.graph();
342 simplifySubgraphs(graph_proto);
344 std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
345 // List of internal blobs shapes.
346 std::map<std::string, MatShape> outShapes;
347 // Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
348 for (int i = 0; i < graph_proto.input_size(); ++i)
350 opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
351 CV_Assert(valueInfoProto.has_type());
352 opencv_onnx::TypeProto typeProto = valueInfoProto.type();
353 CV_Assert(typeProto.has_tensor_type());
354 opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
355 CV_Assert(tensor.has_shape());
356 opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
358 MatShape inpShape(tensorShape.dim_size());
359 for (int j = 0; j < inpShape.size(); ++j)
361 inpShape[j] = tensorShape.dim(j).dim_value();
363 outShapes[valueInfoProto.name()] = inpShape;
366 std::string framework_name;
367 if (model_proto.has_producer_name()) {
368 framework_name = model_proto.producer_name();
371 // create map with network inputs (without const blobs)
372 std::map<std::string, LayerInfo> layer_id;
373 std::map<std::string, LayerInfo>::iterator layerId;
374 std::map<std::string, MatShape>::iterator shapeIt;
375 // fill map: push layer name, layer id and output id
376 std::vector<String> netInputs;
377 for (int j = 0; j < graph_proto.input_size(); j++)
379 const std::string& name = graph_proto.input(j).name();
380 if (constBlobs.find(name) == constBlobs.end()) {
381 netInputs.push_back(name);
382 layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
385 dstNet.setInputsNames(netInputs);
387 int layersSize = graph_proto.node_size();
388 LayerParams layerParams;
389 opencv_onnx::NodeProto node_proto;
391 for(int li = 0; li < layersSize; li++)
393 node_proto = graph_proto.node(li);
394 layerParams = getLayerParams(node_proto);
395 CV_Assert(node_proto.output_size() >= 1);
396 layerParams.name = node_proto.output(0);
398 std::string layer_type = node_proto.op_type();
399 layerParams.type = layer_type;
402 if (layer_type == "MaxPool")
404 layerParams.type = "Pooling";
405 layerParams.set("pool", "MAX");
406 layerParams.set("ceil_mode", layerParams.has("pad_mode"));
408 else if (layer_type == "AveragePool")
410 layerParams.type = "Pooling";
411 layerParams.set("pool", "AVE");
412 layerParams.set("ceil_mode", layerParams.has("pad_mode"));
413 layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
415 else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
417 CV_Assert(node_proto.input_size() == 1);
418 layerParams.type = "Pooling";
419 layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
420 layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
422 if (layer_type == "ReduceMean")
424 if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
425 CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
427 MatShape inpShape = outShapes[node_proto.input(0)];
428 if (inpShape.size() != 4 && inpShape.size() != 5)
429 CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
431 DictValue axes = layerParams.get("axes");
432 CV_Assert(axes.size() <= inpShape.size() - 2);
433 std::vector<int> kernel_size(inpShape.size() - 2, 1);
434 for (int i = 0; i < axes.size(); i++) {
435 int axis = axes.get<int>(i);
436 CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
437 kernel_size[axis - 2] = inpShape[axis];
440 layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
443 else if (layer_type == "Slice")
445 if (layerParams.has("steps")) {
446 DictValue steps = layerParams.get("steps");
447 for (int i = 0; i < steps.size(); ++i) {
448 if (steps.get<int>(i) != 1)
449 CV_Error(Error::StsNotImplemented,
450 "Slice layer only supports steps = 1");
455 if (layerParams.has("axes")) {
456 DictValue axes = layerParams.get("axes");
457 for (int i = 1; i < axes.size(); ++i) {
458 CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
460 axis = axes.get<int>(0);
462 layerParams.set("axis", axis);
464 DictValue starts = layerParams.get("starts");
465 DictValue ends = layerParams.get("ends");
466 CV_Assert(starts.size() == ends.size());
468 std::vector<int> begin;
469 std::vector<int> end;
471 begin.resize(axis, 0);
472 end.resize(axis, -1);
475 for (int i = 0; i < starts.size(); ++i)
477 begin.push_back(starts.get<int>(i));
478 int finish = ends.get<int>(i);
479 end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
481 layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
482 layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
484 else if (layer_type == "Split")
486 DictValue splits = layerParams.get("split");
487 const int numSplits = splits.size();
488 CV_Assert(numSplits > 1);
490 std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
491 for (int i = 1; i < splits.size() - 1; ++i)
493 slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
495 layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
496 layerParams.type = "Slice";
498 else if (layer_type == "Add" || layer_type == "Sum")
500 if (layer_id.find(node_proto.input(1)) == layer_id.end())
502 Mat blob = getBlob(node_proto, constBlobs, 1);
503 blob = blob.reshape(1, 1);
504 if (blob.total() == 1) {
505 layerParams.type = "Power";
506 layerParams.set("shift", blob.at<float>(0));
509 layerParams.type = "Scale";
510 layerParams.set("bias_term", true);
511 layerParams.blobs.push_back(blob);
515 layerParams.type = "Eltwise";
518 else if (layer_type == "Max")
520 layerParams.type = "Eltwise";
521 layerParams.set("operation", "max");
523 else if (layer_type == "Sub")
525 Mat blob = getBlob(node_proto, constBlobs, 1);
526 if (blob.total() == 1) {
527 layerParams.type = "Power";
528 layerParams.set("shift", -blob.at<float>(0));
531 layerParams.type = "Scale";
532 layerParams.set("has_bias", true);
533 layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
536 else if (layer_type == "Div")
538 if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
540 layerParams.type = "Eltwise";
541 layerParams.set("operation", "div");
545 Mat blob = getBlob(node_proto, constBlobs, 1);
546 CV_Assert_N(blob.type() == CV_32F, blob.total());
547 if (blob.total() == 1)
549 layerParams.set("scale", 1.0f / blob.at<float>(0));
550 layerParams.type = "Power";
554 layerParams.type = "Scale";
555 divide(1.0, blob, blob);
556 layerParams.blobs.push_back(blob);
557 layerParams.set("bias_term", false);
561 else if (layer_type == "Neg")
563 layerParams.type = "Power";
564 layerParams.set("scale", -1);
566 else if (layer_type == "Constant")
568 CV_Assert(node_proto.input_size() == 0);
569 CV_Assert(layerParams.blobs.size() == 1);
570 constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
573 else if (layer_type == "ImageScaler")
575 const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
576 layerParams.erase("scale");
578 if (layerParams.has("bias"))
580 layerParams.type = "Scale";
581 layerParams.blobs.push_back(
582 Mat(Size(1, layerParams.get("bias").size()), CV_32FC1, scale));
584 layerParams.set("bias_term", true);
585 Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
586 for (int j = 0; j < bias.total(); j++) {
587 bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
589 layerParams.blobs.push_back(bias);
590 layerParams.erase("bias");
593 layerParams.set("scale", scale);
594 layerParams.type = "Power";
597 else if (layer_type == "Clip")
599 layerParams.type = "ReLU6";
600 replaceLayerParam(layerParams, "min", "min_value");
601 replaceLayerParam(layerParams, "max", "max_value");
604 else if (layer_type == "LeakyRelu")
606 layerParams.type = "ReLU";
607 replaceLayerParam(layerParams, "alpha", "negative_slope");
609 else if (layer_type == "LRN")
611 replaceLayerParam(layerParams, "size", "local_size");
613 else if (layer_type == "InstanceNormalization")
615 if (node_proto.input_size() != 3)
616 CV_Error(Error::StsNotImplemented,
617 "Expected input, scale, bias");
619 layerParams.blobs.resize(4);
620 layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
621 layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
622 layerParams.set("has_bias", true);
623 layerParams.set("has_weight", true);
625 // Get number of channels in input
626 int size = layerParams.blobs[2].total();
627 layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
628 layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
630 LayerParams mvnParams;
631 mvnParams.name = layerParams.name + "/MVN";
632 mvnParams.type = "MVN";
633 mvnParams.set("eps", layerParams.get<float>("epsilon"));
634 layerParams.erase("epsilon");
637 int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
639 layerId = layer_id.find(node_proto.input(0));
640 CV_Assert(layerId != layer_id.end());
641 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
643 layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
644 outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
646 //Replace Batch Norm's input to MVN
647 node_proto.set_input(0, mvnParams.name);
648 layerParams.type = "BatchNorm";
650 else if (layer_type == "BatchNormalization")
652 if (node_proto.input_size() != 5)
653 CV_Error(Error::StsNotImplemented,
654 "Expected input, scale, bias, mean and var");
656 layerParams.type = "BatchNorm";
657 replaceLayerParam(layerParams, "epsilon", "eps");
658 replaceLayerParam(layerParams, "spatial", "use_global_stats");
660 Mat meanData = getBlob(node_proto, constBlobs, 3);
661 Mat stdData = getBlob(node_proto, constBlobs, 4);
663 layerParams.blobs.push_back(meanData);
664 layerParams.blobs.push_back(stdData);
666 if (!node_proto.input(1).empty()) {
667 layerParams.set("has_weight", true);
668 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1)); // weightData
670 layerParams.set("has_weight", false);
673 if (!node_proto.input(2).empty()) {
674 layerParams.set("has_bias", true);
675 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 2)); // biasData
677 layerParams.set("has_bias", false);
680 else if (layer_type == "Gemm")
682 CV_Assert(node_proto.input_size() >= 2);
683 layerParams.type = "InnerProduct";
684 Mat weights = getBlob(node_proto, constBlobs, 1);
686 if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
687 transpose(weights, weights);
690 layerParams.blobs.push_back(weights);
692 if (node_proto.input_size() == 3) {
693 Mat bias = getBlob(node_proto, constBlobs, 2);
694 layerParams.blobs.push_back(bias);
697 layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
698 layerParams.set("bias_term", node_proto.input_size() == 3);
700 else if (layer_type == "MatMul")
702 CV_Assert(node_proto.input_size() == 2);
703 layerParams.type = "InnerProduct";
704 Mat blob = getBlob(node_proto, constBlobs, 1);
705 layerParams.blobs.push_back(blob.t());
706 layerParams.set("bias_term", false);
707 layerParams.set("num_output", layerParams.blobs[0].size[0]);
709 else if (layer_type == "Mul")
711 CV_Assert(node_proto.input_size() == 2);
712 if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
713 Mat blob = getBlob(node_proto, constBlobs, 1);
714 blob = blob.reshape(1, 1);
715 if (blob.total() == 1) {
716 layerParams.set("scale", blob.at<float>(0));
717 layerParams.type = "Power";
720 layerParams.blobs.push_back(blob);
721 layerParams.type = "Scale";
725 layerParams.type = "Eltwise";
726 layerParams.set("operation", "prod");
729 else if (layer_type == "Conv")
731 CV_Assert(node_proto.input_size() >= 2);
732 layerParams.type = "Convolution";
733 for (int j = 1; j < node_proto.input_size(); j++) {
734 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
736 layerParams.set("num_output", layerParams.blobs[0].size[0]);
737 layerParams.set("bias_term", node_proto.input_size() == 3);
739 else if (layer_type == "ConvTranspose")
741 CV_Assert(node_proto.input_size() >= 2);
742 layerParams.type = "Deconvolution";
743 for (int j = 1; j < node_proto.input_size(); j++) {
744 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
746 layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
747 layerParams.set("bias_term", node_proto.input_size() == 3);
749 if (!layerParams.has("kernel_size"))
750 CV_Error(Error::StsNotImplemented,
751 "Required attribute 'kernel_size' is not present.");
753 if (layerParams.has("output_shape"))
755 const DictValue& outShape = layerParams.get("output_shape");
756 DictValue strides = layerParams.get("stride");
757 DictValue kernel = layerParams.get("kernel_size");
760 std::vector<int> adjust_pads;
761 if (layerParams.has("pad_mode"))
763 padMode = toUpperCase(layerParams.get<String>("pad_mode"));
764 if (padMode != "SAME" && padMode != "VALID")
765 CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
767 for (int i = 0; i < strides.size(); i++)
769 int sz = outShape.get<int>(2 + i);
770 int stride = strides.get<int>(i);
771 adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
772 (sz - kernel.get<int>(i)) % stride);
774 layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
777 else if (layerParams.has("output_padding"))
779 replaceLayerParam(layerParams, "output_padding", "adj");
782 else if (layer_type == "Transpose")
784 layerParams.type = "Permute";
785 replaceLayerParam(layerParams, "perm", "order");
787 CV_Assert(node_proto.input_size() == 1);
788 if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
790 std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
791 runLayer(layerParams, inputs, transposed);
792 CV_Assert(transposed.size() == 1);
793 constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
797 else if (layer_type == "ReduceL2")
799 CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
800 CV_Assert(graph_proto.node_size() > li + 1 && graph_proto.node(li + 1).op_type() == "Div");
802 node_proto = graph_proto.node(li);
803 layerParams.name = node_proto.output(0);
804 layerParams.type = "Normalize";
806 DictValue axes_dict = layerParams.get("axes");
807 if (axes_dict.size() != 1)
808 CV_Error(Error::StsNotImplemented, "Multidimensional reduceL2");
809 int axis = axes_dict.getIntValue(0);
810 layerParams.set("axis",axis);
811 layerParams.set("end_axis", axis);
813 else if (layer_type == "Squeeze")
815 CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
816 DictValue axes_dict = layerParams.get("axes");
817 if (axes_dict.size() != 1)
818 CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
820 int axis = axes_dict.getIntValue(0);
821 layerParams.set("axis", axis - 1);
822 layerParams.set("end_axis", axis);
823 layerParams.type = "Flatten";
825 else if (layer_type == "Unsqueeze")
827 CV_Assert(node_proto.input_size() == 1);
828 DictValue axes = layerParams.get("axes");
829 if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
832 Mat input = getBlob(node_proto, constBlobs, 0);
834 std::vector<int> dims;
835 for (int j = 0; j < input.dims; j++) {
836 dims.push_back(input.size[j]);
838 CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
839 for (int j = 0; j < axes.size(); j++) {
840 dims.insert(dims.begin() + axes.getIntValue(j), 1);
843 Mat out = input.reshape(0, dims);
844 constBlobs.insert(std::make_pair(layerParams.name, out));
849 if (axes.size() != 1)
850 CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
852 MatShape inpShape = outShapes[node_proto.input(0)];
853 int axis = axes.getIntValue(0);
854 CV_Assert(0 <= axis && axis <= inpShape.size());
855 std::vector<int> outShape = inpShape;
856 outShape.insert(outShape.begin() + axis, 1);
857 layerParams.type = "Reshape";
858 layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
860 else if (layer_type == "Reshape")
862 CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
864 if (node_proto.input_size() == 2) {
865 Mat blob = getBlob(node_proto, constBlobs, 1);
866 CV_Assert(blob.type() == CV_32SC1);
868 layerParams.set("dim", DictValue::arrayInt<int*>(
869 blob.ptr<int>(), blob.total() ));
871 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
872 std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
873 runLayer(layerParams, inputs, outputs);
874 constBlobs.insert(std::make_pair(layerParams.name, outputs[0]));
879 DictValue shape = layerParams.get("shape");
880 std::vector<int> dim;
881 for (int j = 0; j < shape.size(); j++) {
882 dim.push_back(shape.getIntValue(j));
885 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
886 Mat input = getBlob(node_proto, constBlobs, 0);
887 Mat out = input.reshape(0, dim);
888 constBlobs.insert(std::make_pair(layerParams.name, out));
891 replaceLayerParam(layerParams, "shape", "dim");
894 else if (layer_type == "Pad")
896 layerParams.type = "Padding";
898 else if (layer_type == "Shape")
900 CV_Assert(node_proto.input_size() == 1);
901 shapeIt = outShapes.find(node_proto.input(0));
902 CV_Assert(shapeIt != outShapes.end());
903 MatShape inpShape = shapeIt->second;
905 Mat shapeMat(inpShape.size(), 1, CV_32S);
906 for (int j = 0; j < inpShape.size(); ++j)
907 shapeMat.at<int>(j) = inpShape[j];
910 constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
913 else if (layer_type == "Gather")
915 CV_Assert(node_proto.input_size() == 2);
916 CV_Assert(layerParams.has("axis"));
917 Mat input = getBlob(node_proto, constBlobs, 0);
918 Mat indexMat = getBlob(node_proto, constBlobs, 1);
919 CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
920 int index = indexMat.at<int>(0);
921 int axis = layerParams.get<int>("axis");
923 std::vector<cv::Range> ranges(input.dims, Range::all());
924 ranges[axis] = Range(index, index + 1);
926 Mat out = input(ranges);
927 constBlobs.insert(std::make_pair(layerParams.name, out));
930 else if (layer_type == "Concat")
932 bool hasVariableInps = false;
933 for (int i = 0; i < node_proto.input_size(); ++i)
935 if (layer_id.find(node_proto.input(i)) != layer_id.end())
937 hasVariableInps = true;
942 if (!hasVariableInps)
944 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
945 for (size_t i = 0; i < inputs.size(); ++i)
947 inputs[i] = getBlob(node_proto, constBlobs, i);
949 runLayer(layerParams, inputs, concatenated);
951 CV_Assert(concatenated.size() == 1);
952 constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
956 else if (layer_type == "Upsample")
958 layerParams.type = "Resize";
959 if (layerParams.has("scales"))
962 DictValue scales = layerParams.get("scales");
963 CV_Assert(scales.size() == 4);
964 layerParams.set("zoom_factor_y", scales.getIntValue(2));
965 layerParams.set("zoom_factor_x", scales.getIntValue(3));
970 replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
971 replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
973 replaceLayerParam(layerParams, "mode", "interpolation");
975 else if (layer_type == "LogSoftmax")
977 layerParams.type = "Softmax";
978 layerParams.set("log_softmax", true);
982 for (int j = 0; j < node_proto.input_size(); j++) {
983 if (layer_id.find(node_proto.input(j)) == layer_id.end())
984 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
988 int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
989 for (int i = 0; i < node_proto.output_size(); ++i)
991 layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
994 std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
995 for (int j = 0; j < node_proto.input_size(); j++) {
996 layerId = layer_id.find(node_proto.input(j));
997 if (layerId != layer_id.end()) {
998 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
999 // Collect input shapes.
1000 shapeIt = outShapes.find(node_proto.input(j));
1001 CV_Assert(shapeIt != outShapes.end());
1002 layerInpShapes.push_back(shapeIt->second);
1006 // Compute shape of output blob for this layer.
1007 Ptr<Layer> layer = dstNet.getLayer(id);
1008 layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
1009 for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
1011 outShapes[node_proto.output(i)] = layerOutShapes[i];
1016 Net readNetFromONNX(const String& onnxFile)
1018 ONNXImporter onnxImporter(onnxFile.c_str());
1020 onnxImporter.populateNet(net);
1024 Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
1026 ONNXImporter onnxImporter(buffer, sizeBuffer);
1028 onnxImporter.populateNet(net);
1032 Net readNetFromONNX(const std::vector<uchar>& buffer)
1034 return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
1037 Mat readTensorFromONNX(const String& path)
1039 opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
1040 std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
1041 if (!tensor_proto.ParseFromIstream(&input)) {
1042 CV_Error(Error::StsUnsupportedFormat, "Failed to parse data");
1044 Mat mat = getMatFromTensor(tensor_proto);
1045 releaseONNXTensor(tensor_proto);
1049 CV__DNN_INLINE_NS_END