Merge pull request #9305 from dkurt:public_dnn_importer_is_deprecated
[platform/upstream/opencv.git] / modules / dnn / src / caffe / caffe_importer.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of the copyright holders may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "../precomp.hpp"
43
44 #ifdef HAVE_PROTOBUF
45 #include "caffe.pb.h"
46
47 #include <iostream>
48 #include <fstream>
49 #include <sstream>
50 #include <algorithm>
51 #include <google/protobuf/message.h>
52 #include <google/protobuf/text_format.h>
53 #include <google/protobuf/io/zero_copy_stream_impl.h>
54 #include "caffe_io.hpp"
55 #endif
56
57 namespace cv {
58 namespace dnn {
59 CV__DNN_EXPERIMENTAL_NS_BEGIN
60
61 #ifdef HAVE_PROTOBUF
62 using ::google::protobuf::RepeatedField;
63 using ::google::protobuf::RepeatedPtrField;
64 using ::google::protobuf::Message;
65 using ::google::protobuf::Descriptor;
66 using ::google::protobuf::FieldDescriptor;
67 using ::google::protobuf::Reflection;
68
69 namespace
70 {
71
72 template<typename T>
73 static cv::String toString(const T &v)
74 {
75     std::ostringstream ss;
76     ss << v;
77     return ss.str();
78 }
79
80 class CaffeImporter : public Importer
81 {
82     caffe::NetParameter net;
83     caffe::NetParameter netBinary;
84
85 public:
86
87     CaffeImporter(const char *pototxt, const char *caffeModel)
88     {
89         CV_TRACE_FUNCTION();
90
91         ReadNetParamsFromTextFileOrDie(pototxt, &net);
92
93         if (caffeModel && caffeModel[0])
94             ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
95     }
96
97     void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
98     {
99         const Reflection *refl = msg.GetReflection();
100         int type = field->cpp_type();
101         bool isRepeated = field->is_repeated();
102         const std::string &name = field->name();
103
104         #define SET_UP_FILED(getter, arrayConstr, gtype)                                    \
105             if (isRepeated) {                                                               \
106                 const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field);  \
107                 params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size()));                  \
108             }                                                                               \
109             else {                                                                          \
110                 params.set(name, refl->getter(msg, field));                               \
111             }
112
113         switch (type)
114         {
115         case FieldDescriptor::CPPTYPE_INT32:
116             SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
117             break;
118         case FieldDescriptor::CPPTYPE_UINT32:
119             SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
120             break;
121         case FieldDescriptor::CPPTYPE_INT64:
122             SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
123             break;
124         case FieldDescriptor::CPPTYPE_UINT64:
125             SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
126             break;
127         case FieldDescriptor::CPPTYPE_BOOL:
128             SET_UP_FILED(GetBool, arrayInt, bool);
129             break;
130         case FieldDescriptor::CPPTYPE_DOUBLE:
131             SET_UP_FILED(GetDouble, arrayReal, double);
132             break;
133         case FieldDescriptor::CPPTYPE_FLOAT:
134             SET_UP_FILED(GetFloat, arrayReal, float);
135             break;
136         case FieldDescriptor::CPPTYPE_STRING:
137             if (isRepeated) {
138                 const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
139                 params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
140             }
141             else {
142                 params.set(name, refl->GetString(msg, field));
143             }
144             break;
145         case FieldDescriptor::CPPTYPE_ENUM:
146             if (isRepeated) {
147                 int size = refl->FieldSize(msg, field);
148                 std::vector<cv::String> buf(size);
149                 for (int i = 0; i < size; i++)
150                     buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
151                 params.set(name, DictValue::arrayString(buf.begin(), size));
152             }
153             else {
154                 params.set(name, refl->GetEnum(msg, field)->name());
155             }
156             break;
157         default:
158             CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
159             break;
160         }
161     }
162
163     inline static bool ends_with_param(const std::string &str)
164     {
165         static const std::string _param("_param");
166         return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;
167     }
168
169     void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false)
170     {
171         const Descriptor *msgDesc = msg.GetDescriptor();
172         const Reflection *msgRefl = msg.GetReflection();
173
174         for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)
175         {
176             const FieldDescriptor *fd = msgDesc->field(fieldId);
177
178             if (!isInternal && !ends_with_param(fd->name()))
179                 continue;
180
181             bool hasData =  fd->is_required() ||
182                             (fd->is_optional() && msgRefl->HasField(msg, fd)) ||
183                             (fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
184             if (!hasData)
185                 continue;
186
187             if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
188             {
189                 if (fd->is_repeated()) //Extract only first item!
190                     extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
191                 else
192                     extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
193             }
194             else
195             {
196                 addParam(msg, fd, params);
197             }
198         }
199     }
200
201     void blobShapeFromProto(const caffe::BlobProto &pbBlob, MatShape& shape)
202     {
203         shape.clear();
204         if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
205         {
206             shape.push_back(pbBlob.num());
207             shape.push_back(pbBlob.channels());
208             shape.push_back(pbBlob.height());
209             shape.push_back(pbBlob.width());
210         }
211         else if (pbBlob.has_shape())
212         {
213             const caffe::BlobShape &_shape = pbBlob.shape();
214
215             for (int i = 0; i < _shape.dim_size(); i++)
216                 shape.push_back((int)_shape.dim(i));
217         }
218         else
219             CV_Error(Error::StsError, "Unknown shape of input blob");
220     }
221
222     void blobFromProto(const caffe::BlobProto &pbBlob, cv::Mat &dstBlob)
223     {
224         MatShape shape;
225         blobShapeFromProto(pbBlob, shape);
226
227         dstBlob.create((int)shape.size(), &shape[0], CV_32F);
228         float *dstData = dstBlob.ptr<float>();
229         if (pbBlob.data_size())
230         {
231             // Single precision floats.
232             CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
233
234             CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
235
236             for (int i = 0; i < pbBlob.data_size(); i++)
237                 dstData[i] = pbBlob.data(i);
238         }
239         else
240         {
241             // Half precision floats.
242             CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);
243             std::string raw_data = pbBlob.raw_data();
244
245             CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());
246
247             Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());
248             convertFp16(halfs, dstBlob);
249         }
250     }
251
252     void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
253     {
254         const std::string &name = layer.name();
255
256         int li;
257         for (li = 0; li != netBinary.layer_size(); li++)
258         {
259             if (netBinary.layer(li).name() == name)
260                 break;
261         }
262
263         if (li == netBinary.layer_size() || netBinary.layer(li).blobs_size() == 0)
264             return;
265
266         const caffe::LayerParameter &binLayer = netBinary.layer(li);
267         layerParams.blobs.resize(binLayer.blobs_size());
268         for (int bi = 0; bi < binLayer.blobs_size(); bi++)
269         {
270             blobFromProto(binLayer.blobs(bi), layerParams.blobs[bi]);
271         }
272     }
273
274     struct BlobNote
275     {
276         BlobNote(const std::string &_name, int _layerId, int _outNum) :
277             name(_name.c_str()), layerId(_layerId), outNum(_outNum) {}
278
279         const char *name;
280         int layerId, outNum;
281     };
282
283     std::vector<BlobNote> addedBlobs;
284     std::map<String, int> layerCounter;
285
286     void populateNet(Net dstNet)
287     {
288         CV_TRACE_FUNCTION();
289
290         int layersSize = net.layer_size();
291         layerCounter.clear();
292         addedBlobs.clear();
293         addedBlobs.reserve(layersSize + 1);
294
295         //setup input layer names
296         {
297             std::vector<String> netInputs(net.input_size());
298             for (int inNum = 0; inNum < net.input_size(); inNum++)
299             {
300                 addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
301                 netInputs[inNum] = net.input(inNum);
302             }
303             dstNet.setInputsNames(netInputs);
304         }
305
306         for (int li = 0; li < layersSize; li++)
307         {
308             const caffe::LayerParameter &layer = net.layer(li);
309             String name = layer.name();
310             String type = layer.type();
311             LayerParams layerParams;
312
313             extractLayerParams(layer, layerParams);
314             extractBinaryLayerParms(layer, layerParams);
315
316             int repetitions = layerCounter[name]++;
317             if (repetitions)
318                 name += String("_") + toString(repetitions);
319
320             int id = dstNet.addLayer(name, type, layerParams);
321
322             for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
323                 addInput(layer.bottom(inNum), id, inNum, dstNet);
324
325             for (int outNum = 0; outNum < layer.top_size(); outNum++)
326                 addOutput(layer, id, outNum);
327         }
328
329         addedBlobs.clear();
330     }
331
332     void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
333     {
334         const std::string &name = layer.top(outNum);
335
336         bool haveDups = false;
337         for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
338         {
339             if (addedBlobs[idx].name == name)
340             {
341                 haveDups = true;
342                 break;
343             }
344         }
345
346         if (haveDups)
347         {
348             bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;
349             if (!isInplace)
350                 CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
351         }
352
353         addedBlobs.push_back(BlobNote(name, layerId, outNum));
354     }
355
356     void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)
357     {
358         int idx;
359         for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
360         {
361             if (addedBlobs[idx].name == name)
362                 break;
363         }
364
365         if (idx < 0)
366         {
367             CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");
368             return;
369         }
370
371         dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
372     }
373
374     ~CaffeImporter()
375     {
376
377     }
378
379 };
380
381 }
382
383 Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel)
384 {
385     return Ptr<Importer>(new CaffeImporter(prototxt.c_str(), caffeModel.c_str()));
386 }
387
388 Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
389 {
390     CaffeImporter caffeImporter(prototxt.c_str(), caffeModel.c_str());
391     Net net;
392     caffeImporter.populateNet(net);
393     return net;
394 }
395
396 #endif //HAVE_PROTOBUF
397
398 CV__DNN_EXPERIMENTAL_NS_END
399 }} // namespace