#include "ie_core.hpp"
#include <unordered_set>
-#include <fstream>
#include <functional>
#include <limits>
#include <map>
#include <memory>
-#include <sstream>
-#include <streambuf>
#include <string>
#include <utility>
#include <vector>
#include <istream>
#include <mutex>
-#include "ie_blob_stream.hpp"
-#include <ie_reader_ptr.hpp>
#include <ngraph/opsets/opset.hpp>
#include "cpp/ie_cnn_net_reader.h"
#include "cpp/ie_plugin_cpp.hpp"
#include "cpp_interfaces/base/ie_plugin_base.hpp"
#include "details/ie_exception_conversion.hpp"
#include "details/ie_so_pointer.hpp"
-#include "file_utils.h"
#include "ie_icore.hpp"
#include "ie_plugin.hpp"
#include "ie_plugin_config.hpp"
#include "ie_profiling.hpp"
#include "ie_util_internal.hpp"
+#include "ie_network_reader.hpp"
#include "multi-device/multi_device_config.hpp"
#include "xml_parse_utils.h"
} // namespace
-class Reader: public IReader {
-private:
- InferenceEngine::IReaderPtr ptr;
- std::once_flag readFlag;
- std::string name;
- std::string location;
-
- InferenceEngine::IReaderPtr getReaderPtr() {
- std::call_once(readFlag, [&] () {
- FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
- FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
-
- if (!FileUtils::fileExist(readersLibraryPath)) {
- THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
- << FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
- << getIELibraryPath();
- }
- ptr = IReaderPtr(readersLibraryPath);
- });
-
- return ptr;
- }
-
- InferenceEngine::IReaderPtr getReaderPtr() const {
- return const_cast<Reader*>(this)->getReaderPtr();
- }
-
- void Release() noexcept override {
- delete this;
- }
-
-public:
- using Ptr = std::shared_ptr<Reader>;
- Reader(const std::string& name, const std::string location): name(name), location(location) {}
- bool supportModel(std::istream& model) const override {
- auto reader = getReaderPtr();
- return reader->supportModel(model);
- }
- CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
- auto reader = getReaderPtr();
- return reader->read(model, exts);
- }
- CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
- auto reader = getReaderPtr();
- return reader->read(model, weights, exts);
- }
- std::vector<std::string> getDataFileExtensions() const override {
- auto reader = getReaderPtr();
- return reader->getDataFileExtensions();
- }
- std::string getName() const {
- return name;
- }
-};
-
-namespace {
-
-// Extension to plugins creator
-std::multimap<std::string, Reader::Ptr> readers;
-
-void registerReaders() {
- static std::mutex readerMutex;
- std::lock_guard<std::mutex> lock(readerMutex);
- // TODO: Read readers info from XML
- auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
- readers.emplace("onnx", onnxReader);
- readers.emplace("prototxt", onnxReader);
- auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
- readers.emplace("xml", irReader);
-}
-
-} // namespace
-
CNNNetReaderPtr CreateCNNNetReaderPtr() noexcept {
auto loader = createCnnReaderLoader();
return CNNNetReaderPtr(loader);
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
-
- std::ifstream modelStream(modelPath, std::ios::binary);
- if (!modelStream.is_open())
- THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
-
- auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
- for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
- auto reader = it->second;
- if (reader->supportModel(modelStream)) {
- // Find weights
- std::string bPath = binPath;
- if (bPath.empty()) {
- auto pathWoExt = modelPath;
- auto pos = modelPath.rfind('.');
- if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
- for (const auto& ext : reader->getDataFileExtensions()) {
- bPath = pathWoExt + "." + ext;
- if (!FileUtils::fileExist(bPath)) {
- bPath.clear();
- } else {
- break;
- }
- }
- }
- if (!bPath.empty()) {
- std::ifstream binStream;
- binStream.open(bPath, std::ios::binary);
- if (!binStream.is_open())
- THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
- return reader->read(modelStream, binStream, extensions);
- }
- return reader->read(modelStream, extensions);
- }
- }
- THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
+ return details::ReadNetwork(modelPath, binPath, extensions);
}
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
- std::istringstream modelStream(model);
- details::BlobStream binStream(weights);
-
- for (auto it = readers.begin(); it != readers.end(); it++) {
- auto reader = it->second;
- if (reader->supportModel(modelStream)) {
- if (weights)
- return reader->read(modelStream, binStream, extensions);
- return reader->read(modelStream, extensions);
- }
- }
- THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
+ return details::ReadNetwork(model, weights, extensions);
}
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
opsetNames.insert("opset1");
opsetNames.insert("opset2");
opsetNames.insert("opset3");
- registerReaders();
}
Core::Impl::~Impl() {}
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ie_network_reader.hpp"
+
+#include <details/ie_so_pointer.hpp>
+#include <file_utils.h>
+#include <ie_blob_stream.hpp>
+#include <ie_profiling.hpp>
+#include <ie_reader.hpp>
+
+#include <fstream>
+#include <istream>
+#include <mutex>
+#include <map>
+
+namespace InferenceEngine {
+
+namespace details {
+
+/**
+ * @brief This class defines the name of the fabric for creating an IReader object in DLL
+ */
+template <>
+class SOCreatorTrait<IReader> {
+public:
+ /**
+ * @brief A name of the fabric for creating IReader object in DLL
+ */
+ static constexpr auto name = "CreateReader";
+};
+
+} // namespace details
+
+/**
+ * @brief This class is a wrapper for reader interfaces
+ */
+class Reader: public IReader {
+private:
+ InferenceEngine::details::SOPointer<IReader> ptr;
+ std::once_flag readFlag;
+ std::string name;
+ std::string location;
+
+ InferenceEngine::details::SOPointer<IReader> getReaderPtr() {
+ std::call_once(readFlag, [&] () {
+ FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
+ FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
+
+ if (!FileUtils::fileExist(readersLibraryPath)) {
+ THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
+ << FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
+ << getIELibraryPath();
+ }
+ ptr = InferenceEngine::details::SOPointer<IReader>(readersLibraryPath);
+ });
+
+ return ptr;
+ }
+
+ InferenceEngine::details::SOPointer<IReader> getReaderPtr() const {
+ return const_cast<Reader*>(this)->getReaderPtr();
+ }
+
+ void Release() noexcept override {
+ delete this;
+ }
+
+public:
+ using Ptr = std::shared_ptr<Reader>;
+ Reader(const std::string& name, const std::string location): name(name), location(location) {}
+ bool supportModel(std::istream& model) const override {
+ auto reader = getReaderPtr();
+ return reader->supportModel(model);
+ }
+ CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
+ auto reader = getReaderPtr();
+ return reader->read(model, exts);
+ }
+ CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
+ auto reader = getReaderPtr();
+ return reader->read(model, weights, exts);
+ }
+ std::vector<std::string> getDataFileExtensions() const override {
+ auto reader = getReaderPtr();
+ return reader->getDataFileExtensions();
+ }
+ std::string getName() const {
+ return name;
+ }
+};
+
+namespace {
+
+// Extension to plugins creator
+std::multimap<std::string, Reader::Ptr> readers;
+
+void registerReaders() {
+ IE_PROFILING_AUTO_SCOPE(details::registerReaders)
+ static bool initialized = false;
+ static std::mutex readerMutex;
+ std::lock_guard<std::mutex> lock(readerMutex);
+ if (initialized) return;
+ // TODO: Read readers info from XML
+ auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
+ readers.emplace("onnx", onnxReader);
+ readers.emplace("prototxt", onnxReader);
+ auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
+ readers.emplace("xml", irReader);
+ initialized = true;
+}
+
+} // namespace
+
+CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
+ IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
+ // Register readers if it is needed
+ registerReaders();
+
+ // Fix unicode name
+#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
+ std::wstring model_path = InferenceEngine::details::multiByteCharToWString(modelPath.c_str());
+#else
+ std::string model_path = modelPath;
+#endif
+ // Try to open model file
+ std::ifstream modelStream(model_path, std::ios::binary);
+ if (!modelStream.is_open())
+ THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
+
+ // Find reader for model extension
+ auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
+ for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
+ auto reader = it->second;
+ // Check that reader supports the model
+ if (reader->supportModel(modelStream)) {
+ // Find weights
+ std::string bPath = binPath;
+ if (bPath.empty()) {
+ auto pathWoExt = modelPath;
+ auto pos = modelPath.rfind('.');
+ if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
+ for (const auto& ext : reader->getDataFileExtensions()) {
+ bPath = pathWoExt + "." + ext;
+ if (!FileUtils::fileExist(bPath)) {
+ bPath.clear();
+ } else {
+ break;
+ }
+ }
+ }
+ if (!bPath.empty()) {
+ // Open weights file
+#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
+ std::wstring weights_path = InferenceEngine::details::multiByteCharToWString(bPath.c_str());
+#else
+ std::string weights_path = bPath;
+#endif
+ std::ifstream binStream;
+ binStream.open(weights_path, std::ios::binary);
+ if (!binStream.is_open())
+ THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
+
+ // read model with weights
+ return reader->read(modelStream, binStream, exts);
+ }
+ // read model without weights
+ return reader->read(modelStream, exts);
+ }
+ }
+ THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
+}
+
+CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts) {
+ IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
+ // Register readers if it is needed
+ registerReaders();
+ std::istringstream modelStream(model);
+ details::BlobStream binStream(weights);
+
+ for (auto it = readers.begin(); it != readers.end(); it++) {
+ auto reader = it->second;
+ if (reader->supportModel(modelStream)) {
+ if (weights)
+ return reader->read(modelStream, binStream, exts);
+ return reader->read(modelStream, exts);
+ }
+ }
+ THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
+}
+
+} // namespace InferenceEngine
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <cpp/ie_cnn_network.h>
+#include <ie_blob.h>
+#include <string>
+
+namespace InferenceEngine {
+namespace details {
+
+/**
+ * @brief Reads IR xml and bin files
+ * @param modelPath path to IR file
+ * @param binPath path to bin file, if path is empty, will try to read bin file with the same name as xml and
+ * if bin file with the same name was not found, will load IR without weights.
+ * @param exts vector with extensions
+ * @return CNNNetwork
+ */
+CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts);
+/**
+ * @brief Reads IR xml and bin (with the same name) files
+ * @param model string with IR
+ * @param weights shared pointer to constant blob with weights
+ * @param exts vector with extensions
+ * @return CNNNetwork
+ */
+CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts);
+
+} // namespace details
+} // namespace InferenceEngine
+++ /dev/null
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <string>
-
-#include <details/ie_so_pointer.hpp>
-#include "ie_reader.hpp"
-
-namespace InferenceEngine {
-namespace details {
-
-/**
- * @brief This class defines the name of the fabric for creating an IReader object in DLL
- */
-template <>
-class SOCreatorTrait<IReader> {
-public:
- /**
- * @brief A name of the fabric for creating IReader object in DLL
- */
- static constexpr auto name = "CreateReader";
-};
-
-} // namespace details
-
-/**
- * @brief A C++ helper to work with objects created by the plugin.
- *
- * Implements different interfaces.
- */
-using IReaderPtr = InferenceEngine::details::SOPointer<IReader>;
-
-} // namespace InferenceEngine
#ifdef ENABLE_UNICODE_PATH_SUPPORT
-TEST_P(NetReaderTest, DISABLED_ReadCorrectModelWithWeightsUnicodePath) {
+TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
GTEST_COUT << "params.modelPath: '" << _modelPath << "'" << std::endl;
GTEST_COUT << "params.weightsPath: '" << _weightsPath << "'" << std::endl;
GTEST_COUT << "params.netPrc: '" << _netPrc.name() << "'" << std::endl;