Propogate DEVICE_ID for functions working with RemoteContext (#3109)
authorIlya Lavrenov <ilya.lavrenov@intel.com>
Fri, 13 Nov 2020 16:44:40 +0000 (19:44 +0300)
committerGitHub <noreply@github.com>
Fri, 13 Nov 2020 16:44:40 +0000 (19:44 +0300)
* Propogate DEVICE_ID for functions working with RemoteContext

* More fixes for RemoteContext

* Fixed tests compilation with VariableState

14 files changed:
inference-engine/include/gpu/gpu_ocl_wrapper.hpp
inference-engine/src/cldnn_engine/cldnn_engine.cpp
inference-engine/src/cldnn_engine/cldnn_engine.h
inference-engine/src/gna_plugin/gna_plugin.hpp
inference-engine/src/inference_engine/ie_core.cpp
inference-engine/src/inference_engine/ie_plugin_cpp.hpp
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_plugin_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iplugin_internal.hpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_eltwise_reshape_concat.cpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp
inference-engine/tests/functional/plugin/shared/src/subgraph_tests/negative_memory_layer_offset.cpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp

index 282dcd8..f51076e 100644 (file)
 /**
 * @brief Definitions required by Khronos headers
 */
 /**
 * @brief Definitions required by Khronos headers
 */
-#define CL_HPP_ENABLE_EXCEPTIONS
-#define CL_HPP_MINIMUM_OPENCL_VERSION 120
-#define CL_HPP_TARGET_OPENCL_VERSION 120
 
 
-#if defined __GNUC__
+#ifndef CL_HPP_ENABLE_EXCEPTIONS
+# define CL_HPP_ENABLE_EXCEPTIONS
+#endif
+
+#ifdef CL_HPP_MINIMUM_OPENCL_VERSION
+# if CL_HPP_MINIMUM_OPENCL_VERSION <= 120
+#  error "CL_HPP_MINIMUM_OPENCL_VERSION must be >= 120"
+# endif
+#else
+# define CL_HPP_MINIMUM_OPENCL_VERSION 120
+#endif
+
+#ifdef CL_HPP_TARGET_OPENCL_VERSION
+# if CL_HPP_TARGET_OPENCL_VERSION <= 120
+#  error "CL_HPP_TARGET_OPENCL_VERSION must be >= 120"
+# endif
+#else
+# define CL_HPP_TARGET_OPENCL_VERSION 120
+#endif
+
+#ifdef __GNUC__
 # pragma GCC diagnostic push
 # pragma GCC system_header
 #endif
 
 #include <CL/cl2.hpp>
 
 # pragma GCC diagnostic push
 # pragma GCC system_header
 #endif
 
 #include <CL/cl2.hpp>
 
-#if defined __GNUC__
+#ifdef __GNUC__
 # pragma GCC diagnostic pop
 #endif
 # pragma GCC diagnostic pop
 #endif
index 743ff3c..b53c39d 100644 (file)
@@ -363,9 +363,9 @@ RemoteContext::Ptr clDNNEngine::CreateContext(const ParamMap& params) {
     }
 }
 
     }
 }
 
-RemoteContext::Ptr clDNNEngine::GetDefaultContext() {
+RemoteContext::Ptr clDNNEngine::GetDefaultContext(const ParamMap& params) {
     if (nullptr == m_defaultContext) {
     if (nullptr == m_defaultContext) {
-        m_defaultContext.reset(new CLDNNRemoteCLContext(shared_from_this(), ParamMap(), _impl->m_config));
+        m_defaultContext.reset(new CLDNNRemoteCLContext(shared_from_this(), params, _impl->m_config));
     }
     return std::dynamic_pointer_cast<RemoteContext>(m_defaultContext);
 }
     }
     return std::dynamic_pointer_cast<RemoteContext>(m_defaultContext);
 }
index 84b37d5..b546982 100644 (file)
@@ -46,7 +46,7 @@ public:
                                                      const std::map<std::string, std::string>& config) const override;
 
     InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override;
                                                      const std::map<std::string, std::string>& config) const override;
 
     InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override;
-    InferenceEngine::RemoteContext::Ptr GetDefaultContext() override;
+    InferenceEngine::RemoteContext::Ptr GetDefaultContext(const ParamMap& params) override;
 };
 
 };  // namespace CLDNNPlugin
 };
 
 };  // namespace CLDNNPlugin
