Enabled Unit tests and remove IReaderPtr (#653)
authorIlya Churaev <ilya.churaev@intel.com>
Thu, 28 May 2020 19:40:20 +0000 (22:40 +0300)
committerGitHub <noreply@github.com>
Thu, 28 May 2020 19:40:20 +0000 (22:40 +0300)
* Enabled Unit tests and remove IReaderPtr

* Fixed unicode tests for Windows

* Fixed typo

inference-engine/src/inference_engine/ie_core.cpp
inference-engine/src/inference_engine/ie_network_reader.cpp [new file with mode: 0644]
inference-engine/src/inference_engine/ie_network_reader.hpp [new file with mode: 0644]
inference-engine/src/readers/reader_api/ie_reader_ptr.hpp [deleted file]
inference-engine/tests/functional/inference_engine/net_reader_test.cpp

index 0aab921..754e530 100644 (file)
@@ -5,33 +5,28 @@
 #include "ie_core.hpp"
 
 #include <unordered_set>
-#include <fstream>
 #include <functional>
 #include <limits>
 #include <map>
 #include <memory>
-#include <sstream>
-#include <streambuf>
 #include <string>
 #include <utility>
 #include <vector>
 #include <istream>
 #include <mutex>
 
-#include "ie_blob_stream.hpp"
-#include <ie_reader_ptr.hpp>
 #include <ngraph/opsets/opset.hpp>
 #include "cpp/ie_cnn_net_reader.h"
 #include "cpp/ie_plugin_cpp.hpp"
 #include "cpp_interfaces/base/ie_plugin_base.hpp"
 #include "details/ie_exception_conversion.hpp"
 #include "details/ie_so_pointer.hpp"
-#include "file_utils.h"
 #include "ie_icore.hpp"
 #include "ie_plugin.hpp"
 #include "ie_plugin_config.hpp"
 #include "ie_profiling.hpp"
 #include "ie_util_internal.hpp"
+#include "ie_network_reader.hpp"
 #include "multi-device/multi_device_config.hpp"
 #include "xml_parse_utils.h"
 
@@ -133,79 +128,6 @@ Parameter copyParameterValue(const Parameter & value) {
 
 }  // namespace
 
-class Reader: public IReader {
-private:
-    InferenceEngine::IReaderPtr ptr;
-    std::once_flag readFlag;
-    std::string name;
-    std::string location;
-
-    InferenceEngine::IReaderPtr getReaderPtr() {
-        std::call_once(readFlag, [&] () {
-            FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
-            FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
-
-            if (!FileUtils::fileExist(readersLibraryPath)) {
-                THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
-                    << FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
-                    << getIELibraryPath();
-            }
-            ptr = IReaderPtr(readersLibraryPath);
-        });
-
-        return ptr;
-    }
-
-    InferenceEngine::IReaderPtr getReaderPtr() const {
-        return const_cast<Reader*>(this)->getReaderPtr();
-    }
-
-    void Release() noexcept override {
-        delete this;
-    }
-
-public:
-    using Ptr = std::shared_ptr<Reader>;
-    Reader(const std::string& name, const std::string location): name(name), location(location) {}
-    bool supportModel(std::istream& model) const override {
-        auto reader = getReaderPtr();
-        return reader->supportModel(model);
-    }
-    CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
-        auto reader = getReaderPtr();
-        return reader->read(model, exts);
-    }
-    CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
-        auto reader = getReaderPtr();
-        return reader->read(model, weights, exts);
-    }
-    std::vector<std::string> getDataFileExtensions() const override {
-        auto reader = getReaderPtr();
-        return reader->getDataFileExtensions();
-    }
-    std::string getName() const {
-        return name;
-    }
-};
-
-namespace {
-
-// Extension to plugins creator
-std::multimap<std::string, Reader::Ptr> readers;
-
-void registerReaders() {
-    static std::mutex readerMutex;
-    std::lock_guard<std::mutex> lock(readerMutex);
-    // TODO: Read readers info from XML
-    auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
-    readers.emplace("onnx", onnxReader);
-    readers.emplace("prototxt", onnxReader);
-    auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
-    readers.emplace("xml", irReader);
-}
-
-}  // namespace
-
 CNNNetReaderPtr CreateCNNNetReaderPtr() noexcept {
     auto loader = createCnnReaderLoader();
     return CNNNetReaderPtr(loader);
@@ -374,57 +296,12 @@ public:
 
     CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
         IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
-
-        std::ifstream modelStream(modelPath, std::ios::binary);
-        if (!modelStream.is_open())
-            THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
-
-        auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
-        for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
-            auto reader = it->second;
-            if (reader->supportModel(modelStream)) {
-                // Find weights
-                std::string bPath = binPath;
-                if (bPath.empty()) {
-                    auto pathWoExt = modelPath;
-                    auto pos = modelPath.rfind('.');
-                    if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
-                    for (const auto& ext : reader->getDataFileExtensions()) {
-                        bPath = pathWoExt + "." + ext;
-                        if (!FileUtils::fileExist(bPath)) {
-                            bPath.clear();
-                        } else {
-                            break;
-                        }
-                    }
-                }
-                if (!bPath.empty()) {
-                    std::ifstream binStream;
-                    binStream.open(bPath, std::ios::binary);
-                    if (!binStream.is_open())
-                        THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
-                    return reader->read(modelStream, binStream, extensions);
-                }
-                return reader->read(modelStream, extensions);
-            }
-        }
-        THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
+        return details::ReadNetwork(modelPath, binPath, extensions);
     }
 
     CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
         IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
-        std::istringstream modelStream(model);
-        details::BlobStream binStream(weights);
-
-        for (auto it = readers.begin(); it != readers.end(); it++) {
-            auto reader = it->second;
-            if (reader->supportModel(modelStream)) {
-                if (weights)
-                    return reader->read(modelStream, binStream, extensions);
-                return reader->read(modelStream, extensions);
-            }
-        }
-        THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
+        return details::ReadNetwork(model, weights, extensions);
     }
 
     ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
@@ -704,7 +581,6 @@ Core::Impl::Impl() {
     opsetNames.insert("opset1");
     opsetNames.insert("opset2");
     opsetNames.insert("opset3");
-    registerReaders();
 }
 
 Core::Impl::~Impl() {}
diff --git a/inference-engine/src/inference_engine/ie_network_reader.cpp b/inference-engine/src/inference_engine/ie_network_reader.cpp
new file mode 100644 (file)
index 0000000..9d739b6
--- /dev/null
@@ -0,0 +1,193 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ie_network_reader.hpp"
+
+#include <details/ie_so_pointer.hpp>
+#include <file_utils.h>
+#include <ie_blob_stream.hpp>
+#include <ie_profiling.hpp>
+#include <ie_reader.hpp>
+
+#include <fstream>
+#include <istream>
+#include <mutex>
+#include <map>
+
+namespace InferenceEngine {
+
+namespace details {
+
+/**
+ * @brief This class defines the name of the fabric for creating an IReader object in DLL
+ */
+template <>
+class SOCreatorTrait<IReader> {
+public:
+    /**
+     * @brief A name of the fabric for creating IReader object in DLL
+     */
+    static constexpr auto name = "CreateReader";
+};
+
+}  // namespace details
+
+/**
+ * @brief This class is a wrapper for reader interfaces
+ */
+class Reader: public IReader {
+private:
+    InferenceEngine::details::SOPointer<IReader> ptr;
+    std::once_flag readFlag;
+    std::string name;
+    std::string location;
+
+    InferenceEngine::details::SOPointer<IReader> getReaderPtr() {
+        std::call_once(readFlag, [&] () {
+            FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
+            FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
+
+            if (!FileUtils::fileExist(readersLibraryPath)) {
+                THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
+                    << FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
+                    << getIELibraryPath();
+            }
+            ptr = InferenceEngine::details::SOPointer<IReader>(readersLibraryPath);
+        });
+
+        return ptr;
+    }
+
+    InferenceEngine::details::SOPointer<IReader> getReaderPtr() const {
+        return const_cast<Reader*>(this)->getReaderPtr();
+    }
+
+    void Release() noexcept override {
+        delete this;
+    }
+
+public:
+    using Ptr = std::shared_ptr<Reader>;
+    Reader(const std::string& name, const std::string location): name(name), location(location) {}
+    bool supportModel(std::istream& model) const override {
+        auto reader = getReaderPtr();
+        return reader->supportModel(model);
+    }
+    CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
+        auto reader = getReaderPtr();
+        return reader->read(model, exts);
+    }
+    CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
+        auto reader = getReaderPtr();
+        return reader->read(model, weights, exts);
+    }
+    std::vector<std::string> getDataFileExtensions() const override {
+        auto reader = getReaderPtr();
+        return reader->getDataFileExtensions();
+    }
+    std::string getName() const {
+        return name;
+    }
+};
+
+namespace {
+
+// Extension to plugins creator
+std::multimap<std::string, Reader::Ptr> readers;
+
+void registerReaders() {
+    IE_PROFILING_AUTO_SCOPE(details::registerReaders)
+    static bool initialized = false;
+    static std::mutex readerMutex;
+    std::lock_guard<std::mutex> lock(readerMutex);
+    if (initialized) return;
+    // TODO: Read readers info from XML
+    auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
+    readers.emplace("onnx", onnxReader);
+    readers.emplace("prototxt", onnxReader);
+    auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
+    readers.emplace("xml", irReader);
+    initialized = true;
+}
+
+}  // namespace
+
+CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
+    IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
+    // Register readers if it is needed
+    registerReaders();
+
+    // Fix unicode name
+#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
+    std::wstring model_path = InferenceEngine::details::multiByteCharToWString(modelPath.c_str());
+#else
+    std::string model_path = modelPath;
+#endif
+    // Try to open model file
+    std::ifstream modelStream(model_path, std::ios::binary);
+    if (!modelStream.is_open())
+        THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
+
+    // Find reader for model extension
+    auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
+    for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
+        auto reader = it->second;
+        // Check that reader supports the model
+        if (reader->supportModel(modelStream)) {
+            // Find weights
+            std::string bPath = binPath;
+            if (bPath.empty()) {
+                auto pathWoExt = modelPath;
+                auto pos = modelPath.rfind('.');
+                if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
+                for (const auto& ext : reader->getDataFileExtensions()) {
+                    bPath = pathWoExt + "." + ext;
+                    if (!FileUtils::fileExist(bPath)) {
+                        bPath.clear();
+                    } else {
+                        break;
+                    }
+                }
+            }
+            if (!bPath.empty()) {
+                // Open weights file
+#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
+                std::wstring weights_path = InferenceEngine::details::multiByteCharToWString(bPath.c_str());
+#else
+                std::string weights_path = bPath;
+#endif
+                std::ifstream binStream;
+                binStream.open(weights_path, std::ios::binary);
+                if (!binStream.is_open())
+                    THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
+
+                // read model with weights
+                return reader->read(modelStream, binStream, exts);
+            }
+            // read model without weights
+            return reader->read(modelStream, exts);
+        }
+    }
+    THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
+}
+
+CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts) {
+    IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
+    // Register readers if it is needed
+    registerReaders();
+    std::istringstream modelStream(model);
+    details::BlobStream binStream(weights);
+
+    for (auto it = readers.begin(); it != readers.end(); it++) {
+        auto reader = it->second;
+        if (reader->supportModel(modelStream)) {
+            if (weights)
+                return reader->read(modelStream, binStream, exts);
+            return reader->read(modelStream, exts);
+        }
+    }
+    THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
+}
+
+}  // namespace InferenceEngine
diff --git a/inference-engine/src/inference_engine/ie_network_reader.hpp b/inference-engine/src/inference_engine/ie_network_reader.hpp
new file mode 100644 (file)
index 0000000..2d8ea63
--- /dev/null
@@ -0,0 +1,33 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <cpp/ie_cnn_network.h>
+#include <ie_blob.h>
+#include <string>
+
+namespace InferenceEngine {
+namespace details {
+
+/**
+ * @brief Reads IR xml and bin files
+ * @param modelPath path to IR file
+ * @param binPath path to bin file, if path is empty, will try to read bin file with the same name as xml and
+ * if bin file with the same name was not found, will load IR without weights.
+ * @param exts vector with extensions
+ * @return CNNNetwork
+ */
+CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts);
+/**
+ * @brief Reads IR xml and bin (with the same name) files
+ * @param model string with IR
+ * @param weights shared pointer to constant blob with weights
+ * @param exts vector with extensions
+ * @return CNNNetwork
+ */
+CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts);
+
+}  // namespace details
+}  // namespace InferenceEngine
diff --git a/inference-engine/src/readers/reader_api/ie_reader_ptr.hpp b/inference-engine/src/readers/reader_api/ie_reader_ptr.hpp
deleted file mode 100644 (file)
index 9c3aee3..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <string>
-
-#include <details/ie_so_pointer.hpp>
-#include "ie_reader.hpp"
-
-namespace InferenceEngine {
-namespace details {
-
-/**
- * @brief This class defines the name of the fabric for creating an IReader object in DLL
- */
-template <>
-class SOCreatorTrait<IReader> {
-public:
-    /**
-     * @brief A name of the fabric for creating IReader object in DLL
-     */
-    static constexpr auto name = "CreateReader";
-};
-
-}  // namespace details
-
-/**
- * @brief A C++ helper to work with objects created by the plugin.
- *
- * Implements different interfaces.
- */
-using IReaderPtr = InferenceEngine::details::SOPointer<IReader>;
-
-}  // namespace InferenceEngine
index 4df6db8..2d008bb 100644 (file)
@@ -107,7 +107,7 @@ TEST_P(NetReaderTest, ReadNetworkTwiceSeparately) {
 
 #ifdef ENABLE_UNICODE_PATH_SUPPORT
 
-TEST_P(NetReaderTest, DISABLED_ReadCorrectModelWithWeightsUnicodePath) {
+TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
     GTEST_COUT << "params.modelPath: '" << _modelPath << "'" << std::endl;
     GTEST_COUT << "params.weightsPath: '" << _weightsPath << "'" << std::endl;
     GTEST_COUT << "params.netPrc: '" << _netPrc.name() << "'" << std::endl;