Modified ONNX importer to concatenate constant input blobs
[platform/upstream/opencv.git] modules/dnn/src/onnx/onnx_importer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../precomp.hpp"
9 #include <opencv2/dnn/shape_utils.hpp>
10
11 #include <opencv2/core/utils/logger.defines.hpp>
12 #undef CV_LOG_STRIP_LEVEL
13 #define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_DEBUG + 1
14 #include <opencv2/core/utils/logger.hpp>
15
16 #ifdef HAVE_PROTOBUF
17
18 #include <iostream>
19 #include <fstream>
20 #include <string>
21 #include <limits>
22 #include <algorithm>
23
24
25 #if defined(__GNUC__) && __GNUC__ >= 5
26 #pragma GCC diagnostic push
27 #pragma GCC diagnostic ignored "-Wsuggest-override"
28 #endif
29 #include "opencv-onnx.pb.h"
30 #if defined(__GNUC__) && __GNUC__ >= 5
31 #pragma GCC diagnostic pop
32 #endif
33
34 #include "onnx_graph_simplifier.hpp"
35
36 namespace cv {
37 namespace dnn {
38 CV__DNN_EXPERIMENTAL_NS_BEGIN
39
40
41 class ONNXImporter
42 {
43     opencv_onnx::ModelProto model_proto;
44     struct LayerInfo {
45         int layerId;
46         int outputId;
47         LayerInfo(int _layerId = 0, int _outputId = 0) : layerId(_layerId), outputId(_outputId) {}
48     };
49
50     std::map<std::string, Mat> getGraphTensors(
51                                     const opencv_onnx::GraphProto& graph_proto);
52     Mat getBlob(const opencv_onnx::NodeProto& node_proto, int index);
53     Mat getBlob(const std::string& input_name);
54
55     LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
56     bool isCeilMode(const LayerParams& layerParams);
57
58     void addConstant(const std::string& name, const Mat& blob);
59     void addLayer(LayerParams& layerParams,
60                   const opencv_onnx::NodeProto& node_proto);
61
62 public:
63
64     ONNXImporter(Net& net, const char *onnxFile)
65         : dstNet(net)
66     {
67         hasDynamicShapes = false;
68         CV_Assert(onnxFile);
69         CV_LOG_DEBUG(NULL, "DNN/ONNX: processing ONNX model from file: " << onnxFile);
70
71         std::fstream input(onnxFile, std::ios::in | std::ios::binary);
72         if (!input)
73         {
74             CV_Error(Error::StsBadArg, cv::format("Can't read ONNX file: %s", onnxFile));
75         }
76
77         if (!model_proto.ParseFromIstream(&input))
78         {
79             CV_Error(Error::StsUnsupportedFormat, cv::format("Failed to parse ONNX model: %s", onnxFile));
80         }
81
82         populateNet();
83     }
84
85     ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
86         : dstNet(net)
87     {
88         hasDynamicShapes = false;
89         CV_LOG_DEBUG(NULL, "DNN/ONNX: processing in-memory ONNX model (" << sizeBuffer << " bytes)");
90
91         struct _Buf : public std::streambuf
92         {
93             _Buf(const char* buffer, size_t sizeBuffer)
94             {
95                 char* p = const_cast<char*>(buffer);
96                 setg(p, p, p + sizeBuffer);
97             }
98         };
99
100         _Buf buf(buffer, sizeBuffer);
101         std::istream input(&buf);
102
103         if (!model_proto.ParseFromIstream(&input))
104             CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
105
106         populateNet();
107     }
108
109     void populateNet();
110
111 protected:
112     Net& dstNet;
113
114     opencv_onnx::GraphProto graph_proto;
115     std::string framework_name;
116
117     std::map<std::string, Mat> constBlobs;
118
119     std::map<std::string, MatShape> outShapes;  // List of internal blobs shapes.
120     bool hasDynamicShapes;  // Whether the model has inputs with dynamic shapes
121     typedef std::map<std::string, MatShape>::iterator IterShape_t;
122
123     std::map<std::string, LayerInfo> layer_id;
124     typedef std::map<std::string, LayerInfo>::iterator IterLayerId_t;
125
126     void handleNode(const opencv_onnx::NodeProto& node_proto);
127 };
128
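// Typical entry point: the public readNetFromONNX() helpers construct an ONNXImporter,
// which parses the protobuf model and populates the target Net. A minimal usage sketch
// (assuming a "model.onnx" file on disk):
//
//     Net net;
//     ONNXImporter importer(net, "model.onnx");  // parses the file and calls populateNet()
//     // or, equivalently, through the public API:
//     //     Net net = readNetFromONNX("model.onnx");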
129 inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
130 {
131     if (layerParams.has(oldKey)) {
132         layerParams.set(newKey, layerParams.get(oldKey));
133         layerParams.erase(oldKey);
134     }
135 }
136
137 void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
138 {
139     if (!tensor_proto.raw_data().empty()) {
140         delete tensor_proto.release_raw_data();
141     }
142 }
143
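// Runs a single layer described by 'params' on 'inputs' and writes the results to 'outputs'.
// The importer uses this to constant-fold nodes whose inputs are all known constants,
// roughly following this pattern (a sketch of the Slice/Transpose handling further below):
//
//     std::vector<Mat> inputs(1, getBlob(node_proto, 0)), folded;
//     runLayer(layerParams, inputs, folded);
//     addConstant(layerParams.name, folded[0]);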
144 void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
145               std::vector<Mat>& outputs)
146 {
147     Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
148     CV_Assert((bool)layer);
149
150     std::vector<MatShape> inpShapes(inputs.size());
151     int ddepth = CV_32F;
152     for (size_t i = 0; i < inputs.size(); ++i)
153     {
154         inpShapes[i] = shape(inputs[i]);
155         if (i > 0 && ddepth != inputs[i].depth())
156             CV_Error(Error::StsNotImplemented, "Mixed input data types.");
157         ddepth = inputs[i].depth();
158     }
159
160     std::vector<MatShape> outShapes, internalShapes;
161     layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
162
163     std::vector<Mat> internals(internalShapes.size());
164     outputs.resize(outShapes.size());
165     for (size_t i = 0; i < outShapes.size(); ++i)
166         outputs[i].create(outShapes[i], ddepth);
167     for (size_t i = 0; i < internalShapes.size(); ++i)
168         internals[i].create(internalShapes[i], ddepth);
169
170     layer->finalize(inputs, outputs);
171     layer->forward(inputs, outputs, internals);
172 }
173
174 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
175                                         const opencv_onnx::GraphProto& graph_proto)
176 {
177   opencv_onnx::TensorProto tensor_proto;
178   std::map<std::string, Mat> layers_weights;
179
180   for (int i = 0; i < graph_proto.initializer_size(); i++)
181   {
182     tensor_proto = graph_proto.initializer(i);
183     Mat mat = getMatFromTensor(tensor_proto);
184     releaseONNXTensor(tensor_proto);
185     layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
186   }
187   return layers_weights;
188 }
189
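// Converts a repeated int64 attribute (e.g. kernel_shape, strides, dilations, pads)
// into a 32-bit DictValue array via convertInt64ToInt32.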
190 static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
191     std::vector<int32_t> dst(src.size());
192     convertInt64ToInt32(src, dst, src.size());
193     return DictValue::arrayInt(&dst[0], src.size());
194 }
195
196 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
197 {
198     LayerParams lp;
199     for(int i = 0; i < node_proto.attribute_size(); i++)
200     {
201         opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
202         std::string attribute_name = attribute_proto.name();
203
204         if(attribute_name == "kernel_shape")
205         {
206             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
207             lp.set("kernel_size", parse(attribute_proto.ints()));
208         }
209         else if(attribute_name == "strides")
210         {
211             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
212             lp.set("stride", parse(attribute_proto.ints()));
213         }
214         else if(attribute_name == "pads")
215         {
216             if (node_proto.op_type() == "Pad")
217             {
218                 // Padding layer.
219                 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
220                 // We need to shuffle it to begin0, end0, begin1, end1, ...
221                 CV_Assert(attribute_proto.ints_size() % 2 == 0);
222                 const int dims = attribute_proto.ints_size() / 2;
223                 std::vector<int32_t> paddings;
224                 paddings.reserve(attribute_proto.ints_size());
225                 for (int i = 0; i < dims; ++i)
226                 {
227                     paddings.push_back(attribute_proto.ints(i));
228                     paddings.push_back(attribute_proto.ints(dims + i));
229                 }
230                 lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
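                // e.g. pads = [b0, b1, e0, e1] becomes paddings = [b0, e0, b1, e1]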
231             }
232             else
233             {
234                 // Convolution or pooling.
235                 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
236                 lp.set("pad", parse(attribute_proto.ints()));
237             }
238         }
239         else if(attribute_name == "auto_pad")
240         {
241             if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
242                 lp.set("pad_mode",  "SAME");
243             }
244             else if (attribute_proto.s() == "VALID") {
245                 lp.set("pad_mode", "VALID");
246             }
247         }
248         else if(attribute_name == "dilations")
249         {
250             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
251             lp.set("dilation", parse(attribute_proto.ints()));
252         }
253         else if (attribute_proto.has_i())
254         {
255             ::google::protobuf::int64 src = attribute_proto.i();
256             if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
257                 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
258             else
259                 lp.set(attribute_name, saturate_cast<int32_t>(src));
260         }
261         else if (attribute_proto.has_f())
262         {
263             lp.set(attribute_name, attribute_proto.f());
264         }
265         else if (attribute_proto.has_s())
266         {
267             lp.set(attribute_name, attribute_proto.s());
268         }
269         else if (attribute_proto.floats_size() > 0)
270         {
271             lp.set(attribute_name, DictValue::arrayReal(
272                 attribute_proto.floats().data(), attribute_proto.floats_size()));
273         }
274         else if (attribute_proto.ints_size() > 0)
275         {
276             lp.set(attribute_name, parse(attribute_proto.ints()));
277         }
278         else if (attribute_proto.has_t())
279         {
280             opencv_onnx::TensorProto tensor = attribute_proto.t();
281             Mat blob = getMatFromTensor(tensor);
282             lp.blobs.push_back(blob);
283         }
284         else if (attribute_proto.has_g())
285         {
286             CV_Error(Error::StsNotImplemented, cv::format("DNN/ONNX/Attribute[%s]: 'Graph' is not supported", attribute_name.c_str()));
287         }
288         else if (attribute_proto.graphs_size() > 0)
289         {
290             CV_Error(Error::StsNotImplemented,
291                     cv::format("DNN/ONNX/Attribute[%s]: 'Graphs' (%d) in attributes is not supported",
292                             attribute_name.c_str(), attribute_proto.graphs_size())
293             );
294         }
295         else if (attribute_proto.strings_size() > 0)
296         {
297             std::string msg = cv::format("DNN/ONNX/Attribute[%s]: 'Strings' (%d) are not supported",
298                     attribute_name.c_str(), attribute_proto.strings_size());
299             CV_LOG_ERROR(NULL, msg);
300             for (int i = 0; i < attribute_proto.strings_size(); i++)
301             {
302                 CV_LOG_ERROR(NULL, "    Attribute[" << attribute_name << "].string(" << i << ") = '" << attribute_proto.strings(i) << "'");
303             }
304             CV_Error(Error::StsNotImplemented, msg);
305         }
306         else if (attribute_proto.tensors_size() > 0)
307         {
308             CV_Error(Error::StsNotImplemented,
309                     cv::format("DNN/ONNX/Attribute[%s]: 'Tensors' (%d) in attributes are not supported",
310                             attribute_name.c_str(), attribute_proto.tensors_size())
311             );
312         }
313         else
314         {
315             CV_Error(Error::StsNotImplemented, cv::format("DNN/ONNX/Attribute[%s]: unsupported attribute format", attribute_name.c_str()));
316         }
317     }
318     return lp;
319 }
320
321 Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto, int index)
322 {
323     CV_Assert(index < node_proto.input_size());
324     const std::string& input_name = node_proto.input(index);
325     return getBlob(input_name);
326 }
327
328 Mat ONNXImporter::getBlob(const std::string& input_name)
329 {
330     std::map<std::string, Mat>::const_iterator constBlob = constBlobs.find(input_name);
331     if (constBlob == constBlobs.end())
332     {
333         CV_Error(Error::StsBadArg, std::string("Blob ") + input_name + " not found in const blobs");
334     }
335     return constBlob->second;
336 }
337
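// Adds 'layerParams' as a new layer of dstNet, connects every input that is produced by an
// already registered layer (constant inputs are skipped), and records the resulting output
// shapes in outShapes via Layer::getMemoryShapes().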
338 void ONNXImporter::addLayer(LayerParams& layerParams,
339                             const opencv_onnx::NodeProto& node_proto)
340 {
341     int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
342     for (int i = 0; i < node_proto.output_size(); ++i)
343     {
344         layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
345     }
346
347     std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
348     int inpNum = 0;
349     for (int j = 0; j < node_proto.input_size(); j++)
350     {
351         const std::string& input_name = node_proto.input(j);
352         IterLayerId_t layerId = layer_id.find(input_name);
353         if (layerId != layer_id.end()) {
354             dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
355             ++inpNum;
356             // Collect input shapes.
357             IterShape_t shapeIt = outShapes.find(input_name);
358             CV_Assert(shapeIt != outShapes.end());
359             layerInpShapes.push_back(shapeIt->second);
360         }
361     }
362     // Compute shape of output blob for this layer.
363     Ptr<Layer> layer = dstNet.getLayer(id);  // FIXIT: avoid instantiation of layers during the import stage
364     layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
365     for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
366     {
367         outShapes[node_proto.output(i)] = layerOutShapes[i];
368     }
369 }
370
371 void ONNXImporter::addConstant(const std::string& name, const Mat& blob)
372 {
373     constBlobs.insert(std::make_pair(name, blob));
374     outShapes.insert(std::make_pair(name, shape(blob)));
375 }
376
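// Imports the graph: collects initializers as constant blobs, registers input shapes and
// network inputs, then converts the graph nodes one by one through handleNode().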
377 void ONNXImporter::populateNet()
378 {
379     CV_Assert(model_proto.has_graph());
380     graph_proto = model_proto.graph();
381
382     std::string framework_version;
383     if (model_proto.has_producer_name())
384         framework_name = model_proto.producer_name();
385     if (model_proto.has_producer_version())
386         framework_version = model_proto.producer_version();
387
388     CV_LOG_INFO(NULL, "DNN/ONNX: loading ONNX"
389             << (model_proto.has_ir_version() ? cv::format(" v%d", (int)model_proto.ir_version()) : cv::String())
390             << " model produced by '" << framework_name << "'"
391             << (framework_version.empty() ? cv::String() : cv::format(":%s", framework_version.c_str()))
392             << ". Number of nodes = " << graph_proto.node_size()
393             << ", inputs = " << graph_proto.input_size()
394             << ", outputs = " << graph_proto.output_size()
395             );
396
397     simplifySubgraphs(graph_proto);
398
399     const int layersSize = graph_proto.node_size();
400     CV_LOG_DEBUG(NULL, "DNN/ONNX: graph simplified to " << layersSize << " nodes");
401
402     constBlobs = getGraphTensors(graph_proto);
403     // Add all the input shapes: this covers both the constant blobs and the network's actual inputs.
404     for (int i = 0; i < graph_proto.input_size(); ++i)
405     {
406         const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto.input(i);
407         CV_Assert(valueInfoProto.has_name());
408         CV_Assert(valueInfoProto.has_type());
409         opencv_onnx::TypeProto typeProto = valueInfoProto.type();
410         CV_Assert(typeProto.has_tensor_type());
411         opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
412         CV_Assert(tensor.has_shape());
413         opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
414
415         MatShape inpShape(tensorShape.dim_size());
416         for (int j = 0; j < inpShape.size(); ++j)
417         {
418             inpShape[j] = tensorShape.dim(j).dim_value();
419             if (!tensorShape.dim(j).dim_param().empty())
420                 hasDynamicShapes = true;
421         }
422         if (!inpShape.empty() && !hasDynamicShapes)
423         {
424             inpShape[0] = std::max(inpShape[0], 1); // It's OK to have undetermined batch size
425         }
426         outShapes[valueInfoProto.name()] = inpShape;
427     }
428
429     // Create the list of network input names (constant blobs excluded)
430     // and register each one in layer_id with layer id 0 and its output index.
431     std::vector<String> netInputs;
432     for (int j = 0; j < graph_proto.input_size(); j++)
433     {
434         const std::string& name = graph_proto.input(j).name();
435         if (constBlobs.find(name) == constBlobs.end()) {
436             netInputs.push_back(name);
437             layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
438         }
439     }
440     dstNet.setInputsNames(netInputs);
441
442     for(int li = 0; li < layersSize; li++)
443     {
444         const opencv_onnx::NodeProto& node_proto = graph_proto.node(li);
445         handleNode(node_proto);
446     }
447
448     CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!");
449 }
450
451 void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
452 {
453     opencv_onnx::NodeProto node_proto = node_proto_;  // TODO FIXIT
454
455     CV_Assert(node_proto.output_size() >= 1);
456     std::string name = node_proto.output(0);
457     std::string layer_type = node_proto.op_type();
458     CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
459             << cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
460     );
461
462     try
463     {
464         // FIXIT not all cases can be repacked into "LayerParams". Importer should handle such cases directly for each "layer_type"
465         LayerParams layerParams = getLayerParams(node_proto);
466
467         layerParams.name = name;
468         layerParams.type = layer_type;
469         layerParams.set("has_dynamic_shapes", hasDynamicShapes);
470
471         if (layer_type == "MaxPool")
472         {
473             layerParams.type = "Pooling";
474             layerParams.set("pool", "MAX");
475             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
476         }
477         else if (layer_type == "AveragePool")
478         {
479             layerParams.type = "Pooling";
480             layerParams.set("pool", "AVE");
481             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
482             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
483         }
484         else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
485                 layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
486         {
487             CV_Assert(node_proto.input_size() == 1);
488             layerParams.type = "Pooling";
489             String pool;
490             if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
491                 pool = "MAX";
492             else if (layer_type == "ReduceSum")
493                 pool = "SUM";
494             else
495                 pool = "AVE";
496             layerParams.set("pool", pool);
497             layerParams.set("global_pooling", !layerParams.has("axes"));
498             if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
499             {
500                 MatShape inpShape = outShapes[node_proto.input(0)];
501                 DictValue axes = layerParams.get("axes");
502                 bool keepdims = layerParams.get<int>("keepdims");
503                 MatShape targetShape;
504                 std::vector<bool> shouldDelete(inpShape.size(), false);
505                 for (int i = 0; i < axes.size(); i++) {
506                     int axis = normalize_axis(axes.get<int>(i), inpShape.size());
507                     shouldDelete[axis] = true;
508                 }
509                 for (int axis = 0; axis < inpShape.size(); ++axis){
510                     if (!shouldDelete[axis])
511                         targetShape.push_back(inpShape[axis]);
512                     else if (keepdims)
513                         targetShape.push_back(1);
514                 }
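                // 'targetShape' is the final reduced shape; the node itself is turned into a
                // trailing Reshape to this shape after the pooling layer(s) built below.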
515
516                 if (inpShape.size() == 3 && axes.size() <= 2)
517                 {
518                     int axis = normalize_axis(axes.get<int>(0), inpShape.size());
519                     CV_CheckNE(axis, 0, "");
520
521                     LayerParams reshapeLp;
522                     reshapeLp.name = layerParams.name + "/reshape";
523                     reshapeLp.type = "Reshape";
524                     CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
525                     reshapeLp.set("axis", 0);
526                     reshapeLp.set("num_axes", 1);
527                     int newShape[] = {1, -1};
528                     reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 2));
529
530                     opencv_onnx::NodeProto proto;
531                     proto.add_input(node_proto.input(0));
532                     proto.add_output(reshapeLp.name);
533                     addLayer(reshapeLp, proto);
534
535                     LayerParams avgLp;
536                     avgLp.name = layerParams.name + "/avg";
537                     avgLp.type = "Pooling";
538                     CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
539                     avgLp.set("pool", pool);
540                     if (axes.size() == 2)
541                     {
542                         CV_CheckEQ(normalize_axis(axes.get<int>(0), inpShape.size()), 1, "Unsupported mode");
543                         CV_CheckEQ(normalize_axis(axes.get<int>(1), inpShape.size()), 2, "Unsupported mode");
544                         avgLp.set("global_pooling", true);
545                     }
546                     else
547                     {
548                         avgLp.set(axis == 2 ? "global_pooling_w" : "global_pooling_h", true);
549                         avgLp.set(axis == 2 ? "kernel_h" : "kernel_w", 1);
550                     }
551
552                     node_proto.set_input(0, reshapeLp.name);
553                     node_proto.set_output(0, avgLp.name);
554                     addLayer(avgLp, node_proto);
555                 }
556                 else
557                 {
558                     if (inpShape.size() != 4 && inpShape.size() != 5)
559                         CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
560
561                     CV_Assert(axes.size() <= inpShape.size() - 2);
562                     std::vector<int> kernel_size(inpShape.size() - 2, 1);
563                     if (axes.size() == 1 && (normalize_axis(axes.get<int>(0), inpShape.size()) <= 1))
564                     {
565                         int axis = normalize_axis(axes.get<int>(0), inpShape.size());
566                         MatShape newShape = inpShape;
567                         newShape[axis + 1] = total(newShape, axis + 1);
568                         newShape.resize(axis + 2);
569                         newShape.insert(newShape.begin(), 2 - axis, 1);
570
571                         LayerParams reshapeLp;
572                         reshapeLp.type = "Reshape";
573                         reshapeLp.name = layerParams.name + "/reshape";
574                         CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
575                         reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
576
577                         node_proto.set_output(0, reshapeLp.name);
578                         addLayer(reshapeLp, node_proto);
579
580                         kernel_size.resize(2);
581                         kernel_size[0] = inpShape[axis];
582                         node_proto.set_input(0, node_proto.output(0));
583                     }
584                     else
585                     {
586                         for (int i = 0; i < axes.size(); i++) {
587                             int axis = normalize_axis(axes.get<int>(i), inpShape.size());
588                             CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
589                             kernel_size[axis - 2] = inpShape[axis];
590                         }
591                     }
592
593                     LayerParams poolLp = layerParams;
594                     poolLp.name = layerParams.name + "/avg";
595                     CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
596                     poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
597
598                     node_proto.set_output(0, poolLp.name);
599                     addLayer(poolLp, node_proto);
600                 }
601
602                 layerParams.type = "Reshape";
603                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
604
605                 node_proto.set_input(0, node_proto.output(0));
606                 node_proto.set_output(0, layerParams.name);
607             }
608             else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
609             {
610                 CV_CheckEQ(layerParams.get<int>("keepdims"), 0, "layer only supports keepdims = false");
611                 LayerParams reshapeLp;
612                 reshapeLp.name = layerParams.name + "/reshape";
613                 reshapeLp.type = "Reshape";
614                 CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
615                 int newShape[] = {1, 1, 1, -1};
616                 reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
617
618                 opencv_onnx::NodeProto proto;
619                 proto.add_input(node_proto.input(0));
620                 proto.add_output(reshapeLp.name);
621                 addLayer(reshapeLp, proto);
622
623                 LayerParams poolLp = layerParams;
624                 poolLp.name = layerParams.name + "/pool";
625                 CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
626
627                 node_proto.set_input(0, reshapeLp.name);
628                 node_proto.set_output(0, poolLp.name);
629                 addLayer(poolLp, node_proto);
630
631                 layerParams.type = "Reshape";
632                 int targetShape[] = {1};
633                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
634
635                 node_proto.set_input(0, node_proto.output(0));
636                 node_proto.set_output(0, layerParams.name);
637             }
638         }
639         else if (layer_type == "Slice")
640         {
641             int axis = 0;
642             std::vector<int> begin;
643             std::vector<int> end;
644             std::vector<int> steps;
645             int inp_size = node_proto.input_size();
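            // Slice parameters arrive either as attributes (older opsets: starts/ends/axes)
            // or as constant tensor inputs (opset >= 10: starts, ends, optional axes and steps).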
646
647             if (inp_size == 1)
648             {
649                 if (layerParams.has("axes")) {
650                     DictValue axes = layerParams.get("axes");
651                     for (int i = 1; i < axes.size(); ++i) {
652                         CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
653                     }
654                     axis = axes.get<int>(0);
655                 }
656
657                 DictValue starts = layerParams.get("starts");
658                 DictValue ends = layerParams.get("ends");
659                 CV_Assert(starts.size() == ends.size());
660
661                 if (axis > 0) {
662                     begin.resize(axis, 0);
663                     end.resize(axis, -1);
664                 }
665                 for (int i = 0; i < starts.size(); ++i)
666                 {
667                     begin.push_back(starts.get<int>(i));
668                     int finish = ends.get<int>(i);
669                     end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
670                 }
671             } else { // inp_size > 1
672                 CV_Assert(inp_size >= 3);
673                 for (int i = 1; i < inp_size; i++) {
674                     CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
675                 }
676                 Mat start_blob = getBlob(node_proto, 1);
677                 Mat end_blob   = getBlob(node_proto, 2);
678                 CV_Assert(start_blob.total() == end_blob.total());
679
680                 if (inp_size > 3) {
681                     Mat axes_blob = getBlob(node_proto, 3);
682                     const int* axes = (int*)axes_blob.data;
683                     for (int i = 1; i < axes_blob.total(); ++i) {
684                         CV_Assert(axes[i - 1] == axes[i] - 1);
685                     }
686                     axis = axes[0];
687                 }
688
689                 const int* starts = start_blob.ptr<int>();
690                 const int* ends   = end_blob.ptr<int>();
691                 if (axis > 0) {
692                     begin.resize(axis, 0);
693                     end.resize(axis, -1);
694                 }
695                 std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
696                 for (int i = 0; i < end_blob.total(); ++i)
697                 {
698                     int finish = ends[i];
699                     end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
700                 }
701
702                 if (inp_size == 5) {
703                     CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
704                     Mat step_blob = getBlob(node_proto, 4);
705                     const int* steps_ptr = step_blob.ptr<int>();
706
707                     if (axis > 0)
708                         steps.resize(axis, 1);
709
710                     std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
711
712                     // Unusual use of the Slice op: reversing a tensor (start = -1, step = -1, end = INT_MIN).
713                     // We only work around this case for 2D constant inputs.
714                     if (constBlobs.find(node_proto.input(0)) != constBlobs.end() &&
715                         axis == 0 &&
716                         start_blob.at<int>(0) == -1 && step_blob.at<int>(0) == -1 &&
717                         end_blob.at<int>(0) == std::numeric_limits<int32_t>::min())
718                     {
719                         Mat inp = getBlob(node_proto, 0);
720                         if (inp.dims == 2)
721                         {
722                             Mat flipped;
723                             flip(inp, flipped, 0);
724                             addConstant(layerParams.name, flipped);
725                             return;
726                         }
727                     }
728                 }
729             }
730             layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
731             layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
732             layerParams.set("axis", axis);
733
734             if (!steps.empty())
735                 layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
736
737             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
738             {
739                 Mat inp = getBlob(node_proto, 0);
740                 std::vector<Mat> inputs, sliced;
741                 inputs.push_back(inp);
742                 runLayer(layerParams, inputs, sliced);
743                 CV_Assert(sliced.size() == 1);
744                 addConstant(layerParams.name, sliced[0]);
745                 return;
746             }
747         }
748         else if (layer_type == "Split")
749         {
750             if (layerParams.has("split"))
751             {
752                 DictValue splits = layerParams.get("split");
753                 const int numSplits = splits.size();
754                 CV_Assert(numSplits > 1);
755
756                 std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
757                 for (int i = 1; i < splits.size() - 1; ++i)
758                 {
759                     slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);  // cumulative sum of split sizes
760                 }
761                 layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
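                // e.g. split = [2, 3, 5] -> slice_point = [2, 5] (cumulative boundaries)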
762             }
763             else
764             {
765                 layerParams.set("num_split", node_proto.output_size());
766             }
767             layerParams.type = "Slice";
768         }
769         else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub")
770         {
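            // Cases: both inputs constant -> fold into a constant blob; exactly one constant ->
            // Power (scalar), Const + Eltwise (matching shape) or Scale with bias (broadcast);
            // both variable -> Eltwise for equal shapes, otherwise Scale with bias.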
771             bool isSub = layer_type == "Sub";
772             CV_CheckEQ(node_proto.input_size(), 2, "");
773             bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
774             bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
775             if (is_const_0 && is_const_1)
776             {
777                 Mat blob_0 = getBlob(node_proto, 0);
778                 Mat blob_1 = getBlob(node_proto, 1);
779                 CV_Assert(blob_0.size == blob_1.size);
780                 Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
781                 addConstant(layerParams.name, output);
782                 return;
783             }
784             else if (is_const_0 || is_const_1)
785             {
786                 int const_blob_id = is_const_0 ? 0 : 1;
787                 Mat blob = getBlob(node_proto, const_blob_id);
788                 int blob_total = blob.total();
789                 if (blob_total == 1) {
790                     layerParams.type = "Power";
791                     layerParams.set("shift", (isSub ? -1 : 1) * blob.at<float>(0));
792                 }
793                 else {
794                     MatShape inpShape = outShapes[node_proto.input(1 - const_blob_id)];
795                     if (shape(blob) == inpShape)
796                     {
797                         LayerParams constParams;
798                         constParams.name = layerParams.name + "/const";
799                         constParams.type = "Const";
800                         constParams.blobs.push_back((isSub ? -1 : 1) * blob);
801                         int id = dstNet.addLayer(constParams.name, constParams.type, constParams);
802                         layer_id.insert(std::make_pair(constParams.name, LayerInfo(id, 0)));
803                         outShapes[constParams.name] = shape(blob);
804
805                         layerParams.type = "Eltwise";
806                         node_proto.set_input(const_blob_id, constParams.name);
807                     }
808                     else
809                     {
810                         layerParams.type = "Scale";
811                         layerParams.set("bias_term", true);
812                         int axis = 1;
813                         for (int i = 0; i < graph_proto.initializer_size(); i++)
814                         {
815                             opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
816                             if (tensor_proto.name() == node_proto.input(const_blob_id))
817                             {
818                                 axis = inpShape.size() - tensor_proto.dims_size();
819                                 break;
820                             }
821                         }
822                         layerParams.set("axis", axis);
823                         blob = blob.reshape(1, 1);
824                         layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
825                     }
826                 }
827             }
828             else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
829             {
830                 layerParams.type = "Eltwise";
831                 if (isSub)
832                 {
833                     static float subCoeffs[] = {1.f, -1.f};
834                     layerParams.set("coeff", DictValue::arrayReal<float*>(subCoeffs, 2));
835                 }
836             }
837             else
838             {
839                 if (isSub)
840                 {
841                     LayerParams powerParams;
842                     powerParams.name = layerParams.name + "/neg";
843                     powerParams.type = "Power";
844                     powerParams.set("scale", -1);
845
846                     //Create Power layer
847                     int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
848                     //Connect to input
849                     IterLayerId_t layerId = layer_id.find(node_proto.input(1));
850                     CV_Assert(layerId != layer_id.end());
851                     dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
852                     //Add shape
853                     layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
854                     outShapes[powerParams.name] = outShapes[node_proto.input(1)];
855
856                     // Redirect the second input through the Power (negation) layer
857                     node_proto.set_input(1, powerParams.name);
858                 }
859                 layerParams.type = "Scale";
860                 layerParams.set("bias_term", true);
861             }
862         }
863         else if (layer_type == "Pow")
864         {
865             if (layer_id.find(node_proto.input(1)) != layer_id.end())
866                 CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power");
867
868             Mat blob = getBlob(node_proto, 1);
869             if (blob.total() != 1)
870                 CV_Error(Error::StsNotImplemented, "Pow op supports only scalar power");
871
872             blob.convertTo(blob, CV_32F);
873             layerParams.type = "Power";
874             layerParams.set("power", blob.at<float>(0));
875         }
876         else if (layer_type == "Max")
877         {
878             layerParams.type = "Eltwise";
879             layerParams.set("operation", "max");
880         }
881         else if (layer_type == "Neg")
882         {
883             layerParams.type = "Power";
884             layerParams.set("scale", -1);
885         }
886         else if (layer_type == "Constant")
887         {
888             CV_Assert(node_proto.input_size() == 0);
889             CV_Assert(layerParams.blobs.size() == 1);
890             addConstant(layerParams.name, layerParams.blobs[0]);
891             return;
892         }
893         else if (layer_type == "LSTM")
894         {
895             LayerParams lstmParams = layerParams;
896             lstmParams.name += "/lstm";
897
898             // https://pytorch.org/docs/stable/nn.html#lstm
899             CV_Assert(node_proto.input_size() == 7);
900             Mat Wx = getBlob(node_proto, 1);
901             Mat Wh = getBlob(node_proto, 2);
902             Mat b = getBlob(node_proto, 3);
903             CV_CheckEQ(countNonZero(getBlob(node_proto, 5)), 0, "Unsupported non zero initial_h");
904             CV_CheckEQ(countNonZero(getBlob(node_proto, 6)), 0, "Unsupported non zero initial_c");
905             b = b.reshape(1, b.size[0]);
906
907             const int numHidden = lstmParams.get<int>("hidden_size");
908             const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
909             const int numFeatures = Wx.size[2];
910             Mat bx = b.colRange(0, b.cols / 2);
911             Mat bh = b.colRange(b.cols / 2, b.cols);
912             b = bx + bh;
913
914             // IFGO->IGFO
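            // Swap the second and third blocks of numHidden rows in Wx, Wh and in the summed
            // bias for every direction, reordering the gates as noted above.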
915             for (int k = 0; k < numDirs; ++k)
916             {
917                 float* WxData = Wx.ptr<float>(k);
918                 float* WhData = Wh.ptr<float>(k);
919                 float* biasData = b.ptr<float>(k);
920                 for (int j = 0; j < numHidden; ++j)
921                 {
922                     for (int i = 0; i < numFeatures; ++i)
923                     {
924                         std::swap(WxData[(numHidden + j) * numFeatures + i],
925                                   WxData[(numHidden * 2 + j) * numFeatures + i]);
926                     }
927                     for (int i = 0; i < numHidden; ++i)
928                     {
929                         std::swap(WhData[(numHidden + j) * numHidden + i],
930                                   WhData[(numHidden * 2 + j) * numHidden + i]);
931                     }
932                     std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
933                 }
934             }
935             Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
936             Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
937
938             lstmParams.blobs.resize(3);
939             lstmParams.blobs[0] = Wh;
940             lstmParams.blobs[1] = Wx;
941             lstmParams.blobs[2] = b;
942             lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
943
944             node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
945             addLayer(lstmParams, node_proto);
946
947             MatShape lstmShape = outShapes[node_proto.output(0)];
948
949             // Insert a dummy dimension of size 1 to match the ONNX LSTM output layout
950             lstmShape.insert(lstmShape.begin() + 1, 1);
951
952             layerParams.type = "Reshape";
953             layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
954             node_proto.set_input(0, lstmParams.name);  // redirect input to LSTM
955             node_proto.set_output(0, layerParams.name);  // keep origin LSTM's name
956         }
957         else if (layer_type == "ImageScaler")
958         {
959             const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
960             layerParams.erase("scale");
961
962             if (layerParams.has("bias"))
963             {
964                 layerParams.type = "Scale";
965                 layerParams.blobs.push_back(
966                     Mat(Size(1,  layerParams.get("bias").size()), CV_32FC1, scale));
967
968                 layerParams.set("bias_term", true);
969                 Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
970                 for (int j = 0; j < bias.total(); j++) {
971                     bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
972                 }
973                 layerParams.blobs.push_back(bias);
974                 layerParams.erase("bias");
975             }
976             else {
977                 layerParams.set("scale", scale);
978                 layerParams.type = "Power";
979             }
980         }
981         else if (layer_type == "Clip")
982         {
983             layerParams.type = "ReLU6";
984             replaceLayerParam(layerParams, "min", "min_value");
985             replaceLayerParam(layerParams, "max", "max_value");
986
987         }
988         else if (layer_type == "LeakyRelu")
989         {
990             layerParams.type = "ReLU";
991             replaceLayerParam(layerParams, "alpha", "negative_slope");
992         }
993         else if (layer_type == "Relu")
994         {
995             layerParams.type = "ReLU";
996         }
997         else if (layer_type == "Elu")
998         {
999             layerParams.type = "ELU";
1000         }
1001         else if (layer_type == "Tanh")
1002         {
1003             layerParams.type = "TanH";
1004         }
1005         else if (layer_type == "PRelu")
1006         {
1007             layerParams.type = "PReLU";
1008             layerParams.blobs.push_back(getBlob(node_proto, 1));
1009         }
1010         else if (layer_type == "LRN")
1011         {
1012             replaceLayerParam(layerParams, "size", "local_size");
1013         }
1014         else if (layer_type == "InstanceNormalization")
1015         {
1016             if (node_proto.input_size() != 3)
1017                 CV_Error(Error::StsNotImplemented,
1018                          "Expected input, scale, bias");
1019
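            // InstanceNorm is implemented as an MVN layer (per-instance normalization with the
            // given epsilon) followed by BatchNorm with zero-mean / unit-variance blobs, so that
            // only the provided scale and bias are applied.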
1020             layerParams.blobs.resize(4);
1021             layerParams.blobs[2] = getBlob(node_proto, 1);  // weightData
1022             layerParams.blobs[3] = getBlob(node_proto, 2);  // biasData
1023             layerParams.set("has_bias", true);
1024             layerParams.set("has_weight", true);
1025
1026             // Get number of channels in input
1027             int size = layerParams.blobs[2].total();
1028             layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
1029             layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
1030
1031             LayerParams mvnParams;
1032             mvnParams.name = layerParams.name + "/MVN";
1033             mvnParams.type = "MVN";
1034             mvnParams.set("eps", layerParams.get<float>("epsilon"));
1035             layerParams.erase("epsilon");
1036
1037             //Create MVN layer
1038             int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
1039             //Connect to input
1040             IterLayerId_t layerId = layer_id.find(node_proto.input(0));
1041             CV_Assert(layerId != layer_id.end());
1042             dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
1043             //Add shape
1044             layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
1045             outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
1046
1047             // Redirect the BatchNorm input through the MVN layer
1048             node_proto.set_input(0, mvnParams.name);
1049             layerParams.type = "BatchNorm";
1050         }
1051         else if (layer_type == "BatchNormalization")
1052         {
1053             if (node_proto.input_size() != 5)
1054                 CV_Error(Error::StsNotImplemented,
1055                          "Expected input, scale, bias, mean and var");
1056
1057             layerParams.type = "BatchNorm";
1058             replaceLayerParam(layerParams, "epsilon", "eps");
1059             replaceLayerParam(layerParams, "spatial", "use_global_stats");
1060
1061             Mat meanData = getBlob(node_proto, 3);
1062             Mat stdData =  getBlob(node_proto, 4);
1063
1064             layerParams.blobs.push_back(meanData);
1065             layerParams.blobs.push_back(stdData);
1066
1067             if (!node_proto.input(1).empty()) {
1068                 layerParams.set("has_weight", true);
1069                 layerParams.blobs.push_back(getBlob(node_proto, 1));  // weightData
1070             } else {
1071                 layerParams.set("has_weight", false);
1072             }
1073
1074             if (!node_proto.input(2).empty()) {
1075                 layerParams.set("has_bias", true);
1076                 layerParams.blobs.push_back(getBlob(node_proto, 2)); // biasData
1077             } else {
1078                 layerParams.set("has_bias", false);
1079             }
1080         }
1081         else if (layer_type == "Gemm")
1082         {
1083             CV_Assert(node_proto.input_size() >= 2);
1084             layerParams.type = "InnerProduct";
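            // Gemm(A, B, C) maps to InnerProduct: B becomes the weight blob (transposed when
            // transB is set to 0) and the optional C becomes the bias blob.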
1085             Mat weights = getBlob(node_proto, 1);
1086             int ind_num_out = 0;
1087             if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
1088                 transpose(weights, weights);
1089                 ind_num_out = 1;
1090             }
1091             layerParams.blobs.push_back(weights);
1092
1093             if (node_proto.input_size() == 3) {
1094                 Mat bias = getBlob(node_proto, 2);
1095                 layerParams.blobs.push_back(bias);
1096             }
1097             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1098             {
1099                 Mat inputBuf = getBlob(node_proto, 0);
1100
1101                 LayerParams constParams;
1102                 constParams.name = node_proto.input(0);
1103                 constParams.type = "Const";
1104                 constParams.blobs.push_back(inputBuf);
1105
1106                 opencv_onnx::NodeProto proto;
1107                 proto.add_output(constParams.name);
1108                 addLayer(constParams, proto);
1109             }
1110
1111             layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
1112             layerParams.set("bias_term", node_proto.input_size() == 3);
1113         }
1114         else if (layer_type == "MatMul")
1115         {
1116             CV_Assert(node_proto.input_size() == 2);
1117             layerParams.type = "InnerProduct";
1118             layerParams.set("bias_term", false);
1119             CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
1120             int firstInpDims = outShapes[node_proto.input(0)].size();
1121             int secondInpDims;
1122
1123             if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
1124             {
1125                 Mat blob = getBlob(node_proto, 1);
1126                 secondInpDims = blob.dims;
1127                 layerParams.blobs.push_back(blob.t());
1128                 layerParams.set("num_output", layerParams.blobs[0].size[0]);
1129             } else {
1130                 secondInpDims = outShapes[node_proto.input(1)].size();
1131             }
1132             layerParams.set("axis", firstInpDims - secondInpDims + 1);
1133         }
1134         else if (layer_type == "Mul" || layer_type == "Div")
1135         {
1136             CV_Assert(node_proto.input_size() == 2);
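            // Same pattern as Add/Sub above: fold when both inputs are constant, Power/Scale
            // when exactly one is constant, Eltwise or Scale otherwise.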
1137
1138             bool isDiv = layer_type == "Div";
1139             int constId = -1;
1140             bool haveVariables = false;
1141             for (int i = 0; i < 2; ++i)
1142             {
1143                 if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
1144                     constId = i;
1145                 else
1146                     haveVariables = true;
1147             }
1148             if (constId != -1 && haveVariables)
1149             {
1150                 Mat blob = getBlob(node_proto, constId);
1151                 blob = blob.reshape(1, 1);
1152                 if (blob.total() == 1) {
1153                     float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
1154                     layerParams.set("scale", coeff);
1155                     layerParams.type = "Power";
1156                 }
1157                 else {
1158                     if (isDiv)
1159                         divide(1.0, blob, blob);
1160                     layerParams.blobs.push_back(blob);
1161                     layerParams.type = "Scale";
1162                 }
1163             }
1164             else if (!haveVariables)
1165             {
1166                 Mat inp0 = getBlob(node_proto, 0);
1167                 Mat inp1 = getBlob(node_proto, 1);
1168
1169                 if (inp0.size != inp1.size && (inp0.total() != 1 || inp1.total() != 1))
1170                     CV_Error_(Error::StsNotImplemented, ("Different shapes case is not supported with constant inputs: %s", layer_type.c_str()));
1171
1172                 if (inp0.total() == 1 && inp1.total() == 1 && inp0.dims != inp1.dims)
1173                 {
1174                     if (inp0.dims < inp1.dims)
1175                     {
1176                         inp0 = inp0.reshape(1, inp1.dims, inp1.size);
1177                         inp0.dims = inp1.dims;
1178                     }
1179                     else
1180                     {
1181                         inp1 = inp1.reshape(1, inp0.dims, inp0.size);
1182                         inp1.dims = inp0.dims;
1183                     }
1184                 }
1185
1186                 Mat out;
1187                 if (inp0.total() != inp1.total())
1188                 {
1189                     if (inp0.total() == 1)
1190                     {
1191                         float coeff = isDiv ? 1.0 / inp0.at<float>(0) : inp0.at<float>(0);
1192                         multiply(inp1, coeff, out);
1193                     }
1194                     else
1195                     {
1196                         float coeff = isDiv ? 1.0 / inp1.at<float>(0) : inp1.at<float>(0);
1197                         multiply(inp0, coeff, out);
1198                     }
1199
1200                 }
1201                 else
1202                 {
1203                     out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
1204                 }
1205
1206                 if (inp0.dims == 1 && inp1.dims == 1)
1207                 out.dims = 1;  // keep the result 1-D: Mat arithmetic promotes 1-D inputs to 2-D
1208                 addConstant(layerParams.name, out);
1209                 return;
1210             }
1211             else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
1212             {
1213                 layerParams.type = "Eltwise";
1214                 layerParams.set("operation", isDiv ? "div" : "prod");
1215             }
1216             else
1217             {
1218                 // The Scale layer allocates its output with the first input's shape, so put the larger input first.
1219                 if (total(outShapes[node_proto.input(0)]) < total(outShapes[node_proto.input(1)]))
1220                 {
1221                     opencv_onnx::NodeProto proto;
1222                     proto.add_input(node_proto.input(1));
1223                     proto.add_input(node_proto.input(0));
1224                     proto.add_output(layerParams.name);
1225                     node_proto = proto;
1226                 }
1227
1228                 if (isDiv)
1229                 {
1230                     LayerParams powerParams;
1231                     powerParams.name = layerParams.name + "/inv";
1232                     powerParams.type = "Power";
1233                     powerParams.set("power", -1);
1234
1235                     //Create Power layer
1236                     int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
1237                     //Connect to input
1238                     IterLayerId_t layerId = layer_id.find(node_proto.input(1));
1239                     CV_Assert(layerId != layer_id.end());
1240                     dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
1241                     //Add shape
1242                     layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
1243                     outShapes[powerParams.name] = outShapes[node_proto.input(1)];
1244
1245                     // Redirect the divisor input through the Power (reciprocal) layer
1246                     node_proto.set_input(1, powerParams.name);
1247                 }
1248                 layerParams.type = "Scale";
1249             }
1250         }
1251         else if (layer_type == "Conv")
1252         {
1253             CV_Assert(node_proto.input_size() >= 2);
1254             layerParams.type = "Convolution";
1255             for (int j = 1; j < node_proto.input_size(); j++) {
1256                 if (constBlobs.find(node_proto.input(j)) != constBlobs.end())
1257                 {
1258                     layerParams.blobs.push_back(getBlob(node_proto, j));
1259                 }
1260             }
1261             int outCn = layerParams.blobs.empty() ? outShapes[node_proto.input(1)][0] : layerParams.blobs[0].size[0];
1262             layerParams.set("num_output", outCn);
1263         }
1264         else if (layer_type == "ConvTranspose")
1265         {
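            // Deconvolution weights and bias must be constant blobs. When an explicit
            // 'output_shape' is given, the output adjustment ('adj') is derived from the padding
            // mode: (out - 1) % stride for SAME, (out - kernel) % stride for VALID.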
1266             CV_Assert(node_proto.input_size() >= 2);
1267             layerParams.type = "Deconvolution";
1268             for (int j = 1; j < node_proto.input_size(); j++) {
1269                 layerParams.blobs.push_back(getBlob(node_proto, j));
1270             }
1271             layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
1272             layerParams.set("bias_term", node_proto.input_size() == 3);
1273
1274             if (!layerParams.has("kernel_size"))
1275                 CV_Error(Error::StsNotImplemented,
1276                          "Required attribute 'kernel_size' is not present.");
1277
1278             if (layerParams.has("output_shape"))
1279             {
1280                 const DictValue& outShape = layerParams.get("output_shape");
1281                 DictValue strides = layerParams.get("stride");
1282                 DictValue kernel = layerParams.get("kernel_size");
1283
1284                 String padMode;
1285                 std::vector<int> adjust_pads;
1286                 if (layerParams.has("pad_mode"))
1287                 {
1288                     padMode = toUpperCase(layerParams.get<String>("pad_mode"));
1289                     if (padMode != "SAME" && padMode != "VALID")
1290                         CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
1291
1292                     for (int i = 0; i < strides.size(); i++)
1293                     {
1294                         int sz = outShape.get<int>(2 + i);
1295                         int stride = strides.get<int>(i);
1296                         adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
1297                                                                  (sz - kernel.get<int>(i)) % stride);
1298                     }
1299                     layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
1300                 }
1301             }
1302             else if (layerParams.has("output_padding"))
1303             {
1304                 replaceLayerParam(layerParams, "output_padding", "adj");
1305             }
1306         }
1307         else if (layer_type == "Transpose")
1308         {
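            // ONNX Transpose maps to a Permute layer. A constant input is folded here by running
            // the layer immediately and storing the transposed blob.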
1309             layerParams.type = "Permute";
1310             replaceLayerParam(layerParams, "perm", "order");
1311
1312             CV_Assert(node_proto.input_size() == 1);
1313             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1314             {
1315                 std::vector<Mat> inputs(1, getBlob(node_proto, 0)), transposed;
1316                 runLayer(layerParams, inputs, transposed);
1317                 CV_Assert(transposed.size() == 1);
1318                 addConstant(layerParams.name, transposed[0]);
1319                 return;
1320             }
1321         }
1322         else if (layer_type == "Squeeze")
1323         {
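            // Only the listed axes whose size is 1 are removed. The op becomes a Reshape to the
            // squeezed shape (or an Identity if nothing is removed); constant inputs are folded.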
1324             CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
1325             DictValue axes_dict = layerParams.get("axes");
1326             MatShape inpShape = outShapes[node_proto.input(0)];
1327
1328             std::vector<bool> maskedAxes(inpShape.size(), false);
1329             for (int i = 0; i < axes_dict.size(); ++i)
1330             {
1331                 int axis = axes_dict.getIntValue(i);
1332                 CV_CheckLE(axis, static_cast<int>(inpShape.size()), "Squeeze axis");
1333                 maskedAxes[axis] = inpShape[axis] == 1;
1334             }
1335             MatShape outShape;
1336             for (int i = 0; i < inpShape.size(); ++i)
1337             {
1338                 if (!maskedAxes[i])
1339                     outShape.push_back(inpShape[i]);
1340             }
1341             if (outShape.size() != inpShape.size())
1342             {
1343                 layerParams.type = "Reshape";
1344                 layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
1345                 if (hasDynamicShapes)
1346                 {
1347                     std::vector<int> dynamicAxes;
1348                     std::vector<int> inputIndices;
1349                     for (int index = 0; index < inpShape.size(); ++index)
1350                     {
1351                         if (!maskedAxes[index])
1352                             inputIndices.push_back(index);
1353                     }
1354                     for (int index = 0; index < outShape.size(); ++index)
1355                         dynamicAxes.push_back(index);
1356                     layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1357                     layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1358                 }
1359             }
1360             else
1361                 layerParams.type = "Identity";
1362
1363             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1364             {
1365                 Mat inp = getBlob(node_proto, 0);
1366                 Mat out = inp.reshape(1, outShape);
1367                 out.dims = outShape.size();  // to work around dims == 1
1368                 addConstant(layerParams.name, out);
1369                 return;
1370             }
1371         }
1372         else if (layer_type == "Flatten")
1373         {
1374             CV_CheckEQ(node_proto.input_size(), 1, "");
1375             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1376             {
1377                 Mat input = getBlob(node_proto, 0);
1378                 int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
1379
1380                 std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
1381                 out_size.push_back(input.total(axis));
1382                 Mat output = input.reshape(1, out_size);
1383                 addConstant(layerParams.name, output);
1384                 return;
1385             }
1386         }
1387         else if (layer_type == "Unsqueeze")
1388         {
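            // For a constant input the new size-1 dimensions are inserted directly into the blob.
            // For a variable input only a single axis is supported and the op becomes a Reshape.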
1389             CV_Assert(node_proto.input_size() == 1);
1390             DictValue axes = layerParams.get("axes");
1391             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1392             {
1393                 // Constant input.
1394                 Mat input = getBlob(node_proto, 0);
1395
1396                 std::vector<int> dims;
1397                 for (int j = 0; j < input.dims; j++) {
1398                     dims.push_back(input.size[j]);
1399                 }
1400                 CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
1401                 for (int j = 0; j < axes.size(); j++) {
1402                     dims.insert(dims.begin() + axes.getIntValue(j), 1);
1403                 }
1404
1405                 Mat out = input.reshape(0, dims);
1406                 addConstant(layerParams.name, out);
1407                 return;
1408             }
1409
1410             // Variable input.
1411             if (axes.size() != 1)
1412                 CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
1413
1414             MatShape inpShape = outShapes[node_proto.input(0)];
1415             int axis = axes.getIntValue(0);
1416             CV_Assert(0 <= axis && axis <= inpShape.size());
1417             std::vector<int> outShape = inpShape;
1418             outShape.insert(outShape.begin() + axis, 1);
1419             layerParams.type = "Reshape";
1420             layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
1421             if (hasDynamicShapes)
1422             {
1423                 std::vector<int> dynamicAxes;
1424                 std::vector<int> inputIndices;
1425                 for (int index = 0; index < outShape.size(); ++index) {
1426                     if (index != axis)
1427                         dynamicAxes.push_back(index);
1428                 }
1429                 for (int index = 0; index < inpShape.size(); ++index)
1430                     inputIndices.push_back(index);
1431                 layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1432                 layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1433             }
1434         }
1435         else if (layer_type == "Expand")
1436         {
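            // Expand is emulated in one of three ways: a fully constant input is broadcast with
            // cv::repeat at import time; two adjacent trailing broadcast axes are handled by a
            // Scale layer against a constant tensor of ones (e.g. expanding [1, 3, 1, 1] to
            // [1, 3, 4, 4]); a single broadcast axis (0 or 1) is handled by concatenating
            // Identity copies of the input along that axis.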
1437             CV_CheckEQ(node_proto.input_size(), 2, "");
1438             const std::string& input0 = node_proto.input(0);
1439             const std::string& input1 = node_proto.input(1);
1440             Mat newShapeMat = getBlob(input1);
1441             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
1442
1443             MatShape inpShape;
1444             bool haveVariables = constBlobs.find(input0) == constBlobs.end();
1445             if (haveVariables)
1446             {
1447                 IterShape_t shapeIt = outShapes.find(input0);
1448                 CV_Assert(shapeIt != outShapes.end());
1449                 inpShape = shapeIt->second;
1450             }
1451             else
1452             {
1453                 inpShape = shape(getBlob(input0));
1454             }
1455
1456             String srcName = input0;
1457             // Unsqueeze and repeat along new axis
1458             if (targetShape.size() == inpShape.size() + 1)
1459             {
1460                 for (int i = 0; i < targetShape.size(); i++)
1461                 {
1462                     if (targetShape[i] == -1 && i < inpShape.size())
1463                         targetShape[i] = inpShape[i];
1464                     else if (i < inpShape.size() && targetShape[i] != inpShape[i])
1465                         inpShape.insert(inpShape.begin() + i, 1);
1466                 }
1467                 if (haveVariables)
1468                 {
1469                     LayerParams reshapeLp;
1470                     reshapeLp.name = layerParams.name + "/reshape";
1471                     reshapeLp.type = "Reshape";
1472                     CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
1473                     reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1474
1475                     opencv_onnx::NodeProto proto;
1476                     proto.add_input(node_proto.input(0));
1477                     proto.add_output(reshapeLp.name);
1478                     addLayer(reshapeLp, proto);
1479                     srcName = reshapeLp.name;
1480                 }
1481             }
1482             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
1483
1484             std::vector<int> broadcast_axes;
1485             for (int i = 0; i < targetShape.size(); i++)
1486             {
1487                 if (targetShape[i] != inpShape[i])
1488                 {
1489                     if (inpShape[i] == 1)
1490                         broadcast_axes.push_back(i);
1491                     else
1492                         CV_Error(Error::StsError, format("Could not broadcast along axis: %d", i));
1493                 }
1494             }
1495
1496             if (!haveVariables)
1497             {
1498                 if (broadcast_axes.size() != 1)
1499                     CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
1500
1501                 Mat input = getBlob(node_proto, 0);
1502                 input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
1503                 Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
1504                 output = output.reshape(0, targetShape);
1505                 addConstant(layerParams.name, output);
1506                 return;
1507             }
1508
1509             if (broadcast_axes.size() == 2 &&
1510                 broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
1511             {
1512                 LayerParams constParams;
1513                 constParams.name = layerParams.name + "/const";
1514                 CV_Assert(layer_id.find(constParams.name) == layer_id.end());
1515                 constParams.type = "Const";
1516
1517                 Mat inp = Mat::ones(newShapeMat.total(), newShapeMat.ptr<int>(), CV_32F);
1518                 constParams.blobs.push_back(inp);
1519
1520                 opencv_onnx::NodeProto proto;
1521                 proto.add_output(constParams.name);
1522                 addLayer(constParams, proto);
1523
1524                 layerParams.type = "Scale";
1525                 layerParams.set("bias_term", false);
1526                 node_proto.set_input(0, constParams.name);
1527                 node_proto.set_input(1, srcName);
1528             }
1529             else if (broadcast_axes.size() == 1 && broadcast_axes[0] <= 1)
1530             {
1531                 String base_name = layerParams.name + "/copy_";
1532                 std::vector<std::string> input_names;
1533                 for (int j = 0; j < targetShape[broadcast_axes[0]]; j++)
1534                 {
1535                     std::ostringstream ss;
1536                     ss << j;
1537                     LayerParams copyLP;
1538                     copyLP.name = base_name + ss.str();
1539                     copyLP.type = "Identity";
1540                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
1541                     input_names.push_back(copyLP.name);
1542
1543                     node_proto.set_input(0, srcName);
1544                     node_proto.set_output(0, copyLP.name);
1545                     addLayer(copyLP, node_proto);
1546                 }
1547                 node_proto.clear_input();
1548                 for (int i = 0; i < input_names.size(); i++)
1549                 {
1550                     node_proto.add_input(input_names[i]);
1551                 }
1552                 layerParams.set("axis", broadcast_axes[0]);
1553                 layerParams.type = "Concat";
1554                 node_proto.set_output(0, layerParams.name);
1555             }
1556             else
1557                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
1558         }
1559         else if (layer_type == "Reshape")
1560         {
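            // The target shape comes either from a constant second input or from the 'shape'
            // attribute. Constant data is reshaped at import time; otherwise a Reshape layer is added.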
1561             CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
1562
1563             if (node_proto.input_size() == 2) {
1564                 Mat blob = getBlob(node_proto, 1);
1565                 CV_Assert(blob.type() == CV_32SC1);
1566
1567                 layerParams.set("dim", DictValue::arrayInt<int*>(
1568                             blob.ptr<int>(), blob.total() ));
1569
1570                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
1571                     std::vector<Mat> inputs(1, getBlob(node_proto, 0)), outputs;
1572                     runLayer(layerParams, inputs, outputs);
1573                     addConstant(layerParams.name, outputs[0]);
1574                     return;
1575                 }
1576             }
1577             else {
1578                 DictValue shape = layerParams.get("shape");
1579                 std::vector<int> dim;
1580                 for (int j = 0; j < shape.size(); j++) {
1581                     dim.push_back(shape.getIntValue(j));
1582                 }
1583
1584                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
1585                     Mat input = getBlob(node_proto, 0);
1586                     Mat out = input.reshape(0, dim);
1587                     addConstant(layerParams.name, out);
1588                     return;
1589                 }
1590                 replaceLayerParam(layerParams, "shape", "dim");
1591             }
1592         }
1593         else if (layer_type == "Pad")
1594         {
1595             layerParams.type = "Padding";
1596             replaceLayerParam(layerParams, "mode", "type");
1597             if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
1598             {
1599                 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
1600                 // We need to shuffle it to begin0, end0, begin1, end1, ...
1601                 Mat paddings = getBlob(node_proto, 1).reshape(1, 2);
1602                 paddings = paddings.t();
1603                 layerParams.set("paddings", DictValue::arrayInt(paddings.ptr<int>(), paddings.total()));
1604
1605                 if (node_proto.input_size() == 3)
1606                 {
1607                     Mat value = getBlob(node_proto, 2);
1608                     layerParams.set("value", value.at<float>(0));
1609                 }
1610             }
1611         }
1612         else if (layer_type == "Shape")
1613         {
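            // The input's inferred shape is materialized as a 1-D CV_32S constant, so downstream
            // shape arithmetic can be constant-folded.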
1614             CV_Assert(node_proto.input_size() == 1);
1615             IterShape_t shapeIt = outShapes.find(node_proto.input(0));
1616             CV_Assert(shapeIt != outShapes.end());
1617             const MatShape& inpShape = shapeIt->second;
1618
1619             Mat shapeMat(inpShape.size(), 1, CV_32S);
1620             for (int j = 0; j < inpShape.size(); ++j)
1621                 shapeMat.at<int>(j) = inpShape[j];
1622             shapeMat.dims = 1;
1623
1624             addConstant(layerParams.name, shapeMat);
1625             return;
1626         }
1627         else if (layer_type == "Cast")
1628         {
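            // A constant input is converted to the mapped OpenCV depth at import time; for a
            // variable input the cast becomes an Identity (no runtime conversion).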
1629             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1630             {
1631                 Mat blob = getBlob(node_proto, 0);
1632                 int type;
1633                 switch (layerParams.get<int>("to"))
1634                 {
1635                     case opencv_onnx::TensorProto_DataType_FLOAT:   type = CV_32F; break;
1636                     case opencv_onnx::TensorProto_DataType_UINT8:   type = CV_8U; break;
1637                     case opencv_onnx::TensorProto_DataType_UINT16:  type = CV_16U; break;
1638                     case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
1639                     case opencv_onnx::TensorProto_DataType_INT8:
1640                     case opencv_onnx::TensorProto_DataType_INT16:
1641                     case opencv_onnx::TensorProto_DataType_INT32:
1642                     case opencv_onnx::TensorProto_DataType_INT64:   type = CV_32S; break;
1643                     default: type = blob.type();
1644                 }
1645                 Mat dst;
1646                 blob.convertTo(dst, type);
1647                 dst.dims = blob.dims;
1648                 addConstant(layerParams.name, dst);
1649                 return;
1650             }
1651             else
1652                 layerParams.type = "Identity";
1653         }
1654         else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
1655         {
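            // The fill value is taken from the 'value' tensor/attribute and the output shape from
            // the constant first input; the node is replaced by a constant tensor.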
1656             int depth = CV_32F;
1657             float fill_value;
1658             if (!layerParams.blobs.empty())
1659             {
1660                 CV_Assert(!layerParams.has("value"));
1661                 depth = layerParams.blobs[0].depth();
1662                 Mat floats;
1663                 layerParams.blobs[0].convertTo(floats, CV_32F);
1664                 fill_value = floats.at<float>(0, 0);
1665             }
1666             else
1667                 fill_value = layerParams.get("value", 0);
1668
1669             MatShape inpShape = getBlob(node_proto, 0);
1670             for (int i = 0; i < inpShape.size(); i++)
1671                 CV_CheckGT(inpShape[i], 0, "");
1672             Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
1673             addConstant(layerParams.name, tensor);
1674             return;
1675         }
1676         else if (layer_type == "Gather")
1677         {
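            // Only gathering of a single scalar index is supported. A constant input is sliced at
            // import time; a variable input is lowered to a Slice layer, followed by a Reshape
            // that removes the gathered axis when the input has more than one dimension.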
1678             CV_Assert(node_proto.input_size() == 2);
1679             Mat indexMat = getBlob(node_proto, 1);
1680             CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
1681             int index = indexMat.at<int>(0);
1682             int axis = layerParams.get<int>("axis", 0);
1683
1684             if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
1685             {
1686                 Mat input = getBlob(node_proto, 0);
1687                 Mat out;
1688                 std::vector<cv::Range> ranges(input.dims, Range::all());
1689                 ranges[axis] = Range(index, index + 1);
1690
1691                 out = input(ranges);
1692                 MatShape outShape = shape(out);
1693                 if (outShape.size() > 1)
1694                 {
1695                     outShape.erase(outShape.begin() + axis);
1696                     out = out.clone().reshape(0, outShape);  // clone to make the slice continuous, then drop the gathered axis
1697                 } else {
1698                     out.dims = 1;
1699                 }
1700                 addConstant(layerParams.name, out);
1701                 return;
1702             }
1703             else
1704             {
1705                 IterShape_t shapeIt = outShapes.find(node_proto.input(0));
1706                 CV_Assert(shapeIt != outShapes.end());
1707                 MatShape inpShape = shapeIt->second;
1708
1709                 LayerParams sliceLp;
1710                 sliceLp.type = "Slice";
1711                 sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
1712                 std::vector<int> begin(inpShape.size(), 0);
1713                 std::vector<int> end(inpShape.size(), -1);
1714                 begin[axis] = index;
1715                 end[axis] = index + 1;
1716
1717                 cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
1718                 cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
1719                 sliceLp.set("begin", paramBegin);
1720                 sliceLp.set("end", paramEnd);
1721                 sliceLp.set("has_dynamic_shapes", hasDynamicShapes);
1722
1723                 if (inpShape.size() > 1)
1724                 {
1725                     opencv_onnx::NodeProto proto;
1726                     proto.add_input(node_proto.input(0));
1727                     proto.add_output(sliceLp.name);
1728                     addLayer(sliceLp, proto);
1729
1730                     inpShape.erase(inpShape.begin() + axis);
1731                     layerParams.type = "Reshape";
1732                     layerParams.set("axis", 0);
1733                     layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1734                     if (hasDynamicShapes)
1735                     {
1736                         std::vector<int> dynamicAxes;
1737                         std::vector<int> inputIndices;
1738                         for (int index = 0; index < inpShape.size(); ++index)
1739                             dynamicAxes.push_back(index);
1740                         for (int index = 0; index < inpShape.size(); ++index)
1741                             inputIndices.push_back(index);
1742                         layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1743                         layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1744                     }
1745                     node_proto.set_input(0, sliceLp.name);
1746                 }
1747                 else
1748                 {
1749                     layerParams = sliceLp;
1750                 }
1751             }
1752         }
1753         else if (layer_type == "Concat")
1754         {
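            // If every input is a constant blob, the concatenation is computed at import time.
            // Otherwise, each constant input is wrapped in a Const layer so it can be connected
            // to the Concat layer like a regular variable input.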
1755             bool hasVariableInps = false;
1756             for (int i = 0; i < node_proto.input_size(); ++i)
1757             {
1758                 if (layer_id.find(node_proto.input(i)) != layer_id.end())
1759                 {
1760                     hasVariableInps = true;
1761                     break;
1762                 }
1763             }
1764
1765             if (!hasVariableInps)
1766             {
1767                 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
1768                 // Due to constant folding we can get inputs with different numbers of dimensions.
1769                 // Insert the missing dimensions into those inputs.
1770                 MatShape inputShape;
1771                 for (size_t i = 0; i < inputs.size(); ++i)
1772                 {
1773                     inputs[i] = getBlob(node_proto, i);
1774                     if (inputs[i].size.dims() > inputShape.size())
1775                     {
1776                         inputShape = shape(inputs[i]);
1777                     }
1778                 }
1779
1780                 // Concat-1 uses 1 as the default value for 'axis': https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Concat-1
1781                 int axis = layerParams.get<int>("axis", 1);
1782                 for (size_t i = 0; i < inputs.size(); ++i)
1783                 {
1784                     MatShape targetShape = inputShape;
1785                     targetShape[axis] = shape(inputs[i])[axis];
1786                     CV_CheckEQ(total(targetShape), total(shape(inputs[i])), "");
1787                     inputs[i] = inputs[i].reshape(0, targetShape);
1788                 }
1789                 runLayer(layerParams, inputs, concatenated);
1790
1791                 CV_Assert(concatenated.size() == 1);
1792                 addConstant(layerParams.name, concatenated[0]);
1793                 return;
1794             }
1795             else
1796             {
1797                 for (int i = 0; i < node_proto.input_size(); ++i)
1798                 {
1799                     if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
1800                     {
1801                         LayerParams constParams;
1802                         constParams.name = node_proto.input(i);
1803                         constParams.type = "Const";
1804                         constParams.blobs.push_back(getBlob(node_proto, i));
1805
1806                         opencv_onnx::NodeProto proto;
1807                         proto.add_output(constParams.name);
1808                         addLayer(constParams, proto);
1809                     }
1810                 }
1811             }
1812         }
1813         else if (layer_type == "Resize")
1814         {
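            // The scales / sizes inputs must be constant. 'coordinate_transformation_mode' is
            // mapped onto the 'align_corners' flag and, for linear interpolation, onto the
            // 'opencv_linear' or 'bilinear' mode.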
1815             for (int i = 1; i < node_proto.input_size(); i++)
1816                 CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
1817
1818             if (layerParams.has("coordinate_transformation_mode"))
1819             {
1820                 String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
1821                 CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
1822
1823                 layerParams.set("align_corners", interp_mode == "align_corners");
1824                 if (layerParams.get<String>("mode") == "linear")
1825                 {
1826                     layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
1827                                             "opencv_linear" : "bilinear");
1828                 }
1829             }
1830             if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
1831                 layerParams.set("mode", "opencv_linear");
1832
1833             // input = [X, scales], [X, roi, scales] or [X, roi, scales, sizes]
1834             int foundScaleId = hasDynamicShapes ? node_proto.input_size() - 1
1835                                                 : node_proto.input_size() > 2 ? 2 : 1;
1836
1837             Mat scales = getBlob(node_proto, foundScaleId);
1838             if (scales.total() == 4)
1839             {
1840                 layerParams.set("zoom_factor_y", scales.at<float>(2));
1841                 layerParams.set("zoom_factor_x", scales.at<float>(3));
1842             }
1843             else
1844             {
1845                 const std::string& inputLast = node_proto.input(node_proto.input_size() - 1);
1846                 if (constBlobs.find(inputLast) != constBlobs.end())
1847                 {
1848                     Mat shapes = getBlob(inputLast);
1849                     CV_CheckEQ(shapes.size[0], 4, "");
1850                     CV_CheckEQ(shapes.size[1], 1, "");
1851                     CV_CheckDepth(shapes.depth(), shapes.depth() == CV_32S || shapes.depth() == CV_32F, "");
1852                     if (shapes.depth() == CV_32F)
1853                         shapes.convertTo(shapes, CV_32S);
1854                     layerParams.set("width", shapes.at<int>(3));
1855                     layerParams.set("height", shapes.at<int>(2));
1856                 }
1857             }
1858             replaceLayerParam(layerParams, "mode", "interpolation");
1859         }
1860         else if (layer_type == "Upsample")
1861         {
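            // Upsample is converted to a Resize layer; zoom factors come from the 'scales'
            // attribute (PyTorch), from 'height_scale'/'width_scale' (Caffe2), or from a
            // constant second input.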
1862             // Fused from a Resize subgraph
1863             if (layerParams.has("coordinate_transformation_mode"))
1864             {
1865                 String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
1866                 CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
1867
1868                 layerParams.set("align_corners", interp_mode == "align_corners");
1869                 if (layerParams.get<String>("mode") == "linear")
1870                 {
1871                     layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
1872                                             "opencv_linear" : "bilinear");
1873                 }
1874             }
1875             if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
1876                 layerParams.set("mode", "opencv_linear");
1877
1878             layerParams.type = "Resize";
1879             if (layerParams.has("scales"))
1880             {
1881                 // PyTorch layer
1882                 DictValue scales = layerParams.get("scales");
1883                 CV_Assert(scales.size() == 4);
1884                 layerParams.set("zoom_factor_y", scales.getIntValue(2));
1885                 layerParams.set("zoom_factor_x", scales.getIntValue(3));
1886             }
1887             else if (layerParams.has("height_scale") && layerParams.has("width_scale"))
1888             {
1889                 // Caffe2 layer
1890                 replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
1891                 replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
1892             }
1893             else
1894             {
1895                 // scales as input
1896                 const std::string& input1 = node_proto.input(1);
1897                 if (constBlobs.find(input1) != constBlobs.end())
1898                 {
1899                     Mat scales = getBlob(input1);
1900                     CV_Assert(scales.total() == 4);
1901                     layerParams.set("zoom_factor_y", scales.at<float>(2));
1902                     layerParams.set("zoom_factor_x", scales.at<float>(3));
1903                 }
1904             }
1905             replaceLayerParam(layerParams, "mode", "interpolation");
1906         }
1907         else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
1908         {
1909             layerParams.type = "Softmax";
1910             layerParams.set("log_softmax", layer_type == "LogSoftmax");
1911         }
1912         else if (layer_type == "DetectionOutput")
1913         {
1914             CV_CheckEQ(node_proto.input_size(), 3, "");
1915             if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
1916             {
1917                 Mat priors = getBlob(node_proto, 2);
1918
1919                 LayerParams constParams;
1920                 constParams.name = layerParams.name + "/priors";
1921                 constParams.type = "Const";
1922                 constParams.blobs.push_back(priors);
1923
1924                 opencv_onnx::NodeProto priorsProto;
1925                 priorsProto.add_output(constParams.name);
1926                 addLayer(constParams, priorsProto);
1927
1928                 node_proto.set_input(2, constParams.name);
1929             }
1930         }
1931         else
1932         {
1933             for (int j = 0; j < node_proto.input_size(); j++) {
1934                 if (layer_id.find(node_proto.input(j)) == layer_id.end())
1935                     layerParams.blobs.push_back(getBlob(node_proto, j));
1936             }
1937         }
1938         addLayer(layerParams, node_proto);
1939     }
1940     catch (const cv::Exception& e)
1941     {
1942         CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
1943                 << cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
1944         );
1945         for (int i = 0; i < node_proto.input_size(); i++)
1946         {
1947             CV_LOG_INFO(NULL, "    Input[" << i << "] = '" << node_proto.input(i) << "'");
1948         }
1949         for (int i = 0; i < node_proto.output_size(); i++)
1950         {
1951             CV_LOG_INFO(NULL, "    Output[" << i << "] = '" << node_proto.output(i) << "'");
1952         }
1953         CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
1954     }
1955 }
1956
1957 Net readNetFromONNX(const String& onnxFile)
1958 {
1959     Net net;
1960     ONNXImporter onnxImporter(net, onnxFile.c_str());
1961     return net;
1962 }
1963
1964 Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
1965 {
1966     Net net;
1967     ONNXImporter onnxImporter(net, buffer, sizeBuffer);
1968     return net;
1969 }
1970
1971 Net readNetFromONNX(const std::vector<uchar>& buffer)
1972 {
1973     return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
1974 }
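// A minimal usage sketch (for illustration only; the model path and input image are hypothetical):
//   Net net = readNetFromONNX("model.onnx");
//   net.setInput(blobFromImage(img, 1.0, Size(224, 224), Scalar(), /*swapRB=*/true, /*crop=*/false));
//   Mat scores = net.forward();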
1975
1976 Mat readTensorFromONNX(const String& path)
1977 {
1978     std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
1979     if (!input)
1980     {
1981         CV_Error(Error::StsBadArg, cv::format("Can't read ONNX file: %s", path.c_str()));
1982     }
1983
1984     opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
1985     if (!tensor_proto.ParseFromIstream(&input))
1986     {
1987         CV_Error(Error::StsUnsupportedFormat, cv::format("Failed to parse ONNX data: %s", path.c_str()));
1988     }
1989     Mat mat = getMatFromTensor(tensor_proto);
1990     releaseONNXTensor(tensor_proto);
1991     return mat;
1992 }
1993
1994 CV__DNN_EXPERIMENTAL_NS_END
1995 }} // namespace
1996
1997 #endif