Reading net from std::ifstream
authorasciian <asciian@users.noreply.github.com>
Sun, 18 Mar 2018 02:21:58 +0000 (11:21 +0900)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 9 Jul 2018 07:02:05 +0000 (10:02 +0300)
Remove some assertions

Replace std::ifstream to std::istream

Add test for new importer

Remove constructor to load file

Rename cfgStream and darknetModelStream to ifile

Add error notification to inform pathname to user

Use FileStorage instead of std::istream

Use FileNode instead of FileStorage

Fix typo

modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/src/darknet/darknet_importer.cpp
modules/dnn/src/darknet/darknet_io.cpp
modules/dnn/src/darknet/darknet_io.hpp
modules/dnn/test/test_darknet_importer.cpp

index f65f503..68e1994 100644 (file)
@@ -644,6 +644,14 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     */
     CV_EXPORTS_W Net readNetFromDarknet(const String &cfgFile, const String &darknetModel = String());
 
+   /** @brief Reads a network model stored in <a href="https://pjreddie.com/darknet/">Darknet</a> model files.
+    *  @param cfgFile      file node to the .cfg file with text description of the network architecture.
+    *  @param darknetModel file node to the .weights file with learned network.
+    *  @returns Network object that ready to do forward, throw an exception in failure cases.
+    *  @returns Net object.
+    */
+    CV_EXPORTS_W Net readNetFromDarknet(const FileNode &cfgFile, const FileNode &darknetModel = FileNode());
+
     /** @brief Reads a network model stored in <a href="http://caffe.berkeleyvision.org">Caffe</a> framework's format.
       * @param prototxt   path to the .prototxt file with text description of the network architecture.
       * @param caffeModel path to the .caffemodel file with learned network.
index 8bd64d0..17506c2 100644 (file)
@@ -44,6 +44,7 @@
 #include "../precomp.hpp"
 
 #include <iostream>
+#include <fstream>
 #include <algorithm>
 #include <vector>
 #include <map>
@@ -66,14 +67,19 @@ public:
 
     DarknetImporter() {}
 
-    DarknetImporter(const char *cfgFile, const char *darknetModel)
+    DarknetImporter(std::istream &cfgStream, std::istream &darknetModelStream)
     {
         CV_TRACE_FUNCTION();
 
-        ReadNetParamsFromCfgFileOrDie(cfgFile, &net);
+        ReadNetParamsFromCfgStreamOrDie(cfgStream, &net);
+        ReadNetParamsFromBinaryStreamOrDie(darknetModelStream, &net);
+    }
+
+    DarknetImporter(std::istream &cfgStream)
+    {
+        CV_TRACE_FUNCTION();
 
-        if (darknetModel && darknetModel[0])
-            ReadNetParamsFromBinaryFileOrDie(darknetModel, &net);
+        ReadNetParamsFromCfgStreamOrDie(cfgStream, &net);
     }
 
     struct BlobNote
@@ -179,7 +185,38 @@ public:
 
 Net readNetFromDarknet(const String &cfgFile, const String &darknetModel /*= String()*/)
 {
-    DarknetImporter darknetImporter(cfgFile.c_str(), darknetModel.c_str());
+    Net net;
+    std::ifstream cfgStream(cfgFile.c_str());
+    if(!cfgStream.is_open()) {
+        CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(cfgFile));
+        return net;
+    }
+    DarknetImporter darknetImporter;
+    if (darknetModel != String()) {
+        std::ifstream darknetModelStream(darknetModel.c_str());
+        if(!darknetModelStream.is_open()){
+            CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(darknetModel));
+            return net;
+        }
+        darknetImporter = DarknetImporter(cfgStream, darknetModelStream);
+    } else {
+        darknetImporter = DarknetImporter(cfgStream);
+    }
+    darknetImporter.populateNet(net);
+    return net;
+}
+
+Net readNetFromDarknet(const FileNode &cfgFile, const FileNode &darknetModel /*= FileNode()*/)
+{
+    DarknetImporter darknetImporter;
+    if(darknetModel.empty()){
+        std::istringstream cfgStream((std::string)cfgFile);
+        darknetImporter = DarknetImporter(cfgStream);
+    }else{
+        std::istringstream cfgStream((std::string)cfgFile);
+        std::istringstream darknetModelStream((std::string)darknetModel);
+        darknetImporter = DarknetImporter(cfgStream, darknetModelStream);
+    }
     Net net;
     darknetImporter.populateNet(net);
     return net;
index 03805dd..815b84f 100644 (file)
@@ -476,68 +476,61 @@ namespace cv {
                 return dst;
             }
 
-            bool ReadDarknetFromCfgFile(const char *cfgFile, NetParameter *net)
+            bool ReadDarknetFromCfgStream(std::istream &ifile, NetParameter *net)
             {
-                std::ifstream ifile;
-                ifile.open(cfgFile);
-                if (ifile.is_open())
-                {
-                    bool read_net = false;
-                    int layers_counter = -1;
-                    for (std::string line; std::getline(ifile, line);) {
-                        line = escapeString(line);
-                        if (line.empty()) continue;
-                        switch (line[0]) {
-                        case '\0': break;
-                        case '#': break;
-                        case ';': break;
-                        case '[':
-                            if (line == "[net]") {
-                                read_net = true;
-                            }
-                            else {
-                                // read section
-                                read_net = false;
-                                ++layers_counter;
-                                const size_t layer_type_size = line.find("]") - 1;
-                                CV_Assert(layer_type_size < line.size());
-                                std::string layer_type = line.substr(1, layer_type_size);
-                                net->layers_cfg[layers_counter]["type"] = layer_type;
-                            }
-                            break;
-                        default:
-                            // read entry
-                            const size_t separator_index = line.find('=');
-                            CV_Assert(separator_index < line.size());
-                            if (separator_index != std::string::npos) {
-                                std::string name = line.substr(0, separator_index);
-                                std::string value = line.substr(separator_index + 1, line.size() - (separator_index + 1));
-                                name = escapeString(name);
-                                value = escapeString(value);
-                                if (name.empty() || value.empty()) continue;
-                                if (read_net)
-                                    net->net_cfg[name] = value;
-                                else
-                                    net->layers_cfg[layers_counter][name] = value;
-                            }
+                bool read_net = false;
+                int layers_counter = -1;
+                for (std::string line; std::getline(ifile, line);) {
+                    line = escapeString(line);
+                    if (line.empty()) continue;
+                    switch (line[0]) {
+                    case '\0': break;
+                    case '#': break;
+                    case ';': break;
+                    case '[':
+                        if (line == "[net]") {
+                            read_net = true;
+                        }
+                        else {
+                            // read section
+                            read_net = false;
+                            ++layers_counter;
+                            const size_t layer_type_size = line.find("]") - 1;
+                            CV_Assert(layer_type_size < line.size());
+                            std::string layer_type = line.substr(1, layer_type_size);
+                            net->layers_cfg[layers_counter]["type"] = layer_type;
+                        }
+                        break;
+                    default:
+                        // read entry
+                        const size_t separator_index = line.find('=');
+                        CV_Assert(separator_index < line.size());
+                        if (separator_index != std::string::npos) {
+                            std::string name = line.substr(0, separator_index);
+                            std::string value = line.substr(separator_index + 1, line.size() - (separator_index + 1));
+                            name = escapeString(name);
+                            value = escapeString(value);
+                            if (name.empty() || value.empty()) continue;
+                            if (read_net)
+                                net->net_cfg[name] = value;
+                            else
+                                net->layers_cfg[layers_counter][name] = value;
                         }
                     }
-
-                    std::string anchors = net->layers_cfg[net->layers_cfg.size() - 1]["anchors"];
-                    std::vector<float> vec = getNumbers<float>(anchors);
-                    std::map<std::string, std::string> &net_params = net->net_cfg;
-                    net->width = getParam(net_params, "width", 416);
-                    net->height = getParam(net_params, "height", 416);
-                    net->channels = getParam(net_params, "channels", 3);
-                    CV_Assert(net->width > 0 && net->height > 0 && net->channels > 0);
                 }
-                else
-                    return false;
+
+                std::string anchors = net->layers_cfg[net->layers_cfg.size() - 1]["anchors"];
+                std::vector<float> vec = getNumbers<float>(anchors);
+                std::map<std::string, std::string> &net_params = net->net_cfg;
+                net->width = getParam(net_params, "width", 416);
+                net->height = getParam(net_params, "height", 416);
+                net->channels = getParam(net_params, "channels", 3);
+                CV_Assert(net->width > 0 && net->height > 0 && net->channels > 0);
 
                 int current_channels = net->channels;
                 net->out_channels_vec.resize(net->layers_cfg.size());
 
-                int layers_counter = -1;
+                layers_counter = -1;
 
                 setLayersParams setParams(net);
 
@@ -676,13 +669,8 @@ namespace cv {
                 return true;
             }
 
-
-            bool ReadDarknetFromWeightsFile(const char *darknetModel, NetParameter *net)
+            bool ReadDarknetFromWeightsStream(std::istream &ifile, NetParameter *net)
             {
-                std::ifstream ifile;
-                ifile.open(darknetModel, std::ios::binary);
-                CV_Assert(ifile.is_open());
-
                 int32_t major_ver, minor_ver, revision;
                 ifile.read(reinterpret_cast<char *>(&major_ver), sizeof(int32_t));
                 ifile.read(reinterpret_cast<char *>(&minor_ver), sizeof(int32_t));
@@ -778,19 +766,18 @@ namespace cv {
         }
 
 
-        void ReadNetParamsFromCfgFileOrDie(const char *cfgFile, darknet::NetParameter *net)
+        void ReadNetParamsFromCfgStreamOrDie(std::istream &ifile, darknet::NetParameter *net)
         {
-            if (!darknet::ReadDarknetFromCfgFile(cfgFile, net)) {
-                CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(cfgFile));
+            if (!darknet::ReadDarknetFromCfgStream(ifile, net)) {
+                CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter stream");
             }
         }
 
-        void ReadNetParamsFromBinaryFileOrDie(const char *darknetModel, darknet::NetParameter *net)
+        void ReadNetParamsFromBinaryStreamOrDie(std::istream &ifile, darknet::NetParameter *net)
         {
-            if (!darknet::ReadDarknetFromWeightsFile(darknetModel, net)) {
-                CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(darknetModel));
+            if (!darknet::ReadDarknetFromWeightsStream(ifile, net)) {
+                CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter stream");
             }
         }
-
     }
 }
index 5859f73..f783ca7 100644 (file)
@@ -109,10 +109,9 @@ namespace cv {
             };
         }
 
-        // Read parameters from a file into a NetParameter message.
-        void ReadNetParamsFromCfgFileOrDie(const char *cfgFile, darknet::NetParameter *net);
-        void ReadNetParamsFromBinaryFileOrDie(const char *darknetModel, darknet::NetParameter *net);
-
+        // Read parameters from a stream into a NetParameter message.
+        void ReadNetParamsFromCfgStreamOrDie(std::istream &ifile, darknet::NetParameter *net);
+        void ReadNetParamsFromBinaryStreamOrDie(std::istream &ifile, darknet::NetParameter *net);
     }
 }
 #endif
index 682213b..c585d40 100644 (file)
@@ -65,6 +65,18 @@ TEST(Test_Darknet, read_yolo_voc)
     ASSERT_FALSE(net.empty());
 }
 
+TEST(Test_Darknet, read_filestorage_yolo_voc)
+{
+    std::ifstream ifile(_tf("yolo-voc.cfg").c_str());
+    std::stringstream buffer;
+    buffer << " " << ifile.rdbuf(); // FIXME: FileStorage drops first character.
+    FileStorage ofs(".xml", FileStorage::WRITE | FileStorage::MEMORY);
+    ofs.write("cfgFile", buffer.str());
+    FileStorage ifs(ofs.releaseAndGetString(), FileStorage::READ | FileStorage::MEMORY | FileStorage::FORMAT_XML);
+    Net net = readNetFromDarknet(ifs["cfgFile"]);
+    ASSERT_FALSE(net.empty());
+}
+
 class Test_Darknet_layers : public DNNTestLayer
 {
 public: