refactor import network (#1871)
author Anna Alberska <anna.alberska@intel.com>
Fri, 28 Aug 2020 10:48:48 +0000 (12:48 +0200)
committer GitHub <noreply@github.com>
Fri, 28 Aug 2020 10:48:48 +0000 (13:48 +0300)
inference-engine/src/gna_plugin/gna_executable_network.hpp
inference-engine/src/gna_plugin/gna_plugin.cpp
inference-engine/src/gna_plugin/gna_plugin.hpp
inference-engine/src/gna_plugin/gna_plugin_internal.hpp
inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp

diff --git a/inference-engine/src/gna_plugin/gna_executable_network.hpp b/inference-engine/src/gna_plugin/gna_executable_network.hpp
index dbfe8ad..6fa2cb7 100644
@@ -21,12 +21,17 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
     std::shared_ptr<GNAPlugin> plg;
 
  public:
-    GNAExecutableNetwork(const std::string &aotFileName, std::shared_ptr<GNAPlugin> plg)
-        : plg(plg) {
-        plg->ImportNetwork(aotFileName);
-        _networkInputs  = plg->GetInputs();
-        _networkOutputs = plg->GetOutputs();
-    }
+    GNAExecutableNetwork(const std::string& aotFileName, std::shared_ptr<GNAPlugin> plg)
+        : plg(plg) {
+        std::fstream inputStream(aotFileName, std::ios_base::in | std::ios_base::binary);
+        if (inputStream.fail()) {
+            THROW_GNA_EXCEPTION << "Cannot open file to import model: " << aotFileName;
+        }
+
+        plg->ImportNetwork(inputStream);
+        _networkInputs = plg->GetInputs();
+        _networkOutputs = plg->GetOutputs();
+    }
 
     GNAExecutableNetwork(std::istream& networkModel, std::shared_ptr<GNAPlugin> plg)
         : plg(plg) {
@@ -40,7 +45,7 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
         plg->LoadNetwork(network);
     }
 
-    GNAExecutableNetwork(const std::string &aotFileName, const std::map<std::string, std::string> &config)
+    GNAExecutableNetwork(const std::string& aotFileName, const std::map<std::string, std::string>& config)
         : GNAExecutableNetwork(aotFileName, std::make_shared<GNAPlugin>(config)) {
     }
 
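Every import path of GNAExecutableNetwork now ends in GNAPlugin::ImportNetwork(std::istream&), so an exported model no longer has to be a file on disk. A minimal caller-side sketch, assuming the plugin headers and any enclosing namespace are already in scope; the blob argument and the empty config map are placeholders, not part of this change:

    // Sketch only: import a GNA model blob that is already in memory by handing
    // it to the std::istream constructor instead of the file-name one.
    #include <map>
    #include <memory>
    #include <sstream>
    #include <string>

    void ImportFromMemory(const std::string& modelBlob) {
        std::stringstream networkModel(modelBlob, std::ios_base::in | std::ios_base::binary);
        auto plg = std::make_shared<GNAPlugin>(std::map<std::string, std::string>{});
        GNAExecutableNetwork network(networkModel, plg);
    }
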
diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp
index 5120dbd..23179d7 100644
@@ -1134,16 +1134,6 @@ void GNAPlugin::SetName(const std::string & pluginName) noexcept {
     _pluginName = pluginName;
 }
 
-InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(const std::string &modelFileName) {
-    // no need to return anything dueto weird design of internal base classes
-    std::fstream inputStream(modelFileName, ios_base::in | ios_base::binary);
-    if (inputStream.fail()) {
-        THROW_GNA_EXCEPTION << "Cannot open file to import model: " << modelFileName;
-    }
-
-    return ImportNetwork(inputStream);
-}
-
 InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(std::istream& networkModel) {
     auto header = GNAModelSerial::ReadHeader(networkModel);
 
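With the file-name overload removed from GNAPlugin, any caller that previously wrote plugin.ImportNetwork(modelFileName) now owns the stream and its error handling, exactly as the updated tests below do. A migration sketch; <fstream> and the THROW_GNA_EXCEPTION macro are assumed to be in scope as they are at the existing call sites, and modelFileName is a placeholder:

    // Before this change:  auto network = plugin.ImportNetwork(modelFileName);
    // After it, the caller opens the stream and passes it to the remaining overload.
    std::fstream inputStream(modelFileName, std::ios_base::in | std::ios_base::binary);
    if (inputStream.fail()) {
        THROW_GNA_EXCEPTION << "Cannot open file to import model: " << modelFileName;
    }
    auto network = plugin.ImportNetwork(inputStream);
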
diff --git a/inference-engine/src/gna_plugin/gna_plugin.hpp b/inference-engine/src/gna_plugin/gna_plugin.hpp
index 89e54c6..839d9ab 100644
@@ -146,7 +146,6 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
         THROW_GNA_EXCEPTION << "Not implemented";
     }
 
-    InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(const std::string &modelFileName);
     InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(std::istream& networkModel);
 
     /**
diff --git a/inference-engine/src/gna_plugin/gna_plugin_internal.hpp b/inference-engine/src/gna_plugin/gna_plugin_internal.hpp
index a43f087..df42d03 100644
@@ -43,13 +43,14 @@ public:
         defaultConfig.UpdateFromMap(config);
     }
 
-    InferenceEngine::IExecutableNetwork::Ptr  ImportNetwork(
+    InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(
                                                 const std::string &modelFileName,
                                                 const std::map<std::string, std::string> &config) override {
         Config updated_config(defaultConfig);
         updated_config.UpdateFromMap(config);
         auto plg = std::make_shared<GNAPlugin>(updated_config.key_config_map);
         plgPtr = plg;
+
         return make_executable_network(std::make_shared<GNAExecutableNetwork>(modelFileName, plg));
     }
 
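The file-plus-config entry point in GNAPluginInternal keeps its behaviour: it still builds a GNAPlugin from the merged config and hands the file name to GNAExecutableNetwork, which now does the stream handling itself. A short sketch of the resulting chain; the file name and the empty config are placeholders:

    // Sketch: the file/config constructor still works end to end after the refactor.
    std::map<std::string, std::string> config;   // GNA-specific keys would go here
    GNAExecutableNetwork network("model.aot", config);
    // internally: GNAExecutableNetwork("model.aot", std::make_shared<GNAPlugin>(config))
    //   opens std::fstream("model.aot", in | binary) and calls plg->ImportNetwork(inputStream)
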
diff --git a/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp b/inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp
index 573d178..6a71cb4 100644
@@ -201,7 +201,12 @@ void GNAPropagateMatcher :: match() {
         };
 
         auto loadNetworkFromAOT = [&] () {
-            auto sp = plugin.ImportNetwork(_env.importedModelFileName);
+            std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary);
+            if (inputStream.fail()) {
+                THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName;
+            }
+
+            auto sp = plugin.ImportNetwork(inputStream);
             inputsInfo = plugin.GetInputs();
             outputsInfo = plugin.GetOutputs();
         };
@@ -604,7 +609,12 @@ void GNADumpXNNMatcher::load(std::shared_ptr<GNAPlugin> & plugin) {
     };
 
     auto loadNetworkFromAOT = [&]() {
-        plugin->ImportNetwork(_env.importedModelFileName);
+        std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary);
+        if (inputStream.fail()) {
+            THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName;
+        }
+
+        plugin->ImportNetwork(inputStream);
     };
 
     auto loadNetwork = [&]() {
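
The open-and-throw sequence introduced above now appears verbatim in the GNAExecutableNetwork constructor and in both matcher lambdas. A small helper along these lines could fold that duplication; it is a sketch only, not part of the patch, and assumes a C++11 toolchain where std::fstream is movable plus the same exception macro used at the existing call sites:

    #include <fstream>
    #include <string>

    // Hypothetical helper: open a model file for binary import, or raise the same
    // GNA exception the three call sites currently throw inline.
    inline std::fstream OpenModelForImport(const std::string& aotFileName) {
        std::fstream inputStream(aotFileName, std::ios_base::in | std::ios_base::binary);
        if (inputStream.fail()) {
            THROW_GNA_EXCEPTION << "Cannot open file to import model: " << aotFileName;
        }
        return inputStream;
    }

    // Usage at the existing call sites:
    //     auto inputStream = OpenModelForImport(_env.importedModelFileName);
    //     plugin->ImportNetwork(inputStream);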