1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
11 // For Open Source Computer Vision Library
13 // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
26 // * The name of the copyright holders may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
42 #include "../precomp.hpp"
51 #include <google/protobuf/message.h>
52 #include <google/protobuf/text_format.h>
53 #include <google/protobuf/io/zero_copy_stream_impl.h>
54 #include "caffe_io.hpp"
59 CV__DNN_EXPERIMENTAL_NS_BEGIN
62 using ::google::protobuf::RepeatedField;
63 using ::google::protobuf::RepeatedPtrField;
64 using ::google::protobuf::Message;
65 using ::google::protobuf::Descriptor;
66 using ::google::protobuf::FieldDescriptor;
67 using ::google::protobuf::Reflection;
73 static cv::String toString(const T &v)
75 std::ostringstream ss;
80 class CaffeImporter : public Importer
82 caffe::NetParameter net;
83 caffe::NetParameter netBinary;
87 CaffeImporter(const char *pototxt, const char *caffeModel)
91 ReadNetParamsFromTextFileOrDie(pototxt, &net);
93 if (caffeModel && caffeModel[0])
94 ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
97 void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams ¶ms)
99 const Reflection *refl = msg.GetReflection();
100 int type = field->cpp_type();
101 bool isRepeated = field->is_repeated();
102 const std::string &name = field->name();
104 #define SET_UP_FILED(getter, arrayConstr, gtype) \
106 const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \
107 params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \
110 params.set(name, refl->getter(msg, field)); \
115 case FieldDescriptor::CPPTYPE_INT32:
116 SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
118 case FieldDescriptor::CPPTYPE_UINT32:
119 SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
121 case FieldDescriptor::CPPTYPE_INT64:
122 SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
124 case FieldDescriptor::CPPTYPE_UINT64:
125 SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
127 case FieldDescriptor::CPPTYPE_BOOL:
128 SET_UP_FILED(GetBool, arrayInt, bool);
130 case FieldDescriptor::CPPTYPE_DOUBLE:
131 SET_UP_FILED(GetDouble, arrayReal, double);
133 case FieldDescriptor::CPPTYPE_FLOAT:
134 SET_UP_FILED(GetFloat, arrayReal, float);
136 case FieldDescriptor::CPPTYPE_STRING:
138 const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
139 params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
142 params.set(name, refl->GetString(msg, field));
145 case FieldDescriptor::CPPTYPE_ENUM:
147 int size = refl->FieldSize(msg, field);
148 std::vector<cv::String> buf(size);
149 for (int i = 0; i < size; i++)
150 buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
151 params.set(name, DictValue::arrayString(buf.begin(), size));
154 params.set(name, refl->GetEnum(msg, field)->name());
158 CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
163 inline static bool ends_with_param(const std::string &str)
165 static const std::string _param("_param");
166 return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;
169 void extractLayerParams(const Message &msg, cv::dnn::LayerParams ¶ms, bool isInternal = false)
171 const Descriptor *msgDesc = msg.GetDescriptor();
172 const Reflection *msgRefl = msg.GetReflection();
174 for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)
176 const FieldDescriptor *fd = msgDesc->field(fieldId);
178 if (!isInternal && !ends_with_param(fd->name()))
181 bool hasData = fd->is_required() ||
182 (fd->is_optional() && msgRefl->HasField(msg, fd)) ||
183 (fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
187 if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
189 if (fd->is_repeated()) //Extract only first item!
190 extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
192 extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
196 addParam(msg, fd, params);
201 void blobShapeFromProto(const caffe::BlobProto &pbBlob, MatShape& shape)
204 if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
206 shape.push_back(pbBlob.num());
207 shape.push_back(pbBlob.channels());
208 shape.push_back(pbBlob.height());
209 shape.push_back(pbBlob.width());
211 else if (pbBlob.has_shape())
213 const caffe::BlobShape &_shape = pbBlob.shape();
215 for (int i = 0; i < _shape.dim_size(); i++)
216 shape.push_back((int)_shape.dim(i));
219 CV_Error(Error::StsError, "Unknown shape of input blob");
222 void blobFromProto(const caffe::BlobProto &pbBlob, cv::Mat &dstBlob)
225 blobShapeFromProto(pbBlob, shape);
227 dstBlob.create((int)shape.size(), &shape[0], CV_32F);
228 float *dstData = dstBlob.ptr<float>();
229 if (pbBlob.data_size())
231 // Single precision floats.
232 CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
234 CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
236 for (int i = 0; i < pbBlob.data_size(); i++)
237 dstData[i] = pbBlob.data(i);
241 // Half precision floats.
242 CV_Assert(pbBlob.raw_data_type() == caffe::FLOAT16);
243 std::string raw_data = pbBlob.raw_data();
245 CV_Assert(raw_data.size() / 2 == (int)dstBlob.total());
247 Mat halfs((int)shape.size(), &shape[0], CV_16SC1, (void*)raw_data.c_str());
248 convertFp16(halfs, dstBlob);
252 void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
254 const std::string &name = layer.name();
257 for (li = 0; li != netBinary.layer_size(); li++)
259 if (netBinary.layer(li).name() == name)
263 if (li == netBinary.layer_size() || netBinary.layer(li).blobs_size() == 0)
266 const caffe::LayerParameter &binLayer = netBinary.layer(li);
267 layerParams.blobs.resize(binLayer.blobs_size());
268 for (int bi = 0; bi < binLayer.blobs_size(); bi++)
270 blobFromProto(binLayer.blobs(bi), layerParams.blobs[bi]);
276 BlobNote(const std::string &_name, int _layerId, int _outNum) :
277 name(_name.c_str()), layerId(_layerId), outNum(_outNum) {}
283 std::vector<BlobNote> addedBlobs;
284 std::map<String, int> layerCounter;
286 void populateNet(Net dstNet)
290 int layersSize = net.layer_size();
291 layerCounter.clear();
293 addedBlobs.reserve(layersSize + 1);
295 //setup input layer names
297 std::vector<String> netInputs(net.input_size());
298 for (int inNum = 0; inNum < net.input_size(); inNum++)
300 addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
301 netInputs[inNum] = net.input(inNum);
303 dstNet.setInputsNames(netInputs);
306 for (int li = 0; li < layersSize; li++)
308 const caffe::LayerParameter &layer = net.layer(li);
309 String name = layer.name();
310 String type = layer.type();
311 LayerParams layerParams;
313 extractLayerParams(layer, layerParams);
314 extractBinaryLayerParms(layer, layerParams);
316 int repetitions = layerCounter[name]++;
318 name += String("_") + toString(repetitions);
320 int id = dstNet.addLayer(name, type, layerParams);
322 for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
323 addInput(layer.bottom(inNum), id, inNum, dstNet);
325 for (int outNum = 0; outNum < layer.top_size(); outNum++)
326 addOutput(layer, id, outNum);
332 void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
334 const std::string &name = layer.top(outNum);
336 bool haveDups = false;
337 for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
339 if (addedBlobs[idx].name == name)
348 bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;
350 CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
353 addedBlobs.push_back(BlobNote(name, layerId, outNum));
356 void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)
359 for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
361 if (addedBlobs[idx].name == name)
367 CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");
371 dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
383 Ptr<Importer> createCaffeImporter(const String &prototxt, const String &caffeModel)
385 return Ptr<Importer>(new CaffeImporter(prototxt.c_str(), caffeModel.c_str()));
388 Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
390 CaffeImporter caffeImporter(prototxt.c_str(), caffeModel.c_str());
392 caffeImporter.populateNet(net);
396 #endif //HAVE_PROTOBUF
398 CV__DNN_EXPERIMENTAL_NS_END