// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// Copyright (C) 2018, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.

#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>

#ifdef HAVE_PROTOBUF

#include <fstream>
#include <limits>
#include <string>

#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "opencv-onnx.pb.h"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic pop
#endif

#include "onnx_graph_simplifier.hpp"

namespace cv {
namespace dnn {
CV__DNN_INLINE_NS_BEGIN

class ONNXImporter
{
    opencv_onnx::ModelProto model_proto;

    struct LayerInfo {
        int layerId;
        int outputId;
        LayerInfo(int _layerId, int _outputId) : layerId(_layerId), outputId(_outputId) {}
    };

    std::map<std::string, Mat> getGraphTensors(
                                    const opencv_onnx::GraphProto& graph_proto);
    Mat getBlob(const opencv_onnx::NodeProto& node_proto, const std::map<std::string, Mat>& constBlobs, int index);

    LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
    bool isCeilMode(const LayerParams& layerParams);

public:

    ONNXImporter(const char *onnxFile)
    {
        std::fstream input(onnxFile, std::ios::in | std::ios::binary);

        if (!model_proto.ParseFromIstream(&input))
            CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model");
    }

    ONNXImporter(const char* buffer, size_t sizeBuffer)
    {
        // A minimal read-only streambuf over the caller's buffer: no copy is made,
        // so the buffer only needs to stay alive for the duration of parsing.
        struct _Buf : public std::streambuf
        {
            _Buf(const char* buffer, size_t sizeBuffer)
            {
                char* p = const_cast<char*>(buffer);
                setg(p, p, p + sizeBuffer);
            }
        };

        _Buf buf(buffer, sizeBuffer);
        std::istream input(&buf);

        if (!model_proto.ParseFromIstream(&input))
            CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
    }

    void populateNet(Net dstNet);
};

inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
{
    if (layerParams.has(oldKey)) {
        layerParams.set(newKey, layerParams.get(oldKey));
        layerParams.erase(oldKey);
    }
}

void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
{
    if (!tensor_proto.raw_data().empty()) {
        delete tensor_proto.release_raw_data();
    }
}

void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
              std::vector<Mat>& outputs)
{
    Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
    CV_Assert((bool)layer);

    std::vector<MatShape> inpShapes(inputs.size());
    int ddepth = CV_32F;
    for (size_t i = 0; i < inputs.size(); ++i)
    {
        inpShapes[i] = shape(inputs[i]);
        if (i > 0 && ddepth != inputs[i].depth())
            CV_Error(Error::StsNotImplemented, "Mixed input data types.");
        ddepth = inputs[i].depth();
    }

    std::vector<MatShape> outShapes, internalShapes;
    layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);

    std::vector<Mat> internals(internalShapes.size());
    outputs.resize(outShapes.size());
    for (size_t i = 0; i < outShapes.size(); ++i)
        outputs[i].create(outShapes[i], ddepth);
    for (size_t i = 0; i < internalShapes.size(); ++i)
        internals[i].create(internalShapes[i], ddepth);

    layer->finalize(inputs, outputs);
    layer->forward(inputs, outputs, internals);
}
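
// runLayer() is the constant-folding helper used by populateNet(): when every input
// of a node is a known constant, a temporary Layer instance is created and run once
// so that the node's outputs can be stored as constants too.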

std::map<std::string, Mat> ONNXImporter::getGraphTensors(
                                        const opencv_onnx::GraphProto& graph_proto)
{
    opencv_onnx::TensorProto tensor_proto;
    std::map<std::string, Mat> layers_weights;

    for (int i = 0; i < graph_proto.initializer_size(); i++)
    {
        tensor_proto = graph_proto.initializer(i);
        Mat mat = getMatFromTensor(tensor_proto);
        releaseONNXTensor(tensor_proto);  // frees raw_data only; name() stays valid
        layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
    }
    return layers_weights;
}

static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
    std::vector<int32_t> dst(src.size());
    convertInt64ToInt32(src, dst, src.size());
    return DictValue::arrayInt(&dst[0], src.size());
}
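
// ONNX stores integer attributes as int64, while DictValue works with int32:
// repeated fields are narrowed by parse() above, and scalar attributes are
// range-checked in getLayerParams() below before the cast.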

LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
{
    LayerParams lp;
    for(int i = 0; i < node_proto.attribute_size(); i++)
    {
        opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
        std::string attribute_name = attribute_proto.name();

        if(attribute_name == "kernel_shape")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("kernel_size", parse(attribute_proto.ints()));
        }
        else if(attribute_name == "strides")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("stride", parse(attribute_proto.ints()));
        }
        else if(attribute_name == "pads")
        {
            if (node_proto.op_type() == "Pad")
            {
                // Paddings are in the order begin0, begin1, ..., beginN, end0, end1, ..., endN.
                // We need to shuffle it to begin0, end0, begin1, end1, ...
                CV_Assert(attribute_proto.ints_size() % 2 == 0);
                const int dims = attribute_proto.ints_size() / 2;
                std::vector<int32_t> paddings;
                paddings.reserve(attribute_proto.ints_size());
                for (int i = 0; i < dims; ++i)
                {
                    paddings.push_back(attribute_proto.ints(i));
                    paddings.push_back(attribute_proto.ints(dims + i));
                }
                lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
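
                // e.g. a 2-D Pad with pads = [1, 2, 3, 4] (top, left, bottom, right)
                // becomes paddings = [1, 3, 2, 4] (top, bottom, left, right).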
            }
            else
            {
                // Convolution or pooling.
                CV_Assert(attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
                lp.set("pad", parse(attribute_proto.ints()));
            }
        }
        else if(attribute_name == "auto_pad")
        {
            if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
                lp.set("pad_mode", "SAME");
            }
            else if (attribute_proto.s() == "VALID") {
                lp.set("pad_mode", "VALID");
            }
        }
        else if(attribute_name == "dilations")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("dilation", parse(attribute_proto.ints()));
        }
        else if (attribute_proto.has_i())
        {
            ::google::protobuf::int64 src = attribute_proto.i();
            if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
                CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
            else
                lp.set(attribute_name, saturate_cast<int32_t>(src));
        }
        else if (attribute_proto.has_f())
        {
            lp.set(attribute_name, attribute_proto.f());
        }
        else if (attribute_proto.has_s())
        {
            lp.set(attribute_name, attribute_proto.s());
        }
        else if (attribute_proto.floats_size() > 0)
        {
            lp.set(attribute_name, DictValue::arrayReal(
                attribute_proto.floats().data(), attribute_proto.floats_size()));
        }
        else if (attribute_proto.ints_size() > 0)
        {
            lp.set(attribute_name, parse(attribute_proto.ints()));
        }
        else if (attribute_proto.has_t())
        {
            opencv_onnx::TensorProto tensor = attribute_proto.t();
            Mat blob = getMatFromTensor(tensor);
            lp.blobs.push_back(blob);
        }
        else if (attribute_proto.has_g() || attribute_proto.strings_size() > 0 ||
                 attribute_proto.tensors_size() > 0 || attribute_proto.graphs_size() > 0)
        {
            CV_Error(Error::StsNotImplemented, "Unexpected attribute type");
        }
        else
            CV_Error(Error::StsNotImplemented, "Unsupported attribute type");
    }
    return lp;
}
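
// For example, an ONNX Conv node with kernel_shape = [3, 3], strides = [2, 2] and
// pads = [1, 1, 1, 1] produces LayerParams entries kernel_size, stride and pad,
// which the Convolution branch of populateNet() consumes directly.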

Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
                          const std::map<std::string, Mat>& constBlobs, int index)
{
    CV_Assert(index < node_proto.input_size());
    std::map<std::string, Mat>::const_iterator constBlob;
    constBlob = constBlobs.find(node_proto.input(index));
    if (constBlob == constBlobs.end()) {
        CV_Error(Error::StsObjectNotFound,
                 "Blob " + node_proto.input(index) + " not found in const blobs");
    }
    return constBlob->second;
}

void ONNXImporter::populateNet(Net dstNet)
{
    CV_Assert(model_proto.has_graph());
    opencv_onnx::GraphProto graph_proto = model_proto.graph();

    simplifySubgraphs(graph_proto);

    std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
    // List of internal blobs shapes.
    std::map<std::string, MatShape> outShapes;
    // Add all the input shapes: this covers constant blobs as well as the network's inputs.
    for (int i = 0; i < graph_proto.input_size(); ++i)
    {
        opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
        CV_Assert(valueInfoProto.has_type());
        opencv_onnx::TypeProto typeProto = valueInfoProto.type();
        CV_Assert(typeProto.has_tensor_type());
        opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
        CV_Assert(tensor.has_shape());
        opencv_onnx::TensorShapeProto tensorShape = tensor.shape();

        MatShape inpShape(tensorShape.dim_size());
        for (int j = 0; j < inpShape.size(); ++j)
        {
            inpShape[j] = tensorShape.dim(j).dim_value();
        }
        outShapes[valueInfoProto.name()] = inpShape;
    }

    std::string framework_name;
    if (model_proto.has_producer_name()) {
        framework_name = model_proto.producer_name();
    }

    // Create a map with the network inputs (without const blobs).
    std::map<std::string, LayerInfo> layer_id;
    std::map<std::string, LayerInfo>::iterator layerId;
    std::map<std::string, MatShape>::iterator shapeIt;
    // Fill the map: push layer name, layer id and output id.
    std::vector<String> netInputs;
    for (int j = 0; j < graph_proto.input_size(); j++)
    {
        const std::string& name = graph_proto.input(j).name();
        if (constBlobs.find(name) == constBlobs.end()) {
            netInputs.push_back(name);
            layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
        }
    }
    dstNet.setInputsNames(netInputs);
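
    // OpenCV's dnn reserves layer id 0 for the network input layer, so each
    // non-constant graph input is recorded above as LayerInfo(0, <input index>).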

    int layersSize = graph_proto.node_size();
    LayerParams layerParams;
    opencv_onnx::NodeProto node_proto;

    for(int li = 0; li < layersSize; li++)
    {
        node_proto = graph_proto.node(li);
        layerParams = getLayerParams(node_proto);
        CV_Assert(node_proto.output_size() >= 1);
        layerParams.name = node_proto.output(0);

        std::string layer_type = node_proto.op_type();
        layerParams.type = layer_type;

        if (layer_type == "MaxPool")
        {
            layerParams.type = "Pooling";
            layerParams.set("pool", "MAX");
            layerParams.set("ceil_mode", layerParams.has("pad_mode"));
        }
        else if (layer_type == "AveragePool")
        {
            layerParams.type = "Pooling";
            layerParams.set("pool", "AVE");
            layerParams.set("ceil_mode", layerParams.has("pad_mode"));
            layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
        }
        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
        {
            CV_Assert(node_proto.input_size() == 1);
            layerParams.type = "Pooling";
            layerParams.set("pool", layer_type == "GlobalMaxPool" ? "MAX" : "AVE");
            layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");

            if (layer_type == "ReduceMean")
            {
                if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
                    CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");

                MatShape inpShape = outShapes[node_proto.input(0)];
                if (inpShape.size() != 4 && inpShape.size() != 5)
                    CV_Error(Error::StsNotImplemented, "Unsupported input shape of ReduceMean operation.");

                DictValue axes = layerParams.get("axes");
                CV_Assert(axes.size() <= inpShape.size() - 2);
                std::vector<int> kernel_size(inpShape.size() - 2, 1);
                for (int i = 0; i < axes.size(); i++) {
                    int axis = axes.get<int>(i);
                    CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
                    kernel_size[axis - 2] = inpShape[axis];
                }

                layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
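
                // e.g. ReduceMean over axes = [2, 3] of an NCHW input becomes an
                // average pooling with kernel_size = [H, W].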
            }
        }
        else if (layer_type == "Slice")
        {
            if (layerParams.has("steps")) {
                DictValue steps = layerParams.get("steps");
                for (int i = 0; i < steps.size(); ++i) {
                    if (steps.get<int>(i) != 1)
                        CV_Error(Error::StsNotImplemented,
                                 "Slice layer only supports steps = 1");
                }
            }

            int axis = 0;
            if (layerParams.has("axes")) {
                DictValue axes = layerParams.get("axes");
                for (int i = 1; i < axes.size(); ++i) {
                    // Axes must be consecutive.
                    CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
                }
                axis = axes.get<int>(0);
            }
            layerParams.set("axis", axis);

            DictValue starts = layerParams.get("starts");
            DictValue ends = layerParams.get("ends");
            CV_Assert(starts.size() == ends.size());

            std::vector<int> begin;
            std::vector<int> end;
            begin.resize(axis, 0);
            end.resize(axis, -1);

            for (int i = 0; i < starts.size(); ++i)
            {
                begin.push_back(starts.get<int>(i));
                int finish = ends.get<int>(i);
                end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
            }
            layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
            layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
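
            // e.g. starts = [0], ends = [-1] with axis = 1 yields begin = [0, 0] and
            // end = [-1, -2]: negative ends are shifted by one because ONNX (numpy)
            // ranges exclude the end index while the Slice layer includes it.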
        }
        else if (layer_type == "Split")
        {
            if (layerParams.has("split"))
            {
                DictValue splits = layerParams.get("split");
                const int numSplits = splits.size();
                CV_Assert(numSplits > 1);

                // Convert split sizes to cumulative slice points.
                std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
                for (int i = 1; i < splits.size() - 1; ++i)
                {
                    slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);
                }
                layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
            }
            else
            {
                layerParams.set("num_split", node_proto.output_size());
            }
            layerParams.type = "Slice";
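
            // e.g. split = [2, 3, 4] becomes slice_point = [2, 5]; without an
            // explicit "split" attribute the input is divided evenly into
            // node_proto.output_size() parts.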
        }
        else if (layer_type == "Add" || layer_type == "Sum")
        {
            if (layer_id.find(node_proto.input(1)) == layer_id.end())
            {
                // The second input is a constant: fold it into the layer.
                Mat blob = getBlob(node_proto, constBlobs, 1);
                blob = blob.reshape(1, 1);
                if (blob.total() == 1) {
                    layerParams.type = "Power";
                    layerParams.set("shift", blob.at<float>(0));
                }
                else {
                    layerParams.type = "Scale";
                    layerParams.set("bias_term", true);
                    layerParams.blobs.push_back(blob);
                }
            }
            else {
                layerParams.type = "Eltwise";
            }
        }
        else if (layer_type == "Max")
        {
            layerParams.type = "Eltwise";
            layerParams.set("operation", "max");
        }
        else if (layer_type == "Sub")
        {
            Mat blob = getBlob(node_proto, constBlobs, 1);
            if (blob.total() == 1) {
                layerParams.type = "Power";
                layerParams.set("shift", -blob.at<float>(0));
            }
            else {
                layerParams.type = "Scale";
                layerParams.set("has_bias", true);
                layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
            }
        }
        else if (layer_type == "Neg")
        {
            layerParams.type = "Power";
            layerParams.set("scale", -1);
        }
        else if (layer_type == "Constant")
        {
            CV_Assert(node_proto.input_size() == 0);
            CV_Assert(layerParams.blobs.size() == 1);
            constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
            continue;
        }
        else if (layer_type == "ImageScaler")
        {
            const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
            layerParams.erase("scale");

            if (layerParams.has("bias"))
            {
                layerParams.type = "Scale";
                layerParams.blobs.push_back(
                    Mat(Size(1, layerParams.get("bias").size()), CV_32FC1, scale));

                layerParams.set("bias_term", true);
                Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
                for (int j = 0; j < bias.total(); j++) {
                    bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
                }
                layerParams.blobs.push_back(bias);
                layerParams.erase("bias");
            }
            else {
                layerParams.set("scale", scale);
                layerParams.type = "Power";
            }
        }
        else if (layer_type == "Clip")
        {
            layerParams.type = "ReLU6";
            replaceLayerParam(layerParams, "min", "min_value");
            replaceLayerParam(layerParams, "max", "max_value");
        }
        else if (layer_type == "LeakyRelu")
        {
            layerParams.type = "ReLU";
            replaceLayerParam(layerParams, "alpha", "negative_slope");
        }
        else if (layer_type == "LRN")
        {
            replaceLayerParam(layerParams, "size", "local_size");
        }
        else if (layer_type == "InstanceNormalization")
        {
            if (node_proto.input_size() != 3)
                CV_Error(Error::StsNotImplemented,
                         "Expected input, scale, bias");

            layerParams.blobs.resize(4);
            layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1);  // weightData
            layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2);  // biasData
            layerParams.set("has_bias", true);
            layerParams.set("has_weight", true);

            // Get the number of channels in the input.
            int size = layerParams.blobs[2].total();
            layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F);  // mean
            layerParams.blobs[1] = Mat::ones(size, 1, CV_32F);   // std

            LayerParams mvnParams;
            mvnParams.name = layerParams.name + "/MVN";
            mvnParams.type = "MVN";
            mvnParams.set("eps", layerParams.get<float>("epsilon"));
            layerParams.erase("epsilon");

            // Add the MVN layer and connect it to the node's input.
            int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
            layerId = layer_id.find(node_proto.input(0));
            CV_Assert(layerId != layer_id.end());
            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
            layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
            outShapes[mvnParams.name] = outShapes[node_proto.input(0)];

            // Redirect the BatchNorm's input to the MVN output.
            node_proto.set_input(0, mvnParams.name);
            layerParams.type = "BatchNorm";
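
            // InstanceNormalization is thus emulated as MVN (zero mean, unit variance
            // per instance) followed by a BatchNorm carrying the learned scale and bias.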
        }
        else if (layer_type == "BatchNormalization")
        {
            if (node_proto.input_size() != 5)
                CV_Error(Error::StsNotImplemented,
                         "Expected input, scale, bias, mean and var");

            layerParams.type = "BatchNorm";
            replaceLayerParam(layerParams, "epsilon", "eps");
            replaceLayerParam(layerParams, "spatial", "use_global_stats");

            Mat meanData = getBlob(node_proto, constBlobs, 3);
            Mat stdData = getBlob(node_proto, constBlobs, 4);

            layerParams.blobs.push_back(meanData);
            layerParams.blobs.push_back(stdData);

            if (!node_proto.input(1).empty()) {
                layerParams.set("has_weight", true);
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));  // weightData
            }
            else {
                layerParams.set("has_weight", false);
            }

            if (!node_proto.input(2).empty()) {
                layerParams.set("has_bias", true);
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 2));  // biasData
            }
            else {
                layerParams.set("has_bias", false);
            }
        }
        else if (layer_type == "Gemm")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "InnerProduct";
            Mat weights = getBlob(node_proto, constBlobs, 1);
            int ind_num_out = 0;
            if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
                // ONNX stores B as K x N when transB == 0; transpose it to the
                // N x K layout that InnerProduct expects.
                transpose(weights, weights);
            }
            layerParams.blobs.push_back(weights);

            if (node_proto.input_size() == 3) {
                Mat bias = getBlob(node_proto, constBlobs, 2);
                layerParams.blobs.push_back(bias);
            }

            layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
            layerParams.set("bias_term", node_proto.input_size() == 3);
        }
        else if (layer_type == "MatMul")
        {
            CV_Assert(node_proto.input_size() == 2);
            layerParams.type = "InnerProduct";
            Mat blob = getBlob(node_proto, constBlobs, 1);
            layerParams.blobs.push_back(blob.t());
            layerParams.set("bias_term", false);
            layerParams.set("num_output", layerParams.blobs[0].size[0]);
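
            // The constant operand B of shape K x N is stored transposed (N x K) to
            // match the InnerProduct weight layout, so num_output equals N.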
        }
        else if (layer_type == "Mul" || layer_type == "Div")
        {
            CV_Assert(node_proto.input_size() == 2);

            bool isDiv = layer_type == "Div";
            int constId = -1;
            bool haveVariables = false;
            for (int i = 0; i < 2; ++i)
            {
                if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
                    constId = i;
                else
                    haveVariables = true;
            }
            if (constId != -1 && haveVariables)
            {
                Mat blob = getBlob(node_proto, constBlobs, constId);
                blob = blob.reshape(1, 1);
                if (blob.total() == 1) {
                    float coeff = isDiv ? 1.0f / blob.at<float>(0) : blob.at<float>(0);
                    layerParams.set("scale", coeff);
                    layerParams.type = "Power";
                }
                else {
                    if (isDiv)
                        divide(1.0, blob, blob);
                    layerParams.blobs.push_back(blob);
                    layerParams.type = "Scale";
                }
            }
            else {
                layerParams.type = "Eltwise";
                layerParams.set("operation", isDiv ? "div" : "prod");
            }

            if (!haveVariables)
            {
                Mat inp0 = getBlob(node_proto, constBlobs, 0);
                Mat inp1 = getBlob(node_proto, constBlobs, 1);
                if (inp0.size != inp1.size)
                    CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");

                Mat out;
                if (isDiv)
                    divide(inp0, inp1, out);
                else
                    multiply(inp0, inp1, out);

                out = out.reshape(1, inp0.dims, inp0.size);
                out.dims = inp0.dims;  // to work around dims == 1
                constBlobs.insert(std::make_pair(layerParams.name, out));
                continue;
            }
        }
        else if (layer_type == "Conv")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "Convolution";
            for (int j = 1; j < node_proto.input_size(); j++) {
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
            layerParams.set("num_output", layerParams.blobs[0].size[0]);
            layerParams.set("bias_term", node_proto.input_size() == 3);
        }
        else if (layer_type == "ConvTranspose")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "Deconvolution";
            for (int j = 1; j < node_proto.input_size(); j++) {
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
            layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
            layerParams.set("bias_term", node_proto.input_size() == 3);

            if (!layerParams.has("kernel_size"))
                CV_Error(Error::StsNotImplemented,
                         "Required attribute 'kernel_size' is not present.");

            if (layerParams.has("output_shape"))
            {
                const DictValue& outShape = layerParams.get("output_shape");
                DictValue strides = layerParams.get("stride");
                DictValue kernel = layerParams.get("kernel_size");

                String padMode;
                std::vector<int> adjust_pads;
                if (layerParams.has("pad_mode"))
                {
                    padMode = toUpperCase(layerParams.get<String>("pad_mode"));
                    if (padMode != "SAME" && padMode != "VALID")
                        CV_Error(Error::StsError, "Unsupported padding mode " + padMode);

                    for (int i = 0; i < strides.size(); i++)
                    {
                        int sz = outShape.get<int>(2 + i);
                        int stride = strides.get<int>(i);
                        adjust_pads.push_back(padMode == "SAME" ? (sz - 1) % stride :
                                                                  (sz - kernel.get<int>(i)) % stride);
                    }
                    layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
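
                    // e.g. for stride 2, kernel 3 and a requested output extent of 8,
                    // the VALID branch yields adj = (8 - 3) % 2 = 1, growing the
                    // deconvolution output to the exact requested size.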
                }
            }
            else if (layerParams.has("output_padding"))
            {
                replaceLayerParam(layerParams, "output_padding", "adj");
            }
        }
        else if (layer_type == "Transpose")
        {
            layerParams.type = "Permute";
            replaceLayerParam(layerParams, "perm", "order");

            CV_Assert(node_proto.input_size() == 1);
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                // Fold a transpose of a constant input at import time.
                std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
                runLayer(layerParams, inputs, transposed);
                CV_Assert(transposed.size() == 1);
                constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
                continue;
            }
        }
        else if (layer_type == "ReduceL2")
        {
            CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
            CV_Assert(graph_proto.node_size() > li + 1 && graph_proto.node(li + 1).op_type() == "Div");
            ++li;
            node_proto = graph_proto.node(li);
            layerParams.name = node_proto.output(0);
            layerParams.type = "Normalize";

            DictValue axes_dict = layerParams.get("axes");
            if (axes_dict.size() != 1)
                CV_Error(Error::StsNotImplemented, "Multidimensional ReduceL2");
            int axis = axes_dict.getIntValue(0);
            layerParams.set("axis", axis);
            layerParams.set("end_axis", axis);
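
            // A ReduceL2 node immediately followed by Div is fused here into a single
            // Normalize (L2 normalization) layer over the given axis.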
        }
        else if (layer_type == "Squeeze")
        {
            CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
            DictValue axes_dict = layerParams.get("axes");
            if (axes_dict.size() != 1)
                CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");

            int axis = axes_dict.getIntValue(0);
            layerParams.set("axis", axis - 1);
            layerParams.set("end_axis", axis);
            layerParams.type = "Flatten";
        }
        else if (layer_type == "Unsqueeze")
        {
            CV_Assert(node_proto.input_size() == 1);
            DictValue axes = layerParams.get("axes");
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                // Constant input: insert the new unit dimensions and store the result.
                Mat input = getBlob(node_proto, constBlobs, 0);

                std::vector<int> dims;
                for (int j = 0; j < input.dims; j++) {
                    dims.push_back(input.size[j]);
                }
                CV_Assert(axes.getIntValue(axes.size() - 1) <= dims.size());
                for (int j = 0; j < axes.size(); j++) {
                    dims.insert(dims.begin() + axes.getIntValue(j), 1);
                }

                Mat out = input.reshape(0, dims);
                constBlobs.insert(std::make_pair(layerParams.name, out));
                continue;
            }

            // Variable input: implement unsqueeze as a Reshape with an explicit shape.
            if (axes.size() != 1)
                CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");

            MatShape inpShape = outShapes[node_proto.input(0)];
            int axis = axes.getIntValue(0);
            CV_Assert(0 <= axis && axis <= inpShape.size());
            std::vector<int> outShape = inpShape;
            outShape.insert(outShape.begin() + axis, 1);
            layerParams.type = "Reshape";
            layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
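
            // e.g. an input of shape [3, 4] with axes = [0] is reshaped to [1, 3, 4].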
        }
        else if (layer_type == "Reshape")
        {
            CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));

            if (node_proto.input_size() == 2) {
                Mat blob = getBlob(node_proto, constBlobs, 1);
                CV_Assert(blob.type() == CV_32SC1);

                layerParams.set("dim", DictValue::arrayInt<int*>(
                            blob.ptr<int>(), blob.total() ));

                if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
                    std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
                    runLayer(layerParams, inputs, outputs);
                    constBlobs.insert(std::make_pair(layerParams.name, outputs[0]));
                    continue;
                }
            }
            else {
                DictValue shape = layerParams.get("shape");
                std::vector<int> dim;
                for (int j = 0; j < shape.size(); j++) {
                    dim.push_back(shape.getIntValue(j));
                }

                if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
                    Mat input = getBlob(node_proto, constBlobs, 0);
                    Mat out = input.reshape(0, dim);
                    constBlobs.insert(std::make_pair(layerParams.name, out));
                    continue;
                }
                replaceLayerParam(layerParams, "shape", "dim");
            }
        }
        else if (layer_type == "Pad")
        {
            layerParams.type = "Padding";
        }
        else if (layer_type == "Shape")
        {
            CV_Assert(node_proto.input_size() == 1);
            shapeIt = outShapes.find(node_proto.input(0));
            CV_Assert(shapeIt != outShapes.end());
            MatShape inpShape = shapeIt->second;

            Mat shapeMat(inpShape.size(), 1, CV_32S);
            for (int j = 0; j < inpShape.size(); ++j)
                shapeMat.at<int>(j) = inpShape[j];
            shapeMat.dims = 1;

            constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
            continue;
        }
        else if (layer_type == "Gather")
        {
            CV_Assert(node_proto.input_size() == 2);
            CV_Assert(layerParams.has("axis"));
            Mat input = getBlob(node_proto, constBlobs, 0);
            Mat indexMat = getBlob(node_proto, constBlobs, 1);
            CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
            int index = indexMat.at<int>(0);
            int axis = layerParams.get<int>("axis");

            std::vector<cv::Range> ranges(input.dims, Range::all());
            ranges[axis] = Range(index, index + 1);

            Mat out = input(ranges);
            constBlobs.insert(std::make_pair(layerParams.name, out));
            continue;
        }
        else if (layer_type == "Concat")
        {
            bool hasVariableInps = false;
            for (int i = 0; i < node_proto.input_size(); ++i)
            {
                if (layer_id.find(node_proto.input(i)) != layer_id.end())
                {
                    hasVariableInps = true;
                    break;
                }
            }

            if (!hasVariableInps)
            {
                std::vector<Mat> inputs(node_proto.input_size()), concatenated;
                for (size_t i = 0; i < inputs.size(); ++i)
                {
                    inputs[i] = getBlob(node_proto, constBlobs, i);
                }
                runLayer(layerParams, inputs, concatenated);

                CV_Assert(concatenated.size() == 1);
                constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
                continue;
            }
        }
        else if (layer_type == "Upsample")
        {
            layerParams.type = "Resize";
            if (layerParams.has("scales"))
            {
                // PyTorch layer
                DictValue scales = layerParams.get("scales");
                CV_Assert(scales.size() == 4);
                layerParams.set("zoom_factor_y", scales.getIntValue(2));
                layerParams.set("zoom_factor_x", scales.getIntValue(3));
            }
            else
            {
                // Caffe2 layer
                replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
                replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
            }
            replaceLayerParam(layerParams, "mode", "interpolation");

            if (layerParams.get<String>("interpolation") == "linear" && framework_name == "pytorch") {
                layerParams.type = "Resize";
                Mat scales = getBlob(node_proto, constBlobs, 1);
                CV_Assert(scales.total() == 4);
                layerParams.set("interpolation", "opencv_linear");
                layerParams.set("zoom_factor_y", scales.at<float>(2));
                layerParams.set("zoom_factor_x", scales.at<float>(3));
            }
        }
        else if (layer_type == "LogSoftmax")
        {
            layerParams.type = "Softmax";
            layerParams.set("log_softmax", true);
        }
        else
        {
            // Unrecognized op: pass it through and attach constant inputs as blobs.
            for (int j = 0; j < node_proto.input_size(); j++) {
                if (layer_id.find(node_proto.input(j)) == layer_id.end())
                    layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
        }

        int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
        for (int i = 0; i < node_proto.output_size(); ++i)
        {
            layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
        }

        std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
        for (int j = 0; j < node_proto.input_size(); j++) {
            layerId = layer_id.find(node_proto.input(j));
            if (layerId != layer_id.end()) {
                dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
                // Collect input shapes.
                shapeIt = outShapes.find(node_proto.input(j));
                CV_Assert(shapeIt != outShapes.end());
                layerInpShapes.push_back(shapeIt->second);
            }
        }

        // Compute shape of output blob for this layer.
        Ptr<Layer> layer = dstNet.getLayer(id);
        layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
        for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
        {
            outShapes[node_proto.output(i)] = layerOutShapes[i];
        }
    }
}

Net readNetFromONNX(const String& onnxFile)
{
    ONNXImporter onnxImporter(onnxFile.c_str());
    Net net;
    onnxImporter.populateNet(net);
    return net;
}

Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
{
    ONNXImporter onnxImporter(buffer, sizeBuffer);
    Net net;
    onnxImporter.populateNet(net);
    return net;
}

Net readNetFromONNX(const std::vector<uchar>& buffer)
{
    return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
}
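
// A minimal usage sketch ("model.onnx" and img are placeholders):
//   Net net = readNetFromONNX("model.onnx");
//   net.setInput(blobFromImage(img));
//   Mat prob = net.forward();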

Mat readTensorFromONNX(const String& path)
{
    opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
    std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
    if (!tensor_proto.ParseFromIstream(&input)) {
        CV_Error(Error::StsUnsupportedFormat, "Failed to parse data");
    }
    Mat mat = getMatFromTensor(tensor_proto);
    releaseONNXTensor(tensor_proto);
    return mat;
}

CV__DNN_INLINE_NS_END
}} // namespace cv::dnn

#endif  // HAVE_PROTOBUF