From 37357350ad5c103f2f8ea90e2bbffc15df9f886f Mon Sep 17 00:00:00 2001 From: Anna Alberska Date: Fri, 28 Aug 2020 12:48:48 +0200 Subject: [PATCH] refactor import network (#1871) --- .../src/gna_plugin/gna_executable_network.hpp | 19 ++++++++++++------- inference-engine/src/gna_plugin/gna_plugin.cpp | 10 ---------- inference-engine/src/gna_plugin/gna_plugin.hpp | 1 - .../src/gna_plugin/gna_plugin_internal.hpp | 3 ++- .../tests_deprecated/unit/engines/gna/gna_matcher.cpp | 14 ++++++++++++-- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/inference-engine/src/gna_plugin/gna_executable_network.hpp b/inference-engine/src/gna_plugin/gna_executable_network.hpp index dbfe8ad..6fa2cb7 100644 --- a/inference-engine/src/gna_plugin/gna_executable_network.hpp +++ b/inference-engine/src/gna_plugin/gna_executable_network.hpp @@ -21,12 +21,17 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe std::shared_ptr plg; public: - GNAExecutableNetwork(const std::string &aotFileName, std::shared_ptr plg) - : plg(plg) { - plg->ImportNetwork(aotFileName); - _networkInputs = plg->GetInputs(); - _networkOutputs = plg->GetOutputs(); - } + GNAExecutableNetwork(const std::string& aotFileName, std::shared_ptr plg) + : plg(plg) { + std::fstream inputStream(aotFileName, std::ios_base::in | std::ios_base::binary); + if (inputStream.fail()) { + THROW_GNA_EXCEPTION << "Cannot open file to import model: " << aotFileName; + } + + plg->ImportNetwork(inputStream); + _networkInputs = plg->GetInputs(); + _networkOutputs = plg->GetOutputs(); + } GNAExecutableNetwork(std::istream& networkModel, std::shared_ptr plg) : plg(plg) { @@ -40,7 +45,7 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe plg->LoadNetwork(network); } - GNAExecutableNetwork(const std::string &aotFileName, const std::map &config) + GNAExecutableNetwork(const std::string& aotFileName, const std::map& config) : GNAExecutableNetwork(aotFileName, std::make_shared(config)) { } diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 5120dbd..23179d7 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -1134,16 +1134,6 @@ void GNAPlugin::SetName(const std::string & pluginName) noexcept { _pluginName = pluginName; } -InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(const std::string &modelFileName) { - // no need to return anything dueto weird design of internal base classes - std::fstream inputStream(modelFileName, ios_base::in | ios_base::binary); - if (inputStream.fail()) { - THROW_GNA_EXCEPTION << "Cannot open file to import model: " << modelFileName; - } - - return ImportNetwork(inputStream); -} - InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(std::istream& networkModel) { auto header = GNAModelSerial::ReadHeader(networkModel); diff --git a/inference-engine/src/gna_plugin/gna_plugin.hpp b/inference-engine/src/gna_plugin/gna_plugin.hpp index 89e54c6..839d9ab 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.hpp +++ b/inference-engine/src/gna_plugin/gna_plugin.hpp @@ -146,7 +146,6 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin { THROW_GNA_EXCEPTION << "Not implemented"; } - InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(const std::string &modelFileName); InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(std::istream& networkModel); /** diff --git a/inference-engine/src/gna_plugin/gna_plugin_internal.hpp b/inference-engine/src/gna_plugin/gna_plugin_internal.hpp index a43f087..df42d03 100644 --- a/inference-engine/src/gna_plugin/gna_plugin_internal.hpp +++ b/inference-engine/src/gna_plugin/gna_plugin_internal.hpp @@ -43,13 +43,14 @@ public: defaultConfig.UpdateFromMap(config); } - InferenceEngine::IExecutableNetwork::Ptr ImportNetwork( + InferenceEngine::IExecutableNetwork::Ptr ImportNetwork( const std::string &modelFileName, const std::map &config) override { Config updated_config(defaultConfig); updated_config.UpdateFromMap(config); auto plg = std::make_shared(updated_config.key_config_map); plgPtr = plg; + return make_executable_network(std::make_shared(modelFileName, plg)); } diff --git a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp index 573d178..6a71cb4 100644 --- a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp +++ b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp @@ -201,7 +201,12 @@ void GNAPropagateMatcher :: match() { }; auto loadNetworkFromAOT = [&] () { - auto sp = plugin.ImportNetwork(_env.importedModelFileName); + std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary); + if (inputStream.fail()) { + THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName; + } + + auto sp = plugin.ImportNetwork(inputStream); inputsInfo = plugin.GetInputs(); outputsInfo = plugin.GetOutputs(); }; @@ -604,7 +609,12 @@ void GNADumpXNNMatcher::load(std::shared_ptr & plugin) { }; auto loadNetworkFromAOT = [&]() { - plugin->ImportNetwork(_env.importedModelFileName); + std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary); + if (inputStream.fail()) { + THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName; + } + + plugin->ImportNetwork(inputStream); }; auto loadNetwork = [&]() { -- 2.7.4