// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// Copyright (C) 2018, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.

#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>

#ifdef HAVE_PROTOBUF

#include <iostream>
#include <fstream>
#include <string>
#include <limits>
#include <algorithm>


#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "opencv-onnx.pb.h"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic pop
#endif

#include "onnx_graph_simplifier.hpp"

namespace cv {
namespace dnn {
CV__DNN_INLINE_NS_BEGIN


class ONNXImporter
{
    opencv_onnx::ModelProto model_proto;
    struct LayerInfo {
        int layerId;
        int outputId;
        LayerInfo(int _layerId = 0, int _outputId = 0) : layerId(_layerId), outputId(_outputId) {}
    };

    std::map<std::string, Mat> getGraphTensors(
                                    const opencv_onnx::GraphProto& graph_proto);
    Mat getBlob(const opencv_onnx::NodeProto& node_proto, const std::map<std::string, Mat>& constBlobs, int index);

    LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
    bool isCeilMode(const LayerParams& layerParams);

    void addLayer(Net& dstNet, LayerParams& layerParams,
                  const opencv_onnx::NodeProto& node_proto,
                  std::map<std::string, LayerInfo>& layer_id,
                  std::map<std::string, MatShape>& outShapes);

public:

    ONNXImporter(const char *onnxFile)
    {
        std::fstream input(onnxFile, std::ios::in | std::ios::binary);

        if (!model_proto.ParseFromIstream(&input))
            CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model");
    }

    ONNXImporter(const char* buffer, size_t sizeBuffer)
    {
        struct _Buf : public std::streambuf
        {
            _Buf(const char* buffer, size_t sizeBuffer)
            {
                char* p = const_cast<char*>(buffer);
                setg(p, p, p + sizeBuffer);
            }
        };

        _Buf buf(buffer, sizeBuffer);
        std::istream input(&buf);

        if (!model_proto.ParseFromIstream(&input))
            CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
    }

    void populateNet(Net dstNet);
};
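
// A minimal usage sketch (assuming the readNetFromONNX() wrappers in the dnn
// module, which drive this class; cv::Net shares its state through an internal
// pointer, so populateNet()'s by-value parameter still fills the caller's net):
//
//     ONNXImporter importer("model.onnx");
//     Net net;
//     importer.populateNet(net);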

inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
{
    if (layerParams.has(oldKey)) {
        layerParams.set(newKey, layerParams.get(oldKey));
        layerParams.erase(oldKey);
    }
}
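
// For example, the Transpose branch below calls
// replaceLayerParam(layerParams, "perm", "order") to rename ONNX's "perm"
// attribute to the "order" key expected by OpenCV's Permute layer.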

void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
{
    if (!tensor_proto.raw_data().empty()) {
        delete tensor_proto.release_raw_data();
    }
}

void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
              std::vector<Mat>& outputs)
{
    Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
    CV_Assert((bool)layer);

    std::vector<MatShape> inpShapes(inputs.size());
    int ddepth = CV_32F;
    for (size_t i = 0; i < inputs.size(); ++i)
    {
        inpShapes[i] = shape(inputs[i]);
        if (i > 0 && ddepth != inputs[i].depth())
            CV_Error(Error::StsNotImplemented, "Mixed input data types.");
        ddepth = inputs[i].depth();
    }

    std::vector<MatShape> outShapes, internalShapes;
    layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);

    std::vector<Mat> internals(internalShapes.size());
    outputs.resize(outShapes.size());
    for (size_t i = 0; i < outShapes.size(); ++i)
        outputs[i].create(outShapes[i], ddepth);
    for (size_t i = 0; i < internalShapes.size(); ++i)
        internals[i].create(internalShapes[i], ddepth);

    layer->finalize(inputs, outputs);
    layer->forward(inputs, outputs, internals);
}
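
// runLayer() is the constant-folding helper: when all inputs of a node are
// known constants (see the Slice, Transpose and Reshape branches below), the
// layer is executed eagerly here and its output is stored via addConstant()
// instead of adding a layer to the network.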

std::map<std::string, Mat> ONNXImporter::getGraphTensors(
                                        const opencv_onnx::GraphProto& graph_proto)
{
  opencv_onnx::TensorProto tensor_proto;
  std::map<std::string, Mat> layers_weights;

  for (int i = 0; i < graph_proto.initializer_size(); i++)
  {
    tensor_proto = graph_proto.initializer(i);
    Mat mat = getMatFromTensor(tensor_proto);
    releaseONNXTensor(tensor_proto);
    layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
  }
  return layers_weights;
}
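
// Note: releaseONNXTensor() frees each initializer's raw_data right after it
// is copied into a Mat, keeping peak memory down for large models;
// tensor_proto.name() stays valid since only the raw data is released.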

static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
    std::vector<int32_t> dst(src.size());
    convertInt64ToInt32(src, dst, src.size());
    return DictValue::arrayInt(&dst[0], src.size());
}
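
// e.g. a "kernel_shape" attribute stored as int64 values [3, 3] becomes the
// 32-bit DictValue array {3, 3}.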

LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
{
    LayerParams lp;
    for(int i = 0; i < node_proto.attribute_size(); i++)
    {
        opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
        std::string attribute_name = attribute_proto.name();

        if(attribute_name == "kernel_shape")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("kernel_size", parse(attribute_proto.ints()));
        }
        else if(attribute_name == "strides")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("stride", parse(attribute_proto.ints()));
        }
        else if(attribute_name == "pads")
        {
            if (node_proto.op_type() == "Pad")
            {
                // Padding layer.
                // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
                // We need to shuffle it to begin0, end0, begin1, end1, ...
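                // e.g. for dims = 2, ONNX pads [1, 2, 3, 4] become
                // paddings [1, 3, 2, 4].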
                CV_Assert(attribute_proto.ints_size() % 2 == 0);
                const int dims = attribute_proto.ints_size() / 2;
                std::vector<int32_t> paddings;
                paddings.reserve(attribute_proto.ints_size());
                for (int i = 0; i < dims; ++i)
                {
                    paddings.push_back(attribute_proto.ints(i));
                    paddings.push_back(attribute_proto.ints(dims + i));
                }
                lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
            }
            else
            {
                // Convolution or pooling.
                CV_Assert(attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
                lp.set("pad", parse(attribute_proto.ints()));
            }
        }
        else if(attribute_name == "auto_pad")
        {
            if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
                lp.set("pad_mode", "SAME");
            }
            else if (attribute_proto.s() == "VALID") {
                lp.set("pad_mode", "VALID");
            }
        }
        else if(attribute_name == "dilations")
        {
            CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
            lp.set("dilation", parse(attribute_proto.ints()));
        }
        else if (attribute_proto.has_i())
        {
            ::google::protobuf::int64 src = attribute_proto.i();
            if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
                CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
            else
                lp.set(attribute_name, saturate_cast<int32_t>(src));
        }
        else if (attribute_proto.has_f())
        {
            lp.set(attribute_name, attribute_proto.f());
        }
        else if (attribute_proto.has_s())
        {
            lp.set(attribute_name, attribute_proto.s());
        }
        else if (attribute_proto.floats_size() > 0)
        {
            lp.set(attribute_name, DictValue::arrayReal(
                attribute_proto.floats().data(), attribute_proto.floats_size()));
        }
        else if (attribute_proto.ints_size() > 0)
        {
            lp.set(attribute_proto.name(), parse(attribute_proto.ints()));
        }
        else if (attribute_proto.has_t())
        {
            opencv_onnx::TensorProto tensor = attribute_proto.t();
            Mat blob = getMatFromTensor(tensor);
            lp.blobs.push_back(blob);
        }
        else if (attribute_proto.has_g() || attribute_proto.strings_size() > 0 ||
                    attribute_proto.tensors_size() > 0 || attribute_proto.graphs_size() > 0)
        {
            CV_Error(Error::StsNotImplemented, "Unexpected attribute type");
        }
        else
            CV_Error(Error::StsNotImplemented, "Unsupported attribute type");
    }
    return lp;
}
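
// Attributes with dedicated handling above ("kernel_shape", "pads", ...) are
// renamed to the keys OpenCV layers expect; anything else is stored under its
// ONNX name according to its type, e.g. a Conv node's "group" int attribute
// is kept as the integer parameter "group".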

Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
                    const std::map<std::string, Mat>& constBlobs, int index)
{
    CV_Assert(index < node_proto.input_size());
    std::map<std::string, Mat>::const_iterator constBlob;
    constBlob = constBlobs.find(node_proto.input(index));
    if (constBlob == constBlobs.end()) {
        CV_Error(Error::StsObjectNotFound,
             "Blob " + node_proto.input(index) + " not found in const blobs");
    }
    return constBlob->second;
}

void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
                            const opencv_onnx::NodeProto& node_proto,
                            std::map<std::string, LayerInfo>& layer_id,
                            std::map<std::string, MatShape>& outShapes)
{
    std::map<std::string, LayerInfo>::iterator layerId;
    std::map<std::string, MatShape>::iterator shapeIt;

    int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
    for (int i = 0; i < node_proto.output_size(); ++i)
    {
        layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
    }

    std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
    int inpNum = 0;
    for (int j = 0; j < node_proto.input_size(); j++) {
        layerId = layer_id.find(node_proto.input(j));
        if (layerId != layer_id.end()) {
            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
            ++inpNum;
            // Collect input shapes.
            shapeIt = outShapes.find(node_proto.input(j));
            CV_Assert(shapeIt != outShapes.end());
            layerInpShapes.push_back(shapeIt->second);
        }
    }
    // Compute shape of output blob for this layer.
    Ptr<Layer> layer = dstNet.getLayer(id);
    layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
    for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
    {
        outShapes[node_proto.output(i)] = layerOutShapes[i];
    }
}

static void addConstant(const std::string& name,
                        const Mat& blob,
                        std::map<std::string, Mat>& constBlobs,
                        std::map<std::string, MatShape>& outShapes)
{
    constBlobs.insert(std::make_pair(name, blob));
    outShapes.insert(std::make_pair(name, shape(blob)));
}
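
// Constant nodes never become layers: addConstant() registers the blob so
// later nodes can fetch it via getBlob(), and records its shape in outShapes,
// which several branches below consult (e.g. ReduceMean, Expand).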

void ONNXImporter::populateNet(Net dstNet)
{
    CV_Assert(model_proto.has_graph());
    opencv_onnx::GraphProto graph_proto = model_proto.graph();

    simplifySubgraphs(graph_proto);

    std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
    // List of internal blob shapes.
    std::map<std::string, MatShape> outShapes;
    // Add all input shapes: both constant blobs and the network's inputs.
    for (int i = 0; i < graph_proto.input_size(); ++i)
    {
        opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
        CV_Assert(valueInfoProto.has_type());
        opencv_onnx::TypeProto typeProto = valueInfoProto.type();
        CV_Assert(typeProto.has_tensor_type());
        opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
        CV_Assert(tensor.has_shape());
        opencv_onnx::TensorShapeProto tensorShape = tensor.shape();

        MatShape inpShape(tensorShape.dim_size());
        for (int j = 0; j < inpShape.size(); ++j)
        {
            inpShape[j] = tensorShape.dim(j).dim_value();
        }
        outShapes[valueInfoProto.name()] = inpShape;
    }

    std::string framework_name;
    if (model_proto.has_producer_name()) {
        framework_name = model_proto.producer_name();
    }

    // create map with network inputs (without const blobs)
    std::map<std::string, LayerInfo> layer_id;
    std::map<std::string, LayerInfo>::iterator layerId;
    std::map<std::string, MatShape>::iterator shapeIt;
    // fill map: push layer name, layer id and output id
    std::vector<String> netInputs;
    for (int j = 0; j < graph_proto.input_size(); j++)
    {
        const std::string& name = graph_proto.input(j).name();
        if (constBlobs.find(name) == constBlobs.end()) {
            netInputs.push_back(name);
            layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
        }
    }
    dstNet.setInputsNames(netInputs);

    int layersSize = graph_proto.node_size();
    LayerParams layerParams;
    opencv_onnx::NodeProto node_proto;

    for(int li = 0; li < layersSize; li++)
    {
        node_proto = graph_proto.node(li);
        layerParams = getLayerParams(node_proto);
        CV_Assert(node_proto.output_size() >= 1);
        layerParams.name = node_proto.output(0);

        std::string layer_type = node_proto.op_type();
        layerParams.type = layer_type;


        if (layer_type == "MaxPool")
        {
            layerParams.type = "Pooling";
            layerParams.set("pool", "MAX");
            layerParams.set("ceil_mode", layerParams.has("pad_mode"));
        }
        else if (layer_type == "AveragePool")
        {
            layerParams.type = "Pooling";
            layerParams.set("pool", "AVE");
            layerParams.set("ceil_mode", layerParams.has("pad_mode"));
            layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
        }
        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
        {
            CV_Assert(node_proto.input_size() == 1);
            layerParams.type = "Pooling";
            layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
            layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");

            if (layer_type == "ReduceMean")
            {
                if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
                    CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");

                MatShape inpShape = outShapes[node_proto.input(0)];
                DictValue axes = layerParams.get("axes");
                if (inpShape.size() == 3 && axes.size() <= 2)
                {
                    int axis = axes.get<int>(0);
                    CV_CheckNE(axis, 0, "");
                    outShapes[layerParams.name] = inpShape;
                    outShapes[layerParams.name][axis] = 1;

                    LayerParams reshapeLp;
                    reshapeLp.name = layerParams.name + "/reshape";
                    reshapeLp.type = "Reshape";
                    CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
                    reshapeLp.set("axis", 0);
                    reshapeLp.set("num_axes", 1);
                    int newShape[] = {1, -1};
                    reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 2));

                    opencv_onnx::NodeProto proto;
                    proto.add_input(node_proto.input(0));
                    proto.add_output(reshapeLp.name);
                    addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);

                    LayerParams avgLp;
                    avgLp.name = layerParams.name + "/avg";
                    avgLp.type = "Pooling";
                    CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
                    avgLp.set("pool", "ave");
                    if (axes.size() == 2)
                    {
                        CV_CheckEQ(axes.get<int>(0), 1, "Unsupported ReduceMean mode");
                        CV_CheckEQ(axes.get<int>(1), 2, "Unsupported ReduceMean mode");
                        avgLp.set("global_pooling", true);
                        outShapes[layerParams.name][axes.get<int>(1)] = 1;
                    }
                    else
                    {
                        avgLp.set(axis == 2 ? "global_pooling_w" : "global_pooling_h", true);
                        avgLp.set(axis == 2 ? "kernel_h" : "kernel_w", 1);
                    }

                    node_proto.set_input(0, reshapeLp.name);
                    node_proto.set_output(0, avgLp.name);
                    addLayer(dstNet, avgLp, node_proto, layer_id, outShapes);

                    layerParams.type = "Flatten";
                    layerParams.set("axis", 0);
                    layerParams.set("end_axis", 1);

                    node_proto.set_input(0, avgLp.name);
                    node_proto.set_output(0, layerParams.name);
                }
                else
                {
                    if (inpShape.size() != 4 && inpShape.size() != 5)
                        CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");

                    CV_Assert(axes.size() <= inpShape.size() - 2);
                    std::vector<int> kernel_size(inpShape.size() - 2, 1);
                    for (int i = 0; i < axes.size(); i++) {
                        int axis = axes.get<int>(i);
                        CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
                        kernel_size[axis - 2] = inpShape[axis];
                    }
                    layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
                }
            }
        }
        else if (layer_type == "Slice")
        {
            int axis = 0;
            std::vector<int> begin;
            std::vector<int> end;
            int inp_size = node_proto.input_size();

            if (inp_size == 1)
            {
                if (layerParams.has("steps"))
                {
                    DictValue steps = layerParams.get("steps");
                    for (int i = 0; i < steps.size(); ++i)
                    {
                        if (steps.get<int>(i) != 1)
                            CV_Error(Error::StsNotImplemented,
                                "Slice layer only supports steps = 1");
                    }
                }
                if (layerParams.has("axes")) {
                    DictValue axes = layerParams.get("axes");
                    for (int i = 1; i < axes.size(); ++i) {
                        CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
                    }
                    axis = axes.get<int>(0);
                }

                DictValue starts = layerParams.get("starts");
                DictValue ends = layerParams.get("ends");
                CV_Assert(starts.size() == ends.size());

                if (axis > 0) {
                    begin.resize(axis, 0);
                    end.resize(axis, -1);
                }
                for (int i = 0; i < starts.size(); ++i)
                {
                    begin.push_back(starts.get<int>(i));
                    int finish = ends.get<int>(i);
                    end.push_back((finish < 0) ? --finish : finish); // ONNX "ends" are exclusive, like numpy, so shift negative ends by one
                }
            } else {
                CV_Assert(inp_size >= 3);
                for (int i = 1; i < inp_size; i++) {
                    CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
                }
                Mat start_blob = getBlob(node_proto, constBlobs, 1);
                Mat end_blob   = getBlob(node_proto, constBlobs, 2);
                CV_Assert(start_blob.total() == end_blob.total());

                if (inp_size > 3) {
                    Mat axes_blob = getBlob(node_proto, constBlobs, 3);
                    const int* axes = (int*)axes_blob.data;
                    for (int i = 1; i < axes_blob.total(); ++i) {
                        CV_Assert(axes[i - 1] == axes[i] - 1);
                    }
                    axis = axes[0];
                }

                const int* starts = start_blob.ptr<int>();
                const int* ends   = end_blob.ptr<int>();
                if (axis > 0) {
                    begin.resize(axis, 0);
                    end.resize(axis, -1);
                }
                std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
                for (int i = 0; i < end_blob.total(); ++i)
                {
                    int finish = ends[i];
                    end.push_back((finish < 0) ? --finish : finish); // ONNX "ends" are exclusive, like numpy, so shift negative ends by one
                }

                if (inp_size == 5) {
                    CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
                    Mat step_blob = getBlob(node_proto, constBlobs, 4);

                    // A very unusual use of the Slice op: reversing a tensor.
                    // We only work around it for 2-D constants.
                    if (constBlobs.find(node_proto.input(0)) != constBlobs.end() &&
                        axis == 0 &&
                        start_blob.at<int>(0) == -1 && step_blob.at<int>(0) == -1 &&
                        end_blob.at<int>(0) == std::numeric_limits<int32_t>::min())
                    {
                        Mat inp = getBlob(node_proto, constBlobs, 0);
                        if (inp.dims == 2)
                        {
                            Mat flipped;
                            flip(inp, flipped, 0);
                            addConstant(layerParams.name, flipped, constBlobs, outShapes);
                            continue;
                        }
                    }
                    CV_CheckEQ(countNonZero(step_blob != 1), 0, "Slice layer only supports steps = 1");
                }
            }
            layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
            layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
            layerParams.set("axis", axis);

            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                Mat inp = getBlob(node_proto, constBlobs, 0);
                std::vector<Mat> inputs, sliced;
                inputs.push_back(inp);
                runLayer(layerParams, inputs, sliced);
                CV_Assert(sliced.size() == 1);
                addConstant(layerParams.name, sliced[0], constBlobs, outShapes);
                continue;
            }
        }
        else if (layer_type == "Split")
        {
            if (layerParams.has("split"))
            {
                DictValue splits = layerParams.get("split");
                const int numSplits = splits.size();
                CV_Assert(numSplits > 1);

                // Accumulate split sizes into cumulative slice points,
                // e.g. split = [2, 3, 4] -> slice_point = [2, 5].
                std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
                for (int i = 1; i < splits.size() - 1; ++i)
                {
                    slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i);
                }
                layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
            }
            else
            {
                layerParams.set("num_split", node_proto.output_size());
            }
            layerParams.type = "Slice";
        }
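        // Add/Sum/Sub dispatch on what is constant (a summary of the branch
        // below): two constants are folded eagerly; a constant scalar becomes
        // a Power layer "shift"; a constant matching the variable input's
        // shape becomes a Const layer feeding an Eltwise; any other constant
        // becomes a Scale layer bias; two variables of equal shape use Eltwise
        // (with coeffs {1, -1} for Sub); otherwise Scale is used, negating the
        // second input through a Power layer for Sub.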
        else if (layer_type == "Add" || layer_type == "Sum" || layer_type == "Sub")
        {
            bool isSub = layer_type == "Sub";
            CV_CheckEQ(node_proto.input_size(), 2, "");
            bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
            bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
            if (is_const_0 && is_const_1)
            {
                Mat blob_0 = getBlob(node_proto, constBlobs, 0);
                Mat blob_1 = getBlob(node_proto, constBlobs, 1);
                CV_Assert(blob_0.size == blob_1.size);
                Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
                addConstant(layerParams.name, output, constBlobs, outShapes);
                continue;
            }
            else if (is_const_0 || is_const_1)
            {
                int const_blob_id = is_const_0 ? 0 : 1;
                Mat blob = getBlob(node_proto, constBlobs, const_blob_id);
                int blob_total = blob.total();
                if (blob_total == 1) {
                    layerParams.type = "Power";
                    layerParams.set("shift", (isSub ? -1 : 1) * blob.at<float>(0));
                }
                else {
                    MatShape inpShape = outShapes[node_proto.input(1 - const_blob_id)];
                    if (shape(blob) == inpShape)
                    {
                        LayerParams constParams;
                        constParams.name = layerParams.name + "/const";
                        constParams.type = "Const";
                        constParams.blobs.push_back(blob);
                        int id = dstNet.addLayer(constParams.name, constParams.type, constParams);
                        layer_id.insert(std::make_pair(constParams.name, LayerInfo(id, 0)));
                        outShapes[constParams.name] = shape(blob);

                        layerParams.type = "Eltwise";
                        node_proto.set_input(const_blob_id, constParams.name);
                    }
                    else
                    {
                        layerParams.type = "Scale";
                        layerParams.set("bias_term", true);
                        blob = blob.reshape(1, 1);
                        layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
                    }
                }
            }
            else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
            {
                layerParams.type = "Eltwise";
                if (isSub)
                {
                    static float subCoeffs[] = {1.f, -1.f};
                    layerParams.set("coeff", DictValue::arrayReal<float*>(subCoeffs, 2));
                }
            }
            else
            {
                if (isSub)
                {
                    LayerParams powerParams;
                    powerParams.name = layerParams.name + "/neg";
                    powerParams.type = "Power";
                    powerParams.set("scale", -1);

                    // Create Power layer
                    int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
                    // Connect to input
                    layerId = layer_id.find(node_proto.input(1));
                    CV_Assert(layerId != layer_id.end());
                    dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
                    // Add shape
                    layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
                    outShapes[powerParams.name] = outShapes[node_proto.input(1)];

                    // Replace input to Power
                    node_proto.set_input(1, powerParams.name);
                }
                layerParams.type = "Scale";
                layerParams.set("bias_term", true);
            }
        }
        else if (layer_type == "Max")
        {
            layerParams.type = "Eltwise";
            layerParams.set("operation", "max");
        }
        else if (layer_type == "Neg")
        {
            layerParams.type = "Power";
            layerParams.set("scale", -1);
        }
        else if (layer_type == "Constant")
        {
            CV_Assert(node_proto.input_size() == 0);
            CV_Assert(layerParams.blobs.size() == 1);
            addConstant(layerParams.name, layerParams.blobs[0], constBlobs, outShapes);
            continue;
        }
        else if (layer_type == "LSTM")
        {
            LayerParams lstmParams = layerParams;
            lstmParams.name += "/lstm";

            // https://pytorch.org/docs/stable/nn.html#lstm
            CV_Assert(node_proto.input_size() == 7);
            Mat Wx = getBlob(node_proto, constBlobs, 1);
            Mat Wh = getBlob(node_proto, constBlobs, 2);
            Mat b = getBlob(node_proto, constBlobs, 3);
            CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 5)), 0, "Unsupported non zero initial_h");
            CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 6)), 0, "Unsupported non zero initial_c");
            b = b.reshape(1, b.size[0]);

            const int numHidden = lstmParams.get<int>("hidden_size");
            const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
            const int numFeatures = Wx.size[2];
            Mat bx = b.colRange(0, b.cols / 2);
            Mat bh = b.colRange(b.cols / 2, b.cols);
            b = bx + bh;

            // IFGO->IGFO
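            // The exporter packs the gate blocks of Wx, Wh and b in I,F,G,O
            // order, while OpenCV's LSTM layer expects I,G,F,O; the loop below
            // swaps the F and G blocks in place. The bias was already folded
            // to a single vector above (b = bx + bh), since ONNX keeps
            // separate input and recurrent biases.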
            for (int k = 0; k < numDirs; ++k)
            {
                float* WxData = Wx.ptr<float>(k);
                float* WhData = Wh.ptr<float>(k);
                float* biasData = b.ptr<float>(k);
                for (int j = 0; j < numHidden; ++j)
                {
                    for (int i = 0; i < numFeatures; ++i)
                    {
                        std::swap(WxData[(numHidden + j) * numFeatures + i],
                                  WxData[(numHidden * 2 + j) * numFeatures + i]);
                    }
                    for (int i = 0; i < numHidden; ++i)
                    {
                        std::swap(WhData[(numHidden + j) * numHidden + i],
                                  WhData[(numHidden * 2 + j) * numHidden + i]);
                    }
                    std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
                }
            }
            Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
            Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);

            lstmParams.blobs.resize(3);
            lstmParams.blobs[0] = Wh;
            lstmParams.blobs[1] = Wx;
            lstmParams.blobs[2] = b;
            lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");

            node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
            addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);

            MatShape lstmShape = outShapes[node_proto.output(0)];

            // Add fake 1 as it is done in ONNX
            lstmShape.insert(lstmShape.begin() + 1, 1);

            layerParams.type = "Reshape";
            layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
            node_proto.set_input(0, lstmParams.name);  // redirect input to LSTM
            node_proto.set_output(0, layerParams.name);  // keep the original LSTM node's name
        }
        else if (layer_type == "ImageScaler")
        {
            const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
            layerParams.erase("scale");

            if (layerParams.has("bias"))
            {
                layerParams.type = "Scale";
                layerParams.blobs.push_back(
                    Mat(Size(1, layerParams.get("bias").size()), CV_32FC1, scale));

                layerParams.set("bias_term", true);
                Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
                for (int j = 0; j < bias.total(); j++) {
                    bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
                }
                layerParams.blobs.push_back(bias);
                layerParams.erase("bias");
            }
            else {
                layerParams.set("scale", scale);
                layerParams.type = "Power";
            }
        }
        else if (layer_type == "Clip")
        {
            layerParams.type = "ReLU6";
            replaceLayerParam(layerParams, "min", "min_value");
            replaceLayerParam(layerParams, "max", "max_value");
        }
        else if (layer_type == "LeakyRelu")
        {
            layerParams.type = "ReLU";
            replaceLayerParam(layerParams, "alpha", "negative_slope");
        }
        else if (layer_type == "Relu")
        {
            layerParams.type = "ReLU";
        }
        else if (layer_type == "Elu")
        {
            layerParams.type = "ELU";
        }
        else if (layer_type == "PRelu")
        {
            layerParams.type = "PReLU";
            layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));
        }
        else if (layer_type == "LRN")
        {
            replaceLayerParam(layerParams, "size", "local_size");
        }
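        // InstanceNormalization computes y = scale * (x - mean) / sqrt(var + eps) + bias
        // per channel. It is decomposed below into an MVN layer (the
        // normalization itself) followed by a BatchNorm layer that carries the
        // learned scale and bias together with zero-mean / unit-variance blobs.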
        else if (layer_type == "InstanceNormalization")
        {
            if (node_proto.input_size() != 3)
                CV_Error(Error::StsNotImplemented,
                         "Expected input, scale, bias");

            layerParams.blobs.resize(4);
            layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1);  // weightData
            layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2);  // biasData
            layerParams.set("has_bias", true);
            layerParams.set("has_weight", true);

            // Get number of channels in input
            int size = layerParams.blobs[2].total();
            layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
            layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std

            LayerParams mvnParams;
            mvnParams.name = layerParams.name + "/MVN";
            mvnParams.type = "MVN";
            mvnParams.set("eps", layerParams.get<float>("epsilon"));
            layerParams.erase("epsilon");

            // Create MVN layer
            int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
            // Connect to input
            layerId = layer_id.find(node_proto.input(0));
            CV_Assert(layerId != layer_id.end());
            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
            // Add shape
            layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
            outShapes[mvnParams.name] = outShapes[node_proto.input(0)];

            // Replace Batch Norm's input to MVN
            node_proto.set_input(0, mvnParams.name);
            layerParams.type = "BatchNorm";
        }
        else if (layer_type == "BatchNormalization")
        {
            if (node_proto.input_size() != 5)
                CV_Error(Error::StsNotImplemented,
                         "Expected input, scale, bias, mean and var");

            layerParams.type = "BatchNorm";
            replaceLayerParam(layerParams, "epsilon", "eps");
            replaceLayerParam(layerParams, "spatial", "use_global_stats");

            Mat meanData = getBlob(node_proto, constBlobs, 3);
            Mat stdData = getBlob(node_proto, constBlobs, 4);

            layerParams.blobs.push_back(meanData);
            layerParams.blobs.push_back(stdData);

            if (!node_proto.input(1).empty()) {
                layerParams.set("has_weight", true);
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));  // weightData
            } else {
                layerParams.set("has_weight", false);
            }

            if (!node_proto.input(2).empty()) {
                layerParams.set("has_bias", true);
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 2)); // biasData
            } else {
                layerParams.set("has_bias", false);
            }
        }
        else if (layer_type == "Gemm")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "InnerProduct";
            Mat weights = getBlob(node_proto, constBlobs, 1);
            int ind_num_out = 0;
            if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
                transpose(weights, weights);
                ind_num_out = 1;
            }
            layerParams.blobs.push_back(weights);

            if (node_proto.input_size() == 3) {
                Mat bias = getBlob(node_proto, constBlobs, 2);
                layerParams.blobs.push_back(bias);
            }

            layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
            layerParams.set("bias_term", node_proto.input_size() == 3);
        }
        else if (layer_type == "MatMul")
        {
            CV_Assert(node_proto.input_size() == 2);
            layerParams.type = "InnerProduct";
            layerParams.set("bias_term", false);

            if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
            {
                Mat blob = getBlob(node_proto, constBlobs, 1);
                layerParams.blobs.push_back(blob.t());
                layerParams.set("num_output", layerParams.blobs[0].size[0]);
            }
        }
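        // Mul/Div follows the same dispatch as Add/Sub above: a constant
        // scalar becomes a Power layer "scale" (inverted for Div), a constant
        // vector becomes a Scale layer (with 1/blob for Div), equal-shape
        // variables use Eltwise, and two constants are folded eagerly at the
        // end of the branch.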
        else if (layer_type == "Mul" || layer_type == "Div")
        {
            CV_Assert(node_proto.input_size() == 2);

            bool isDiv = layer_type == "Div";
            int constId = -1;
            bool haveVariables = false;
            for (int i = 0; i < 2; ++i)
            {
                if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
                    constId = i;
                else
                    haveVariables = true;
            }
            if (constId != -1 && haveVariables)
            {
                Mat blob = getBlob(node_proto, constBlobs, constId);
                blob = blob.reshape(1, 1);
                if (blob.total() == 1) {
                    float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
                    layerParams.set("scale", coeff);
                    layerParams.type = "Power";
                }
                else {
                    if (isDiv)
                        divide(1.0, blob, blob);
                    layerParams.blobs.push_back(blob);
                    layerParams.type = "Scale";
                }
            }
            else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)])
            {
                layerParams.type = "Eltwise";
                layerParams.set("operation", isDiv ? "div" : "prod");
            }
            else
            {
                if (isDiv)
                {
                    LayerParams powerParams;
                    powerParams.name = layerParams.name + "/inv";
                    powerParams.type = "Power";
                    powerParams.set("power", -1);

                    // Create Power layer
                    int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams);
                    // Connect to input
                    layerId = layer_id.find(node_proto.input(1));
                    CV_Assert(layerId != layer_id.end());
                    dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
                    // Add shape
                    layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0)));
                    outShapes[powerParams.name] = outShapes[node_proto.input(1)];

                    // Replace input to Power
                    node_proto.set_input(1, powerParams.name);
                }
                layerParams.type = "Scale";
            }

            if (!haveVariables)
            {
                Mat inp0 = getBlob(node_proto, constBlobs, 0);
                Mat inp1 = getBlob(node_proto, constBlobs, 1);
                if (inp0.size != inp1.size)
                    CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");

                Mat out;
                if (isDiv)
                    divide(inp0, inp1, out);
                else
                    multiply(inp0, inp1, out);

                out = out.reshape(1, inp0.dims, inp0.size);
                out.dims = inp0.dims;  // to workaround dims == 1
                addConstant(layerParams.name, out, constBlobs, outShapes);
                continue;
            }
        }
        else if (layer_type == "Conv")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "Convolution";
            for (int j = 1; j < node_proto.input_size(); j++) {
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
            layerParams.set("num_output", layerParams.blobs[0].size[0]);
            layerParams.set("bias_term", node_proto.input_size() == 3);
        }
        else if (layer_type == "ConvTranspose")
        {
            CV_Assert(node_proto.input_size() >= 2);
            layerParams.type = "Deconvolution";
            for (int j = 1; j < node_proto.input_size(); j++) {
                layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
            layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
            layerParams.set("bias_term", node_proto.input_size() == 3);

            if (!layerParams.has("kernel_size"))
                CV_Error(Error::StsNotImplemented,
                         "Required attribute 'kernel_size' is not present.");

            if (layerParams.has("output_shape"))
            {
                const DictValue& outShape = layerParams.get("output_shape");
                DictValue strides = layerParams.get("stride");
                DictValue kernel = layerParams.get("kernel_size");

                String padMode;
                std::vector<int> adjust_pads;
                if (layerParams.has("pad_mode"))
                {
                    padMode = toUpperCase(layerParams.get<String>("pad_mode"));
                    if (padMode != "SAME" && padMode != "VALID")
                        CV_Error(Error::StsError, "Unsupported padding mode " + padMode);

                    for (int i = 0; i < strides.size(); i++)
                    {
                        int sz = outShape.get<int>(2 + i);
                        int stride = strides.get<int>(i);
                        adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
                                                                 (sz - kernel.get<int>(i)) % stride);
                    }
                    layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
                }
            }
            else if (layerParams.has("output_padding"))
            {
                replaceLayerParam(layerParams, "output_padding", "adj");
            }
        }
        else if (layer_type == "Transpose")
        {
            layerParams.type = "Permute";
            replaceLayerParam(layerParams, "perm", "order");

            CV_Assert(node_proto.input_size() == 1);
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
                runLayer(layerParams, inputs, transposed);
                CV_Assert(transposed.size() == 1);
                addConstant(layerParams.name, transposed[0], constBlobs, outShapes);
                continue;
            }
        }
        else if (layer_type == "Squeeze")
        {
            CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
            DictValue axes_dict = layerParams.get("axes");
            MatShape inpShape = outShapes[node_proto.input(0)];

            std::vector<bool> maskedAxes(inpShape.size(), false);
            for (int i = 0; i < axes_dict.size(); ++i)
            {
                int axis = axes_dict.getIntValue(i);
                CV_CheckLE(axis, static_cast<int>(inpShape.size()), "Squeeze axis");
                maskedAxes[axis] = inpShape[axis] == 1;
            }
            MatShape outShape;
            for (int i = 0; i < inpShape.size(); ++i)
            {
                if (!maskedAxes[i])
                    outShape.push_back(inpShape[i]);
            }
            if (outShape.size() != inpShape.size())
            {
                layerParams.type = "Reshape";
                layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
            }
            else
                layerParams.type = "Identity";

            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                Mat inp = getBlob(node_proto, constBlobs, 0);
                Mat out = inp.reshape(1, outShape);
                out.dims = outShape.size();  // to workaround dims == 1
                addConstant(layerParams.name, out, constBlobs, outShapes);
                continue;
            }
        }
        else if (layer_type == "Flatten")
        {
            CV_CheckEQ(node_proto.input_size(), 1, "");
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                Mat input = getBlob(node_proto, constBlobs, 0);
                int axis = clamp(layerParams.get<int>("axis", 1), input.dims);

                std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
                out_size.push_back(input.total(axis));
                Mat output = input.reshape(1, out_size);
                addConstant(layerParams.name, output, constBlobs, outShapes);
                continue;
            }
        }
        else if (layer_type == "Unsqueeze")
        {
            CV_Assert(node_proto.input_size() == 1);
            DictValue axes = layerParams.get("axes");
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                // Constant input.
                Mat input = getBlob(node_proto, constBlobs, 0);

                std::vector<int> dims;
                for (int j = 0; j < input.dims; j++) {
                    dims.push_back(input.size[j]);
                }
                CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
                for (int j = 0; j < axes.size(); j++) {
                    dims.insert(dims.begin() + axes.getIntValue(j), 1);
                }

                Mat out = input.reshape(0, dims);
                addConstant(layerParams.name, out, constBlobs, outShapes);
                continue;
            }

            // Variable input.
            if (axes.size() != 1)
                CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");

            MatShape inpShape = outShapes[node_proto.input(0)];
            int axis = axes.getIntValue(0);
            CV_Assert(0 <= axis && axis <= inpShape.size());
            std::vector<int> outShape = inpShape;
            outShape.insert(outShape.begin() + axis, 1);
            layerParams.type = "Reshape";
            layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
        }
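        // Expand has no direct OpenCV layer, so two broadcast patterns are
        // emulated below: broadcasting over the two innermost axes multiplies
        // an all-ones constant of the target shape by the input via Scale,
        // and broadcasting a single leading axis concatenates Identity copies
        // of the input.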
1151         else if (layer_type == "Expand")
1152         {
1153             CV_CheckEQ(node_proto.input_size(), 2, "");
1154             CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end());
1155             Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
1156             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
1157
1158             shapeIt = outShapes.find(node_proto.input(0));
1159             CV_Assert(shapeIt != outShapes.end());
1160             MatShape inpShape = shapeIt->second;
1161             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
1162
1163             std::vector<int> broadcast_axes;
1164             for (int i = 0; i < targetShape.size(); i++)
1165             {
1166                 if (targetShape[i] != inpShape[i])
1167                 {
1168                     if (inpShape[i] == 1)
1169                         broadcast_axes.push_back(i);
1170                     else
1171                         CV_Error(Error::StsError, format("Could not be broadcast by axis: %d", i));
1172                 }
1173             }
1174
1175             if (broadcast_axes.size() == 2 &&
1176                 broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
1177             {
1178                 LayerParams constParams;
1179                 constParams.name = layerParams.name + "/const";
1180                 CV_Assert(layer_id.find(constParams.name) == layer_id.end());
1181                 constParams.type = "Const";
1182
1183                 Mat inp = Mat::ones(newShapeMat.total(), newShapeMat.ptr<int>(), CV_32F);
1184                 constParams.blobs.push_back(inp);
1185
1186                 opencv_onnx::NodeProto proto;
1187                 proto.add_output(constParams.name);
1188                 addLayer(dstNet, constParams, proto, layer_id, outShapes);
1189
1190                 layerParams.type = "Scale";
1191                 layerParams.set("bias_term", false);
1192                 node_proto.set_input(0, constParams.name);
1193                 node_proto.set_input(1, shapeIt->first);
1194             }
1195             else if (broadcast_axes.size() == 1 && broadcast_axes[0] <= 1)
1196             {
1197                 String base_name = layerParams.name + "/copy_";
1198                 std::vector<std::string> input_names;
1199                 for (int j = 0; j < targetShape[broadcast_axes[0]]; j++)
1200                 {
1201                     std::ostringstream ss;
1202                     ss << j;
1203                     LayerParams copyLP;
1204                     copyLP.name = base_name + ss.str();
1205                     copyLP.type = "Identity";
1206                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
1207                     input_names.push_back(copyLP.name);
1208
1209                     node_proto.set_output(0, copyLP.name);
1210                     addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
1211                 }
1212                 node_proto.clear_input();
1213                 for (int i = 0; i < input_names.size(); i++)
1214                 {
1215                     node_proto.add_input(input_names[i]);
1216                 }
1217                 layerParams.set("axis", broadcast_axes[0]);
1218                 layerParams.type = "Concat";
1219             }
1220             else
1221                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
1222         }
        else if (layer_type == "Reshape")
        {
            CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));

            if (node_proto.input_size() == 2) {
                Mat blob = getBlob(node_proto, constBlobs, 1);
                CV_Assert(blob.type() == CV_32SC1);

                layerParams.set("dim", DictValue::arrayInt<int*>(
                            blob.ptr<int>(), blob.total() ));

                if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
                    std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
                    runLayer(layerParams, inputs, outputs);
                    addConstant(layerParams.name, outputs[0], constBlobs, outShapes);
                    continue;
                }
            }
            else {
                DictValue shape = layerParams.get("shape");
                std::vector<int> dim;
                for (int j = 0; j < shape.size(); j++) {
                    dim.push_back(shape.getIntValue(j));
                }

                if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
                    Mat input = getBlob(node_proto, constBlobs, 0);
                    Mat out = input.reshape(0, dim);
                    addConstant(layerParams.name, out, constBlobs, outShapes);
                    continue;
                }
                replaceLayerParam(layerParams, "shape", "dim");
            }
        }
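        // Pad is mapped to OpenCV's Padding layer; opset-11 style models pass
        // the paddings (and optionally the fill value) as constant inputs.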
        else if (layer_type == "Pad")
        {
            layerParams.type = "Padding";
            replaceLayerParam(layerParams, "mode", "type");
            if (node_proto.input_size() == 3 || node_proto.input_size() == 2)
            {
                // Paddings are in order begin0, begin1, ..., beginN, end0, end1, ..., endN.
                // We need to shuffle it to begin0, end0, begin1, end1, ...
                Mat paddings = getBlob(node_proto, constBlobs, 1).reshape(1, 2);
                paddings = paddings.t();
                layerParams.set("paddings", DictValue::arrayInt(paddings.ptr<int>(), paddings.total()));

                if (node_proto.input_size() == 3)
                {
                    Mat value = getBlob(node_proto, constBlobs, 2);
                    layerParams.set("value", value.at<float>(0));
                }
            }
        }
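        // Shape is evaluated at import time from the inferred input shape and
        // registered as a 1D CV_32S constant.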
        else if (layer_type == "Shape")
        {
            CV_Assert(node_proto.input_size() == 1);
            shapeIt = outShapes.find(node_proto.input(0));
            CV_Assert(shapeIt != outShapes.end());
            MatShape inpShape = shapeIt->second;

            Mat shapeMat(inpShape.size(), 1, CV_32S);
            for (int j = 0; j < (int)inpShape.size(); ++j)
                shapeMat.at<int>(j) = inpShape[j];
            shapeMat.dims = 1;

            addConstant(layerParams.name, shapeMat, constBlobs, outShapes);
            continue;
        }
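        // Cast on a constant input is folded immediately. Note that FLOAT16 is
        // mapped to CV_16S, which this OpenCV version uses as its raw FP16
        // container (see cv::convertFp16). For variable inputs the cast is a
        // no-op at this level, so an Identity layer is emitted.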
        else if (layer_type == "Cast")
        {
            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
            {
                Mat blob = getBlob(node_proto, constBlobs, 0);
                int type;
                switch (layerParams.get<int>("to"))
                {
                    case opencv_onnx::TensorProto_DataType_FLOAT:   type = CV_32F; break;
                    case opencv_onnx::TensorProto_DataType_UINT8:   type = CV_8U; break;
                    case opencv_onnx::TensorProto_DataType_UINT16:  type = CV_16U; break;
                    case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
                    case opencv_onnx::TensorProto_DataType_INT8:
                    case opencv_onnx::TensorProto_DataType_INT16:
                    case opencv_onnx::TensorProto_DataType_INT32:
                    case opencv_onnx::TensorProto_DataType_INT64:   type = CV_32S; break;
                    default: type = blob.type();
                }
                blob.convertTo(blob, type);
                addConstant(layerParams.name, blob, constBlobs, outShapes);
                continue;
            }
            else
                layerParams.type = "Identity";
        }
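        // ConstantOfShape / ConstantFill: the target shape comes from a
        // constant input (converted to MatShape via Mat's std::vector<int>
        // conversion operator), and the filled tensor is registered as a
        // constant.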
        else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
        {
            int depth = CV_32F;
            float fill_value;
            if (!layerParams.blobs.empty())
            {
                CV_Assert(!layerParams.has("value"));
                depth = layerParams.blobs[0].depth();
                Mat floats;
                layerParams.blobs[0].convertTo(floats, CV_32F);
                fill_value = floats.at<float>(0, 0);
            }
            else
                fill_value = layerParams.get("value", 0);

            MatShape inpShape = getBlob(node_proto, constBlobs, 0);
            for (int i = 0; i < (int)inpShape.size(); i++)
                CV_CheckGT(inpShape[i], 0, "");
            Mat tensor(inpShape.size(), &inpShape[0], depth, Scalar(fill_value));
            addConstant(layerParams.name, tensor, constBlobs, outShapes);
            continue;
        }
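        // Gather is supported only for constant inputs with a single index;
        // the slice is computed at import time.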
        else if (layer_type == "Gather")
        {
            CV_Assert(node_proto.input_size() == 2);
            Mat input = getBlob(node_proto, constBlobs, 0);
            Mat indexMat = getBlob(node_proto, constBlobs, 1);
            CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
            int index = indexMat.at<int>(0);

            Mat out;
            if (layerParams.has("axis"))
            {
                int axis = layerParams.get<int>("axis");

                std::vector<cv::Range> ranges(input.dims, Range::all());
                ranges[axis] = Range(index, index + 1);

                out = input(ranges);
            }
            else
            {
                CV_Assert(index < (int)input.total());
                const int dims = input.dims;
                input = input.reshape(1, 1);
                input.dims = 2;
                out = input.reshape(1, 1).colRange(index, index + 1);
                out.dims = dims;
            }
            addConstant(layerParams.name, out, constBlobs, outShapes);
            continue;
        }
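        // Concat of purely constant inputs is folded by running the layer at
        // import time; otherwise it falls through to a regular Concat layer.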
        else if (layer_type == "Concat")
        {
            bool hasVariableInps = false;
            for (int i = 0; i < node_proto.input_size(); ++i)
            {
                if (layer_id.find(node_proto.input(i)) != layer_id.end())
                {
                    hasVariableInps = true;
                    break;
                }
            }

            if (!hasVariableInps)
            {
                std::vector<Mat> inputs(node_proto.input_size()), concatenated;
                for (size_t i = 0; i < inputs.size(); ++i)
                {
                    inputs[i] = getBlob(node_proto, constBlobs, (int)i);
                }
                runLayer(layerParams, inputs, concatenated);

                CV_Assert(concatenated.size() == 1);
                addConstant(layerParams.name, concatenated[0], constBlobs, outShapes);
                continue;
            }
        }
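        // Resize: only constant size/scale inputs are supported, and the
        // tf_crop_and_resize / tf_half_pixel_for_nn coordinate transformation
        // modes are rejected. With three inputs the last blob holds per-axis
        // scale factors, which are multiplied by the input's spatial dims.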
        else if (layer_type == "Resize")
        {
            for (int i = 1; i < node_proto.input_size(); i++)
                CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());

            String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
            CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");

            layerParams.set("align_corners", interp_mode == "align_corners");
            Mat shapes = getBlob(node_proto, constBlobs, node_proto.input_size() - 1);
            CV_CheckEQ(shapes.size[0], 4, "");
            CV_CheckEQ(shapes.size[1], 1, "");
            CV_CheckTypeEQ(shapes.depth(), CV_32S, "");
            int height = shapes.at<int>(2);
            int width  = shapes.at<int>(3);
            if (node_proto.input_size() == 3)
            {
                shapeIt = outShapes.find(node_proto.input(0));
                CV_Assert(shapeIt != outShapes.end());
                MatShape inpShape = shapeIt->second;
                height *= inpShape[2];
                width  *= inpShape[3];
            }
            layerParams.set("width", width);
            layerParams.set("height", height);

            if (layerParams.get<String>("mode") == "linear") {
                layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
                                        "opencv_linear" : "bilinear");
            }
            replaceLayerParam(layerParams, "mode", "interpolation");
        }
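        // Upsample is lowered to a Resize layer; the zoom factors may come
        // from a "scales" attribute (PyTorch), height_scale/width_scale
        // attributes (Caffe2), or a constant input blob.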
        else if (layer_type == "Upsample")
        {
            // Fused from a Resize subgraph.
            if (layerParams.has("coordinate_transformation_mode"))
            {
                String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
                CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");

                layerParams.set("align_corners", interp_mode == "align_corners");
                if (layerParams.get<String>("mode") == "linear")
                {
                    layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
                                            "opencv_linear" : "bilinear");
                }
            }
            if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
                layerParams.set("mode", "opencv_linear");

            layerParams.type = "Resize";
            if (layerParams.has("scales"))
            {
                // PyTorch layer
                DictValue scales = layerParams.get("scales");
                CV_Assert(scales.size() == 4);
                layerParams.set("zoom_factor_y", scales.getIntValue(2));
                layerParams.set("zoom_factor_x", scales.getIntValue(3));
            }
            else if (layerParams.has("height_scale") && layerParams.has("width_scale"))
            {
                // Caffe2 layer
                replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
                replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
            }
            else
            {
                // Scales provided as a constant input blob.
                Mat scales = getBlob(node_proto, constBlobs, 1);
                CV_Assert(scales.total() == 4);
                layerParams.set("zoom_factor_y", scales.at<float>(2));
                layerParams.set("zoom_factor_x", scales.at<float>(3));
            }
            replaceLayerParam(layerParams, "mode", "interpolation");
        }
        else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
        {
            layerParams.type = "Softmax";
            layerParams.set("log_softmax", layer_type == "LogSoftmax");
        }
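        // DetectionOutput expects constant prior boxes; they are wrapped in a
        // Const layer so the network sees them as a regular input.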
        else if (layer_type == "DetectionOutput")
        {
            CV_CheckEQ(node_proto.input_size(), 3, "");
            if (constBlobs.find(node_proto.input(2)) != constBlobs.end())
            {
                Mat priors = getBlob(node_proto, constBlobs, 2);

                LayerParams constParams;
                constParams.name = layerParams.name + "/priors";
                constParams.type = "Const";
                constParams.blobs.push_back(priors);

                opencv_onnx::NodeProto priorsProto;
                priorsProto.add_output(constParams.name);
                addLayer(dstNet, constParams, priorsProto, layer_id, outShapes);

                node_proto.set_input(2, constParams.name);
            }
        }
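        // Default path: any constant inputs become the layer's weight blobs.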
        else
        {
            for (int j = 0; j < node_proto.input_size(); j++) {
                if (layer_id.find(node_proto.input(j)) == layer_id.end())
                    layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
            }
        }
        addLayer(dstNet, layerParams, node_proto, layer_id, outShapes);
    }
}

Net readNetFromONNX(const String& onnxFile)
{
    ONNXImporter onnxImporter(onnxFile.c_str());
    Net net;
    onnxImporter.populateNet(net);
    return net;
}

Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
{
    ONNXImporter onnxImporter(buffer, sizeBuffer);
    Net net;
    onnxImporter.populateNet(net);
    return net;
}

Net readNetFromONNX(const std::vector<uchar>& buffer)
{
    return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
}
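
// A minimal usage sketch (illustrative only; "model.onnx" is a hypothetical
// path and `img` is assumed to be a preloaded BGR image):
//
//     Net net = readNetFromONNX("model.onnx");
//     net.setInput(blobFromImage(img, 1.0, Size(224, 224)));
//     Mat prob = net.forward();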

Mat readTensorFromONNX(const String& path)
{
    opencv_onnx::TensorProto tensor_proto;
    std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
    if (!tensor_proto.ParseFromIstream(&input)) {
        CV_Error(Error::StsUnsupportedFormat, "Failed to parse ONNX tensor data");
    }
    Mat mat = getMatFromTensor(tensor_proto);
    releaseONNXTensor(tensor_proto);
    return mat;
}

CV__DNN_INLINE_NS_END
}} // namespace

#endif