Merge remote-tracking branch 'upstream/3.4' into merge-3.4
[platform/upstream/opencv.git] / modules / dnn / src / onnx / onnx_importer.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../precomp.hpp"
9 #include <opencv2/dnn/shape_utils.hpp>
10
11 #ifdef HAVE_PROTOBUF
12
13 #include <iostream>
14 #include <fstream>
15 #include <string>
16 #include <limits>
17 #include <algorithm>
18
19
20 #if defined(__GNUC__) && __GNUC__ >= 5
21 #pragma GCC diagnostic push
22 #pragma GCC diagnostic ignored "-Wsuggest-override"
23 #endif
24 #include "opencv-onnx.pb.h"
25 #if defined(__GNUC__) && __GNUC__ >= 5
26 #pragma GCC diagnostic pop
27 #endif
28
29 #include "onnx_graph_simplifier.hpp"
30
31 namespace cv {
32 namespace dnn {
33 CV__DNN_INLINE_NS_BEGIN
34
35
36 class ONNXImporter
37 {
38     opencv_onnx::ModelProto model_proto;
39     struct LayerInfo {
40         int layerId;
41         int outputId;
42         LayerInfo(int _layerId, int _outputId) : layerId(_layerId), outputId(_outputId) {}
43     };
44
45     std::map<std::string, Mat> getGraphTensors(
46                                     const opencv_onnx::GraphProto& graph_proto);
47     Mat getBlob(const opencv_onnx::NodeProto& node_proto, const std::map<std::string, Mat>& constBlobs, int index);
48
49     LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
50     bool isCeilMode(const LayerParams& layerParams);
51
52 public:
53
54     ONNXImporter(const char *onnxFile)
55     {
56         std::fstream input(onnxFile, std::ios::in | std::ios::binary);
57
58         if (!model_proto.ParseFromIstream(&input))
59             CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model");
60     }
61
62     ONNXImporter(const char* buffer, size_t sizeBuffer)
63     {
64         struct _Buf : public std::streambuf
65         {
66             _Buf(const char* buffer, size_t sizeBuffer)
67             {
68                 char* p = const_cast<char*>(buffer);
69                 setg(p, p, p + sizeBuffer);
70             }
71         };
72
73         _Buf buf(buffer, sizeBuffer);
74         std::istream input(&buf);
75
76         if (!model_proto.ParseFromIstream(&input))
77             CV_Error(Error::StsUnsupportedFormat, "Failed to parse onnx model from in-memory byte array.");
78     }
79
80     void populateNet(Net dstNet);
81 };
82
83 inline void replaceLayerParam(LayerParams& layerParams, const String& oldKey, const String& newKey)
84 {
85     if (layerParams.has(oldKey)) {
86         layerParams.set(newKey, layerParams.get(oldKey));
87         layerParams.erase(oldKey);
88     }
89 }
90
91 void releaseONNXTensor(opencv_onnx::TensorProto& tensor_proto)
92 {
93     if (!tensor_proto.raw_data().empty()) {
94         delete tensor_proto.release_raw_data();
95     }
96 }
97
98 template<typename T1, typename T2>
99 void convertInt64ToInt32(const T1& src, T2& dst, int size)
100 {
101     for (int i = 0; i < size; i++) {
102         if (src[i] < std::numeric_limits<int32_t>::min() || src[i] > std::numeric_limits<int32_t>::max()) {
103             CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
104         }
105         dst[i] = saturate_cast<int32_t>(src[i]);
106     }
107 }
108
109 Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
110 {
111     CV_Assert(!tensor_proto.raw_data().empty() || !tensor_proto.float_data().empty()
112                     || !tensor_proto.double_data().empty() || !tensor_proto.int64_data().empty());
113
114     opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
115     Mat blob;
116     std::vector<int> sizes;
117     for (int i = 0; i < tensor_proto.dims_size(); i++) {
118             sizes.push_back(tensor_proto.dims(i));
119     }
120     if (sizes.empty())
121         sizes.assign(1, 1);
122     if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
123
124         if (!tensor_proto.float_data().empty()) {
125             const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
126             Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
127         }
128         else {
129             char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
130             Mat(sizes, CV_32FC1, val).copyTo(blob);
131         }
132     }
133     else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
134     {
135         const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
136         CV_Assert(!field.empty());
137         Mat(sizes, CV_64FC1, (void*)field.data()).convertTo(blob, CV_32FC1);
138     }
139     else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
140     {
141         blob.create(sizes, CV_32SC1);
142         int32_t* dst = reinterpret_cast<int32_t*>(blob.data);
143
144         if (!tensor_proto.int64_data().empty()) {
145             ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
146             convertInt64ToInt32(src, dst, blob.total());
147         }
148         else
149         {
150             const char* val = tensor_proto.raw_data().c_str();
151 #if CV_STRONG_ALIGNMENT
152             // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
153             // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
154             AutoBuffer<int64_t, 16> aligned_val;
155             if (!isAligned<sizeof(int64_t)>(val))
156             {
157                 size_t sz = tensor_proto.raw_data().size();
158                 aligned_val.allocate(divUp(sz, sizeof(int64_t)));
159                 memcpy(aligned_val.data(), val, sz);
160                 val = (const char*)aligned_val.data();
161             }
162 #endif
163             const int64_t* src = reinterpret_cast<const int64_t*>(val);
164             convertInt64ToInt32(src, dst, blob.total());
165         }
166     }
167     else
168         CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
169                         opencv_onnx::TensorProto_DataType_Name(datatype));
170     if (tensor_proto.dims_size() == 0)
171         blob.dims = 1;  // To force 1-dimensional cv::Mat for scalars.
172     return blob;
173 }
174
175 void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
176               std::vector<Mat>& outputs)
177 {
178     Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
179     CV_Assert((bool)layer);
180
181     std::vector<MatShape> inpShapes(inputs.size());
182     int ddepth = CV_32F;
183     for (size_t i = 0; i < inputs.size(); ++i)
184     {
185         inpShapes[i] = shape(inputs[i]);
186         if (i > 0 && ddepth != inputs[i].depth())
187             CV_Error(Error::StsNotImplemented, "Mixed input data types.");
188         ddepth = inputs[i].depth();
189     }
190
191     std::vector<MatShape> outShapes, internalShapes;
192     layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
193
194     std::vector<Mat> internals(internalShapes.size());
195     outputs.resize(outShapes.size());
196     for (size_t i = 0; i < outShapes.size(); ++i)
197         outputs[i].create(outShapes[i], ddepth);
198     for (size_t i = 0; i < internalShapes.size(); ++i)
199         internals[i].create(internalShapes[i], ddepth);
200
201     layer->finalize(inputs, outputs);
202     layer->forward(inputs, outputs, internals);
203 }
204
205 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
206                                         const opencv_onnx::GraphProto& graph_proto)
207 {
208   opencv_onnx::TensorProto tensor_proto;
209   std::map<std::string, Mat> layers_weights;
210
211   for (int i = 0; i < graph_proto.initializer_size(); i++)
212   {
213     tensor_proto = graph_proto.initializer(i);
214     Mat mat = getMatFromTensor(tensor_proto);
215     releaseONNXTensor(tensor_proto);
216     layers_weights.insert(std::make_pair(tensor_proto.name(), mat));
217   }
218   return layers_weights;
219 }
220
221 static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protobuf::int64>& src) {
222     std::vector<int32_t> dst(src.size());
223     convertInt64ToInt32(src, dst, src.size());
224     return DictValue::arrayInt(&dst[0], src.size());
225 }
226
227 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
228 {
229     LayerParams lp;
230     for(int i = 0; i < node_proto.attribute_size(); i++)
231     {
232         opencv_onnx::AttributeProto attribute_proto = node_proto.attribute(i);
233         std::string attribute_name = attribute_proto.name();
234
235         if(attribute_name == "kernel_shape")
236         {
237             CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
238             lp.set("kernel_size", parse(attribute_proto.ints()));
239         }
240         else if(attribute_name == "strides")
241         {
242             CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
243             lp.set("stride", parse(attribute_proto.ints()));
244         }
245         else if(attribute_name == "pads")
246         {
247             if (node_proto.op_type() == "Pad")
248             {
249                 // Padding layer.
250                 // Paddings are in order begin0, begin1, .. beginN, end0, end1, ..., endN.
251                 // We need to shuffle it to begin0, end0, begin1, end1, ...
252                 CV_Assert(attribute_proto.ints_size() % 2 == 0);
253                 const int dims = attribute_proto.ints_size() / 2;
254                 std::vector<int32_t> paddings;
255                 paddings.reserve(attribute_proto.ints_size());
256                 for (int i = 0; i < dims; ++i)
257                 {
258                     paddings.push_back(attribute_proto.ints(i));
259                     paddings.push_back(attribute_proto.ints(dims + i));
260                 }
261                 lp.set("paddings", DictValue::arrayInt(&paddings[0], paddings.size()));
262             }
263             else
264             {
265                 // Convolution or pooling.
266                 CV_Assert(attribute_proto.ints_size() == 4 || attribute_proto.ints_size() == 6);
267                 lp.set("pad", parse(attribute_proto.ints()));
268             }
269         }
270         else if(attribute_name == "auto_pad")
271         {
272             if (attribute_proto.s() == "SAME_UPPER" || attribute_proto.s() == "SAME_LOWER") {
273                 lp.set("pad_mode",  "SAME");
274             }
275             else if (attribute_proto.s() == "VALID") {
276                 lp.set("pad_mode", "VALID");
277             }
278         }
279         else if(attribute_name == "dilations")
280         {
281             CV_Assert(attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
282             lp.set("dilation", parse(attribute_proto.ints()));
283         }
284         else if (attribute_proto.has_i())
285         {
286             ::google::protobuf::int64 src = attribute_proto.i();
287             if (src < std::numeric_limits<int32_t>::min() || src > std::numeric_limits<int32_t>::max())
288                 CV_Error(Error::StsOutOfRange, "Input is out of OpenCV 32S range");
289             else
290                 lp.set(attribute_name, saturate_cast<int32_t>(src));
291         }
292         else if (attribute_proto.has_f())
293         {
294             lp.set(attribute_name, attribute_proto.f());
295         }
296         else if (attribute_proto.has_s())
297         {
298             lp.set(attribute_name, attribute_proto.s());
299         }
300         else if (attribute_proto.floats_size() > 0)
301         {
302             lp.set(attribute_name, DictValue::arrayReal(
303                 attribute_proto.floats().data(), attribute_proto.floats_size()));
304         }
305         else if (attribute_proto.ints_size() > 0)
306         {
307             lp.set(attribute_proto.name(), parse(attribute_proto.ints()));
308         }
309         else if (attribute_proto.has_t())
310         {
311             opencv_onnx::TensorProto tensor = attribute_proto.t();
312             Mat blob = getMatFromTensor(tensor);
313             lp.blobs.push_back(blob);
314         }
315         else if (attribute_proto.has_g() || attribute_proto.strings_size() > 0 ||
316                     attribute_proto.tensors_size() > 0 || attribute_proto.graphs_size() > 0)
317         {
318                 CV_Error(Error::StsNotImplemented, "Unexpected attribute type");
319         }
320         else
321             CV_Error(Error::StsNotImplemented, "Unsupported attribute type");
322     }
323     return lp;
324 }
325
326 Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
327                     const std::map<std::string, Mat>& constBlobs, int index)
328 {
329     CV_Assert(index < node_proto.input_size());
330     std::map<std::string, Mat>::const_iterator constBlob;
331     constBlob = constBlobs.find(node_proto.input(index));
332     if (constBlob == constBlobs.end()) {
333         CV_Error(Error::StsObjectNotFound,
334              "Blob " + node_proto.input(index) + " not found in const blobs");
335     }
336     return constBlob->second;
337 }
338
339 void ONNXImporter::populateNet(Net dstNet)
340 {
341     CV_Assert(model_proto.has_graph());
342     opencv_onnx::GraphProto graph_proto = model_proto.graph();
343
344     simplifySubgraphs(graph_proto);
345
346     std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
347     // List of internal blobs shapes.
348     std::map<std::string, MatShape> outShapes;
349     // Add all the inputs shapes. It includes as constant blobs as network's inputs shapes.
350     for (int i = 0; i < graph_proto.input_size(); ++i)
351     {
352         opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
353         CV_Assert(valueInfoProto.has_type());
354         opencv_onnx::TypeProto typeProto = valueInfoProto.type();
355         CV_Assert(typeProto.has_tensor_type());
356         opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
357         CV_Assert(tensor.has_shape());
358         opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
359
360         MatShape inpShape(tensorShape.dim_size());
361         for (int j = 0; j < inpShape.size(); ++j)
362         {
363             inpShape[j] = tensorShape.dim(j).dim_value();
364         }
365         outShapes[valueInfoProto.name()] = inpShape;
366     }
367
368     std::string framework_name;
369     if (model_proto.has_producer_name()) {
370         framework_name = model_proto.producer_name();
371     }
372
373     // create map with network inputs (without const blobs)
374     std::map<std::string, LayerInfo> layer_id;
375     std::map<std::string, LayerInfo>::iterator layerId;
376     std::map<std::string, MatShape>::iterator shapeIt;
377     // fill map: push layer name, layer id and output id
378     std::vector<String> netInputs;
379     for (int j = 0; j < graph_proto.input_size(); j++)
380     {
381         const std::string& name = graph_proto.input(j).name();
382         if (constBlobs.find(name) == constBlobs.end()) {
383             netInputs.push_back(name);
384             layer_id.insert(std::make_pair(name, LayerInfo(0, netInputs.size() - 1)));
385         }
386     }
387     dstNet.setInputsNames(netInputs);
388
389     int layersSize = graph_proto.node_size();
390     LayerParams layerParams;
391     opencv_onnx::NodeProto node_proto;
392
393     for(int li = 0; li < layersSize; li++)
394     {
395         node_proto = graph_proto.node(li);
396         layerParams = getLayerParams(node_proto);
397         CV_Assert(node_proto.output_size() >= 1);
398         layerParams.name = node_proto.output(0);
399
400         std::string layer_type = node_proto.op_type();
401         layerParams.type = layer_type;
402
403
404         if (layer_type == "MaxPool")
405         {
406             layerParams.type = "Pooling";
407             layerParams.set("pool", "MAX");
408             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
409         }
410         else if (layer_type == "AveragePool")
411         {
412             layerParams.type = "Pooling";
413             layerParams.set("pool", "AVE");
414             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
415             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
416         }
417         else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
418         {
419             CV_Assert(node_proto.input_size() == 1);
420             layerParams.type = "Pooling";
421             layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
422             layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
423
424             if (layer_type == "ReduceMean")
425             {
426                 if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
427                     CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
428
429                 MatShape inpShape = outShapes[node_proto.input(0)];
430                 if (inpShape.size() != 4 && inpShape.size() != 5)
431                     CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
432
433                 DictValue axes = layerParams.get("axes");
434                 CV_Assert(axes.size() <= inpShape.size() - 2);
435                 std::vector<int> kernel_size(inpShape.size() - 2, 1);
436                 for (int i = 0; i < axes.size(); i++) {
437                     int axis = axes.get<int>(i);
438                     CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
439                     kernel_size[axis - 2] = inpShape[axis];
440                 }
441
442                 layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
443             }
444         }
445         else if (layer_type == "Slice")
446         {
447             if (layerParams.has("steps")) {
448                 DictValue steps = layerParams.get("steps");
449                 for (int i = 0; i < steps.size(); ++i) {
450                     if (steps.get<int>(i) != 1)
451                         CV_Error(Error::StsNotImplemented,
452                                  "Slice layer only supports steps = 1");
453                 }
454             }
455
456             int axis = 0;
457             if (layerParams.has("axes")) {
458                 DictValue axes = layerParams.get("axes");
459                 for (int i = 1; i < axes.size(); ++i) {
460                     CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
461                 }
462                 axis = axes.get<int>(0);
463             }
464             layerParams.set("axis", axis);
465
466             DictValue starts = layerParams.get("starts");
467             DictValue ends = layerParams.get("ends");
468             CV_Assert(starts.size() == ends.size());
469
470             std::vector<int> begin;
471             std::vector<int> end;
472             if (axis > 0) {
473                 begin.resize(axis, 0);
474                 end.resize(axis, -1);
475             }
476
477             for (int i = 0; i < starts.size(); ++i)
478             {
479                 begin.push_back(starts.get<int>(i));
480                 int finish = ends.get<int>(i);
481                 end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
482             }
483             layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
484             layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
485          }
486         else if (layer_type == "Split")
487         {
488             DictValue splits = layerParams.get("split");
489             const int numSplits = splits.size();
490             CV_Assert(numSplits > 1);
491
492             std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
493             for (int i = 1; i < splits.size() - 1; ++i)
494             {
495                 slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
496             }
497             layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
498             layerParams.type = "Slice";
499         }
500         else if (layer_type == "Add" || layer_type == "Sum")
501         {
502             if (layer_id.find(node_proto.input(1)) == layer_id.end())
503             {
504                 Mat blob = getBlob(node_proto, constBlobs, 1);
505                 blob = blob.reshape(1, 1);
506                 if (blob.total() == 1) {
507                     layerParams.type = "Power";
508                     layerParams.set("shift", blob.at<float>(0));
509                 }
510                 else {
511                     layerParams.type = "Scale";
512                     layerParams.set("bias_term", true);
513                     layerParams.blobs.push_back(blob);
514                 }
515             }
516             else {
517                 layerParams.type = "Eltwise";
518             }
519         }
520         else if (layer_type == "Max")
521         {
522             layerParams.type = "Eltwise";
523             layerParams.set("operation", "max");
524         }
525         else if (layer_type == "Sub")
526         {
527             Mat blob = getBlob(node_proto, constBlobs, 1);
528             if (blob.total() == 1) {
529                 layerParams.type = "Power";
530                 layerParams.set("shift", -blob.at<float>(0));
531             }
532             else {
533                 layerParams.type = "Scale";
534                 layerParams.set("has_bias", true);
535                 layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
536             }
537         }
538         else if (layer_type == "Div")
539         {
540             if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
541             {
542                 layerParams.type = "Eltwise";
543                 layerParams.set("operation", "div");
544             }
545             else
546             {
547                 Mat blob = getBlob(node_proto, constBlobs, 1);
548                 CV_Assert_N(blob.type() == CV_32F, blob.total());
549                 if (blob.total() == 1)
550                 {
551                     layerParams.set("scale", 1.0f / blob.at<float>(0));
552                     layerParams.type = "Power";
553                 }
554                 else
555                 {
556                     layerParams.type = "Scale";
557                     divide(1.0, blob, blob);
558                     layerParams.blobs.push_back(blob);
559                     layerParams.set("bias_term", false);
560                 }
561             }
562         }
563         else if (layer_type == "Neg")
564         {
565             layerParams.type = "Power";
566             layerParams.set("scale", -1);
567         }
568         else if (layer_type == "Constant")
569         {
570             CV_Assert(node_proto.input_size() == 0);
571             CV_Assert(layerParams.blobs.size() == 1);
572             constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
573             continue;
574         }
575         else if (layer_type == "ImageScaler")
576         {
577             const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
578             layerParams.erase("scale");
579
580             if (layerParams.has("bias"))
581             {
582                 layerParams.type = "Scale";
583                 layerParams.blobs.push_back(
584                     Mat(Size(1,  layerParams.get("bias").size()), CV_32FC1, scale));
585
586                 layerParams.set("bias_term", true);
587                 Mat bias(1, layerParams.get("bias").size(), CV_32FC1);
588                 for (int j = 0; j < bias.total(); j++) {
589                     bias.at<float>(0, j) = layerParams.get("bias").getRealValue(j);
590                 }
591                 layerParams.blobs.push_back(bias);
592                 layerParams.erase("bias");
593             }
594             else {
595                 layerParams.set("scale", scale);
596                 layerParams.type = "Power";
597             }
598         }
599         else if (layer_type == "Clip")
600         {
601             layerParams.type = "ReLU6";
602             replaceLayerParam(layerParams, "min", "min_value");
603             replaceLayerParam(layerParams, "max", "max_value");
604
605         }
606         else if (layer_type == "LeakyRelu")
607         {
608             layerParams.type = "ReLU";
609             replaceLayerParam(layerParams, "alpha", "negative_slope");
610         }
611         else if (layer_type == "LRN")
612         {
613             replaceLayerParam(layerParams, "size", "local_size");
614         }
615         else if (layer_type == "InstanceNormalization")
616         {
617             if (node_proto.input_size() != 3)
618                 CV_Error(Error::StsNotImplemented,
619                          "Expected input, scale, bias");
620
621             layerParams.blobs.resize(4);
622             layerParams.blobs[2] = getBlob(node_proto, constBlobs, 1);  // weightData
623             layerParams.blobs[3] = getBlob(node_proto, constBlobs, 2);  // biasData
624             layerParams.set("has_bias", true);
625             layerParams.set("has_weight", true);
626
627             // Get number of channels in input
628             int size = layerParams.blobs[2].total();
629             layerParams.blobs[0] = Mat::zeros(size, 1, CV_32F); // mean
630             layerParams.blobs[1] = Mat::ones(size, 1, CV_32F); // std
631
632             LayerParams mvnParams;
633             mvnParams.name = layerParams.name + "/MVN";
634             mvnParams.type = "MVN";
635             mvnParams.set("eps", layerParams.get<float>("epsilon"));
636             layerParams.erase("epsilon");
637
638             //Create MVN layer
639             int id = dstNet.addLayer(mvnParams.name, mvnParams.type, mvnParams);
640             //Connect to input
641             layerId = layer_id.find(node_proto.input(0));
642             CV_Assert(layerId != layer_id.end());
643             dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0);
644             //Add shape
645             layer_id.insert(std::make_pair(mvnParams.name, LayerInfo(id, 0)));
646             outShapes[mvnParams.name] = outShapes[node_proto.input(0)];
647
648             //Replace Batch Norm's input to MVN
649             node_proto.set_input(0, mvnParams.name);
650             layerParams.type = "BatchNorm";
651         }
652         else if (layer_type == "BatchNormalization")
653         {
654             if (node_proto.input_size() != 5)
655                 CV_Error(Error::StsNotImplemented,
656                          "Expected input, scale, bias, mean and var");
657
658             layerParams.type = "BatchNorm";
659             replaceLayerParam(layerParams, "epsilon", "eps");
660             replaceLayerParam(layerParams, "spatial", "use_global_stats");
661
662             Mat meanData = getBlob(node_proto, constBlobs, 3);
663             Mat stdData =  getBlob(node_proto, constBlobs, 4);
664
665             layerParams.blobs.push_back(meanData);
666             layerParams.blobs.push_back(stdData);
667
668             if (!node_proto.input(1).empty()) {
669                 layerParams.set("has_weight", true);
670                 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));  // weightData
671             } else {
672                 layerParams.set("has_weight", false);
673             }
674
675             if (!node_proto.input(2).empty()) {
676                 layerParams.set("has_bias", true);
677                 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 2)); // biasData
678             } else {
679                 layerParams.set("has_bias", false);
680             }
681         }
682         else if (layer_type == "Gemm")
683         {
684             CV_Assert(node_proto.input_size() >= 2);
685             layerParams.type = "InnerProduct";
686             Mat weights = getBlob(node_proto, constBlobs, 1);
687             int ind_num_out = 0;
688             if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
689                 transpose(weights, weights);
690                 ind_num_out = 1;
691             }
692             layerParams.blobs.push_back(weights);
693
694             if (node_proto.input_size() == 3) {
695                 Mat bias = getBlob(node_proto, constBlobs, 2);
696                 layerParams.blobs.push_back(bias);
697             }
698
699             layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
700             layerParams.set("bias_term", node_proto.input_size() == 3);
701         }
702         else if (layer_type == "MatMul")
703         {
704             CV_Assert(node_proto.input_size() == 2);
705             layerParams.type = "InnerProduct";
706             Mat blob = getBlob(node_proto, constBlobs, 1);
707             layerParams.blobs.push_back(blob.t());
708             layerParams.set("bias_term", false);
709             layerParams.set("num_output", layerParams.blobs[0].size[0]);
710         }
711         else if (layer_type == "Mul")
712         {
713             CV_Assert(node_proto.input_size() == 2);
714             if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
715                 Mat blob = getBlob(node_proto, constBlobs, 1);
716                 blob = blob.reshape(1, 1);
717                 if (blob.total() == 1) {
718                     layerParams.set("scale", blob.at<float>(0));
719                     layerParams.type = "Power";
720                 }
721                 else {
722                     layerParams.blobs.push_back(blob);
723                     layerParams.type = "Scale";
724                 }
725             }
726             else {
727                 layerParams.type = "Eltwise";
728                 layerParams.set("operation", "prod");
729             }
730         }
731         else if (layer_type == "Conv")
732         {
733             CV_Assert(node_proto.input_size() >= 2);
734             layerParams.type = "Convolution";
735             for (int j = 1; j < node_proto.input_size(); j++) {
736                 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
737             }
738             layerParams.set("num_output", layerParams.blobs[0].size[0]);
739             layerParams.set("bias_term", node_proto.input_size() == 3);
740         }
741         else if (layer_type == "ConvTranspose")
742         {
743             CV_Assert(node_proto.input_size() >= 2);
744             layerParams.type = "Deconvolution";
745             for (int j = 1; j < node_proto.input_size(); j++) {
746                 layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
747             }
748             layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
749             layerParams.set("bias_term", node_proto.input_size() == 3);
750
751             if (!layerParams.has("kernel_size"))
752                 CV_Error(Error::StsNotImplemented,
753                          "Required attribute 'kernel_size' is not present.");
754
755             if (layerParams.has("output_shape"))
756             {
757                 const DictValue& outShape = layerParams.get("output_shape");
758                 DictValue strides = layerParams.get("stride");
759                 DictValue kernel = layerParams.get("kernel_size");
760
761                 String padMode;
762                 std::vector<int> adjust_pads;
763                 if (layerParams.has("pad_mode"))
764                 {
765                     padMode = toUpperCase(layerParams.get<String>("pad_mode"));
766                     if (padMode != "SAME" && padMode != "VALID")
767                         CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
768
769                     for (int i = 0; i < strides.size(); i++)
770                     {
771                         int sz = outShape.get<int>(2 + i);
772                         int stride = strides.get<int>(i);
773                         adjust_pads.push_back(padMode == "SAME"? (sz - 1) % stride :
774                                                                  (sz - kernel.get<int>(i)) % stride);
775                     }
776                     layerParams.set("adj", DictValue::arrayInt(&adjust_pads[0], adjust_pads.size()));
777                 }
778             }
779             else if (layerParams.has("output_padding"))
780             {
781                 replaceLayerParam(layerParams, "output_padding", "adj");
782             }
783         }
784         else if (layer_type == "Transpose")
785         {
786             layerParams.type = "Permute";
787             replaceLayerParam(layerParams, "perm", "order");
788
789             CV_Assert(node_proto.input_size() == 1);
790             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
791             {
792                 std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
793                 runLayer(layerParams, inputs, transposed);
794                 CV_Assert(transposed.size() == 1);
795                 constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
796                 continue;
797             }
798         }
799         else if (layer_type == "ReduceL2")
800         {
801             CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
802             CV_Assert(graph_proto.node_size() > li + 1 && graph_proto.node(li + 1).op_type() == "Div");
803             ++li;
804             node_proto = graph_proto.node(li);
805             layerParams.name = node_proto.output(0);
806             layerParams.type = "Normalize";
807
808             DictValue axes_dict = layerParams.get("axes");
809             if (axes_dict.size() != 1)
810                 CV_Error(Error::StsNotImplemented, "Multidimensional reduceL2");
811             int axis = axes_dict.getIntValue(0);
812             layerParams.set("axis",axis);
813             layerParams.set("end_axis", axis);
814         }
815         else if (layer_type == "Squeeze")
816         {
817             CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
818             DictValue axes_dict = layerParams.get("axes");
819             if (axes_dict.size() != 1)
820                 CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
821
822             int axis = axes_dict.getIntValue(0);
823             layerParams.set("axis", axis - 1);
824             layerParams.set("end_axis", axis);
825             layerParams.type = "Flatten";
826         }
827         else if (layer_type == "Unsqueeze")
828         {
829             CV_Assert(node_proto.input_size() == 1);
830             DictValue axes = layerParams.get("axes");
831             if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
832             {
833                 // Constant input.
834                 Mat input = getBlob(node_proto, constBlobs, 0);
835
836                 std::vector<int> dims;
837                 for (int j = 0; j < input.dims; j++) {
838                     dims.push_back(input.size[j]);
839                 }
840                 CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
841                 for (int j = 0; j < axes.size(); j++) {
842                     dims.insert(dims.begin() + axes.getIntValue(j), 1);
843                 }
844
845                 Mat out = input.reshape(0, dims);
846                 constBlobs.insert(std::make_pair(layerParams.name, out));
847                 continue;
848             }
849
850             // Variable input.
851             if (axes.size() != 1)
852                 CV_Error(Error::StsNotImplemented, "Multidimensional unsqueeze");
853
854             MatShape inpShape = outShapes[node_proto.input(0)];
855             int axis = axes.getIntValue(0);
856             CV_Assert(0 <= axis && axis <= inpShape.size());
857             std::vector<int> outShape = inpShape;
858             outShape.insert(outShape.begin() + axis, 1);
859             layerParams.type = "Reshape";
860             layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
861         }
862         else if (layer_type == "Reshape")
863         {
864             CV_Assert(node_proto.input_size() == 2 || layerParams.has("shape"));
865
866             if (node_proto.input_size() == 2) {
867                 Mat blob = getBlob(node_proto, constBlobs, 1);
868                 CV_Assert(blob.type() == CV_32SC1);
869
870                 layerParams.set("dim", DictValue::arrayInt<int*>(
871                             blob.ptr<int>(), blob.total() ));
872
873                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
874                     std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
875                     runLayer(layerParams, inputs, outputs);
876                     constBlobs.insert(std::make_pair(layerParams.name, outputs[0]));
877                     continue;
878                 }
879             }
880             else {
881                 DictValue shape = layerParams.get("shape");
882                 std::vector<int> dim;
883                 for (int j = 0; j < shape.size(); j++) {
884                     dim.push_back(shape.getIntValue(j));
885                 }
886
887                 if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
888                     Mat input = getBlob(node_proto, constBlobs, 0);
889                     Mat out = input.reshape(0, dim);
890                     constBlobs.insert(std::make_pair(layerParams.name, out));
891                     continue;
892                 }
893                 replaceLayerParam(layerParams, "shape", "dim");
894             }
895         }
896         else if (layer_type == "Pad")
897         {
898             layerParams.type = "Padding";
899         }
900         else if (layer_type == "Shape")
901         {
902             CV_Assert(node_proto.input_size() == 1);
903             shapeIt = outShapes.find(node_proto.input(0));
904             CV_Assert(shapeIt != outShapes.end());
905             MatShape inpShape = shapeIt->second;
906
907             Mat shapeMat(inpShape.size(), 1, CV_32S);
908             for (int j = 0; j < inpShape.size(); ++j)
909                 shapeMat.at<int>(j) = inpShape[j];
910             shapeMat.dims = 1;
911
912             constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
913             continue;
914         }
915         else if (layer_type == "Gather")
916         {
917             CV_Assert(node_proto.input_size() == 2);
918             CV_Assert(layerParams.has("axis"));
919             Mat input = getBlob(node_proto, constBlobs, 0);
920             Mat indexMat = getBlob(node_proto, constBlobs, 1);
921             CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
922             int index = indexMat.at<int>(0);
923             int axis = layerParams.get<int>("axis");
924
925             std::vector<cv::Range> ranges(input.dims, Range::all());
926             ranges[axis] = Range(index, index + 1);
927
928             Mat out = input(ranges);
929             constBlobs.insert(std::make_pair(layerParams.name, out));
930             continue;
931         }
932         else if (layer_type == "Concat")
933         {
934             bool hasVariableInps = false;
935             for (int i = 0; i < node_proto.input_size(); ++i)
936             {
937                 if (layer_id.find(node_proto.input(i)) != layer_id.end())
938                 {
939                     hasVariableInps = true;
940                     break;
941                 }
942             }
943
944             if (!hasVariableInps)
945             {
946                 std::vector<Mat> inputs(node_proto.input_size()), concatenated;
947                 for (size_t i = 0; i < inputs.size(); ++i)
948                 {
949                     inputs[i] = getBlob(node_proto, constBlobs, i);
950                 }
951                 runLayer(layerParams, inputs, concatenated);
952
953                 CV_Assert(concatenated.size() == 1);
954                 constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
955                 continue;
956             }
957         }
958         else if (layer_type == "Upsample")
959         {
960             layerParams.type = "Resize";
961             if (layerParams.has("scales"))
962             {
963                 // Pytorch layer
964                 DictValue scales = layerParams.get("scales");
965                 CV_Assert(scales.size() == 4);
966                 layerParams.set("zoom_factor_y", scales.getIntValue(2));
967                 layerParams.set("zoom_factor_x", scales.getIntValue(3));
968             }
969             else
970             {
971                 // Caffe2 layer
972                 replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
973                 replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
974             }
975             replaceLayerParam(layerParams, "mode", "interpolation");
976         }
977         else if (layer_type == "LogSoftmax")
978         {
979             layerParams.type = "Softmax";
980             layerParams.set("log_softmax", true);
981         }
982         else
983         {
984             for (int j = 0; j < node_proto.input_size(); j++) {
985                 if (layer_id.find(node_proto.input(j)) == layer_id.end())
986                     layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
987             }
988         }
989
990         int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
991         for (int i = 0; i < node_proto.output_size(); ++i)
992         {
993             layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
994         }
995
996         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
997         for (int j = 0; j < node_proto.input_size(); j++) {
998             layerId = layer_id.find(node_proto.input(j));
999             if (layerId != layer_id.end()) {
1000                 dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
1001                 // Collect input shapes.
1002                 shapeIt = outShapes.find(node_proto.input(j));
1003                 CV_Assert(shapeIt != outShapes.end());
1004                 layerInpShapes.push_back(shapeIt->second);
1005             }
1006         }
1007
1008         // Compute shape of output blob for this layer.
1009         Ptr<Layer> layer = dstNet.getLayer(id);
1010         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
1011         for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
1012         {
1013             outShapes[node_proto.output(i)] = layerOutShapes[i];
1014         }
1015     }
1016 }
1017
1018 Net readNetFromONNX(const String& onnxFile)
1019 {
1020     ONNXImporter onnxImporter(onnxFile.c_str());
1021     Net net;
1022     onnxImporter.populateNet(net);
1023     return net;
1024 }
1025
1026 Net readNetFromONNX(const char* buffer, size_t sizeBuffer)
1027 {
1028     ONNXImporter onnxImporter(buffer, sizeBuffer);
1029     Net net;
1030     onnxImporter.populateNet(net);
1031     return net;
1032 }
1033
1034 Net readNetFromONNX(const std::vector<uchar>& buffer)
1035 {
1036     return readNetFromONNX(reinterpret_cast<const char*>(buffer.data()), buffer.size());
1037 }
1038
1039 Mat readTensorFromONNX(const String& path)
1040 {
1041     opencv_onnx::TensorProto tensor_proto = opencv_onnx::TensorProto();
1042     std::fstream input(path.c_str(), std::ios::in | std::ios::binary);
1043     if (!tensor_proto.ParseFromIstream(&input)) {
1044         CV_Error(Error::StsUnsupportedFormat, "Failed to parse data");
1045     }
1046     Mat mat = getMatFromTensor(tensor_proto);
1047     releaseONNXTensor(tensor_proto);
1048     return mat;
1049 }
1050
1051 CV__DNN_INLINE_NS_END
1052 }} // namespace
1053
1054 #endif