modules/dnn/src/onnx/onnx_importer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../precomp.hpp"
9 #include <opencv2/dnn/shape_utils.hpp>
10
11 #include <opencv2/core/utils/logger.defines.hpp>
12 #undef CV_LOG_STRIP_LEVEL
13 #define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_DEBUG + 1
14 #include <opencv2/core/utils/logger.hpp>
15
16 #ifdef HAVE_PROTOBUF
17
18 #include <iostream>
19 #include <fstream>
20 #include <string>
21 #include <limits>
22 #include <algorithm>
23
24
25 #if defined(__GNUC__) && __GNUC__ >= 5
26 #pragma GCC diagnostic push
27 #pragma GCC diagnostic ignored "-Wsuggest-override"
28 #endif
29 #include "opencv-onnx.pb.h"
30 #if defined(__GNUC__) && __GNUC__ >= 5
31 #pragma GCC diagnostic pop
32 #endif
33
34 #include "onnx_graph_simplifier.hpp"
35
36 namespace cv {
37 namespace dnn {
38 CV__DNN_INLINE_NS_BEGIN
39
40
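// Importer that parses an ONNX ModelProto (from a file or an in-memory buffer)
// and translates each graph node into cv::dnn layers via handleNode().
// Typically reached through the public cv::dnn::readNetFromONNX() entry points.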
41 class ONNXImporter
42 {
43     opencv_onnx::ModelProto model_proto;
44     struct LayerInfo {
45         int layerId;
46         int outputId;
47         LayerInfo(int _layerId = 0, int _outputId = 0) : layerId(_layerId), outputId(_outputId) {}
48     };
49
50     std::map<std::string, Mat> getGraphTensors(
51                                     const opencv_onnx::GraphProto& graph_proto);
52     Mat getBlob(const opencv_onnx::NodeProto& node_proto, int index);
53     Mat getBlob(const std::string& input_name);
54
55     LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
56     bool isCeilMode(const LayerParams& layerParams);
57
58     void addConstant(const std::string& name, const Mat& blob);
59     void addLayer(LayerParams& layerParams,
60                   const opencv_onnx::NodeProto& node_proto);
61
62 public:
63
64     ONNXImporter(Net& net, const char *onnxFile)
65         : dstNet(net)
66     {
67         hasDynamicShapes = false;
68         CV_Assert(onnxFile);
69         CV_LOG_DEBUG(NULL, "DNN/ONNX: processing ONNX model from file: " << onnxFile);
70
71         std::fstream input(onnxFile, std::ios::in | std::ios::binary);
72         if (!input)
73         {
74             CV_Error(Error::StsBadArg, cv::format("Can't read ONNX file: %s", onnxFile));
75         }
76
77         if (!model_proto.ParseFromIstream(&input))
78         {
79             CV_Error(Error::StsUnsupportedFormat, cv::format("Failed to parse ONNX model: %s", onnxFile));
80         }
81
82         populateNet();
83     }
84
85     ONNXImporter(Net& net, const char* buffer, size_t sizeBuffer)
86         : dstNet(net)
87     {
88         hasDynamicShapes = false;
89         CV_LOG_DEBUG(NULL, "DNN/ONNX: processing in-memory ONNX model (" << sizeBuffer << " bytes)");
90
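        // Minimal streambuf wrapper exposing the caller's buffer to protobuf's
        // istream-based parser without copying the data.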
91         struct _Buf : public std::streambuf
92         {
93             _Buf(const char* buffer, size_t sizeBuffer)
94             {
95                 char* p = const_cast<char*>(buffer);
96                 setg(p, p, p + sizeBuffer);
97             }
98         };
99
100         _Buf buf(buffer, sizeBuffer);
101         std::istream input(&buf);
102
103         if (!model_proto.ParseFromIstream(&input))
104             CV_Error(Error::StsUnsupportedFormat, "Failed to parse ONNX model from in-memory byte array.");
105
106         populateNet();
107     }
108
109     void populateNet();
110
111 protected:
112     Net& dstNet;
113
114     opencv_onnx::GraphProto graph_proto;
115     std::string framework_name;
116
117     std::map<std::string, Mat> constBlobs;
118
119     std::map<std::string, MatShape> outShapes;  // List of internal blobs shapes.
120     bool hasDynamicShapes;  // Whether the model has inputs with dynamic shapes
121     typedef std::map<std::string, MatShape>::iterator IterShape_t;
122
123     std::map<std::string, LayerInfo> layer_id;
124     typedef std::map<std::string, LayerInfo>::iterator IterLayerId_t;
125
126     void handleNode(const opencv_onnx::NodeProto& node_proto);
127 };
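// A minimal usage sketch (not part of this file; "model.onnx" and 'img' are placeholders):
//
//     Net net = readNetFromONNX("model.onnx");   // uses the file-based constructor above
//     net.setInput(blobFromImage(img));          // 'img' is any preprocessed cv::Mat
//     Mat scores = net.forward();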
128
129 inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
130 {
131     if (layerParams.has(oldKey)) {
132         layerParams.set(newKey, layerParams.get(oldKey));
133         layerParams.erase(oldKey);
134     }
135 }
136
137 void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
138 {
139     if (!tensor_proto.raw_data().empty()) {
140         delete tensor_proto.release_raw_data();
141     }
142 }
143
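// Instantiates a layer from 'params' and runs it immediately on constant inputs.
// Used to fold nodes whose inputs are all known at import time (e.g. Slice or
// Transpose applied to initializers) into constant blobs.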
144 void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
145               std::vector<Mat>& outputs)
146 {
147     Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
148     CV_Assert((bool)layer);
149
150     std::vector<MatShape> inpShapes(inputs.size());
151     int ddepth = CV_32F;
152     for (size_t i = 0; i < inputs.size(); ++i)
153     {
154         inpShapes[i] = shape(inputs[i]);
155         if (i > 0 && ddepth != inputs[i].depth())
156             CV_Error(Error::StsNotImplemented, "Mixed input data types.");
157         ddepth = inputs[i].depth();
158     }
159
160     std::vector<MatShape> outShapes, internalShapes;
161     layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
162
163     std::vector<Mat> internals(internalShapes.size());
164     outputs.resize(outShapes.size());
165     for (size_t i = 0; i < outShapes.size(); ++i)
166         outputs[i].create(outShapes[i], ddepth);
167     for (size_t i = 0; i < internalShapes.size(); ++i)
168         internals[i].create(internalShapes[i], ddepth);
169
170     layer->finalize(inputs, outputs);
171     layer->forward(inputs, outputs, internals);
172 }
173
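// Collects all graph initializers into a name -> Mat map; the protobuf raw data
// is released after each tensor is converted to a cv::Mat.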
174 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
175                                         const opencv_onnx::GraphProto& graph_proto)
176 {
177   opencv_onnx::TensorProto tensor_proto;
178   std::map<std::string, Mat> layers_weights;
179
180   for (int i = 0; i < graph_proto.initializer_size(); i++)
181   {
182     tensor_proto = graph_proto.initializer(i);
183     Mat mat = getMatFromTensor(tensor_proto);
184     releaseONNXTensor(tensor_proto);
185     layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
186   }
187   return layers_weights;
188 }
189
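// ONNX stores integer attributes as int64; convert to a 32-bit DictValue array.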
190 static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
191     std::vector<int32_t> dst(src.size());
192     convertInt64ToInt32(src, dst, src.size());
193     return DictValue::arrayInt(&dst[0], src.size());
194 }
195
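// Translates ONNX node attributes into LayerParams. Convolution/pooling attributes
// (kernel_shape, strides, pads, dilations, auto_pad) are renamed to their OpenCV
// equivalents; any other attribute is stored under its original name.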
196 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
197 {
198     LayerParams lp;
199     for(int i = 0; i < node_proto.attribute_size(); i++)
200     {
201         opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
202         std::string attribute_name = attribute_proto.name();
203
204         if(attribute_name == "kernel_shape")
205         {
206             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
207             lp.set("kernel_size", parse(attribute_proto.ints()));
208         }
209         else if(attribute_name == "strides")
210         {
211             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
212             lp.set("stride", parse(attribute_proto.ints()));
213         }
214         else if(attribute_name == "pads")
215         {
216             if (node_proto.op_type() == "Pad")
217             {
218                 // Padding layer.
219                 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
220                 // We need to shuffle it to begin0, end0, begin1, end1, ...
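                // e.g. for two dimensions: [b0, b1, e0, e1] -> [b0, e0, b1, e1].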
221                 CV_Assert(attribute_proto.ints_size() % 2 == 0);
222                 const int dims = attribute_proto.ints_size() / 2;
223                 std::vector<int32_t> paddings;
224                 paddings.reserve(attribute_proto.ints_size());
225                 for (int i = 0; i < dims; ++i)
226                 {
227                     paddings.push_back(attribute_proto.ints(i));
228                     paddings.push_back(attribute_proto.ints(dims + i));
229                 }
230                 lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
231             }
232             else
233             {
234                 // Convolution or pooling.
235                 CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
236                 lp.set("pad", parse(attribute_proto.ints()));
237             }
238         }
239         else if(attribute_name == "auto_pad")
240         {
241             if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
242                 lp.set("pad_mode",  "SAME");
243             }
244             else if (attribute_proto.s() == "VALID") {
245                 lp.set("pad_mode", "VALID");
246             }
247         }
248         else if(attribute_name == "dilations")
249         {
250             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
251             lp.set("dilation", parse(attribute_proto.ints()));
252         }
253         else if (attribute_proto.has_i())
254         {
255             ::google::protobuf::int64 src = attribute_proto.i();
256             if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
257                 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
258             else
259                 lp.set(attribute_name, saturate_cast<int32_t>(src));
260         }
261         else if (attribute_proto.has_f())
262         {
263             lp.set(attribute_name, attribute_proto.f());
264         }
265         else if (attribute_proto.has_s())
266         {
267             lp.set(attribute_name, attribute_proto.s());
268         }
269         else if (attribute_proto.floats_size() > 0)
270         {
271             lp.set(attribute_name, DictValue::arrayReal(
272                 attribute_proto.floats().data(), attribute_proto.floats_size()));
273         }
274         else if (attribute_proto.ints_size() > 0)
275         {
276             lp.set(attribute_name, parse(attribute_proto.ints()));
277         }
278         else if (attribute_proto.has_t())
279         {
280             opencv_onnx::TensorProto tensor = attribute_proto.t();
281             Mat blob = getMatFromTensor(tensor);
282             lp.blobs.push_back(blob);
283         }
284         else if (attribute_proto.has_g())
285         {
286             CV_Error(Error::StsNotImplemented, cv::format("DNN/ONNX/Attribute[%s]: 'Graph' is not supported", attribute_name.c_str()));
287         }
288         else if (attribute_proto.graphs_size() > 0)
289         {
290             CV_Error(Error::StsNotImplemented,
291                     cv::format("DNN/ONNX/Attribute[%s]: 'Graphs' (%d) in attributes is not supported",
292                             attribute_name.c_str(), attribute_proto.graphs_size())
293             );
294         }
295         else if (attribute_proto.strings_size() > 0)
296         {
297             std::string msg = cv::format("DNN/ONNX/Attribute[%s]: 'Strings' (%d) are not supported",
298                     attribute_name.c_str(), attribute_proto.strings_size());
299             CV_LOG_ERROR(NULL, msg);
300             for (int i = 0; i < attribute_proto.strings_size(); i++)
301             {
302                 CV_LOG_ERROR(NULL, "    Attribute[" << attribute_name << "].string(" << i << ") = '" << attribute_proto.strings(i) << "'");
303             }
304             CV_Error(Error::StsNotImplemented, msg);
305         }
306         else if (attribute_proto.tensors_size() > 0)
307         {
308             CV_Error(Error::StsNotImplemented,
309                     cv::format("DNN/ONNX/Attribute[%s]: 'Tensors' (%d) in attributes are not supported",
310                             attribute_name.c_str(), attribute_proto.tensors_size())
311             );
312         }
313         else
314         {
315             CV_Error(Error::StsNotImplemented, cv::format("DNN/ONNX/Attribute[%s]: unsupported attribute format", attribute_name.c_str()));
316         }
317     }
318     return lp;
319 }
320
321 Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto, int index)
322 {
323     CV_Assert(index < node_proto.input_size());
324     const std::string& input_name = node_proto.input(index);
325     return getBlob(input_name);
326 }
327
328 Mat ONNXImporter::getBlob(const std::string& input_name)
329 {
330     std::map<std::string, Mat>::const_iterator constBlob = constBlobs.find(input_name);
331     if (constBlob == constBlobs.end())
332     {
333         CV_Error(Error::StsBadArg, std::string("Blob ") + input_name + " not found in const blobs");
334     }
335     return constBlob->second;
336 }
337
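// Adds a layer to the target network, connects it to already-imported inputs,
// and runs shape inference so that outShapes has entries for its outputs.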
338 void ONNXImporter::addLayer(LayerParams& layerParams,
339                             const opencv_onnx::NodeProto& node_proto)
340 {
341     int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
342     for (int i = 0; i < node_proto.output_size(); ++i)
343     {
344         layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
345     }
346
347     std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
348     int inpNum = 0;
349     for (int j = 0; j < node_proto.input_size(); j++)
350     {
351         const std::string& input_name = node_proto.input(j);
352         IterLayerId_t layerId = layer_id.find(input_name);
353         if (layerId != layer_id.end()) {
354             dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
355             ++inpNum;
356             // Collect input shapes.
357             IterShape_t shapeIt = outShapes.find(input_name);
358             CV_Assert(shapeIt != outShapes.end());
359             layerInpShapes.push_back(shapeIt->second);
360         }
361     }
362     // Compute shape of output blob for this layer.
363     Ptr<Layer> layer = dstNet.getLayer(id);  // FIXIT: avoid instantiation of layers during the import stage
364     layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
365     for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
366     {
367         outShapes[node_proto.output(i)] = layerOutShapes[i];
368     }
369 }
370
371 void ONNXImporter::addConstant(const std::string& name, const Mat& blob)
372 {
373     constBlobs.insert(std::make_pair(name, blob));
374     outShapes.insert(std::make_pair(name, shape(blob)));
375 }
376
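// Main import routine: simplifies the graph, loads initializers as constant blobs,
// registers the real network inputs, then imports the nodes one by one.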
377 void ONNXImporter::populateNet()
378 {
379     CV_Assert(model_proto.has_graph());
380     graph_proto = model_proto.graph();
381
382     std::string framework_version;
383     if (model_proto.has_producer_name())
384         framework_name = model_proto.producer_name();
385     if (model_proto.has_producer_version())
386         framework_version = model_proto.producer_version();
387
388     CV_LOG_INFO(NULL, "DNN/ONNX: loading ONNX"
389             << (model_proto.has_ir_version() ? cv::format(" v%d", (int)model_proto.ir_version()) : cv::String())
390             << " model produced by '" << framework_name << "'"
391             << (framework_version.empty() ? cv::String() : cv::format(":%s", framework_version.c_str()))
392             << ". Number of nodes = " << graph_proto.node_size()
393             << ", inputs = " << graph_proto.input_size()
394             << ", outputs = " << graph_proto.output_size()
395             );
396
397     simplifySubgraphs(graph_proto);
398
399     const int layersSize = graph_proto.node_size();
400     CV_LOG_DEBUG(NULL, "DNN/ONNX: graph simplified to " << layersSize << " nodes");
401
402     constBlobs = getGraphTensors(graph_proto);
403     // Register shapes for all graph inputs: both constant blobs (initializers) and the network's real inputs.
404     for (int i = 0; i < graph_proto.input_size(); ++i)
405     {
406         const opencv_onnx::ValueInfoProto& valueInfoProto = graph_proto.input(i);
407         CV_Assert(valueInfoProto.has_name());
408         CV_Assert(valueInfoProto.has_type());
409         opencv_onnx::TypeProto typeProto = valueInfoProto.type();
410         CV_Assert(typeProto.has_tensor_type());
411         opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
412         CV_Assert(tensor.has_shape());
413         opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
414
415         MatShape inpShape(tensorShape.dim_size());
416         for (int j = 0; j < inpShape.size(); ++j)
417         {
418             inpShape[j] = tensorShape.dim(j).dim_value();
419             if (!tensorShape.dim(j).dim_param().empty())
420                 hasDynamicShapes = true;
421         }
422         if (!inpShape.empty() && !hasDynamicShapes)
423         {
424             inpShape[0] = std::max(inpShape[0], 1); // It's OK to have undetermined batch size
425         }
426         outShapes[valueInfoProto.name()] = inpShape;
427     }
428
429     // Create the list of real network inputs (initializers excluded) and register
430     // each one in layer_id with layer id 0 and its output index.
431     std::vector<String> netInputs;
432     for (int j = 0; j < graph_proto.input_size(); j++)
433     {
434         const std::string& name = graph_proto.input(j).name();
435         if (constBlobs.find(name) == constBlobs.end()) {
436             netInputs.push_back(name);
437             layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
438         }
439     }
440     dstNet.setInputsNames(netInputs);
441
442     for(int li = 0; li < layersSize; li++)
443     {
444         const opencv_onnx::NodeProto& node_proto = graph_proto.node(li);
445         handleNode(node_proto);
446     }
447
448     CV_LOG_DEBUG(NULL, "DNN/ONNX: import completed!");
449 }
450
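// Imports a single ONNX node. Most op types are repacked into LayerParams for an
// existing OpenCV layer; nodes whose inputs are all constant are evaluated here
// and stored as constant blobs instead of becoming layers.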
451 void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
452 {
453     opencv_onnx::NodeProto node_proto = node_proto_;  // TODO FIXIT
454
455     CV_Assert(node_proto.output_size() >= 1);
456     std::string name = node_proto.output(0);
457     std::string layer_type = node_proto.op_type();
458     CV_LOG_DEBUG(NULL, "DNN/ONNX: processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
459             << cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
460     );
461
462     try
463     {
464         // FIXIT not all cases can be repacked into "LayerParams". Importer should handle such cases directly for each "layer_type"
465         LayerParams layerParams = getLayerParams(node_proto);
466
467         layerParams.name = name;
468         layerParams.type = layer_type;
469         layerParams.set("has_dynamic_shapes", hasDynamicShapes);
470
471         if (layer_type == "MaxPool")
472         {
473             layerParams.type = "Pooling";
474             layerParams.set("pool", "MAX");
475             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
476         }
477         else if (layer_type == "AveragePool")
478         {
479             layerParams.type = "Pooling";
480             layerParams.set("pool", "AVE");
481             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
482             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
483         }
484         else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
485                 layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
486         {
487             CV_Assert(node_proto.input_size() == 1);
488             layerParams.type = "Pooling";
489             String pool;
490             if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
491                 pool = "MAX";
492             else if (layer_type == "ReduceSum")
493                 pool = "SUM";
494             else
495                 pool = "AVE";
496             layerParams.set("pool", pool);
497             layerParams.set("global_pooling", !layerParams.has("axes"));
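            // Reduce* over explicit axes is emulated with a Pooling layer over the
            // reduced dimensions, followed by a Reshape to the expected output shape.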
498             if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
499             {
500                 MatShape inpShape = outShapes[node_proto.input(0)];
501                 DictValue axes = layerParams.get("axes");
502                 bool keepdims = layerParams.get<int>("keepdims");
503                 MatShape targetShape;
504                 std::vector<bool> shouldDelete(inpShape.size(), false);
505                 for (int i = 0; i < axes.size(); i++) {
506                     int axis = clamp(axes.get<int>(i), inpShape.size());
507                     shouldDelete[axis] = true;
508                 }
509                 for (int axis = 0; axis < inpShape.size(); ++axis){
510                     if (!shouldDelete[axis])
511                         targetShape.push_back(inpShape[axis]);
512                     else if (keepdims)
513                         targetShape.push_back(1);
514                 }
515
516                 if (inpShape.size() == 3 && axes.size() <= 2)
517                 {
518                     int axis = clamp(axes.get<int>(0), inpShape.size());
519                     CV_CheckNE(axis, 0, "");
520
521                     LayerParams reshapeLp;
522                     reshapeLp.name = layerParams.name + "/reshape";
523                     reshapeLp.type = "Reshape";
524                     CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
525                     reshapeLp.set("axis", 0);
526                     reshapeLp.set("num_axes", 1);
527                     int newShape[] = {1, -1};
528                     reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 2));
529
530                     opencv_onnx::NodeProto proto;
531                     proto.add_input(node_proto.input(0));
532                     proto.add_output(reshapeLp.name);
533                     addLayer(reshapeLp, proto);
534
535                     LayerParams avgLp;
536                     avgLp.name = layerParams.name + "/avg";
537                     avgLp.type = "Pooling";
538                     CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
539                     avgLp.set("pool", pool);
540                     if (axes.size() == 2)
541                     {
542                         CV_CheckEQ(clamp(axes.get<int>(0), inpShape.size()), 1, "Unsupported mode");
543                         CV_CheckEQ(clamp(axes.get<int>(1), inpShape.size()), 2, "Unsupported mode");
544                         avgLp.set("global_pooling", true);
545                     }
546                     else
547                     {
548                         avgLp.set(axis == 2 ? "global_pooling_w" : "global_pooling_h", true);
549                         avgLp.set(axis == 2 ? "kernel_h" : "kernel_w", 1);
550                     }
551
552                     node_proto.set_input(0, reshapeLp.name);
553                     node_proto.set_output(0, avgLp.name);
554                     addLayer(avgLp, node_proto);
555                 }
556                 else
557                 {
558                     if (inpShape.size() != 4 && inpShape.size() != 5)
559                         CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
560
561                     CV_Assert(axes.size() <= inpShape.size() - 2);
562                     std::vector<int> kernel_size(inpShape.size() - 2, 1);
563                     if (axes.size() == 1 && (clamp(axes.get<int>(0), inpShape.size()) <= 1))
564                     {
565                         int axis = clamp(axes.get<int>(0), inpShape.size());
566                         MatShape newShape = inpShape;
567                         newShape[axis + 1] = total(newShape, axis + 1);
568                         newShape.resize(axis + 2);
569                         newShape.insert(newShape.begin(), 2 - axis, 1);
570
571                         LayerParams reshapeLp;
572                         reshapeLp.type = "Reshape";
573                         reshapeLp.name = layerParams.name + "/reshape";
574                         CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
575                         reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
576
577                         node_proto.set_output(0, reshapeLp.name);
578                         addLayer(reshapeLp, node_proto);
579
580                         kernel_size.resize(2);
581                         kernel_size[0] = inpShape[axis];
582                         node_proto.set_input(0, node_proto.output(0));
583                     }
584                     else
585                     {
586                         for (int i = 0; i < axes.size(); i++) {
587                             int axis = clamp(axes.get<int>(i), inpShape.size());
588                             CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
589                             kernel_size[axis - 2] = inpShape[axis];
590                         }
591                     }
592
593                     LayerParams poolLp = layerParams;
594                     poolLp.name = layerParams.name + "/avg";
595                     CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
596                     poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
597
598                     node_proto.set_output(0, poolLp.name);
599                     addLayer(poolLp, node_proto);
600                 }
601
602                 layerParams.type = "Reshape";
603                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
604
605                 node_proto.set_input(0, node_proto.output(0));
606                 node_proto.set_output(0, layerParams.name);
607             }
608             else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
609             {
610                 CV_CheckEQ(layerParams.get<int>("keepdims"), 0, "layer only supports keepdims = false");
611                 LayerParams reshapeLp;
612                 reshapeLp.name = layerParams.name + "/reshape";
613                 reshapeLp.type = "Reshape";
614                 CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
615                 int newShape[] = {1, 1, 1, -1};
616                 reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
617
618                 opencv_onnx::NodeProto proto;
619                 proto.add_input(node_proto.input(0));
620                 proto.add_output(reshapeLp.name);
621                 addLayer(reshapeLp, proto);
622
623                 LayerParams poolLp = layerParams;
624                 poolLp.name = layerParams.name + "/pool";
625                 CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
626
627                 node_proto.set_input(0, reshapeLp.name);
628                 node_proto.set_output(0, poolLp.name);
629                 addLayer(poolLp, node_proto);
630
631                 layerParams.type = "Reshape";
632                 int targetShape[] = {1};
633                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
634
635                 node_proto.set_input(0, node_proto.output(0));
636                 node_proto.set_output(0, layerParams.name);
637             }
638         }
639         else if (layer_type == "Slice")
640         {
641             int axis = 0;
642             std::vector<int> begin;
643             std::vector<int> end;
644             int inp_size = node_proto.input_size();
645
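            // Slice-1 carries starts/ends/axes/steps as attributes; Slice-10 and newer
            // pass them as extra constant inputs. Both layouts are handled below.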
646             if (inp_size == 1)
647             {
648                 if (layerParams.has("steps"))
649                 {
650                     DictValue steps = layerParams.get("steps");
651                     for (int i = 0; i < steps.size(); ++i)
652                     {
653                         if (steps.get<int>(i) != 1)
654                             CV_Error(Error::StsNotImplemented,
655                                 "Slice layer only supports steps = 1");
656                     }
657                 }
658                 if (layerParams.has("axes")) {
659                     DictValue axes = layerParams.get("axes");
660                     for (int i = 1; i < axes.size(); ++i) {
661                         CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
662                     }
663                     axis = axes.get<int>(0);
664                 }
665
666                 DictValue starts = layerParams.get("starts");
667                 DictValue ends = layerParams.get("ends");
668                 CV_Assert(starts.size() == ends.size());
669
670                 if (axis > 0) {
671                     begin.resize(axis, 0);
672                     end.resize(axis, -1);
673                 }
674                 for (int i = 0; i < starts.size(); ++i)
675                 {
676                     begin.push_back(starts.get<int>(i));
677                     int finish = ends.get<int>(i);
678                     end.push_back((finish < 0) ? --finish : finish); // ONNX/numpy ends are exclusive; adjust negative ends for the Slice layer
679                 }
680             } else {
681                 CV_Assert(inp_size >= 3);
682                 for (int i = 1; i < inp_size; i++) {
683                     CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
684                 }
685                 Mat start_blob = getBlob(node_proto, 1);
686                 Mat end_blob   = getBlob(node_proto, 2);
687                 CV_Assert(start_blob.total() == end_blob.total());
688
689                 if (inp_size > 3) {
690                     Mat axes_blob = getBlob(node_proto, 3);
691                     const int* axes = (int*)axes_blob.data;
692                     for (int i = 1; i < axes_blob.total(); ++i) {
693                         CV_Assert(axes[i - 1] == axes[i] - 1);
694                     }
695                     axis = axes[0];
696                 }
697
698                 const int* starts = start_blob.ptr<int>();
699                 const int* ends   = end_blob.ptr<int>();
700                 if (axis > 0) {
701                     begin.resize(axis, 0);
702                     end.resize(axis, -1);
703                 }
704                 std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
705                 for (int i = 0; i < end_blob.total(); ++i)
706                 {
707                     int finish = ends[i];
708                     end.push_back((finish < 0) ? --finish : finish); // ONNX/numpy ends are exclusive; adjust negative ends for the Slice layer
709                 }
710
711                 if (inp_size == 5) {
712                     CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
713                     Mat step_blob = getBlob(node_proto, 4);
714
715                     // Unusual use of Slice: reversing a tensor (start = -1, step = -1, end = INT_MIN).
716                     // Handle it only for 2D constant inputs by flipping the blob directly.
717                     if (constBlobs.find(node_proto.input(0)) != constBlobs.end() &&
718                         axis == 0 &&
719                         start_blob.at<int>(0) == -1 && step_blob.at<int>(0) == -1 &&
720                         end_blob.at<int>(0) == std::numeric_limits<int32_t>::min())
721                     {
722                         Mat inp = getBlob(node_proto, 0);
723                         if (inp.dims == 2)
724                         {
725                             Mat flipped;
726                             flip(inp, flipped, 0);
727                             addConstant(layerParams.name, flipped);
728                             return;
729                         }
730                     }
731                     CV_CheckEQ(countNonZero(step_blob != 1), 0, "Slice layer only supports steps = 1");
732                 }
733             }
734             layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
735             layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
736             layerParams.set("axis", axis);
737
738             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
739             {
740                 Mat inp = getBlob(node_proto, 0);
741                 std::vector<Mat> inputs, sliced;
742                 inputs.push_back(inp);
743                 runLayer(layerParams, inputs, sliced);
744                 CV_Assert(sliced.size() == 1);
745                 addConstant(layerParams.name, sliced[0]);
746                 return;
747             }
748         }
749         else if (layer_type == "Split")
750         {
751             if (layerParams.has("split"))
752             {
753                 DictValue splits = layerParams.get("split");
754                 const int numSplits = splits.size();
755                 CV_Assert(numSplits > 1);
756
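                // Convert per-output split sizes into cumulative slice points for the
                // Slice layer, e.g. split = [2, 3, 4] -> slice_point = [2, 5].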
757                 std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
758                 for (int i = 1; i < splits.size() - 1; ++i)
759                 {
760                     slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);
761                 }
762                 layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
763             }
764             else
765             {
766                 layerParams.set("num_split", node_proto.output_size());
767             }
768             layerParams.type = "Slice";
769         }
770         else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub")
771         {
772             bool isSub = layer_type == "Sub";
773             CV_CheckEQ(node_proto.input_size(), 2, "");
774             bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
775             bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
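            // Four cases: fold the node if both inputs are constant; use Power/Const+Eltwise/Scale
            // when exactly one input is constant; use Eltwise for equal variable shapes;
            // otherwise fall back to a broadcasting Scale layer.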
776             if (is_const_0 && is_const_1)
777             {
778                 Mat blob_0 = getBlob(node_proto, 0);
779                 Mat blob_1 = getBlob(node_proto, 1);
780                 CV_Assert(blob_0.size == blob_1.size);
781                 Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
782                 addConstant(layerParams.name, output);
783                 return;
784             }
785             else if (is_const_0 || is_const_1)
786             {
787                 int const_blob_id = is_const_0 ? 0 : 1;
788                 Mat blob = getBlob(node_proto, const_blob_id);
789                 int blob_total = blob.total();
790                 if (blob_total == 1) {
791                     layerParams.type = "Power";
792                     layerParams.set("shift", (isSub ? -1 : 1) * blob.at<float>(0));
793                 }
794                 else {
795                     MatShape inpShape = outShapes[node_proto.input(1 - const_blob_id)];
796                     if (shape(blob) == inpShape)
797                     {
798                         LayerParams constParams;
799                         constParams.name = layerParams.name + "/const";
800                         constParams.type = "Const";
801                         constParams.blobs.push_back((isSub ? -1 : 1) * blob);
802                         int id = dstNet.addLayer(constParams.name, constParams.type, constParams);
803                         layer_id.insert(std::make_pair(constParams.name, LayerInfo(id, 0)));
804                         outShapes[constParams.name] = shape(blob);
805
806                         layerParams.type = "Eltwise";
807                         node_proto.set_input(const_blob_id, constParams.name);
808                     }
809                     else
810                     {
811                         layerParams.type = "Scale";
812                         layerParams.set("bias_term", true);
813                         int axis = 1;
814                         for (int i = 0; i < graph_proto.initializer_size(); i++)
815                         {
816                             opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
817                             if (tensor_proto.name() == node_proto.input(const_blob_id))
818                             {
819                                 axis = inpShape.size() - tensor_proto.dims_size();
820                                 break;
821                             }
822                         }
823                         layerParams.set("axis", axis);
824                         blob = blob.reshape(1, 1);
825                         layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
826                     }
827                 }
828             }
829             else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
830             {
831                 layerParams.type = "Eltwise";
832                 if (isSub)
833                 {
834                     static float subCoeffs[] = {1.f, -1.f};
835                     layerParams.set("coeff", DictValue::arrayReal<float*>(subCoeffs, 2));
836                 }
837             }
838             else
839             {
840                 if (isSub)
841                 {
842                     LayerParams powerParams;
843                     powerParams.name = layerParams.name + "/neg";
844                     powerParams.type = "Power";
845                     powerParams.set("scale", -1);
846
847                     //Create Power layer
848                     int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
849                     //Connect to input
850                     IterLayerId_t layerId = layer_id.find(node_proto.input(1));
851                     CV_Assert(layerId != layer_id.end());
852                     dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
853                     //Add shape
854                     layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
855                     outShapes[powerParams.name] = outShapes[node_proto.input(1)];
856
857                     //Replace input to Power
858                     node_proto.set_input(1, powerParams.name);
859                 }
860                 layerParams.type = "Scale";
861                 layerParams.set("bias_term", true);
862             }
863         }
864         else if (layer_type == "Pow")
865         {
866             if (layer_id.find(node_proto.input(1)) != layer_id.end())
867                 CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power");
868
869             Mat blob = getBlob(node_proto, 1);
870             if (blob.total() != 1)
871                 CV_Error(Error::StsNotImplemented, "Pow op supports only scalar power");
872
873             blob.convertTo(blob, CV_32F);
874             layerParams.type = "Power";
875             layerParams.set("power", blob.at<float>(0));
876         }
877         else if (layer_type == "Max")
878         {
879             layerParams.type = "Eltwise";
880             layerParams.set("operation", "max");
881         }
882         else if (layer_type == "Neg")
883         {
884             layerParams.type = "Power";
885             layerParams.set("scale", -1);
886         }
887         else if (layer_type == "Constant")
888         {
889             CV_Assert(node_proto.input_size() == 0);
890             CV_Assert(layerParams.blobs.size() == 1);
891             addConstant(layerParams.name, layerParams.blobs[0]);
892             return;
893         }
894         else if (layer_type == "LSTM")
895         {
896             LayerParams lstmParams = layerParams;
897             lstmParams.name += "/lstm";
898
899             // https://pytorch.org/docs/stable/nn.html#lstm
900             CV_Assert(node_proto.input_size() == 7);
901             Mat Wx = getBlob(node_proto, 1);
902             Mat Wh = getBlob(node_proto, 2);
903             Mat b = getBlob(node_proto, 3);
904             CV_CheckEQ(countNonZero(getBlob(node_proto, 5)), 0, "Unsupported non zero initial_h");
905             CV_CheckEQ(countNonZero(getBlob(node_proto, 6)), 0, "Unsupported non zero initial_c");
906             b = b.reshape(1, b.size[0]);
907
908             const int numHidden = lstmParams.get<int>("hidden_size");
909             const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
910             const int numFeatures = Wx.size[2];
911             Mat bx = b.colRange(0, b.cols / 2);
912             Mat bh = b.colRange(b.cols / 2, b.cols);
913             b = bx + bh;
914
915             // IFGO->IGFO
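            // Exported (PyTorch-style) weights pack the gates as I, F, G, O, while OpenCV's
            // LSTM layer expects I, G, F, O, so swap the F and G blocks of Wx, Wh and the bias.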
916             for (int k = 0; k < numDirs; ++k)
917             {
918                 float* WxData = Wx.ptr<float>(k);
919                 float* WhData = Wh.ptr<float>(k);
920                 float* biasData = b.ptr<float>(k);
921                 for (int j = 0; j < numHidden; ++j)
922                 {
923                     for (int i = 0; i < numFeatures; ++i)
924                     {
925                         std::swap(WxData[(numHidden + j) * numFeatures + i],
926                                   WxData[(numHidden * 2 + j) * numFeatures + i]);
927                     }
928                     for (int i = 0; i < numHidden; ++i)
929                     {
930                         std::swap(WhData[(numHidden + j) * numHidden + i],
931                                   WhData[(numHidden * 2 + j) * numHidden + i]);
932                     }
933                     std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
934                 }
935             }
936             Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
937             Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
938
939             lstmParams.blobs.resize(3);
940             lstmParams.blobs[0] = Wh;
941             lstmParams.blobs[1] = Wx;
942             lstmParams.blobs[2] = b;
943             lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
944
945             node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
946             addLayer(lstmParams, node_proto);
947
948             MatShape lstmShape = outShapes[node_proto.output(0)];
949
950             // Insert the num_directions dimension of size 1 to match the ONNX LSTM output layout
951             lstmShape.insert(lstmShape.begin() + 1, 1);
952
953             layerParams.type = "Reshape";
954             layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
955             node_proto.set_input(0, lstmParams.name);  // redirect input to LSTM
956             node_proto.set_output(0, layerParams.name);  // keep origin LSTM's name
957         }
958         else if (layer_type == "ImageScaler")
959         {
960             const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
961             layerParams.erase("scale");
962
963             if (layerParams.has("bias"))
964             {
965                 layerParams.type = "Scale";
966                 layerParams.blobs.push_back(
967                     Mat(Size(1,  layerParams.get("bias").size()), CV_32FC1, scale));
968
969                 layerParams.set("bias_term", true);
970                 Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
971                 for (int j = 0; j < bias.total(); j++) {
972                     bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
973                 }
974                 layerParams.blobs.push_back(bias);
975                 layerParams.erase("bias");
976             }
977             else {
978                 layerParams.set("scale", scale);
979                 layerParams.type = "Power";
980             }
981         }
982         else if (layer_type == "Clip")
983         {
984             layerParams.type = "ReLU6";
985             replaceLayerParam(layerParams, "min", "min_value");
986             replaceLayerParam(layerParams, "max", "max_value");
987
988         }
989         else if (layer_type == "LeakyRelu")
990         {
991             layerParams.type = "ReLU";
992             replaceLayerParam(layerParams, "alpha", "negative_slope");
993         }
994         else if (layer_type == "Relu")
995         {
996             layerParams.type = "ReLU";
997         }
998         else if (layer_type == "Elu")
999         {
1000             layerParams.type = "ELU";
1001         }
1002         else if (layer_type == "Tanh")
1003         {
1004             layerParams.type = "TanH";
1005         }
1006         else if (layer_type == "PRelu")
1007         {
1008             layerParams.type = "PReLU";
1009             layerParams.blobs.push_back(getBlob(node_proto, 1));
1010         }
1011         else if (layer_type == "LRN")
1012         {
1013             replaceLayerParam(layerParams, "size", "local_size");
1014         }
1015         else if (layer_type == "InstanceNormalization")
1016         {
1017             if (node_proto.input_size() != 3)
1018                 CV_Error(Error::StsNotImplemented,
1019                          "Expected input, scale, bias");
1020
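            // InstanceNormalization is decomposed into an MVN layer (per-instance
            // normalization) followed by a BatchNorm layer that applies scale and bias.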
1021             layerParams.blobs.resize(4);
1022             layerParams.blobs[2] = getBlob(node_proto, 1);  // weightData
1023             layerParams.blobs[3] = getBlob(node_proto, 2);  // biasData
1024             layerParams.set("has_bias", true);
1025             layerParams.set("has_weight", true);
1026
1027             // Get number of channels in input
1028             int size = layerParams.blobs[2].total();
1029             layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
1030             layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
1031
1032             LayerParams mvnParams;
1033             mvnParams.name = layerParams.name + "/MVN";
1034             mvnParams.type = "MVN";
1035             mvnParams.set("eps", layerParams.get<float>("epsilon"));
1036             layerParams.erase("epsilon");
1037
1038             //Create MVN layer
1039             int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
1040             //Connect to input
1041             IterLayerId_t layerId = layer_id.find(node_proto.input(0));
1042             CV_Assert(layerId != layer_id.end());
1043             dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
1044             //Add shape
1045             layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
1046             outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
1047
1048             // Redirect the BatchNorm layer's input to the MVN output
1049             node_proto.set_input(0, mvnParams.name);
1050             layerParams.type = "BatchNorm";
1051         }
1052         else if (layer_type == "BatchNormalization")
1053         {
1054             if (node_proto.input_size() != 5)
1055                 CV_Error(Error::StsNotImplemented,
1056                          "Expected input, scale, bias, mean and var");
1057
1058             layerParams.type = "BatchNorm";
1059             replaceLayerParam(layerParams, "epsilon", "eps");
1060             replaceLayerParam(layerParams, "spatial", "use_global_stats");
1061
1062             Mat meanData = getBlob(node_proto, 3);
1063             Mat stdData =  getBlob(node_proto, 4);
1064
1065             layerParams.blobs.push_back(meanData);
1066             layerParams.blobs.push_back(stdData);
1067
1068             if (!node_proto.input(1).empty()) {
1069                 layerParams.set("has_weight", true);
1070                 layerParams.blobs.push_back(getBlob(node_proto, 1));  // weightData
1071             } else {
1072                 layerParams.set("has_weight", false);
1073             }
1074
1075             if (!node_proto.input(2).empty()) {
1076                 layerParams.set("has_bias", true);
1077                 layerParams.blobs.push_back(getBlob(node_proto, 2)); // biasData
1078             } else {
1079                 layerParams.set("has_bias", false);
1080             }
1081         }
1082         else if (layer_type == "Gemm")
1083         {
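            // Gemm maps to InnerProduct: the weight matrix B is transposed when the transB
            // attribute is present and equal to 0, and a constant A input is materialized
            // as a Const layer.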
1084             CV_Assert(node_proto.input_size() >= 2);
1085             layerParams.type = "InnerProduct";
1086             Mat weights = getBlob(node_proto, 1);
1087             int ind_num_out = 0;
1088             if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
1089                 transpose(weights, weights);
1090                 ind_num_out = 1;
1091             }
1092             layerParams.blobs.push_back(weights);
1093
1094             if (node_proto.input_size() == 3) {
1095                 Mat bias = getBlob(node_proto, 2);
1096                 layerParams.blobs.push_back(bias);
1097             }
1098             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1099             {
1100                 Mat inputBuf = getBlob(node_proto, 0);
1101
1102                 LayerParams constParams;
1103                 constParams.name = node_proto.input(0);
1104                 constParams.type = "Const";
1105                 constParams.blobs.push_back(inputBuf);
1106
1107                 opencv_onnx::NodeProto proto;
1108                 proto.add_output(constParams.name);
1109                 addLayer(constParams, proto);
1110             }
1111
1112             layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
1113             layerParams.set("bias_term", node_proto.input_size() == 3);
1114         }
1115         else if (layer_type == "MatMul")
1116         {
1117             CV_Assert(node_proto.input_size() == 2);
1118             layerParams.type = "InnerProduct";
1119             layerParams.set("bias_term", false);
1120             CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
1121             int firstInpDims = outShapes[node_proto.input(0)].size();
1122             int secondInpDims;
1123
1124             if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
1125             {
1126                 Mat blob = getBlob(node_proto, 1);
1127                 secondInpDims = blob.dims;
1128                 layerParams.blobs.push_back(blob.t());
1129                 layerParams.set("num_output", layerParams.blobs[0].size[0]);
1130             } else {
1131                 secondInpDims = outShapes[node_proto.input(1)].size();
1132             }
1133             layerParams.set("axis", firstInpDims - secondInpDims + 1);
1134         }
1135         else if (layer_type == "Mul" || layer_type == "Div")
1136         {
1137             CV_Assert(node_proto.input_size() == 2);
1138
1139             bool isDiv = layer_type == "Div";
1140             int constId = -1;
1141             bool haveVariables = false;
1142             for (int i = 0; i < 2; ++i)
1143             {
1144                 if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
1145                     constId = i;
1146                 else
1147                     haveVariables = true;
1148             }
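            // Same strategy as Add/Sub: fold when both inputs are constant, use Power for a
            // scalar constant, Scale for a broadcastable constant blob, Eltwise for equal
            // variable shapes, and a Scale layer (with a 1/x Power for Div) otherwise.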
1149             if (constId != -1 && haveVariables)
1150             {
1151                 Mat blob = getBlob(node_proto, constId);
1152                 blob = blob.reshape(1, 1);
1153                 if (blob.total() == 1) {
1154                     float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
1155                     layerParams.set("scale", coeff);
1156                     layerParams.type = "Power";
1157                 }
1158                 else {
1159                     if (isDiv)
1160                         divide(1.0, blob, blob);
1161                     layerParams.blobs.push_back(blob);
1162                     layerParams.type = "Scale";
1163                 }
1164             }
1165             else if (!haveVariables)
1166             {
1167                 Mat inp0 = getBlob(node_proto, 0);
1168                 Mat inp1 = getBlob(node_proto, 1);
1169
1170                 if (inp0.size != inp1.size && (inp0.total() != 1 || inp1.total() != 1))
1171                     CV_Error_(Error::StsNotImplemented, ("Different shapes case is not supported with constant inputs: %s", layer_type.c_str()));
1172
1173                 if (inp0.total() == 1 && inp1.total() == 1 && inp0.dims != inp1.dims)
1174                 {
1175                     if (inp0.dims < inp1.dims)
1176                     {
1177                         inp0 = inp0.reshape(1, inp1.dims, inp1.size);
1178                         inp0.dims = inp1.dims;
1179                     }
1180                     else
1181                     {
1182                         inp1 = inp1.reshape(1, inp0.dims, inp0.size);
1183                         inp1.dims = inp0.dims;
1184                     }
1185                 }
1186
1187                 Mat out;
1188                 if (inp0.total() != inp1.total())
1189                 {
1190                     if (inp0.total() == 1)
1191                     {
1192                         float coeff = isDiv ? 1.0 / inp0.at<float>(0) : inp0.at<float>(0);
1193                         multiply(inp1, coeff, out);
1194                     }
1195                     else
1196                     {
1197                         float coeff = isDiv ? 1.0 / inp1.at<float>(0) : inp1.at<float>(0);
1198                         multiply(inp0, coeff, out);
1199                     }
1200
1201                 }
1202                 else
1203                 {
1204                     out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
1205                 }
1206
1207                 if (inp0.dims == 1 && inp1.dims == 1)
1208                     out.dims = 1;  // restore dims == 1: the arithmetic above promotes 1-D Mats to 2-D
1209                 addConstant(layerParams.name, out);
1210                 return;
1211             }
1212             else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
1213             {
1214                 layerParams.type = "Eltwise";
1215                 layerParams.set("operation", isDiv ? "div" : "prod");
1216             }
1217             else
1218             {
1219                 // The Scale layer allocates its output with the first input's shape, so put the larger input first
1220                 if (total(outShapes[node_proto.input(0)]) < total(outShapes[node_proto.input(1)]))
1221                 {
1222                     opencv_onnx::NodeProto proto;
1223                     proto.add_input(node_proto.input(1));
1224                     proto.add_input(node_proto.input(0));
1225                     proto.add_output(layerParams.name);
1226                     node_proto = proto;
1227                 }
1228
1229                 if (isDiv)
1230                 {
1231                     LayerParams powerParams;
1232                     powerParams.name = layerParams.name + "/inv";
1233                     powerParams.type = "Power";
1234                     powerParams.set("power", -1);
1235
1236                     //Create Power layer
1237                     int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
1238                     //Connect to input
1239                     IterLayerId_t layerId = layer_id.find(node_proto.input(1));
1240                     CV_Assert(layerId != layer_id.end());
1241                     dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
1242                     //Add shape
1243                     layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
1244                     outShapes[powerParams.name] = outShapes[node_proto.input(1)];
1245
1246                     //Replace input to Power
1247                     node_proto.set_input(1, powerParams.name);
1248                 }
1249                 layerParams.type = "Scale";
1250             }
1251         }
1252         else if (layer_type == "Conv")
1253         {
1254             CV_Assert(node_proto.input_size() >= 2);
1255             layerParams.type = "Convolution";
1256             for (int j = 1; j < node_proto.input_size(); j++) {
1257                 if (constBlobs.find(node_proto.input(j)) != constBlobs.end())
1258                 {
1259                     layerParams.blobs.push_back(getBlob(node_proto, j));
1260                 }
1261             }
1262             int outCn = layerParams.blobs.empty() ? outShapes[node_proto.input(1)][0] : layerParams.blobs[0].size[0];
1263             layerParams.set("num_output", outCn);
1264         }
1265         else if (layer_type == "ConvTranspose")
1266         {
1267             CV_Assert(node_proto.input_size() >= 2);
1268             layerParams.type = "Deconvolution";
1269             for (int j = 1; j < node_proto.input_size(); j++) {
1270                 layerParams.blobs.push_back(getBlob(node_proto, j));
1271             }
1272             layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
1273             layerParams.set("bias_term", node_proto.input_size() == 3);
1274
1275             if (!layerParams.has("kernel_size"))
1276                 CV_Error(Error::StsNotImplemented,
1277                          "Required attribute 'kernel_size' is not present.");
1278
1279             if (layerParams.has("output_shape"))
1280             {
1281                 const DictValue& outShape = layerParams.get("output_shape");
1282                 DictValue strides = layerParams.get("stride");
1283                 DictValue kernel = layerParams.get("kernel_size");
1284
1285                 String padMode;
1286                 std::vector<int> adjust_pads;
1287                 if (layerParams.has("pad_mode"))
1288                 {
1289                     padMode = toUpperCase(layerParams.get<String>("pad_mode"));
1290                     if (padMode != "SAME" && padMode != "VALID")
1291                         CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
1292
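                    // Compute the extra output padding ('adj') needed so the deconvolution
                    // output matches the requested output_shape for this padding mode.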
1293                     for (int i = 0; i < strides.size(); i++)
1294                     {
1295                         int sz = outShape.get<int>(2 + i);
1296                         int stride = strides.get<int>(i);
1297                         adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
1298                                                                  (sz - kernel.get<int>(i)) % stride);
1299                     }
1300                     layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
1301                 }
1302             }
1303             else if (layerParams.has("output_padding"))
1304             {
1305                 replaceLayerParam(layerParams, "output_padding", "adj");
1306             }
1307         }
1308         else if (layer_type == "Transpose")
1309         {
1310             layerParams.type = "Permute";
1311             replaceLayerParam(layerParams, "perm", "order");
1312
1313             CV_Assert(node_proto.input_size() == 1);
1314             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1315             {
1316                 std::vector<Mat> inputs(1, getBlob(node_proto, 0)), transposed;
1317                 runLayer(layerParams, inputs, transposed);
1318                 CV_Assert(transposed.size() == 1);
1319                 addConstant(layerParams.name, transposed[0]);
1320                 return;
1321             }
1322         }
1323         else if (layer_type == "Squeeze")
1324         {
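            // Squeeze maps to a Reshape whose target shape drops the size-1 axes listed in 'axes'.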
1325             CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
1326             DictValue axes_dict = layerParams.get("axes");
1327             MatShape inpShape = outShapes[node_proto.input(0)];
1328
1329             std::vector<bool> maskedAxes(inpShape.size(), false);
1330             for (int i = 0; i < axes_dict.size(); ++i)
1331             {
1332                 int axis = axes_dict.getIntValue(i);
1333                 CV_CheckLT(axis, static_cast<int>(inpShape.size()), "Squeeze axis");
1334                 maskedAxes[axis] = inpShape[axis] == 1;
1335             }
1336             MatShape outShape;
1337             for (int i = 0; i < inpShape.size(); ++i)
1338             {
1339                 if (!maskedAxes[i])
1340                     outShape.push_back(inpShape[i]);
1341             }
1342             if (outShape.size() != inpShape.size())
1343             {
1344                 layerParams.type = "Reshape";
1345                 layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
1346                 if (hasDynamicShapes)
1347                 {
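                    // With dynamic shapes, record which output axes correspond to which input axes
                    // so the Reshape layer can recompute its target shape at run time.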
1348                     std::vector<int> dynamicAxes;
1349                     std::vector<int> inputIndices;
1350                     for (int index = 0; index < inpShape.size(); ++index)
1351                     {
1352                         if (!maskedAxes[index])
1353                             inputIndices.push_back(index);
1354                     }
1355                     for (int index = 0; index < outShape.size(); ++index)
1356                         dynamicAxes.push_back(index);
1357                     layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1358                     layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1359                 }
1360             }
1361             else
1362                 layerParams.type = "Identity";
1363
1364             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1365             {
1366                 Mat inp = getBlob(node_proto, 0);
1367                 Mat out = inp.reshape(1, outShape);
1368                 out.dims = outShape.size();  // workaround: force the intended rank for 1-D outputs
1369                 addConstant(layerParams.name, out);
1370                 return;
1371             }
1372         }
1373         else if (layer_type == "Flatten")
1374         {
1375             CV_CheckEQ(node_proto.input_size(), 1, "");
1376             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1377             {
1378                 Mat input = getBlob(node_proto, 0);
1379                 int axis = clamp(layerParams.get<int>("axis", 1), input.dims);
1380
1381                 std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
1382                 out_size.push_back(input.total(axis));
1383                 Mat output = input.reshape(1, out_size);
1384                 addConstant(layerParams.name, output);
1385                 return;
1386             }
1387         }
1388         else if (layer_type == "Unsqueeze")
1389         {
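            // Unsqueeze inserts size-1 axes: constant inputs are folded here,
            // variable inputs are converted to a Reshape below.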
1390             CV_Assert(node_proto.input_size() == 1);
1391             DictValue axes = layerParams.get("axes");
1392             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1393             {
1394                 // Constant input.
1395                 Mat input = getBlob(node_proto, 0);
1396
1397                 std::vector<int> dims;
1398                 for (int j = 0; j < input.dims; j++) {
1399                     dims.push_back(input.size[j]);
1400                 }
1401                 CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
1402                 for (int j = 0; j < axes.size(); j++) {
1403                     dims.insert(dims.begin() + axes.getIntValue(j), 1);
1404                 }
1405
1406                 Mat out = input.reshape(0, dims);
1407                 addConstant(layerParams.name, out);
1408                 return;
1409             }
1410
1411             // Variable input.
1412             if (axes.size() != 1)
1413                 CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
1414
1415             MatShape inpShape = outShapes[node_proto.input(0)];
1416             int axis = axes.getIntValue(0);
1417             CV_Assert(0 <= axis && axis <= inpShape.size());
1418             std::vector<int> outShape = inpShape;
1419             outShape.insert(outShape.begin() + axis, 1);
1420             layerParams.type = "Reshape";
1421             layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
1422             if (hasDynamicShapes)
1423             {
1424                 std::vector<int> dynamicAxes;
1425                 std::vector<int> inputIndices;
1426                 for (int index = 0; index < outShape.size(); ++index) {
1427                     if (index != axis)
1428                         dynamicAxes.push_back(index);
1429                 }
1430                 for (int index = 0; index < inpShape.size(); ++index)
1431                     inputIndices.push_back(index);
1432                 layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1433                 layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1434             }
1435         }
1436         else if (layer_type == "Expand")
1437         {
1438             CV_CheckEQ(node_proto.input_size(), 2, "");
1439             const std::string& input0 = node_proto.input(0);
1440             const std::string& input1 = node_proto.input(1);
1441             Mat newShapeMat = getBlob(input1);
1442             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
1443
1444             MatShape inpShape;
1445             bool haveVariables = constBlobs.find(input0) == constBlobs.end();
1446             if (haveVariables)
1447             {
1448                 IterShape_t shapeIt = outShapes.find(input0);
1449                 CV_Assert(shapeIt != outShapes.end());
1450                 inpShape = shapeIt->second;
1451             }
1452             else
1453             {
1454                 inpShape = shape(getBlob(input0));
1455             }
1456
1457             String srcName = input0;
1458             // Unsqueeze the input and repeat it along the new axis
1459             if (targetShape.size() == inpShape.size() + 1)
1460             {
1461                 for (int i = 0; i < targetShape.size(); i++)
1462                 {
1463                     if (targetShape[i] == -1 && i < inpShape.size())
1464                         targetShape[i] = inpShape[i];
1465                     else if (i < inpShape.size() && targetShape[i] != inpShape[i])
1466                         inpShape.insert(inpShape.begin() + i, 1);
1467                 }
1468                 if (haveVariables)
1469                 {
1470                     LayerParams reshapeLp;
1471                     reshapeLp.name = layerParams.name + "/reshape";
1472                     reshapeLp.type = "Reshape";
1473                     CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
1474                     reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1475
1476                     opencv_onnx::NodeProto proto;
1477                     proto.add_input(node_proto.input(0));
1478                     proto.add_output(reshapeLp.name);
1479                     addLayer(reshapeLp, proto);
1480                     srcName = reshapeLp.name;
1481                 }
1482             }
1483             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
1484
1485             std::vector<int> broadcast_axes;
1486             for (int i = 0; i < targetShape.size(); i++)
1487             {
1488                 if (targetShape[i] != inpShape[i])
1489                 {
1490                     if (inpShape[i] == 1)
1491                         broadcast_axes.push_back(i);
1492                     else
1493                         CV_Error(Error::StsError, format("Could not be broadcast by axis: %d", i));
1494                 }
1495             }
1496
1497             if (!haveVariables)
1498             {
1499                 if (broadcast_axes.size() != 1)
1500                     CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
1501
1502                 Mat input = getBlob(node_proto, 0);
1503                 input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
1504                 Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
1505                 output = output.reshape(0, targetShape);
1506                 addConstant(layerParams.name, output);
1507                 return;
1508             }
1509
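            // Broadcasting over the two innermost axes: multiply an all-ones constant of the
            // target shape by the input using a Scale layer.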
1510             if (broadcast_axes.size() == 2 &&
1511                 broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
1512             {
1513                 LayerParams constParams;
1514                 constParams.name = layerParams.name + "/const";
1515                 CV_Assert(layer_id.find(constParams.name) == layer_id.end());
1516                 constParams.type = "Const";
1517
1518                 Mat inp = Mat::ones(newShapeMat.total(), newShapeMat.ptr<int>(), CV_32F);
1519                 constParams.blobs.push_back(inp);
1520
1521                 opencv_onnx::NodeProto proto;
1522                 proto.add_output(constParams.name);
1523                 addLayer(constParams, proto);
1524
1525                 layerParams.type = "Scale";
1526                 layerParams.set("bias_term", false);
1527                 node_proto.set_input(0, constParams.name);
1528                 node_proto.set_input(1, srcName);
1529             }
1530             else if (broadcast_axes.size() == 1 && broadcast_axes[0] <= 1)
1531             {
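                // Broadcasting along a single outer axis: create N Identity copies of the input
                // and concatenate them along that axis.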
1532                 String base_name = layerParams.name + "/copy_";
1533                 std::vector<std::string> input_names;
1534                 for (int j = 0; j < targetShape[broadcast_axes[0]]; j++)
1535                 {
1536                     std::ostringstream ss;
1537                     ss << j;
1538                     LayerParams copyLP;
1539                     copyLP.name = base_name + ss.str();
1540                     copyLP.type = "Identity";
1541                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
1542                     input_names.push_back(copyLP.name);
1543
1544                     node_proto.set_input(0, srcName);
1545                     node_proto.set_output(0, copyLP.name);
1546                     addLayer(copyLP, node_proto);
1547                 }
1548                 node_proto.clear_input();
1549                 for (int i = 0; i < input_names.size(); i++)
1550                 {
1551                     node_proto.add_input(input_names[i]);
1552                 }
1553                 layerParams.set("axis", broadcast_axes[0]);
1554                 layerParams.type = "Concat";
1555                 node_proto.set_output(0, layerParams.name);
1556             }
1557             else
1558                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
1559         }
1560         else if (layer_type == "Reshape")
1561         {
1562             CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
1563
1564             if (node_proto.input_size() == 2) {
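                // The target shape is given as the second (constant) input tensor.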
1565                 Mat blob = getBlob(node_proto, 1);
1566                 CV_Assert(blob.type() == CV_32SC1);
1567
1568                 layerParams.set("dim", DictValue::arrayInt<int*>(
1569                             blob.ptr<int>(), blob.total() ));
1570
1571                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
1572                     std::vector<Mat> inputs(1, getBlob(node_proto, 0)), outputs;
1573                     runLayer(layerParams, inputs, outputs);
1574                     addConstant(layerParams.name, outputs[0]);
1575                     return;
1576                 }
1577             }
1578             else {
1579                 DictValue shape = layerParams.get("shape");
1580                 std::vector<int> dim;
1581                 for (int j = 0; j < shape.size(); j++) {
1582                     dim.push_back(shape.getIntValue(j));
1583                 }
1584
1585                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
1586                     Mat input = getBlob(node_proto, 0);
1587                     Mat out = input.reshape(0, dim);
1588                     addConstant(layerParams.name, out);
1589                     return;
1590                 }
1591                 replaceLayerParam(layerParams, "shape", "dim");
1592             }
1593         }
1594         else if (layer_type == "Pad")
1595         {
1596             layerParams.type = "Padding";
1597             replaceLayerParam(layerParams, "mode", "type");
1598             if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
1599             {
1600                 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
1601                 // We need to shuffle it to begin0, end0, begin1, end1, ...
1602                 Mat paddings = getBlob(node_proto, 1).reshape(1, 2);
1603                 paddings = paddings.t();
1604                 layerParams.set("paddings", DictValue::arrayInt(paddings.ptr<int>(), paddings.total()));
1605
1606                 if (node_proto.input_size() == 3)
1607                 {
1608                     Mat value = getBlob(node_proto, 2);
1609                     layerParams.set("value", value.at<float>(0));
1610                 }
1611             }
1612         }
1613         else if (layer_type == "Shape")
1614         {
1615             CV_Assert(node_proto.input_size() == 1);
1616             IterShape_t shapeIt = outShapes.find(node_proto.input(0));
1617             CV_Assert(shapeIt != outShapes.end());
1618             const MatShape& inpShape = shapeIt->second;
1619
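            // The input shape must be known at import time; it is emitted as a 1-D CV_32S constant.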
1620             Mat shapeMat(inpShape.size(), 1, CV_32S);
1621             for (int j = 0; j < inpShape.size(); ++j)
1622                 shapeMat.at<int>(j) = inpShape[j];
1623             shapeMat.dims = 1;
1624
1625             addConstant(layerParams.name, shapeMat);
1626             return;
1627         }
1628         else if (layer_type == "Cast")
1629         {
1630             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
1631             {
1632                 Mat blob = getBlob(node_proto, 0);
1633                 int type;
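                // Map ONNX element types to OpenCV depths: 64-bit integers fall back to CV_32S
                // and FLOAT16 data is kept as raw 16-bit words in CV_16S.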
1634                 switch (layerParams.get<int>("to"))
1635                 {
1636                     case opencv_onnx::TensorProto_DataType_FLOAT:   type = CV_32F; break;
1637                     case opencv_onnx::TensorProto_DataType_UINT8:   type = CV_8U; break;
1638                     case opencv_onnx::TensorProto_DataType_UINT16:  type = CV_16U; break;
1639                     case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
1640                     case opencv_onnx::TensorProto_DataType_INT8:
1641                     case opencv_onnx::TensorProto_DataType_INT16:
1642                     case opencv_onnx::TensorProto_DataType_INT32:
1643                     case opencv_onnx::TensorProto_DataType_INT64:   type = CV_32S; break;
1644                     default: type = blob.type();
1645                 }
1646                 Mat dst;
1647                 blob.convertTo(dst, type);
1648                 dst.dims = blob.dims;
1649                 addConstant(layerParams.name, dst);
1650                 return;
1651             }
1652             else
1653                 layerParams.type = "Identity";
1654         }
1655         else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
1656         {
1657             int depth = CV_32F;
1658             float fill_value;
1659             if (!layerParams.blobs.empty())
1660             {
1661                 CV_Assert(!layerParams.has("value"));
1662                 depth = layerParams.blobs[0].depth();
1663                 Mat floats;
1664                 layerParams.blobs[0].convertTo(floats, CV_32F);
1665                 fill_value = floats.at<float>(0, 0);
1666             }
1667             else
1668                 fill_value = layerParams.get("value", 0);
1669
1670             MatShape inpShape = getBlob(node_proto, 0);
1671             for (int i = 0; i < inpShape.size(); i++)
1672                 CV_CheckGT(inpShape[i], 0, "");
1673             Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
1674             addConstant(layerParams.name, tensor);
1675             return;
1676         }
1677         else if (layer_type == "Gather")
1678         {
1679             CV_Assert(node_proto.input_size() == 2);
1680             Mat indexMat = getBlob(node_proto, 1);
1681             CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
1682             int index = indexMat.at<int>(0);
1683             int axis = layerParams.get<int>("axis", 0);
1684
1685             if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
1686             {
1687                 Mat input = getBlob(node_proto, 0);
1688                 Mat out;
1689                 std::vector<cv::Range> ranges(input.dims, Range::all());
1690                 ranges[axis] = Range(index, index + 1);
1691
1692                 out = input(ranges);
1693                 MatShape outShape = shape(out);
1694                 if (outShape.size() > 1)
1695                 {
1696                     outShape.erase(outShape.begin() + axis);
1697                     out = out.reshape(0, outShape);
1698                 } else {
1699                     out.dims = 1;
1700                 }
1701                 addConstant(layerParams.name, out);
1702                 return;
1703             }
1704             else
1705             {
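                // Gather of a scalar index from a variable input becomes a width-1 Slice along
                // 'axis'; for inputs with rank > 1 a Reshape then drops the sliced axis.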
1706                 IterShape_t shapeIt = outShapes.find(node_proto.input(0));
1707                 CV_Assert(shapeIt != outShapes.end());
1708                 MatShape inpShape = shapeIt->second;
1709
1710                 LayerParams sliceLp;
1711                 sliceLp.type = "Slice";
1712                 sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
1713                 std::vector<int> begin(inpShape.size(), 0);
1714                 std::vector<int> end(inpShape.size(), -1);
1715                 begin[axis] = index;
1716                 end[axis] = index + 1;
1717
1718                 cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
1719                 cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
1720                 sliceLp.set("begin", paramBegin);
1721                 sliceLp.set("end", paramEnd);
1722                 sliceLp.set("has_dynamic_shapes", hasDynamicShapes);
1723
1724                 if (inpShape.size() > 1)
1725                 {
1726                     opencv_onnx::NodeProto proto;
1727                     proto.add_input(node_proto.input(0));
1728                     proto.add_output(sliceLp.name);
1729                     addLayer(sliceLp, proto);
1730
1731                     inpShape.erase(inpShape.begin() + axis);
1732                     layerParams.type = "Reshape";
1733                     layerParams.set("axis", 0);
1734                     layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1735                     if (hasDynamicShapes)
1736                     {
1737                         std::vector<int> dynamicAxes;
1738                         std::vector<int> inputIndices;
1739                         for (int index = 0; index < inpShape.size(); ++index)
1740                             dynamicAxes.push_back(index);
1741                         for (int index = 0; index < inpShape.size(); ++index)
1742                             inputIndices.push_back(index);
1743                         layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
1744                         layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
1745                     }
1746                     node_proto.set_input(0, sliceLp.name);
1747                 }
1748                 else
1749                 {
1750                     layerParams = sliceLp;
1751                 }
1752             }
1753         }
1754         else if (layer_type == "Concat")
1755         {
1756             bool hasVariableInps = false;
1757             for (int i = 0; i < node_proto.input_size(); ++i)
1758             {
1759                 if (layer_id.find(node_proto.input(i)) != layer_id.end())
1760                 {
1761                     hasVariableInps = true;
1762                     break;
1763                 }
1764             }
1765
1766             if (!hasVariableInps)
1767             {
1768                 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
1769                 // Due to constant folding, the inputs may have different numbers of dimensions.
1770                 // Insert the missing unit dimensions so all inputs match the widest shape.
1771                 MatShape inputShape;
1772                 for (size_t i = 0; i < inputs.size(); ++i)
1773                 {
1774                     inputs[i] = getBlob(node_proto, i);
1775                     if (inputs[i].size.dims() > inputShape.size())
1776                     {
1777                         inputShape = shape(inputs[i]);
1778                     }
1779                 }
1780
1781                 // Concat-1 defaults 'axis' to 1: https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Concat-1
1782                 int axis = layerParams.get<int>("axis", 1);
1783                 for (size_t i = 0; i < inputs.size(); ++i)
1784                 {
1785                     MatShape targetShape = inputShape;
1786                     targetShape[axis] = shape(inputs[i])[axis];
1787                     CV_CheckEQ(total(targetShape), total(shape(inputs[i])), "");
1788                     inputs[i] = inputs[i].reshape(0, targetShape);
1789                 }
1790                 runLayer(layerParams, inputs, concatenated);
1791
1792                 CV_Assert(concatenated.size() == 1);
1793                 addConstant(layerParams.name, concatenated[0]);
1794                 return;
1795             }
1796         }
1797         else if (layer_type == "Resize")
1798         {
1799             for (int i = 1; i < node_proto.input_size(); i++)
1800                 CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
1801
1802             if (layerParams.has("coordinate_transformation_mode"))
1803             {
1804                 String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
1805                 CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
1806
1807                 layerParams.set("align_corners", interp_mode == "align_corners");
1808                 if (layerParams.get<String>("mode") == "linear")
1809                 {
1810                     layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
1811                                             "opencv_linear" : "bilinear");
1812                 }
1813             }
1814             if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
1815                 layerParams.set("mode", "opencv_linear");
1816
1817             // input = [X, scales], [X, roi, scales] or [X, roi, scales, sizes]
1818             int foundScaleId = hasDynamicShapes ? node_proto.input_size() - 1
1819                                                 : node_proto.input_size() > 2 ? 2 : 1;
1820
1821             Mat scales = getBlob(node_proto, foundScaleId);
1822             if (scales.total() == 4)
1823             {
1824                 layerParams.set("zoom_factor_y", scales.at<float>(2));
1825                 layerParams.set("zoom_factor_x", scales.at<float>(3));
1826             }
1827             else
1828             {
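                // No usable scale factors: take explicit output sizes from the last input if it is constant.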
1829                 const std::string& inputLast = node_proto.input(node_proto.input_size() - 1);
1830                 if (constBlobs.find(inputLast) != constBlobs.end())
1831                 {
1832                     Mat shapes = getBlob(inputLast);
1833                     CV_CheckEQ(shapes.size[0], 4, "");
1834                     CV_CheckEQ(shapes.size[1], 1, "");
1835                     CV_CheckDepth(shapes.depth(), shapes.depth() == CV_32S || shapes.depth() == CV_32F, "");
1836                     if (shapes.depth() == CV_32F)
1837                         shapes.convertTo(shapes, CV_32S);
1838                     layerParams.set("width", shapes.at<int>(3));
1839                     layerParams.set("height", shapes.at<int>(2));
1840                 }
1841             }
1842             replaceLayerParam(layerParams, "mode", "interpolation");
1843         }
1844         else if (layer_type == "Upsample")
1845         {
1846             // Fused from a Resize subgraph
1847             if (layerParams.has("coordinate_transformation_mode"))
1848             {
1849                 String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
1850                 CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
1851
1852                 layerParams.set("align_corners", interp_mode == "align_corners");
1853                 if (layerParams.get<String>("mode") == "linear")
1854                 {
1855                     layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
1856                                             "opencv_linear" : "bilinear");
1857                 }
1858             }
1859             if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
1860                 layerParams.set("mode", "opencv_linear");
1861
1862             layerParams.type = "Resize";
1863             if (layerParams.has("scales"))
1864             {
1865                 // PyTorch layer
1866                 DictValue scales = layerParams.get("scales");
1867                 CV_Assert(scales.size() == 4);
1868                 layerParams.set("zoom_factor_y", scales.getIntValue(2));
1869                 layerParams.set("zoom_factor_x", scales.getIntValue(3));
1870             }
1871             else if (layerParams.has("height_scale") && layerParams.has("width_scale"))
1872             {
1873                 // Caffe2 layer
1874                 replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
1875                 replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
1876             }
1877             else
1878             {
1879                 // Scales provided as a separate constant input
1880                 const std::string& input1 = node_proto.input(1);
1881                 if (constBlobs.find(input1) != constBlobs.end())
1882                 {
1883                     Mat scales = getBlob(input1);
1884                     CV_Assert(scales.total() == 4);
1885                     layerParams.set("zoom_factor_y", scales.at<float>(2));
1886                     layerParams.set("zoom_factor_x", scales.at<float>(3));
1887                 }
1888             }
1889             replaceLayerParam(layerParams, "mode", "interpolation");
1890         }
1891         else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
1892         {
1893             layerParams.type = "Softmax";
1894             layerParams.set("log_softmax", layer_type == "LogSoftmax");
1895         }
1896         else if (layer_type == "DetectionOutput")
1897         {
1898             CV_CheckEQ(node_proto.input_size(), 3, "");
1899             if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
1900             {
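                // Wrap the constant prior boxes into a Const layer so they can be connected as the third input.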
1901                 Mat priors = getBlob(node_proto, 2);
1902
1903                 LayerParams constParams;
1904                 constParams.name = layerParams.name + "/priors";
1905                 constParams.type = "Const";
1906                 constParams.blobs.push_back(priors);
1907
1908                 opencv_onnx::NodeProto priorsProto;
1909                 priorsProto.add_output(constParams.name);
1910                 addLayer(constParams, priorsProto);
1911
1912                 node_proto.set_input(2, constParams.name);
1913             }
1914         }
1915         else
1916         {
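            // Default path: keep the ONNX op type and attach any constant inputs as layer blobs.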
1917             for (int j = 0; j < node_proto.input_size(); j++) {
1918                 if (layer_id.find(node_proto.input(j)) == layer_id.end())
1919                     layerParams.blobs.push_back(getBlob(node_proto, j));
1920             }
1921         }
1922         addLayer(layerParams, node_proto);
1923     }
1924     catch (const cv::Exception& e)
1925     {
1926         CV_LOG_ERROR(NULL, "DNN/ONNX: ERROR during processing node with " << node_proto.input_size() << " inputs and " << node_proto.output_size() << " outputs: "
1927                 << cv::format("[%s]:(%s)", layer_type.c_str(), name.c_str())
1928         );
1929         for (int i = 0; i < node_proto.input_size(); i++)
1930         {
1931             CV_LOG_INFO(NULL, "    Input[" << i << "] = '" << node_proto.input(i) << "'");
1932         }
1933         for (int i = 0; i < node_proto.output_size(); i++)
1934         {
1935             CV_LOG_INFO(NULL, "    Output[" << i << "] = '" << node_proto.output(i) << "'");
1936         }
1937         CV_Error(Error::StsError, cv::format("Node [%s]:(%s) parse error: %s", layer_type.c_str(), name.c_str(), e.what()));
1938     }
1939 }
1940
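// Minimal usage sketch (illustrative; "model.onnx" and the input blob are placeholders):
//   cv::dnn::Net net = cv::dnn::readNetFromONNX("model.onnx");
//   net.setInput(blob);            // e.g. a blob made with cv::dnn::blobFromImage
//   cv::Mat out = net.forward();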
1941 Net readNetFromONNX(const String& onnxFile)
1942 {
1943     Net net;
1944     ONNXImporter onnxImporter(net, onnxFile.c_str());
1945     return net;
1946 }
1947
1948 Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
1949 {
1950     Net net;
1951     ONNXImporter onnxImporter(net, buffer, sizeBuffer);
1952     return net;
1953 }
1954
1955 Net readNetFromONNX(const std::vector<uchar>& buffer)
1956 {
1957     return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
1958 }
1959
1960 Mat readTensorFromONNX(const String& path)
1961 {
1962     std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
1963     if (!input)
1964     {
1965         CV_Error(Error::StsBadArg, cv::format("Can't read ONNX file: %s", path.c_str()));
1966     }
1967
1968     opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
1969     if (!tensor_proto.ParseFromIstream(&input))
1970     {
1971         CV_Error(Error::StsUnsupportedFormat, cv::format("Failed to parse ONNX data: %s", path.c_str()));
1972     }
1973     Mat mat = getMatFromTensor(tensor_proto);
1974     releaseONNXTensor(tensor_proto);
1975     return mat;
1976 }
1977
1978 CV__DNN_INLINE_NS_END
1979 }} // namespace
1980
1981 #endif