index dbe98fd..838e904 100644 (file)
@@ -123,7 +123,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
     InferenceEngine::Parameter GetMetric(const std::string& name,
                                          const std::map<std::string, InferenceEngine::Parameter> & options) const override;
     InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override { THROW_GNA_EXCEPTION << "Not implemented"; }
     InferenceEngine::Parameter GetMetric(const std::string& name,
                                          const std::map<std::string, InferenceEngine::Parameter> & options) const override;
     InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override { THROW_GNA_EXCEPTION << "Not implemented"; }
-    InferenceEngine::RemoteContext::Ptr GetDefaultContext() override { THROW_GNA_EXCEPTION << "Not implemented"; }
+    InferenceEngine::RemoteContext::Ptr GetDefaultContext(const InferenceEngine::ParamMap&) override { THROW_GNA_EXCEPTION << "Not implemented"; }
 
     void Wait(uint32_t sync, InferenceEngine::Blob &result) { THROW_GNA_EXCEPTION << "Not implemented"; }
 
 
     void Wait(uint32_t sync, InferenceEngine::Blob &result) { THROW_GNA_EXCEPTION << "Not implemented"; }
 
index c22eeb5..ddce658 100644 (file)
@@ -598,45 +598,37 @@ void Core::AddExtension(const IExtensionPtr& extension) {
 ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context,
                                     const std::map<std::string, std::string>& config) {
     OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::LoadNetwork");
 ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context,
                                     const std::map<std::string, std::string>& config) {
     OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::LoadNetwork");
-    std::map<std::string, std::string> config_ = config;
 
     if (context == nullptr) {
         THROW_IE_EXCEPTION << "Remote context is null";
     }
 
 
     if (context == nullptr) {
         THROW_IE_EXCEPTION << "Remote context is null";
     }
 
-    std::string deviceName_ = context->getDeviceName();
-    DeviceIDParser device(deviceName_);
-    std::string deviceName = device.getDeviceName();
-
-    return _impl->GetCPPPluginByName(deviceName).LoadNetwork(network, config_, context);
+    auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
+    return _impl->GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config, context);
 }
 
 }
 
-RemoteContext::Ptr Core::CreateContext(const std::string& deviceName_, const ParamMap& params) {
-    if (deviceName_.find("HETERO") == 0) {
-        THROW_IE_EXCEPTION << "HETERO device does not support remote contexts";
+RemoteContext::Ptr Core::CreateContext(const std::string& deviceName, const ParamMap& params) {
+    if (deviceName.find("HETERO") == 0) {
+        THROW_IE_EXCEPTION << "HETERO device does not support remote context";
     }
     }
-    if (deviceName_.find("MULTI") == 0) {
-        THROW_IE_EXCEPTION << "MULTI device does not support remote contexts";
+    if (deviceName.find("MULTI") == 0) {
+        THROW_IE_EXCEPTION << "MULTI device does not support remote context";
     }
 
     }
 
-    DeviceIDParser device(deviceName_);
-    std::string deviceName = device.getDeviceName();
-
-    return _impl->GetCPPPluginByName(deviceName).CreateContext(params);
+    auto parsed = parseDeviceNameIntoConfig(deviceName, params);
+    return _impl->GetCPPPluginByName(parsed._deviceName).CreateContext(parsed._config);
 }
 
 }
 
-RemoteContext::Ptr Core::GetDefaultContext(const std::string& deviceName_) {
-    if (deviceName_.find("HETERO") == 0) {
-        THROW_IE_EXCEPTION << "HETERO device does not support remote contexts";
+RemoteContext::Ptr Core::GetDefaultContext(const std::string& deviceName) {
+    if (deviceName.find("HETERO") == 0) {
+        THROW_IE_EXCEPTION << "HETERO device does not support remote context";
     }
     }
-    if (deviceName_.find("MULTI") == 0) {
-        THROW_IE_EXCEPTION << "MULTI device does not support remote contexts";
+    if (deviceName.find("MULTI") == 0) {
+        THROW_IE_EXCEPTION << "MULTI device does not support remote context";
     }
 
     }
 
-    DeviceIDParser device(deviceName_);
-    std::string deviceName = device.getDeviceName();
-
-    return _impl->GetCPPPluginByName(deviceName).GetDefaultContext();
+    auto parsed = parseDeviceNameIntoConfig(deviceName, ParamMap());
+    return _impl->GetCPPPluginByName(parsed._deviceName).GetDefaultContext(parsed._config);
 }
 
 void Core::AddExtension(IExtensionPtr extension, const std::string& deviceName_) {
 }
 
 void Core::AddExtension(IExtensionPtr extension, const std::string& deviceName_) {
index 9b3be1f..ec8f893 100644 (file)
@@ -118,8 +118,8 @@ public:
         CALL_STATEMENT(return actual->CreateContext(params));
     }
 
         CALL_STATEMENT(return actual->CreateContext(params));
     }
 
