1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
8 #include "../precomp.hpp"
9 #include <opencv2/dnn/shape_utils.hpp>
20 #if defined(__GNUC__) && __GNUC__ >= 5
21 #pragma GCC diagnostic push
22 #pragma GCC diagnostic ignored "-Wsuggest-override"
24 #include "opencv-onnx.pb.h"
25 #if defined(__GNUC__) && __GNUC__ >= 5
26 #pragma GCC diagnostic pop
31 CV__DNN_INLINE_NS_BEGIN
36 opencv_onnx::ModelProto model_proto;
40 LayerInfo(int _layerId, int _outputId) : layerId(_layerId), outputId(_outputId) {}
43 std::map<std::string, Mat> getGraphTensors(
44 const opencv_onnx::GraphProto& graph_proto);
45 Mat getBlob(const opencv_onnx::NodeProto& node_proto, const std::map<std::string, Mat>& constBlobs, int index);
47 LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
48 bool isCeilMode(const LayerParams& layerParams);
52 ONNXImporter(const char *onnxFile)
54 std::fstream input(onnxFile, std::ios::in | std::ios::binary);
56 if (!model_proto.ParseFromIstream(&input))
57 CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model");
60 ONNXImporter(const char* buffer, size_t sizeBuffer)
62 struct _Buf : public std::streambuf
64 _Buf(const char* buffer, size_t sizeBuffer)
66 char* p = const_cast<char*>(buffer);
67 setg(p, p, p + sizeBuffer);
71 _Buf buf(buffer, sizeBuffer);
72 std::istream input(&buf);
74 if (!model_proto.ParseFromIstream(&input))
75 CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
78 void populateNet(Net dstNet);
81 inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
83 if (layerParams.has(oldKey)) {
84 layerParams.set(newKey, layerParams.get(oldKey));
85 layerParams.erase(oldKey);
89 void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
91 if (!tensor_proto.raw_data().empty()) {
92 delete tensor_proto.release_raw_data();
96 template<typename T1, typename T2>
97 void convertInt64ToInt32(const T1& src, T2& dst, int size)
99 for (int i = 0; i < size; i++) {
100 if (src[i] < std::numeric_limits<int32_t>::min() || src[i] > std::numeric_limits<int32_t>::max()) {
101 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
103 dst[i] = saturate_cast<int32_t>(src[i]);
107 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
109 CV_Assert(!tensor_proto.raw_data().empty() || !tensor_proto.float_data().empty()
110 || !tensor_proto.double_data().empty() || !tensor_proto.int64_data().empty());
112 opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
114 std::vector<int> sizes;
115 for (int i = 0; i < tensor_proto.dims_size(); i++) {
116 sizes.push_back(tensor_proto.dims(i));
120 if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
122 if (!tensor_proto.float_data().empty()) {
123 const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
124 Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
127 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
128 Mat(sizes, CV_32FC1, val).copyTo(blob);
131 else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
133 const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
134 CV_Assert(!field.empty());
135 Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
137 else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
139 blob.create(sizes, CV_32SC1);
140 int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
142 if (!tensor_proto.int64_data().empty()) {
143 ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
144 convertInt64ToInt32(src, dst, blob.total());
148 char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
149 int64_t* src = reinterpret_cast<int64_t*>(val);
150 convertInt64ToInt32(src, dst, blob.total());
154 CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
155 opencv_onnx::TensorProto_DataType_Name(datatype));
156 if (tensor_proto.dims_size() == 0)
157 blob.dims = 1; // To force 1-dimensional cv::Mat for scalars.
161 void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
162 std::vector<Mat>& outputs)
164 Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
165 CV_Assert((bool)layer);
167 std::vector<MatShape> inpShapes(inputs.size());
169 for (size_t i = 0; i < inputs.size(); ++i)
171 inpShapes[i] = shape(inputs[i]);
172 if (i > 0 && ddepth != inputs[i].depth())
173 CV_Error(Error::StsNotImplemented, "Mixed input data types.");
174 ddepth = inputs[i].depth();
177 std::vector<MatShape> outShapes, internalShapes;
178 layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
180 std::vector<Mat> internals(internalShapes.size());
181 outputs.resize(outShapes.size());
182 for (size_t i = 0; i < outShapes.size(); ++i)
183 outputs[i].create(outShapes[i], ddepth);
184 for (size_t i = 0; i < internalShapes.size(); ++i)
185 internals[i].create(internalShapes[i], ddepth);
187 layer->finalize(inputs, outputs);
188 layer->forward(inputs, outputs, internals);
191 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
192 const opencv_onnx::GraphProto& graph_proto)
194 opencv_onnx::TensorProto tensor_proto;
195 std::map<std::string, Mat> layers_weights;
197 for (int i = 0; i < graph_proto.initializer_size(); i++)
199 tensor_proto = graph_proto.initializer(i);
200 Mat mat = getMatFromTensor(tensor_proto);
201 releaseONNXTensor(tensor_proto);
202 layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
204 return layers_weights;
207 static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
208 std::vector<int32_t> dst(src.size());
209 convertInt64ToInt32(src, dst, src.size());
210 return DictValue::arrayInt(&dst[0], src.size());
213 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
216 for(int i = 0; i < node_proto.attribute_size(); i++)
218 opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
219 std::string attribute_name = attribute_proto.name();
221 if(attribute_name == "kernel_shape")
223 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
224 lp.set("kernel_size", parse(attribute_proto.ints()));
226 else if(attribute_name == "strides")
228 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
229 lp.set("stride", parse(attribute_proto.ints()));
231 else if(attribute_name == "pads")
233 if (node_proto.op_type() == "Pad")
236 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
237 // We need to shuffle it to begin0, end0, begin1, end1, ...
238 CV_Assert(attribute_proto.ints_size() % 2 == 0);
239 const int dims = attribute_proto.ints_size() / 2;
240 std::vector<int32_t> paddings;
241 paddings.reserve(attribute_proto.ints_size());
242 for (int i = 0; i < dims; ++i)
244 paddings.push_back(attribute_proto.ints(i));
245 paddings.push_back(attribute_proto.ints(dims + i));
247 lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
251 // Convolution or pooling.
252 CV_Assert(attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
253 lp.set("pad", parse(attribute_proto.ints()));
256 else if(attribute_name == "auto_pad")
258 if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
259 lp.set("pad_mode", "SAME");
261 else if (attribute_proto.s() == "VALID") {
262 lp.set("pad_mode", "VALID");
265 else if(attribute_name == "dilations")
267 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
268 lp.set("dilation", parse(attribute_proto.ints()));
270 else if (attribute_proto.has_i())
272 ::google::protobuf::int64 src = attribute_proto.i();
273 if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
274 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
276 lp.set(attribute_name, saturate_cast<int32_t>(src));
278 else if (attribute_proto.has_f())
280 lp.set(attribute_name, attribute_proto.f());
282 else if (attribute_proto.has_s())
284 lp.set(attribute_name, attribute_proto.s());
286 else if (attribute_proto.floats_size() > 0)
288 lp.set(attribute_name, DictValue::arrayReal(
289 attribute_proto.floats().data(), attribute_proto.floats_size()));
291 else if (attribute_proto.ints_size() > 0)
293 lp.set(attribute_proto.name(), parse(attribute_proto.ints()));
295 else if (attribute_proto.has_t())
297 opencv_onnx::TensorProto tensor = attribute_proto.t();
298 Mat blob = getMatFromTensor(tensor);
299 lp.blobs.push_back(blob);
301 else if (attribute_proto.has_g() || attribute_proto.strings_size() > 0 ||
302 attribute_proto.tensors_size() > 0 || attribute_proto.graphs_size() > 0)
304 CV_Error(Error::StsNotImplemented, "Unexpected attribute type");
307 CV_Error(Error::StsNotImplemented, "Unsupported attribute type");
312 Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
313 const std::map<std::string, Mat>& constBlobs, int index)
315 CV_Assert(index < node_proto.input_size());
316 std::map<std::string, Mat>::const_iterator constBlob;
317 constBlob = constBlobs.find(node_proto.input(index));
318 if (constBlob == constBlobs.end()) {
319 CV_Error(Error::StsObjectNotFound,
320 "Blob " + node_proto.input(index) + " not found in const blobs");
322 return constBlob->second;
// Builds dstNet from the graph stored in model_proto.
//
// Constant data is held in constBlobs. This covers:
// - the graph initializers (from getGraphTensors);
// - the outputs of nodes evaluated here at import time: Constant, Shape and
//   Gather, plus Transpose/Unsqueeze/Reshape/Concat when their inputs are
//   constant.
// Graph inputs that are not constants become the network inputs.
// Every other node is mapped to an OpenCV layer type, added to dstNet and
// connected to the layers producing its non-constant inputs. Its inferred
// output shapes are recorded in outShapes, so later nodes (e.g. Shape,
// ReduceMean, Unsqueeze) can use them.
//
// NOTE(review): dstNet is taken by value. readNetFromONNX passes its local
// Net here and returns that same object afterwards, which presumably relies
// on copies of Net sharing one implementation — confirm.
325 void ONNXImporter::populateNet(Net dstNet)
327 CV_Assert(model_proto.has_graph());
328 opencv_onnx::GraphProto graph_proto = model_proto.graph();
329 std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
330 // List of internal blobs shapes.
331 std::map<std::string, MatShape> outShapes;
332 // Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
333 for (int i = 0; i < graph_proto.input_size(); ++i)
335 opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
336 CV_Assert(valueInfoProto.has_type());
337 opencv_onnx::TypeProto typeProto = valueInfoProto.type();
338 CV_Assert(typeProto.has_tensor_type());
339 opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
340 CV_Assert(tensor.has_shape());
341 opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
343 MatShape inpShape(tensorShape.dim_size());
// NOTE(review): only dim_value() is read. A dimension declared symbolically
// (dim_param) presumably reads as 0 here — confirm.
344 for (int j = 0; j < inpShape.size(); ++j)
346 inpShape[j] = tensorShape.dim(j).dim_value();
348 outShapes[valueInfoProto.name()] = inpShape;
// Name of the producing framework. AveragePool uses it below to enable
// ave_pool_padded_area only for "pytorch" models.
351 std::string framework_name;
352 if (model_proto.has_producer_name()) {
353 framework_name = model_proto.producer_name();
356 // create map with network inputs (without const blobs)
// layer_id maps a blob name to the (layer id, output index) that produces it.
357 std::map<std::string, LayerInfo> layer_id;
358 std::map<std::string, LayerInfo>::iterator layerId;
359 std::map<std::string, MatShape>::iterator shapeIt;
360 // fill map: push layer name, layer id and output id
// Each non-constant graph input k is registered as output k of layer 0
// (presumably the network's input layer), in declaration order.
361 std::vector<String> netInputs;
362 for (int j = 0; j < graph_proto.input_size(); j++)
364 const std::string& name = graph_proto.input(j).name();
365 if (constBlobs.find(name) == constBlobs.end()) {
366 netInputs.push_back(name);
367 layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
370 dstNet.setInputsNames(netInputs);
372 int layersSize = graph_proto.node_size();
373 LayerParams layerParams;
374 opencv_onnx::NodeProto node_proto;
// Translate the nodes in the order they appear in graph_proto.
// layerParams starts from the node's attributes (getLayerParams) and is
// named after the node's first output.
376 for(int li = 0; li < layersSize; li++)
378 node_proto = graph_proto.node(li);
379 layerParams = getLayerParams(node_proto);
380 CV_Assert(node_proto.output_size() >= 1);
381 layerParams.name = node_proto.output(0);
383 std::string layer_type = node_proto.op_type();
384 layerParams.type = layer_type;
// Pooling: ceil_mode is turned on whenever auto_pad produced a pad_mode.
387 if (layer_type == "MaxPool")
389 layerParams.type = "Pooling";
390 layerParams.set("pool", "MAX");
391 layerParams.set("ceil_mode", layerParams.has("pad_mode"));
// ave_pool_padded_area is enabled only for models produced by "pytorch".
393 else if (layer_type == "AveragePool")
395 layerParams.type = "Pooling";
396 layerParams.set("pool", "AVE");
397 layerParams.set("ceil_mode", layerParams.has("pad_mode"));
398 layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
400 else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
402 CV_Assert(node_proto.input_size() == 1);
403 layerParams.type = "Pooling";
404 layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
405 layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
// ReduceMean (keepdims != 0, explicit axes, 4-D or 5-D input) is emulated
// by average pooling. The kernel spans the reduced axes and is 1 on the
// other spatial axes. Axis i must be >= 2 + i and below the input rank.
407 if (layer_type == "ReduceMean")
409 if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
410 CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
412 MatShape inpShape = outShapes[node_proto.input(0)];
413 if (inpShape.size() != 4 && inpShape.size() != 5)
414 CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
416 DictValue axes = layerParams.get("axes");
417 CV_Assert(axes.size() <= inpShape.size() - 2);
418 std::vector<int> kernel_size(inpShape.size() - 2, 1);
419 for (int i = 0; i < axes.size(); i++) {
420 int axis = axes.get<int>(i);
421 CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
422 kernel_size[axis - 2] = inpShape[axis];
425 layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
// Slice: only unit steps and consecutive axes are supported. begin/end list
// every dimension up to the last sliced one. Dimensions before the first
// sliced axis are taken whole (begin 0, end -1).
428 else if (layer_type == "Slice")
430 if (layerParams.has("steps")) {
431 DictValue steps = layerParams.get("steps");
432 for (int i = 0; i < steps.size(); ++i) {
433 if (steps.get<int>(i) != 1)
434 CV_Error(Error::StsNotImplemented,
435 "Slice layer only supports steps = 1");
440 if (layerParams.has("axes")) {
441 DictValue axes = layerParams.get("axes");
442 for (int i = 1; i < axes.size(); ++i) {
443 CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
445 axis = axes.get<int>(0);
447 layerParams.set("axis", axis);
449 DictValue starts = layerParams.get("starts");
450 DictValue ends = layerParams.get("ends");
451 CV_Assert(starts.size() == ends.size());
453 std::vector<int> begin;
454 std::vector<int> end;
456 begin.resize(axis, 0);
457 end.resize(axis, -1);
460 for (int i = 0; i < starts.size(); ++i)
462 begin.push_back(starts.get<int>(i));
463 int finish = ends.get<int>(i);
464 end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
466 layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
467 layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
// Split is expressed as a Slice layer with explicit slice points.
// NOTE(review): slicePoints[i] adds splits.get<int>(i - 1). For
// split = {a, b, c} this yields {a, 2a} instead of the cumulative {a, a + b};
// the increment presumably should be splits.get<int>(i). With exactly two
// parts the loop does not run and the result is correct.
469 else if (layer_type == "Split")
471 DictValue splits = layerParams.get("split");
472 const int numSplits = splits.size();
473 CV_Assert(numSplits > 1);
475 std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
476 for (int i = 1; i < splits.size() - 1; ++i)
478 slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
480 layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
481 layerParams.type = "Slice";
// Add/Sum with a constant second operand:
// - a single value becomes a Power layer with that shift;
// - otherwise a Scale layer with bias_term set and the flattened operand as
//   its blob.
// When both operands are produced by layers, an Eltwise layer is used
// without setting an operation.
483 else if (layer_type == "Add" || layer_type == "Sum")
485 if (layer_id.find(node_proto.input(1)) == layer_id.end())
487 Mat blob = getBlob(node_proto, constBlobs, 1);
488 blob = blob.reshape(1, 1);
489 if (blob.total() == 1) {
490 layerParams.type = "Power";
491 layerParams.set("shift", blob.at<float>(0));
494 layerParams.type = "Scale";
495 layerParams.set("bias_term", true);
496 layerParams.blobs.push_back(blob);
500 layerParams.type = "Eltwise";
// Max: element-wise maximum of the inputs.
503 else if (layer_type == "Max")
505 layerParams.type = "Eltwise";
506 layerParams.set("operation", "max");
// Sub: the second operand is always looked up as a constant, so a
// non-constant subtrahend raises StsObjectNotFound from getBlob.
// NOTE(review): the Add branch sets "bias_term" for the same Scale-with-bias
// construction, while this branch sets "has_bias" — confirm which key the
// Scale layer reads.
508 else if (layer_type == "Sub")
510 Mat blob = getBlob(node_proto, constBlobs, 1);
511 if (blob.total() == 1) {
512 layerParams.type = "Power";
513 layerParams.set("shift", -blob.at<float>(0));
516 layerParams.type = "Scale";
517 layerParams.set("has_bias", true);
518 layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
// Div: only a non-empty constant CV_32F divisor is supported. A scalar
// becomes Power with scale 1/x; otherwise Scale with the element-wise
// reciprocal as weights.
521 else if (layer_type == "Div")
523 Mat blob = getBlob(node_proto, constBlobs, 1);
524 CV_Assert_N(blob.type() == CV_32F, blob.total());
525 if (blob.total() == 1)
527 layerParams.set("scale", 1.0f / blob.at<float>(0));
528 layerParams.type = "Power";
532 layerParams.type = "Scale";
533 divide(1.0, blob, blob);
534 layerParams.blobs.push_back(blob);
535 layerParams.set("bias_term", false);
// Neg: Power with scale -1.
538 else if (layer_type == "Neg")
540 layerParams.type = "Power";
541 layerParams.set("scale", -1);
// Constant: its tensor attribute (collected into layerParams.blobs by
// getLayerParams) is registered as a constant blob under the node's
// output name.
543 else if (layer_type == "Constant")
545 CV_Assert(node_proto.input_size() == 0);
546 CV_Assert(layerParams.blobs.size() == 1);
547 constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
// ImageScaler behaves as follows:
// - with a bias list: a Scale layer whose first blob is a column filled with
//   the scale and whose second blob is the bias;
// - otherwise: a Power layer with that scale.
550 else if (layer_type == "ImageScaler")
552 const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
553 layerParams.erase("scale");
555 if (layerParams.has("bias"))
557 layerParams.type = "Scale";
558 layerParams.blobs.push_back(
559 Mat(Size(1, layerParams.get("bias").size()), CV_32FC1, scale));
561 layerParams.set("bias_term", true);
562 Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
563 for (int j = 0; j < bias.total(); j++) {
564 bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
566 layerParams.blobs.push_back(bias);
567 layerParams.erase("bias");
570 layerParams.set("scale", scale);
571 layerParams.type = "Power";
// Clip maps onto ReLU6, with min/max renamed to min_value/max_value.
574 else if (layer_type == "Clip")
576 layerParams.type = "ReLU6";
577 replaceLayerParam(layerParams, "min", "min_value");
578 replaceLayerParam(layerParams, "max", "max_value");
// LeakyRelu: ReLU with alpha used as negative_slope.
581 else if (layer_type == "LeakyRelu")
583 layerParams.type = "ReLU";
584 replaceLayerParam(layerParams, "alpha", "negative_slope");
// LRN: the ONNX 'size' attribute is OpenCV's local_size.
586 else if (layer_type == "LRN")
588 replaceLayerParam(layerParams, "size", "local_size");
// InstanceNormalization is emulated as MVN (eps = epsilon) followed by
// BatchNorm with zero mean, unit std and the node's scale/bias. The MVN
// layer is added and connected here, and the node's first input is
// redirected to it, so the generic code below connects BatchNorm after MVN.
590 else if (layer_type == "InstanceNormalization")
592 if (node_proto.input_size() != 3)
593 CV_Error(Error::StsNotImplemented,
594 "Expected input, scale, bias");
596 layerParams.blobs.resize(4);
597 layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1); // weightData
598 layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2); // biasData
599 layerParams.set("has_bias", true);
600 layerParams.set("has_weight", true);
602 // Get number of channels in input
603 int size = layerParams.blobs[2].total();
604 layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
605 layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
607 LayerParams mvnParams;
608 mvnParams.name = layerParams.name + "/MVN";
609 mvnParams.type = "MVN";
610 mvnParams.set("eps", layerParams.get<float>("epsilon"));
611 layerParams.erase("epsilon");
614 int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
616 layerId = layer_id.find(node_proto.input(0));
617 CV_Assert(layerId != layer_id.end());
618 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
620 layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
621 outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
623 //Replace Batch Norm's input to MVN
624 node_proto.set_input(0, mvnParams.name);
625 layerParams.type = "BatchNorm";
// BatchNormalization: the inputs are (X, scale, bias, mean, var), per the
// error message. Blobs are pushed as mean, var, then scale and bias when
// their input names are non-empty.
627 else if (layer_type == "BatchNormalization")
629 if (node_proto.input_size() != 5)
630 CV_Error(Error::StsNotImplemented,
631 "Expected input, scale, bias, mean and var");
633 layerParams.type = "BatchNorm";
634 replaceLayerParam(layerParams, "epsilon", "eps");
635 replaceLayerParam(layerParams, "spatial", "use_global_stats");
637 Mat meanData = getBlob(node_proto, constBlobs, 3);
638 Mat stdData = getBlob(node_proto, constBlobs, 4);
640 layerParams.blobs.push_back(meanData);
641 layerParams.blobs.push_back(stdData);
643 if (!node_proto.input(1).empty()) {
644 layerParams.set("has_weight", true);
645 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1)); // weightData
647 layerParams.set("has_weight", false);
650 if (!node_proto.input(2).empty()) {
651 layerParams.set("has_bias", true);
652 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 2)); // biasData
654 layerParams.set("has_bias", false);
// Gemm: the constant B becomes InnerProduct weights, transposed when transB
// is present and 0. An optional third input becomes the bias. alpha, beta
// and transA are not handled in this branch.
// NOTE(review): ONNX defines transB's default as 0, but no transpose happens
// here when the attribute is absent — verify with models that omit it.
657 else if (layer_type == "Gemm")
659 CV_Assert(node_proto.input_size() >= 2);
660 layerParams.type = "InnerProduct";
661 Mat weights = getBlob(node_proto, constBlobs, 1);
663 if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
664 transpose(weights, weights);
667 layerParams.blobs.push_back(weights);
669 if (node_proto.input_size() == 3) {
670 Mat bias = getBlob(node_proto, constBlobs, 2);
671 layerParams.blobs.push_back(bias);
674 layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
675 layerParams.set("bias_term", node_proto.input_size() == 3);
// MatMul: the constant second operand is transposed into InnerProduct
// weights. num_output is the operand's number of columns.
677 else if (layer_type == "MatMul")
679 CV_Assert(node_proto.input_size() == 2);
680 layerParams.type = "InnerProduct";
681 Mat blob = getBlob(node_proto, constBlobs, 1);
682 layerParams.blobs.push_back(blob.t());
683 layerParams.set("bias_term", false);
684 layerParams.set("num_output", layerParams.blobs[0].size[0]);
// Mul mirrors Add. A constant operand becomes Power (scalar scale) or Scale
// (flattened weights); two layer-produced inputs use Eltwise "prod".
686 else if (layer_type == "Mul")
688 CV_Assert(node_proto.input_size() == 2);
689 if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
690 Mat blob = getBlob(node_proto, constBlobs, 1);
691 blob = blob.reshape(1, 1);
692 if (blob.total() == 1) {
693 layerParams.set("scale", blob.at<float>(0));
694 layerParams.type = "Power";
697 layerParams.blobs.push_back(blob);
698 layerParams.type = "Scale";
702 layerParams.type = "Eltwise";
703 layerParams.set("operation", "prod");
// Conv: weights and the optional bias come from constant inputs 1..N.
// num_output is the first weight dimension.
706 else if (layer_type == "Conv")
708 CV_Assert(node_proto.input_size() >= 2);
709 layerParams.type = "Convolution";
710 for (int j = 1; j < node_proto.input_size(); j++) {
711 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
713 layerParams.set("num_output", layerParams.blobs[0].size[0]);
714 layerParams.set("bias_term", node_proto.input_size() == 3);
// ConvTranspose: num_output is weight dimension 1 times group. When
// output_shape is given, the per-axis output adjustment (adj) is derived
// from it, the strides and (for non-SAME padding) the kernel. Otherwise
// output_padding is used as adj.
716 else if (layer_type == "ConvTranspose")
718 CV_Assert(node_proto.input_size() >= 2);
719 layerParams.type = "Deconvolution";
720 for (int j = 1; j < node_proto.input_size(); j++) {
721 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
723 layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
724 layerParams.set("bias_term", node_proto.input_size() == 3);
726 if (!layerParams.has("kernel_size"))
727 CV_Error(Error::StsNotImplemented,
728 "Required attribute 'kernel_size' is not present.");
730 if (layerParams.has("output_shape"))
732 const DictValue& outShape = layerParams.get("output_shape");
733 DictValue strides = layerParams.get("stride");
734 DictValue kernel = layerParams.get("kernel_size");
737 std::vector<int> adjust_pads;
738 if (layerParams.has("pad_mode"))
740 padMode = toUpperCase(layerParams.get<String>("pad_mode"));
741 if (padMode != "SAME" && padMode != "VALID")
742 CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
744 for (int i = 0; i < strides.size(); i++)
746 int sz = outShape.get<int>(2 + i);
747 int stride = strides.get<int>(i);
748 adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
749 (sz - kernel.get<int>(i)) % stride);
751 layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
754 else if (layerParams.has("output_padding"))
756 replaceLayerParam(layerParams, "output_padding", "adj");
// Transpose becomes Permute. With a constant input it is evaluated now via
// runLayer.
759 else if (layer_type == "Transpose")
761 layerParams.type = "Permute";
762 replaceLayerParam(layerParams, "perm", "order");
764 CV_Assert(node_proto.input_size() == 1);
765 if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
767 std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
768 runLayer(layerParams, inputs, transposed);
769 CV_Assert(transposed.size() == 1);
770 constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
// Unsqueeze has two paths:
// - constant input: reshaped with 1s inserted at the given axes. Axes are
//   applied in order, so they are assumed to be ascending.
// - otherwise: only a single axis is supported, implemented as a Reshape to
//   the tracked input shape with that axis inserted.
774 else if (layer_type == "Unsqueeze")
776 CV_Assert(node_proto.input_size() == 1);
777 DictValue axes = layerParams.get("axes");
778 if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
781 Mat input = getBlob(node_proto, constBlobs, 0);
783 std::vector<int> dims;
784 for (int j = 0; j < input.dims; j++) {
785 dims.push_back(input.size[j]);
787 CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
788 for (int j = 0; j < axes.size(); j++) {
789 dims.insert(dims.begin() + axes.getIntValue(j), 1);
792 Mat out = input.reshape(0, dims);
793 constBlobs.insert(std::make_pair(layerParams.name, out));
798 if (axes.size() != 1)
799 CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
801 MatShape inpShape = outShapes[node_proto.input(0)];
802 int axis = axes.getIntValue(0);
803 CV_Assert(0 <= axis && axis <= inpShape.size());
804 std::vector<int> outShape = inpShape;
805 outShape.insert(outShape.begin() + axis, 1);
806 layerParams.type = "Reshape";
807 layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
// Reshape: the target shape comes from the constant CV_32S second input or
// from the 'shape' attribute. Constant data is reshaped at import time.
809 else if (layer_type == "Reshape")
811 CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
813 if (node_proto.input_size() == 2) {
814 Mat blob = getBlob(node_proto, constBlobs, 1);
815 CV_Assert(blob.type() == CV_32SC1);
817 layerParams.set("dim", DictValue::arrayInt<int*>(
818 blob.ptr<int>(), blob.total() ));
820 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
821 std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
822 runLayer(layerParams, inputs, outputs);
823 constBlobs.insert(std::make_pair(layerParams.name, outputs[0]));
828 DictValue shape = layerParams.get("shape");
829 std::vector<int> dim;
830 for (int j = 0; j < shape.size(); j++) {
831 dim.push_back(shape.getIntValue(j));
834 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
835 Mat input = getBlob(node_proto, constBlobs, 0);
836 Mat out = input.reshape(0, dim);
837 constBlobs.insert(std::make_pair(layerParams.name, out));
840 replaceLayerParam(layerParams, "shape", "dim");
// Pad becomes Padding; the paddings were prepared by getLayerParams.
843 else if (layer_type == "Pad")
845 layerParams.type = "Padding";
// Shape is computed from the tracked input shape and stored as a constant
// CV_32S column.
847 else if (layer_type == "Shape")
849 CV_Assert(node_proto.input_size() == 1);
850 shapeIt = outShapes.find(node_proto.input(0));
851 CV_Assert(shapeIt != outShapes.end());
852 MatShape inpShape = shapeIt->second;
854 Mat shapeMat(inpShape.size(), 1, CV_32S);
855 for (int j = 0; j < inpShape.size(); ++j)
856 shapeMat.at<int>(j) = inpShape[j];
859 constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
// Gather: only constant data with a single CV_32S index is supported. The
// result keeps the gathered axis with size 1.
// NOTE(review): ONNX Gather with a scalar index drops that axis, and a
// negative index is not normalised here — confirm this matches consumers.
862 else if (layer_type == "Gather")
864 CV_Assert(node_proto.input_size() == 2);
865 CV_Assert(layerParams.has("axis"));
866 Mat input = getBlob(node_proto, constBlobs, 0);
867 Mat indexMat = getBlob(node_proto, constBlobs, 1);
868 CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
869 int index = indexMat.at<int>(0);
870 int axis = layerParams.get<int>("axis");
872 std::vector<cv::Range> ranges(input.dims, Range::all());
873 ranges[axis] = Range(index, index + 1);
875 Mat out = input(ranges);
876 constBlobs.insert(std::make_pair(layerParams.name, out));
// Concat is evaluated at import time when none of its inputs is produced by
// a layer.
879 else if (layer_type == "Concat")
881 bool hasVariableInps = false;
882 for (int i = 0; i < node_proto.input_size(); ++i)
884 if (layer_id.find(node_proto.input(i)) != layer_id.end())
886 hasVariableInps = true;
891 if (!hasVariableInps)
893 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
894 for (size_t i = 0; i < inputs.size(); ++i)
896 inputs[i] = getBlob(node_proto, constBlobs, i);
898 runLayer(layerParams, inputs, concatenated);
900 CV_Assert(concatenated.size() == 1);
901 constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
// Upsample becomes Resize. Zoom factors come from one of two places:
// - the 4-element 'scales' attribute (entries 2 and 3, read with
//   getIntValue);
// - height_scale/width_scale.
// 'mode' becomes 'interpolation'.
905 else if (layer_type == "Upsample")
907 layerParams.type = "Resize";
908 if (layerParams.has("scales"))
911 DictValue scales = layerParams.get("scales");
912 CV_Assert(scales.size() == 4);
913 layerParams.set("zoom_factor_y", scales.getIntValue(2));
914 layerParams.set("zoom_factor_x", scales.getIntValue(3));
919 replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
920 replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
922 replaceLayerParam(layerParams, "mode", "interpolation");
// LogSoftmax becomes Softmax with log_softmax enabled.
924 else if (layer_type == "LogSoftmax")
926 layerParams.type = "Softmax";
927 layerParams.set("log_softmax", true);
// Any input not produced by a layer is attached to the layer as a constant blob.
931 for (int j = 0; j < node_proto.input_size(); j++) {
932 if (layer_id.find(node_proto.input(j)) == layer_id.end())
933 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
// Register the layer and every one of its outputs.
937 int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
938 for (int i = 0; i < node_proto.output_size(); ++i)
940 layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
// Connect each layer-produced input j to input slot j of this layer, and
// collect its shape for shape inference.
943 std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
944 for (int j = 0; j < node_proto.input_size(); j++) {
945 layerId = layer_id.find(node_proto.input(j));
946 if (layerId != layer_id.end()) {
947 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
948 // Collect input shapes.
949 shapeIt = outShapes.find(node_proto.input(j));
950 CV_Assert(shapeIt != outShapes.end());
951 layerInpShapes.push_back(shapeIt->second);
955 // Compute shape of output blob for this layer.
956 Ptr<Layer> layer = dstNet.getLayer(id);
957 layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
958 for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
960 outShapes[node_proto.output(i)] = layerOutShapes[i];
965 Net readNetFromONNX(const String& onnxFile)
967 ONNXImporter onnxImporter(onnxFile.c_str());
969 onnxImporter.populateNet(net);
973 Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
975 ONNXImporter onnxImporter(buffer, sizeBuffer);
977 onnxImporter.populateNet(net);
981 Net readNetFromONNX(const std::vector<uchar>& buffer)
983 return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
986 Mat readTensorFromONNX(const String& path)
988 opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
989 std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
990 if (!tensor_proto.ParseFromIstream(&input)) {
991 CV_Error(Error::StsUnsupportedFormat, "Failed to parse data");
993 Mat mat = getMatFromTensor(tensor_proto);
994 releaseONNXTensor(tensor_proto);
998 CV__DNN_INLINE_NS_END