-    RemoteContext::Ptr GetDefaultContext() {
-        CALL_STATEMENT(return actual->GetDefaultContext());
+    RemoteContext::Ptr GetDefaultContext(const ParamMap& params) {
+        CALL_STATEMENT(return actual->GetDefaultContext(params));
     }
 
     ExecutableNetwork ImportNetwork(std::istream& networkModel,
     }
 
     ExecutableNetwork ImportNetwork(std::istream& networkModel,
index 9069132..dec31fa 100644 (file)
@@ -150,7 +150,7 @@ public:
         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
     }
 
         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
     }
 
-    RemoteContext::Ptr GetDefaultContext() override {
+    RemoteContext::Ptr GetDefaultContext(const ParamMap& /*params*/) override {
         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
     }
 
         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
     }
 
index 67fac3d..d949e2d 100644 (file)
@@ -214,9 +214,10 @@ public:
 
     /**
      * @brief      Provides a default remote context instance if supported by a plugin
 
     /**
      * @brief      Provides a default remote context instance if supported by a plugin
+     * @param[in]  params  The map of parameters
      * @return     The default context.
      */
      * @return     The default context.
      */
-    virtual RemoteContext::Ptr GetDefaultContext() = 0;
+    virtual RemoteContext::Ptr GetDefaultContext(const ParamMap& params) = 0;
 
     /**
      * @deprecated Use ImportNetwork(std::istream& networkModel, const std::map<std::string, std::string>& config)
 
     /**
      * @deprecated Use ImportNetwork(std::istream& networkModel, const std::map<std::string, std::string>& config)
index b06cee1..ebb66d7 100644 (file)
@@ -200,6 +200,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
     manager.register_pass<ngraph::pass::LowLatency>(); // LowLatency enables UnrollTI
     manager.run_passes(function);
     LoadNetwork();
     manager.register_pass<ngraph::pass::LowLatency>(); // LowLatency enables UnrollTI
     manager.run_passes(function);
     LoadNetwork();
+    IE_SUPPRESS_DEPRECATED_START
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
@@ -215,6 +216,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
             GTEST_FAIL() << "unknown memory state";
         }
     }
             GTEST_FAIL() << "unknown memory state";
         }
     }
+    IE_SUPPRESS_DEPRECATED_END
     // Run and compare
     Infer();
     const auto& actualOutputs = GetOutputs();
     // Run and compare
     Infer();
     const auto& actualOutputs = GetOutputs();
index 93a8837..4542c8f 100644 (file)
@@ -260,6 +260,7 @@ namespace SubgraphTestsDefinitions {
     void MemoryLSTMCellTest::Run() {
         SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
     void MemoryLSTMCellTest::Run() {
         SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
+        IE_SUPPRESS_DEPRECATED_START
         LoadNetwork();
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
         LoadNetwork();
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
@@ -276,6 +277,7 @@ namespace SubgraphTestsDefinitions {
                 GTEST_FAIL() << "unknown memory state";
             }
         }
                 GTEST_FAIL() << "unknown memory state";
             }
         }
+        IE_SUPPRESS_DEPRECATED_END
         Infer();
         switchToNgraphFriendlyModel();
         Validate();
         Infer();
         switchToNgraphFriendlyModel();
         Validate();
@@ -297,6 +299,7 @@ namespace SubgraphTestsDefinitions {
             manager.run_passes(function);
             LoadNetwork();
         }
             manager.run_passes(function);
             LoadNetwork();
         }
+        IE_SUPPRESS_DEPRECATED_START
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
             auto name = state.GetName();
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
             auto name = state.GetName();
@@ -312,6 +315,7 @@ namespace SubgraphTestsDefinitions {
                 GTEST_FAIL() << "unknown memory state";
             }
         }
                 GTEST_FAIL() << "unknown memory state";
             }
         }
+        IE_SUPPRESS_DEPRECATED_END
         Infer();
 
         CreatePureTensorIteratorModel();
         Infer();
 
         CreatePureTensorIteratorModel();
index 7b9a13f..a1754b0 100644 (file)
@@ -135,10 +135,12 @@ void MemoryEltwiseReshapeConcatTest::Run() {
                                                   InferenceEngine::SizeVector({1, inputSize * concatSize}),
                                                   InferenceEngine::Layout::NC);
 
                                                   InferenceEngine::SizeVector({1, inputSize * concatSize}),
                                                   InferenceEngine::Layout::NC);
 
+    IE_SUPPRESS_DEPRECATED_START
     auto states = executableNetwork.QueryState();
     auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
                                                                             memory_init.data(), memory_init.size());
     states[0].SetState(state_values_blob);
     auto states = executableNetwork.QueryState();
     auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
                                                                             memory_init.data(), memory_init.size());
     states[0].SetState(state_values_blob);
+    IE_SUPPRESS_DEPRECATED_END
     Infer();
     initNgraphFriendlyModel();
     Validate();
     Infer();
     initNgraphFriendlyModel();
     Validate();
index 9463031..0197077 100644 (file)
@@ -402,6 +402,7 @@ void MultipleLSTMCellTest::Run() {
                                                   InferenceEngine::SizeVector({1, hiddenSize}),
                                                   InferenceEngine::Layout::NC);
     LoadNetwork();
                                                   InferenceEngine::SizeVector({1, hiddenSize}),
                                                   InferenceEngine::Layout::NC);
     LoadNetwork();
+    IE_SUPPRESS_DEPRECATED_START
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
@@ -425,6 +426,7 @@ void MultipleLSTMCellTest::Run() {
             GTEST_FAIL() << "unknown memory state";
         }
     }
             GTEST_FAIL() << "unknown memory state";
         }
     }
+    IE_SUPPRESS_DEPRECATED_END
     Infer();
     switchToNgraphFriendlyModel();
     Validate();
     Infer();
     switchToNgraphFriendlyModel();
     Validate();
@@ -450,6 +452,7 @@ void MultipleLSTMCellTest::RunLowLatency(bool regular_api) {
         manager.run_passes(function);
         LoadNetwork();
     }
         manager.run_passes(function);
         LoadNetwork();
     }
+    IE_SUPPRESS_DEPRECATED_START
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
     auto states = executableNetwork.QueryState();
     for (auto& state : states) {
         auto name = state.GetName();
@@ -473,6 +476,7 @@ void MultipleLSTMCellTest::RunLowLatency(bool regular_api) {
             GTEST_FAIL() << "unknown memory state";
         }
     }
             GTEST_FAIL() << "unknown memory state";
         }
     }
+    IE_SUPPRESS_DEPRECATED_END
     Infer();
 
     // Calculate ref values for Unrolled TI
     Infer();
 
     // Calculate ref values for Unrolled TI
index 781e92f..d1efd0e 100644 (file)
@@ -79,6 +79,7 @@ namespace LayerTestsDefinitions {
         SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
         LoadNetwork();
         SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
         LoadNetwork();
+        IE_SUPPRESS_DEPRECATED_START
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
             auto name = state.GetName();
         auto states = executableNetwork.QueryState();
         for (auto& state : states) {
             auto name = state.GetName();
@@ -90,6 +91,7 @@ namespace LayerTestsDefinitions {
                 GTEST_FAIL() << "unknown memory state";
             }
         }
                 GTEST_FAIL() << "unknown memory state";
             }
         }
+        IE_SUPPRESS_DEPRECATED_END
         Infer();
         switchToNgraphFriendlyModel();
         Validate();
         Infer();
         switchToNgraphFriendlyModel();
         Validate();
index 8e31e8e..a0cfd45 100644 (file)
@@ -29,7 +29,7 @@ public:
                 const std::string&, const std::map<std::string, InferenceEngine::Parameter>&));
     MOCK_METHOD1(CreateContext,
                 InferenceEngine::RemoteContext::Ptr(const InferenceEngine::ParamMap&));
                 const std::string&, const std::map<std::string, InferenceEngine::Parameter>&));
     MOCK_METHOD1(CreateContext,
                 InferenceEngine::RemoteContext::Ptr(const InferenceEngine::ParamMap&));
-    MOCK_METHOD0(GetDefaultContext, InferenceEngine::RemoteContext::Ptr(void));
+    MOCK_METHOD1(GetDefaultContext, InferenceEngine::RemoteContext::Ptr(const InferenceEngine::ParamMap&));
     MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
                 const InferenceEngine::ICNNNetwork&, const std::map<std::string, std::string>&,
                 InferenceEngine::RemoteContext::Ptr));
     MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
                 const InferenceEngine::ICNNNetwork&, const std::map<std::string, std::string>&,
                 InferenceEngine::RemoteContext::Ptr));