Move QueryState from ExecutableNetwork to InferRequest (#2818)
authorSvetlana Dolinina <svetlana.a.dolinina@intel.com>
Thu, 12 Nov 2020 09:40:43 +0000 (12:40 +0300)
committerGitHub <noreply@github.com>
Thu, 12 Nov 2020 09:40:43 +0000 (12:40 +0300)
* QueryState moved to InferRequest

* deprecate ExecutableNetwork::QueryState,chaged tests (without any check yet)

* fix build

* review fixes + build fix

* build fix + review changes

* remove blank line

* style fixes

* test build fixes

* style fix

* style fix

* fixed build of tests

* fix

* mac build fix

* hddl plugin build fix

* clean up unneeded implementation for method

* fixed tests build

* add implementation for getstate, correct getName for MklDNN

* fixed description of state API in comments

* lint fixes

* Rename MemoryState to VariableState

* added tests for cpu for VariableStates, several small fixes in tests and code

* merge fix

* lint fix

* remove whitespaces

* spaces fix

* fix in test to make it workable for all plugins

* fix typo

* fix test for gna

* remove extra comment

* fix test for gna

51 files changed:
inference-engine/include/cpp/ie_executable_network.hpp
inference-engine/include/cpp/ie_infer_request.hpp
inference-engine/include/cpp/ie_memory_state.hpp
inference-engine/include/ie_iexecutable_network.hpp
inference-engine/include/ie_iinfer_request.hpp
inference-engine/include/ie_imemory_state.hpp
inference-engine/samples/speech_sample/main.cpp
inference-engine/src/gna_plugin/gna_executable_network.hpp
inference-engine/src/gna_plugin/gna_infer_request.hpp
inference-engine/src/gna_plugin/gna_plugin.cpp
inference-engine/src/gna_plugin/gna_plugin.hpp
inference-engine/src/gna_plugin/memory/gna_memory_state.cpp
inference-engine/src/gna_plugin/memory/gna_memory_state.hpp
inference-engine/src/mkldnn_plugin/mkldnn_exec_network.cpp
inference-engine/src/mkldnn_plugin/mkldnn_exec_network.h
inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp
inference-engine/src/mkldnn_plugin/mkldnn_infer_request.h
inference-engine/src/mkldnn_plugin/mkldnn_memory_state.cpp
inference-engine/src/mkldnn_plugin/mkldnn_memory_state.h
inference-engine/src/plugin_api/cpp_interfaces/base/ie_executable_network_base.hpp
inference-engine/src/plugin_api/cpp_interfaces/base/ie_infer_async_request_base.hpp
inference-engine/src/plugin_api/cpp_interfaces/base/ie_memory_state_base.hpp
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_async_request_thread_safe_default.hpp
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_infer_request_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/impl/ie_memory_state_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iexecutable_network_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iinfer_request_internal.hpp
inference-engine/src/plugin_api/cpp_interfaces/interface/ie_imemory_state_internal.hpp
inference-engine/tests/functional/inference_engine/async_infer_request_test.cpp
inference-engine/tests/functional/inference_engine/executable_network.cpp
inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/cpp_holders.cpp
inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp [new file with mode: 0644]
inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/cpp_holders.cpp
inference-engine/tests/functional/plugin/gna/shared_tests_instances/behavior/memory_states.cpp
inference-engine/tests/functional/plugin/shared/include/behavior/memory_states.hpp
inference-engine/tests/functional/plugin/shared/src/behavior/cpp_holders.cpp
inference-engine/tests/functional/plugin/shared/src/behavior/memory_states.cpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/impl/mock_async_infer_request_thread_safe_internal.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_iinfer_request_internal.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_ie_imemory_state.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iexecutable_network.hpp
inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_iinfer_request.hpp
inference-engine/tests/ngraph_functions/include/ngraph_functions/subgraph_builders.hpp
inference-engine/tests/unit/inference_engine/cpp_interfaces/ie_memory_state_internal_test.cpp
inference-engine/tests/unit/inference_engine/ie_executable_network_test.cpp
inference-engine/tests_deprecated/functional/ie_tests/src/custom_matcher.cpp
inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp

index 31d246c..d748820 100644 (file)
@@ -175,11 +175,13 @@ public:
      * Wraps IExecutableNetwork::QueryState
      * @return A vector of Memory State objects
      */
-    std::vector<MemoryState> QueryState() {
+    INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
+    std::vector<VariableState> QueryState() {
+        IE_SUPPRESS_DEPRECATED_START
         if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
-        IMemoryState::Ptr pState = nullptr;
+        IVariableState::Ptr pState = nullptr;
         auto res = OK;
-        std::vector<MemoryState> controller;
+        std::vector<VariableState> controller;
         for (size_t idx = 0; res == OK; ++idx) {
             ResponseDesc resp;
             res = actual->QueryState(pState, idx, &resp);
@@ -187,10 +189,11 @@ public:
                 THROW_IE_EXCEPTION << resp.msg;
             }
             if (res != OUT_OF_BOUNDS) {
-                controller.push_back(MemoryState(pState));
+                controller.push_back(VariableState(pState, plg));
             }
         }
 
+        IE_SUPPRESS_DEPRECATED_END
         return controller;
     }
 
index 18daaf7..c750a5d 100644 (file)
@@ -13,6 +13,7 @@
 #include <memory>
 #include <string>
 
+#include "cpp/ie_memory_state.hpp"
 #include "ie_iinfer_request.hpp"
 #include "details/ie_exception_conversion.hpp"
 #include "details/ie_so_loader.h"
@@ -251,6 +252,31 @@ public:
     }
 
     /**
+     * @copybrief IExecutableNetwork::QueryState
+     *
+     * Wraps IExecutableNetwork::QueryState
+     * @return A vector of Memory State objects
+     */
+    std::vector<VariableState> QueryState() {
+        if (actual == nullptr) THROW_IE_EXCEPTION << "ExecutableNetwork was not initialized.";
+        IVariableState::Ptr pState = nullptr;
+        auto res = OK;
+        std::vector<VariableState> controller;
+        for (size_t idx = 0; res == OK; ++idx) {
+            ResponseDesc resp;
+            res = actual->QueryState(pState, idx, &resp);
+            if (res != OK && res != OUT_OF_BOUNDS) {
+                THROW_IE_EXCEPTION << resp.msg;
+            }
+            if (res != OUT_OF_BOUNDS) {
+                controller.push_back(VariableState(pState, plg));
+            }
+        }
+
+        return controller;
+    }
+
+    /**
      * @brief  IInferRequest pointer to be used directly in CreateInferRequest functions
      * @return A shared pointer to underlying IInferRequest interface
      */
index 3fe79d0..24fd4d7 100644 (file)
 #include <string>
 
 #include "ie_imemory_state.hpp"
+#include "details/ie_exception_conversion.hpp"
+#include "details/ie_so_loader.h"
 
 namespace InferenceEngine {
 
 /**
- * @brief C++ exception based error reporting wrapper of API class IMemoryState
+ * @brief C++ exception based error reporting wrapper of API class IVariableState
  */
-class MemoryState {
-    IMemoryState::Ptr actual = nullptr;
+class VariableState {
+    IVariableState::Ptr actual = nullptr;
+    details::SharedObjectLoader::Ptr plugin = {};
 
 public:
     /**
-     * constructs MemoryState from the initialized shared_pointer
+     * constructs VariableState from the initialized shared_pointer
      * @param pState Initialized shared pointer
      */
-    explicit MemoryState(IMemoryState::Ptr pState): actual(pState) {
+    explicit VariableState(IVariableState::Ptr pState, details::SharedObjectLoader::Ptr plg = {}) : actual(pState), plugin(plg) {
         if (actual == nullptr) {
-            THROW_IE_EXCEPTION << "MemoryState wrapper was not initialized.";
+            THROW_IE_EXCEPTION << "VariableState wrapper was not initialized.";
         }
     }
 
     /**
-     * @copybrief IMemoryState::Reset
+     * @copybrief IVariableState::Reset
      *
-     * Wraps IMemoryState::Reset
+     * Wraps IVariableState::Reset
      */
     void Reset() {
         CALL_STATUS_FNC_NO_ARGS(Reset);
     }
 
     /**
-     * @copybrief IMemoryState::GetName
+     * @copybrief IVariableState::GetName
      *
-     * Wraps IMemoryState::GetName
+     * Wraps IVariableState::GetName
      * @return A string representing a state name
      */
     std::string GetName() const {
@@ -53,21 +56,26 @@ public:
     }
 
     /**
-     * @copybrief IMemoryState::GetLastState
+     * @copybrief IVariableState::GetState
      *
-     * Wraps IMemoryState::GetLastState
+     * Wraps IVariableState::GetState
      * @return A blob representing a last state 
      */
-    Blob::CPtr GetLastState() const {
+    Blob::CPtr GetState() const {
         Blob::CPtr stateBlob;
-        CALL_STATUS_FNC(GetLastState, stateBlob);
+        CALL_STATUS_FNC(GetState, stateBlob);
         return stateBlob;
     }
 
+    INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
+    Blob::CPtr GetLastState() const {
+        return GetState();
+    }
+
     /**
-     * @copybrief IMemoryState::SetState
+     * @copybrief IVariableState::SetState
      *
-     * Wraps IMemoryState::SetState
+     * Wraps IVariableState::SetState
      * @param state The current state to set
      */
     void SetState(Blob::Ptr state) {
@@ -75,4 +83,8 @@ public:
     }
 };
 
-}  // namespace InferenceEngine
\ No newline at end of file
+/*
+ * @brief For compatibility reasons.
+ */
+using MemoryState = VariableState;
+}  // namespace InferenceEngine
index 491c24e..8e7c5fa 100644 (file)
@@ -118,7 +118,7 @@ public:
      * @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
      * given index
      */
-    virtual StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
+    virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
 
     /**
      * @brief Sets configuration for current executable network
index a83b613..e5674fe 100644 (file)
@@ -17,6 +17,7 @@
 #include "ie_blob.h"
 #include "ie_common.h"
 #include "ie_preprocess.hpp"
+#include "ie_imemory_state.hpp"
 #include "details/ie_irelease.hpp"
 
 namespace InferenceEngine {
@@ -177,6 +178,18 @@ public:
      * @return Enumeration of the resulted action: InferenceEngine::OK (0) for success
      */
     virtual InferenceEngine::StatusCode SetBatch(int batch_size, ResponseDesc* resp) noexcept = 0;
-};
 
-}  // namespace InferenceEngine
+    /**
+    * @brief Gets state control interface for given infer request.
+    *
+    * State control essential for recurrent networks
+    *
+    * @param pState reference to a pointer that receives internal states
+    * @param idx requested index for receiving memory state
+    * @param resp Optional: pointer to an already allocated object to contain information in case of failure
+    * @return Status code of the operation: InferenceEngine::OK (0) for success, OUT_OF_BOUNDS (-6) no memory state for
+    * given index
+    */
+    virtual StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept = 0;
+};
+}  // namespace InferenceEngine
\ No newline at end of file
index 98c9f34..2e44350 100644 (file)
@@ -3,7 +3,7 @@
 //
 
 /**
- * @brief a header file for IMemoryState interface
+ * @brief a header file for IVariableState interface
  *
  * @file ie_imemory_state.hpp
  */
 namespace InferenceEngine {
 
 /**
- * @interface IMemoryState
+ * @interface IVariableState
  * @brief manages data for reset operations
  */
-class IMemoryState : public details::no_copy {
+class IVariableState : public details::no_copy {
 public:
     /**
-     * @brief A shared pointer to the IMemoryState interface
+     * @brief A shared pointer to the IVariableState interface
      */
-    using Ptr = std::shared_ptr<IMemoryState>;
+    using Ptr = std::shared_ptr<IVariableState>;
 
     /**
      * @brief Gets name of current memory state, if length of array is not enough name is truncated by len, null
-     * terminator is inserted as well.
+     * terminator is inserted as well. As memory state name variable_id from according ReadValue used. 
      *
      * @param name preallocated buffer for receiving name
      * @param len Length of the buffer
@@ -41,7 +41,7 @@ public:
     virtual StatusCode GetName(char* name, size_t len, ResponseDesc* resp) const noexcept = 0;
 
     /**
-     * @brief reset internal memory state for relevant iexecutable network, to a value specified in SetState
+     * @brief Reset internal memory state for relevant infer request, to a value specified as default for according ReadValue node
      *
      * @param  resp Optional: pointer to an already allocated object to contain information in case of failure
      * @return Status code of the operation: InferenceEngine::OK (0) for success*
@@ -49,25 +49,30 @@ public:
     virtual StatusCode Reset(ResponseDesc* resp) noexcept = 0;
 
     /**
-     * @brief  Sets the new state that is used for all future Reset() operations as a base.
+     * @brief  Sets the new state for the next inference.
      *
      * This method can fail if Blob size does not match the internal state size or precision
      *
-     * @param  newState is the data to use as base state
+     * @param  newState is the data to use as new state
      * @param  resp Optional: pointer to an already allocated object to contain information in case of failure
      * @return Status code of the operation: InferenceEngine::OK (0) for success
      */
     virtual StatusCode SetState(Blob::Ptr newState, ResponseDesc* resp) noexcept = 0;
 
     /**
-     * @brief returns the value of the last memory state.
+     * @brief Returns the value of the memory state.
      *
-     * @details Since we roll memory after each infer, we can query the input state always and still get the last state.
      * @param lastState
      * @param  resp Optional: pointer to an already allocated object to contain information in case of failure
      * @return Status code of the operation: InferenceEngine::OK (0) for success
      * */
-    virtual StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept = 0;
+    INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
+    virtual StatusCode GetLastState(Blob::CPtr& state, ResponseDesc* resp) const noexcept {return GetState(state, resp);}
+    virtual StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept = 0;
 };
 
+/*
+ * @brief For compatibility reasons.
+ */
+using IMemoryState = IVariableState;
 }  // namespace InferenceEngine
\ No newline at end of file
index c917473..c7db028 100644 (file)
@@ -845,7 +845,7 @@ int main(int argc, char *argv[]) {
             ptrUtterances.resize(inputArkFiles.size());
 
             // initialize memory state before starting
-            for (auto &&state : executableNet.QueryState()) {
+            for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
                 state.Reset();
             }
 
@@ -1080,7 +1080,7 @@ int main(int argc, char *argv[]) {
                 totalTime += d.count();
 
                 // resetting state between utterances
-                for (auto &&state : executableNet.QueryState()) {
+                for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
                     state.Reset();
                 }
 
index b7a1088..d240c78 100644 (file)
@@ -59,12 +59,13 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
         return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
     }
 
-
-
-    std::vector<InferenceEngine::IMemoryStateInternal::Ptr>  QueryState() override {
+    INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr>  QueryState() override {
+        IE_SUPPRESS_DEPRECATED_START
         auto pluginStates = plg->QueryState();
-        std::vector<InferenceEngine::IMemoryStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
+        std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
         return plg->QueryState();
+        IE_SUPPRESS_DEPRECATED_END
     }
 
     void Export(const std::string &modelFileName) override {
index fd2cc69..fcdc92b 100644 (file)
@@ -111,5 +111,13 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
         }
         return InferenceEngine::OK;
     }
+
+    IE_SUPPRESS_DEPRECATED_START
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr>  QueryState() override {
+        auto pluginStates = plg->QueryState();
+        std::vector<InferenceEngine::IVariableStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
+        return plg->QueryState();
+    }
+    IE_SUPPRESS_DEPRECATED_END
 };
 }  // namespace GNAPluginNS
index 5f0a04f..7d6e676 100644 (file)
@@ -1186,11 +1186,11 @@ Blob::Ptr GNAPlugin::GetInputBlob(const std::string& name, InferenceEngine::Prec
     return inputBlob;
 }
 
-std::vector<InferenceEngine::MemoryStateInternal::Ptr>  GNAPlugin::QueryState() {
+std::vector<InferenceEngine::VariableStateInternal::Ptr>  GNAPlugin::QueryState() {
     if (memoryStates.size() != graphCompiler.memory_connection.size()) {
         memoryStates.clear();
         for (auto& connection : graphCompiler.memory_connection) {
-            auto state = std::make_shared<memory::GNAMemoryState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
+            auto state = std::make_shared<memory::GNAVariableState>(connection.first, std::make_shared <GNAMemoryLayer>(connection.second));
             memoryStates.emplace_back(state);
         }
     }
index 1e4c4fd..dbe98fd 100644 (file)
@@ -84,7 +84,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
 
     InferenceEngine::InputsDataMap inputsDataMap;
     InferenceEngine::OutputsDataMap outputsDataMap;
-    std::vector<InferenceEngine::MemoryStateInternal::Ptr> memoryStates;
+    std::vector<InferenceEngine::VariableStateInternal::Ptr> memoryStates;
 
  public:
     explicit GNAPlugin(const std::map<std::string, std::string>& configMap);
@@ -159,7 +159,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
      * QueryState API
      * @return
      */
-     std::vector<InferenceEngine::IMemoryStateInternal::Ptr>  QueryState();
+    INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr>  QueryState();
 
      /**
       * test-wise API
index bb25cd9..27e9384 100644 (file)
@@ -12,15 +12,15 @@ namespace  GNAPluginNS {
 
 namespace memory {
 
-    std::string GNAMemoryState::GetName() const {
+    std::string GNAVariableState::GetName() const {
         return name;
     }
 
-    void GNAMemoryState::Reset() {
+    void GNAVariableState::Reset() {
         state->Reset();
     }
 
-    InferenceEngine::Precision GNAMemoryState::getPrecision() const {
+    InferenceEngine::Precision GNAVariableState::getPrecision() const {
         InferenceEngine::Precision state_precision;
 
         if (state->getInput()) {
@@ -36,14 +36,14 @@ namespace memory {
                 break;
             default:
                 THROW_GNA_EXCEPTION << "Incorrect state element size " << element_size <<
-                    " to determine precision for MemoryState " << name;
+                    " to determine precision for VariableState " << name;
             }
         }
 
         return state_precision;
     }
 
-    void GNAMemoryState::SetState(InferenceEngine::Blob::Ptr newState) {
+    void GNAVariableState::SetState(InferenceEngine::Blob::Ptr newState) {
         IE_ASSERT(newState != nullptr);
 
         auto data_ptr = newState->cbuffer().as<void*>();
@@ -78,20 +78,20 @@ namespace memory {
                     data_elements,
                     scale_factor);
             } else {
-                THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
+                THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
                     << ". If old state precision is I16 only I16 and FP32 are allowed as new state precisions."
                     << " Old state: " << state_precision << " New state: " << new_state_precision;
             }
             break;
         }
         default:
-            THROW_GNA_EXCEPTION << "Failed to SetState for MemoryState " << name
+            THROW_GNA_EXCEPTION << "Failed to SetState for VariableState " << name
                 << ". Incorrect new/old precision pair"
                 << " Old state: " << state_precision << " New state: " << new_state_precision;
         }
     }
 
-    InferenceEngine::Blob::CPtr GNAMemoryState::GetLastState() const {
+    InferenceEngine::Blob::CPtr GNAVariableState::GetState() const {
         auto elements = state->reserved_size / state->elementSizeBytes();
         InferenceEngine::Precision state_precision = getPrecision();
 
index 499c4c9..2a7c83d 100644 (file)
 
 namespace  GNAPluginNS {
 namespace memory {
-class GNAMemoryState : public InferenceEngine::IMemoryStateInternal {
+class GNAVariableState : public InferenceEngine::IVariableStateInternal {
  public:
-    GNAMemoryState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
+    GNAVariableState(std::string name, std::shared_ptr<GNAMemoryLayer> state)
         : name(name), state(state) { IE_ASSERT(state != nullptr); }
 
     void Reset() override;
     void SetState(InferenceEngine::Blob::Ptr newState) override;
-    InferenceEngine::Blob::CPtr GetLastState() const override;
+    InferenceEngine::Blob::CPtr GetState() const override;
     std::string GetName() const override;
 
 private:
index e6bd3b2..94919a1 100644 (file)
@@ -183,14 +183,14 @@ MKLDNNExecNetwork::MKLDNNExecNetwork(const InferenceEngine::ICNNNetwork &network
             if (node->getType() == MemoryInput) {
                 auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
                 auto state_store = memoryNode->getStore();
-                auto state_name = node->getName();
+                auto state_name = memoryNode->getId();
 
                 // Remove suffix with pair ID. Internal information.
                 auto suffix_idx = state_name.find("/id=");
                 if (suffix_idx != std::string::npos)
                     state_name = state_name.substr(0, suffix_idx);
 
-                memoryStates.emplace_back(new MKLDNNMemoryState(state_name, state_store));
+                memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
             }
         }
     }
@@ -314,6 +314,8 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::ICNNNetwork &n
     return check_result;
 }
 
-std::vector<IMemoryStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
+IE_SUPPRESS_DEPRECATED_START
+std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
     return memoryStates;
 }
+IE_SUPPRESS_DEPRECATED_END
index 8ea85bb..4247503 100644 (file)
@@ -42,14 +42,15 @@ public:
 
     InferenceEngine::CNNNetwork GetExecGraphInfo() override;
 
-    std::vector<InferenceEngine::IMemoryStateInternal::Ptr> QueryState() override;
+    INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
 
     InferenceEngine::ThreadLocal<MKLDNNGraph::Ptr>  _graphs;
 
 protected:
     friend class MKLDNNInferRequest;
     MKLDNNExtensionManager::Ptr extensionManager;
-    std::vector<InferenceEngine::IMemoryStateInternal::Ptr> memoryStates;
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
     InferenceEngine::details::CNNNetworkImplPtr _clonedNetwork;
     std::mutex                                  _cfgMutex;
     Config                                      _cfg;
index e07dfca..ae7db88 100644 (file)
@@ -14,6 +14,8 @@
 #include "mkldnn_exec_network.h"
 #include "mkldnn_itt.h"
 #include "nodes/common/cpu_convert.h"
+#include "mkldnn_memory_state.h"
+#include "nodes/mkldnn_memory_node.hpp"
 
 MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsDataMap     networkInputs,
                                                      InferenceEngine::OutputsDataMap    networkOutputs,
@@ -35,6 +37,30 @@ MKLDNNPlugin::MKLDNNInferRequest::MKLDNNInferRequest(InferenceEngine::InputsData
         InferenceEngine::Blob::Ptr blob;
         MKLDNNInferRequest::GetBlob(it.first.c_str(), blob);
     }
+
+    // Save all MemoryLayer data tensors. Will use insight about mechanics
+    // of MemoryLayer implementation. It uses output edge of MemoryLayer
+    // producer as storage for tensor to keep it between infer calls.
+    IE_SUPPRESS_DEPRECATED_START
+    if (execNetwork->QueryState().size() == 0) {
+        for (auto &node : graph->GetNodes()) {
+            if (node->getType() == MemoryInput) {
+                auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
+                auto state_store = memoryNode->getStore();
+                auto state_name = memoryNode->getId();
+
+                // Remove suffix with pair ID. Internal information.
+                auto suffix_idx = state_name.find("/id=");
+                if (suffix_idx != std::string::npos)
+                    state_name = state_name.substr(0, suffix_idx);
+
+                memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
+           }
+        }
+    } else {
+        memoryStates = execNetwork->QueryState();
+    }
+    IE_SUPPRESS_DEPRECATED_END
 }
 
 MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
@@ -390,3 +416,7 @@ void MKLDNNPlugin::MKLDNNInferRequest::SetBatch(int new_batch) {
 
     m_curBatch = new_batch;
 }
+
+std::vector<InferenceEngine::IVariableStateInternal::Ptr> MKLDNNPlugin::MKLDNNInferRequest::QueryState() {
+    return memoryStates;
+}
index 4c058a4..e9863be 100644 (file)
@@ -43,6 +43,8 @@ public:
 
     void SetBatch(int batch = -1) override;
 
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
+
 private:
     void PushInputData();
 
@@ -53,5 +55,6 @@ private:
     MKLDNNGraph*                        graph = nullptr;
     std::map<std::string, void*>        externalPtr;
     openvino::itt::handle_t             profilingTask;
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr> memoryStates;
 };
 }  // namespace MKLDNNPlugin
index 56d74d2..4af7550 100644 (file)
@@ -4,20 +4,21 @@
 
 #include "mkldnn_memory_state.h"
 #include "mkldnn_extension_utils.h"
+#include "blob_factory.hpp"
 
 using namespace InferenceEngine;
 
 namespace MKLDNNPlugin {
 
-std::string  MKLDNNMemoryState::GetName() const {
+std::string  MKLDNNVariableState::GetName() const {
     return name;
 }
 
-void  MKLDNNMemoryState::Reset() {
+void  MKLDNNVariableState::Reset() {
     storage->FillZero();
 }
 
-void  MKLDNNMemoryState::SetState(Blob::Ptr newState) {
+void  MKLDNNVariableState::SetState(Blob::Ptr newState) {
     auto prec = newState->getTensorDesc().getPrecision();
     auto data_type = MKLDNNExtensionUtils::IEPrecisionToDataType(prec);
     auto data_layout = MKLDNNMemory::Convert(newState->getTensorDesc().getLayout());
@@ -27,9 +28,11 @@ void  MKLDNNMemoryState::SetState(Blob::Ptr newState) {
     storage->SetData(data_type, data_layout, data_ptr, data_size);
 }
 
-InferenceEngine::Blob::CPtr MKLDNNMemoryState::GetLastState() const {
-    THROW_IE_EXCEPTION << "GetLastState method is not implemented for MemoryState";
-    return nullptr;
+InferenceEngine::Blob::CPtr MKLDNNVariableState::GetState() const {
+    auto result_blob = make_blob_with_precision(MKLDNNMemoryDesc(storage->GetDescriptor()));
+    result_blob->allocate();
+    std::memcpy(result_blob->buffer(), storage->GetData(), storage->GetSize());
+    return result_blob;
 }
 
-}  // namespace MKLDNNPlugin
\ No newline at end of file
+}  // namespace MKLDNNPlugin
index cb024df..751635b 100644 (file)
 
 namespace MKLDNNPlugin {
 
-class MKLDNNMemoryState : public InferenceEngine::IMemoryStateInternal {
+class MKLDNNVariableState : public InferenceEngine::IVariableStateInternal {
 public:
-    MKLDNNMemoryState(std::string name, MKLDNNMemoryPtr storage) :
+    MKLDNNVariableState(std::string name, MKLDNNMemoryPtr storage) :
             name(name), storage(storage) {}
 
     std::string GetName() const override;
     void Reset() override;
     void SetState(InferenceEngine::Blob::Ptr newState) override;
-    InferenceEngine::Blob::CPtr GetLastState() const override;
+    InferenceEngine::Blob::CPtr GetState() const override;
 
 private:
     std::string name;
index fd86780..b9d7833 100644 (file)
@@ -66,19 +66,22 @@ public:
         TO_STATUS(graphPtr = _impl->GetExecGraphInfo());
     }
 
-    StatusCode QueryState(IMemoryState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
+    INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
+    StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
+        IE_SUPPRESS_DEPRECATED_START
         try {
             auto v = _impl->QueryState();
             if (idx >= v.size()) {
                 return OUT_OF_BOUNDS;
             }
-            pState = std::make_shared<MemoryStateBase<IMemoryStateInternal>>(v[idx]);
+            pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
             return OK;
         } catch (const std::exception& ex) {
             return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
         } catch (...) {
             return InferenceEngine::DescriptionBuffer(UNEXPECTED);
         }
+        IE_SUPPRESS_DEPRECATED_END
     }
 
     void Release() noexcept override {
index e350e6e..5892ef0 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "cpp_interfaces/exception2status.hpp"
 #include "cpp_interfaces/plugin_itt.hpp"
+#include <cpp_interfaces/base/ie_memory_state_base.hpp>
 #include "ie_iinfer_request.hpp"
 #include "ie_preprocess.hpp"
 #include "ie_profiling.hpp"
@@ -88,6 +89,21 @@ public:
         TO_STATUS(_impl->SetBatch(batch_size));
     }
 
+    StatusCode QueryState(IVariableState::Ptr& pState, size_t idx, ResponseDesc* resp) noexcept override {
+        try {
+            auto v = _impl->QueryState();
+            if (idx >= v.size()) {
+                return OUT_OF_BOUNDS;
+            }
+            pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
+            return OK;
+        } catch (const std::exception& ex) {
+            return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
+        } catch (...) {
+            return InferenceEngine::DescriptionBuffer(UNEXPECTED);
+        }
+    }
+
 private:
     ~InferRequestBase() = default;
 };
index 2b88ee5..fe191a8 100644 (file)
@@ -7,23 +7,24 @@
 #include <memory>
 
 #include "cpp_interfaces/exception2status.hpp"
+#include "cpp_interfaces/impl/ie_memory_state_internal.hpp"
 #include "ie_imemory_state.hpp"
 
 namespace InferenceEngine {
 
 /**
- * @brief default implementation for IMemoryState
+ * @brief default implementation for IVariableState
  * @ingroup ie_dev_api_mem_state_api
  */
 template <class T>
-class MemoryStateBase : public IMemoryState {
+class VariableStateBase : public IVariableState {
 protected:
     std::shared_ptr<T> impl;
 
 public:
-    explicit MemoryStateBase(std::shared_ptr<T> impl): impl(impl) {
+    explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
         if (impl == nullptr) {
-            THROW_IE_EXCEPTION << "MemoryStateBase implementation not defined";
+            THROW_IE_EXCEPTION << "VariableStateBase implementation not defined";
         }
     }
 
@@ -44,9 +45,9 @@ public:
         TO_STATUS(impl->SetState(newState));
     }
 
-    StatusCode GetLastState(Blob::CPtr& lastState, ResponseDesc* resp) const noexcept override {
-        TO_STATUS(lastState = impl->GetLastState());
+    StatusCode GetState(Blob::CPtr& state, ResponseDesc* resp) const noexcept override {
+        TO_STATUS(state = impl->GetState());
     }
 };
 
-}  // namespace InferenceEngine
\ No newline at end of file
+}  // namespace InferenceEngine
index 41f5d16..c2e70b5 100644 (file)
@@ -88,7 +88,7 @@ public:
         _plugin = plugin;
     }
 
-    std::vector<IMemoryStateInternal::Ptr> QueryState() override {
+    std::vector<IVariableStateInternal::Ptr> QueryState() override {
         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
     }
 
index d7b2da1..71d2f5a 100644 (file)
@@ -152,6 +152,10 @@ public:
         _publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
     }
 
+    std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
+        return _syncRequest->QueryState();
+    }
+
 protected:
     /**
      * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation
index 7fe1c30..50671b4 100644 (file)
@@ -223,6 +223,12 @@ public:
         }
     }
 
+    std::vector<IVariableStateInternal::Ptr> QueryState() override {
+        // meaning base plugin reports as no state available - plugin owners need to create proper override of this
+        THROW_IE_EXCEPTION << "Plugin doesn't override QueryState";
+        return {};
+    }
+
 protected:
     InferenceEngine::InputsDataMap _networkInputs;  //!< Holds information about network inputs info
     InferenceEngine::OutputsDataMap _networkOutputs;  //!< Holds information about network outputs data
index 5da62e3..05f96d5 100644 (file)
@@ -13,21 +13,25 @@ namespace InferenceEngine {
  * @brief minimal interface for memory state implementation
  * @ingroup ie_dev_api_mem_state_api
  */
-class MemoryStateInternal : public IMemoryStateInternal {
+class VariableStateInternal : public IVariableStateInternal {
     std::string name;
     Blob::Ptr state;
 
 public:
-    explicit MemoryStateInternal(std::string name): name(name) {}
+    explicit VariableStateInternal(std::string name): name(name) {}
     std::string GetName() const override {
         return name;
     }
     void SetState(Blob::Ptr newState) override {
         state = newState;
     }
-    Blob::CPtr GetLastState() const override {
+    Blob::CPtr GetState() const override {
         return state;
     }
 };
 
-}  // namespace InferenceEngine
\ No newline at end of file
+/*
+ * @brief For compatibility reasons.
+ */
+using MemoryStateInternal = VariableStateInternal;
+}  // namespace InferenceEngine
index 9efdb66..17cc927 100644 (file)
@@ -79,7 +79,7 @@ public:
      * @brief Queries memory states.
      * @return Returns memory states
      */
-    virtual std::vector<IMemoryStateInternal::Ptr> QueryState() = 0;
+    virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
 
     /**
      * @brief Sets configuration for current executable network
index d62cd42..c09a15a 100644 (file)
@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
 #include <ie_blob.h>
 #include <ie_common.h>
 #include <ie_preprocess.hpp>
@@ -83,6 +84,12 @@ public:
      * @param batch - new batch size to be used by all the following inference calls for this request.
      */
     virtual void SetBatch(int batch) = 0;
+
+    /**
+     * @brief Queries memory states.
+     * @return Returns memory states
+     */
+    virtual std::vector<IVariableStateInternal::Ptr> QueryState() = 0;
 };
 
 }  // namespace InferenceEngine
index aa81e69..ef37d8b 100644 (file)
 
 namespace InferenceEngine {
 /**
- * @interface IMemoryStateInternal
+ * @interface IVariableStateInternal
  * @brief minimal interface for memory state implementation
  * @ingroup ie_dev_api_mem_state_api
  */
-class IMemoryStateInternal {
+class IVariableStateInternal {
 public:
-    using Ptr = std::shared_ptr<IMemoryStateInternal>;
+    using Ptr = std::shared_ptr<IVariableStateInternal>;
 
-    virtual ~IMemoryStateInternal() = default;
+    virtual ~IVariableStateInternal() = default;
     virtual std::string GetName() const = 0;
     virtual void Reset() = 0;
     virtual void SetState(Blob::Ptr newState) = 0;
-    virtual Blob::CPtr GetLastState() const = 0;
+    virtual Blob::CPtr GetState() const = 0;
+    INFERENCE_ENGINE_DEPRECATED("Use GetState function instead")
+    virtual Blob::CPtr GetLastState() const {return GetState();}
 };
 
+/*
+ * @brief For compatibility reasons.
+ */
+using IMemoryStateInternal = IVariableStateInternal;
 }  // namespace InferenceEngine
index 861ceee..1ce30df 100644 (file)
@@ -83,3 +83,8 @@ TEST(InferRequestCPPTests, throwsOnUninitializedCast) {
     InferRequest req;
     ASSERT_THROW(auto &ireq = static_cast<IInferRequest::Ptr &>(req), InferenceEngine::details::InferenceEngineException);
 }
+
+TEST(InferRequestCPPTests, throwsOnUninitializedQueryState) {
+    InferRequest req;
+    ASSERT_THROW(req.QueryState(), InferenceEngine::details::InferenceEngineException);
+}
index f449c4c..89f3b75 100644 (file)
@@ -46,8 +46,10 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
 }
 
 TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
+    IE_SUPPRESS_DEPRECATED_START
     ExecutableNetwork exec;
     ASSERT_THROW(exec.QueryState(), InferenceEngine::details::InferenceEngineException);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
 TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
index 17de549..f01442e 100644 (file)
@@ -10,12 +10,15 @@ namespace {
             // 0 - plugin
             // 1 - executable_network
             // 2 - infer_request
-            {0, 1, 2},
-            {0, 2, 1},
-            {1, 0, 2},
-            {1, 2, 0},
-            {2, 0, 1},
-            {2, 1, 0}
+            // 3 - variable state
+            {3, 0, 1, 2},
+            {3, 0, 2, 1},
+            {3, 1, 0, 2},
+            {3, 1, 2, 0},
+            {3, 2, 0, 1},
+            {3, 2, 1, 0},
+            {0, 3, 1, 2},
+            {0, 1, 3, 2}
     };
 
     INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,
@@ -24,4 +27,4 @@ namespace {
             ::testing::ValuesIn(orders)),
             HoldersTest::getTestCaseName);
 
-}  // namespace
\ No newline at end of file
+}  // namespace
diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/behavior/memory_states.cpp
new file mode 100644 (file)
index 0000000..0a7bc37
--- /dev/null
@@ -0,0 +1,22 @@
+// Copyright (C) 2020 Intel Corporation
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <common_test_utils/test_constants.hpp>
+#include "behavior/memory_states.hpp"
+#include "functional_test_utils/test_model/test_model.hpp"
+#include "functional_test_utils/plugin_cache.hpp"
+
+InferenceEngine::CNNNetwork getNetwork() {
+    auto model = FuncTestUtils::TestModel::getModelWithMultipleMemoryConnections(InferenceEngine::Precision::FP32);
+    auto ie = PluginCache::get().ie();
+    return ie->ReadNetwork(model.model_xml_str, model.weights_blob);
+}
+std::vector<memoryStateParams> memoryStateTestCases = {
+        memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_CPU)
+};
+
+INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
+        ::testing::ValuesIn(memoryStateTestCases),
+        VariableStateTest::getTestCaseName);
index 2c27ccc..c14c4dc 100644 (file)
@@ -10,12 +10,15 @@ namespace {
             // 0 - plugin
             // 1 - executable_network
             // 2 - infer_request
-            {0, 1, 2},
-            {0, 2, 1},
-            {1, 0, 2},
-            {1, 2, 0},
-            {2, 0, 1},
-            {2, 1, 0}
+            // 3 - variable state
+            {3, 0, 1, 2},
+            {3, 0, 2, 1},
+            {3, 1, 0, 2},
+            {3, 1, 2, 0},
+            {3, 2, 0, 1},
+            {3, 2, 1, 0},
+            {0, 3, 1, 2},
+            {0, 1, 3, 2}
     };
 
     INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, HoldersTest,
index 62ab38b..6a7b4cc 100644 (file)
@@ -17,6 +17,6 @@ std::vector<memoryStateParams> memoryStateTestCases = {
         memoryStateParams(getNetwork(), {"c_1-3", "r_1-3"}, CommonTestUtils::DEVICE_GNA)
 };
 
-INSTANTIATE_TEST_CASE_P(smoke_MemoryStateBasic, MemoryStateTest,
+INSTANTIATE_TEST_CASE_P(smoke_VariableStateBasic, VariableStateTest,
         ::testing::ValuesIn(memoryStateTestCases),
-        MemoryStateTest::getTestCaseName);
+        VariableStateTest::getTestCaseName);
index bac01c1..a9718b1 100644 (file)
@@ -14,7 +14,7 @@ typedef std::tuple<
         std::string>                 // Target device name
         memoryStateParams;
 
-class MemoryStateTest : public CommonTestUtils::TestsCommon,
+class VariableStateTest : public CommonTestUtils::TestsCommon,
                         public testing::WithParamInterface<memoryStateParams> {
 protected:
     InferenceEngine::CNNNetwork net;
index 61f9eb2..62db86b 100644 (file)
@@ -25,7 +25,11 @@ namespace BehaviorTestsDefinitions {
         if (deathTestStyle == "fast") {
             ::testing::GTEST_FLAG(death_test_style) = "threadsafe";
         }
-        function = ngraph::builder::subgraph::makeConvPoolRelu();
+        if (targetDevice == CommonTestUtils::DEVICE_CPU) {
+            function = ngraph::builder::subgraph::makeReadConcatSplitAssign();
+        } else {
+            function = ngraph::builder::subgraph::makeConvPoolRelu();
+        }
     }
 
     void HoldersTest::TearDown() {
@@ -42,6 +46,12 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
         InferenceEngine::Core core;
         auto exe_net = core.LoadNetwork(cnnNet, deviceName);
         auto request = exe_net.CreateInferRequest();
+        std::vector<InferenceEngine::VariableState> states = {};
+        try {
+            states = request.QueryState();
+        } catch(...) {
+            // do nothing
+        }
 
         auto release = [&](int i) {
             switch (i) {
@@ -54,6 +64,9 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
                 case 2:
                     request = {};
                     break;
+                case 3:
+                    states = {};
+                    break;
                 default:
                     break;
             }
@@ -67,4 +80,4 @@ EXPECT_EXIT(_statement; exit(0), testing::ExitedWithCode(0), "")
         // Test failed if crash happens
         EXPECT_NO_CRASH(release_order_test(order, targetDevice, function));
     }
-}  // namespace BehaviorTestsDefinitions
\ No newline at end of file
+}  // namespace BehaviorTestsDefinitions
index 2aa6694..4ef378e 100644 (file)
@@ -7,7 +7,7 @@
 #include "behavior/memory_states.hpp"
 #include "functional_test_utils/plugin_cache.hpp"
 
-std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
+std::string VariableStateTest::getTestCaseName(const testing::TestParamInfo<memoryStateParams> &obj) {
     std::ostringstream result;
     InferenceEngine::CNNNetwork net;
     std::string targetDevice;
@@ -17,47 +17,51 @@ std::string MemoryStateTest::getTestCaseName(const testing::TestParamInfo<memory
     return result.str();
 }
 
-void MemoryStateTest::SetUp() {
+void VariableStateTest::SetUp() {
     std::tie(net, statesToQuery, deviceName) = GetParam();
 }
 
-InferenceEngine::ExecutableNetwork MemoryStateTest::PrepareNetwork() {
+InferenceEngine::ExecutableNetwork VariableStateTest::PrepareNetwork() {
     net.addOutput("Memory_1");
     net.addOutput("Memory_2");
     auto ie = PluginCache::get().ie(deviceName);
     return ie->LoadNetwork(net, deviceName);
 }
 
-TEST_P(MemoryStateTest, smoke_MemoryState_QueryState) {
+TEST_P(VariableStateTest, smoke_VariableState_QueryState) {
+    IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
 
     auto states = executableNet.QueryState();
-    ASSERT_TRUE(states.size() == 2) << "Incorrect number of MemoryStates";
+    ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
 
     for (auto&& state : states) {
         auto name = state.GetName();
         ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
             << "State " << name << "expected to be in memory states but it is not!";
     }
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
+TEST_P(VariableStateTest, smoke_VariableState_SetState) {
+    IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
     const float new_state_val = 13.0f;
     for (auto&& state : executableNet.QueryState()) {
         state.Reset();
-        auto element_count = state.GetLastState()->size();
+        auto state_val = state.GetState();
+        auto element_count = state_val->size();
 
         std::vector<float> new_state_data(element_count, new_state_val);
         auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
+            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
             new_state_data.data(), new_state_data.size());
 
         state.SetState(stateBlob);
     }
 
     for (auto&& state : executableNet.QueryState()) {
-        auto lastState = state.GetLastState();
+        auto lastState = state.GetState();
         auto last_state_size = lastState->size();
         auto last_state_data = lastState->cbuffer().as<float*>();
         ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
@@ -66,18 +70,21 @@ TEST_P(MemoryStateTest, smoke_MemoryState_SetState) {
             EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
         }
     }
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_P(MemoryStateTest, smoke_MemoryState_Reset) {
+TEST_P(VariableStateTest, smoke_VariableState_Reset) {
+    IE_SUPPRESS_DEPRECATED_START
     auto executableNet = PrepareNetwork();
     const float new_state_val = 13.0f;
     for (auto&& state : executableNet.QueryState()) {
         state.Reset();
-        auto element_count = state.GetLastState()->size();
+        auto state_val = state.GetState();
+        auto element_count = state_val->size();
 
         std::vector<float> new_state_data(element_count, new_state_val);
         auto stateBlob = InferenceEngine::make_shared_blob<float>(
-            { InferenceEngine::Precision::FP32, {element_count}, InferenceEngine::C },
+            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
             new_state_data.data(), new_state_data.size());
 
         state.SetState(stateBlob);
@@ -87,7 +94,92 @@ TEST_P(MemoryStateTest, smoke_MemoryState_Reset) {
 
     auto states = executableNet.QueryState();
     for (int i = 0; i < states.size(); ++i) {
-        auto lastState = states[i].GetLastState();
+        auto lastState = states[i].GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+
+        if (i == 0) {
+            for (int j = 0; j < last_state_size; ++j) {
+                EXPECT_NEAR(0, last_state_data[j], 1e-5);
+            }
+        } else {
+            for (int j = 0; j < last_state_size; ++j) {
+                EXPECT_NEAR(13.0f, last_state_data[j], 1e-5);
+            }
+        }
+    }
+    IE_SUPPRESS_DEPRECATED_END
+}
+
+TEST_P(VariableStateTest, inferreq_smoke_VariableState_QueryState) {
+    auto executableNet = PrepareNetwork();
+    auto inferReq = executableNet.CreateInferRequest();
+
+    auto states = inferReq.QueryState();
+    ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
+
+    for (auto&& state : states) {
+        auto name = state.GetName();
+        ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
+            << "State " << name << "expected to be in memory states but it is not!";
+    }
+}
+
+TEST_P(VariableStateTest, inferreq_smoke_VariableState_SetState) {
+    auto executableNet = PrepareNetwork();
+    auto inferReq = executableNet.CreateInferRequest();
+
+    const float new_state_val = 13.0f;
+    for (auto&& state : inferReq.QueryState()) {
+        state.Reset();
+        auto state_val = state.GetState();
+        auto element_count = state_val->size();
+
+        std::vector<float> new_state_data(element_count, new_state_val);
+        auto stateBlob = InferenceEngine::make_shared_blob<float>(
+            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
+            new_state_data.data(), new_state_data.size());
+
+        state.SetState(stateBlob);
+    }
+
+    for (auto&& state : inferReq.QueryState()) {
+        auto lastState = state.GetState();
+        auto last_state_size = lastState->size();
+        auto last_state_data = lastState->cbuffer().as<float*>();
+        ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
+
+        for (int i = 0; i < last_state_size; i++) {
+            EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
+        }
+    }
+}
+
+TEST_P(VariableStateTest, inferreq_smoke_VariableState_Reset) {
+    auto executableNet = PrepareNetwork();
+    auto inferReq = executableNet.CreateInferRequest();
+
+    const float new_state_val = 13.0f;
+    for (auto&& state : inferReq.QueryState()) {
+        state.Reset();
+        auto state_val = state.GetState();
+        auto element_count = state_val->size();
+
+        std::vector<float> new_state_data(element_count, new_state_val);
+        auto stateBlob = InferenceEngine::make_shared_blob<float>(
+            { state_val->getTensorDesc().getPrecision(), {1, element_count}, state_val->getTensorDesc().getLayout() },
+            new_state_data.data(), new_state_data.size());
+
+        state.SetState(stateBlob);
+    }
+
+    inferReq.QueryState().front().Reset();
+
+    auto states = inferReq.QueryState();
+    for (int i = 0; i < states.size(); ++i) {
+        auto lastState = states[i].GetState();
         auto last_state_size = lastState->size();
         auto last_state_data = lastState->cbuffer().as<float*>();
 
index 1c60681..d544e6b 100644 (file)
@@ -11,6 +11,7 @@
 #include <vector>
 
 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
+#include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
 
 class MockIAsyncInferRequestInternal : public InferenceEngine::IAsyncInferRequestInternal {
 public:
@@ -26,4 +27,5 @@ public:
     MOCK_CONST_METHOD2(GetPreProcess, void(const char* name, const InferenceEngine::PreProcessInfo**));
     MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
     MOCK_METHOD1(SetBatch, void(int));
+    MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
 };
index 9cec0ff..54a26b9 100644 (file)
@@ -28,7 +28,7 @@ public:
     MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr(void));
     MOCK_METHOD1(Export, void(const std::string &));
     void Export(std::ostream &) override {};
-    MOCK_METHOD0(QueryState, std::vector<IMemoryStateInternal::Ptr>(void));
+    MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
     MOCK_METHOD0(GetExecGraphInfo, CNNNetwork(void));
 
     MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));
index 30343f8..0cc1e7f 100644 (file)
@@ -11,6 +11,7 @@
 #include <vector>
 
 #include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
+#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
 
 class MockIInferRequestInternal : public InferenceEngine::IInferRequestInternal {
 public:
@@ -20,4 +21,5 @@ public:
     MOCK_METHOD2(GetBlob, void(const char *name, InferenceEngine::Blob::Ptr &));
     MOCK_METHOD3(SetBlob, void(const char*, const InferenceEngine::Blob::Ptr&, const InferenceEngine::PreProcessInfo&));
     MOCK_METHOD2(GetPreProcess, void(const char*, const InferenceEngine::PreProcessInfo**));
+    MOCK_METHOD0(QueryState, std::vector<InferenceEngine::IVariableStateInternal::Ptr>());
 };
index c57cae8..13cd110 100644 (file)
 
 #include <cpp_interfaces/interface/ie_imemory_state_internal.hpp>
 
-class MockIMemoryStateInternal : public InferenceEngine::IMemoryStateInternal {
+class MockIVariableStateInternal : public InferenceEngine::IVariableStateInternal {
  public:
     MOCK_CONST_METHOD0(GetName, std::string());
     MOCK_METHOD0(Reset, void());
     MOCK_METHOD1(SetState, void(InferenceEngine::Blob::Ptr));
-    MOCK_CONST_METHOD0(GetLastState, InferenceEngine::Blob::CPtr());
+    MOCK_CONST_METHOD0(GetState, InferenceEngine::Blob::CPtr());
 };
index 62dc9de..32135cc 100644 (file)
 
 using namespace InferenceEngine;
 
-class MockIMemoryState : public InferenceEngine::IMemoryState {
+class MockIVariableState : public InferenceEngine::IVariableState {
 public:
     MOCK_QUALIFIED_METHOD3(GetName, const noexcept, StatusCode(char * , size_t, ResponseDesc *));
     MOCK_QUALIFIED_METHOD1(Reset, noexcept, StatusCode(ResponseDesc *));
     MOCK_QUALIFIED_METHOD2(SetState, noexcept, StatusCode(Blob::Ptr, ResponseDesc *));
-    MOCK_QUALIFIED_METHOD2(GetLastState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
+    MOCK_QUALIFIED_METHOD2(GetState, const noexcept, StatusCode(Blob::CPtr &, ResponseDesc *));
 };
index 2af9e83..903cb04 100644 (file)
@@ -30,6 +30,6 @@ public:
     MOCK_QUALIFIED_METHOD3(GetConfig, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
     MOCK_QUALIFIED_METHOD3(GetMetric, const noexcept, StatusCode(const std::string &name, Parameter &result, ResponseDesc *resp));
     MOCK_QUALIFIED_METHOD2(GetContext, const noexcept, StatusCode(RemoteContext::Ptr &pContext, ResponseDesc *resp));
-    MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IMemoryState::Ptr &, size_t, ResponseDesc *));
+    MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
     MOCK_QUALIFIED_METHOD0(Release, noexcept, void());
 };
index 3489c0e..613898a 100644 (file)
@@ -34,4 +34,5 @@ public:
     MOCK_QUALIFIED_METHOD3(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, ResponseDesc*));
     MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*));
     MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*));
+    MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
 };
index 8064ffb..91d43e2 100644 (file)
@@ -484,6 +484,33 @@ static std::shared_ptr<ngraph::Function> makeConvBias(std::vector<size_t> inputS
     fn_ptr->set_friendly_name("ConvBias");
     return fn_ptr;
 }
+
+static std::shared_ptr<ngraph::Function> makeReadConcatSplitAssign(std::vector<size_t> inputShape = {1, 1, 2, 4},
+                                                                   InferenceEngine::Precision prc = InferenceEngine::Precision::FP32) {
+    ngraph::element::Type type = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
+    auto parameter =  ngraph::builder::makeParams(type, {inputShape});
+    parameter[0]->set_friendly_name("parameter");
+    auto init_const = ngraph::op::Constant::create(element::f32, Shape{1, 1, 2, 2}, {0, 0, 0, 0});
+    auto read = std::make_shared<ngraph::op::ReadValue>(init_const, "v0");
+    read->set_friendly_name("read");
+    std::vector<std::shared_ptr<ngraph::Node>> args = {parameter[0], read};
+    auto conc = std::make_shared<ngraph::op::Concat>(args, 3);
+    conc->set_friendly_name("concat");
+    auto res = std::make_shared<ngraph::op::Result>(conc);
+    res->set_friendly_name("result");
+    const auto axis = ngraph::op::Constant::create(element::i64, Shape{}, {3});
+    axis->set_friendly_name("axis");
+    auto crop = std::make_shared<ngraph::op::v1::Split>(conc, axis, 3);
+    crop->set_friendly_name("crop");
+    auto assign = std::make_shared<ngraph::op::Assign>(crop, "v0");
+    assign->set_friendly_name("assign");
+
+    std::shared_ptr<ngraph::Function> fn_ptr = std::make_shared<ngraph::Function>(ngraph::ResultVector({res}),
+                                                                                  ngraph::SinkVector({assign}),
+                                                                                  ngraph::ParameterVector{parameter});
+    fn_ptr->set_friendly_name("ReadConcatSplitAssign");
+    return fn_ptr;
+}
 }  // namespace subgraph
 }  // namespace builder
 }  // namespace ngraph
index 64499f2..ec2cd9c 100644 (file)
 #include <cpp/ie_executable_network.hpp>
 
 #include <cpp_interfaces/base/ie_executable_network_base.hpp>
-#include <cpp_interfaces/impl/ie_memory_state_internal.hpp>
+#include <cpp_interfaces/base/ie_infer_async_request_base.hpp>
 
 #include "unit_test_utils/mocks/cpp_interfaces/interface/mock_imemory_state_internal.hpp"
 #include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iexecutable_network_internal.hpp"
+#include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iasync_infer_request_internal.hpp"
 
 using namespace ::testing;
 using namespace std;
 using namespace InferenceEngine;
 using namespace InferenceEngine::details;
 
-class MemoryStateTests : public ::testing::Test {
+template <class T>
+inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
+    typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
+        p->Release();
+    });
+    return InferenceEngine::InferRequest(req);
+}
+
+
+class VariableStateTests : public ::testing::Test {
  protected:
     shared_ptr<MockIExecutableNetworkInternal> mockExeNetworkInternal;
-    shared_ptr<MockIMemoryStateInternal> mockMemoryStateInternal;
+    shared_ptr<MockIAsyncInferRequestInternal> mockInferRequestInternal;
+    shared_ptr<MockIVariableStateInternal> mockVariableStateInternal;
 
     virtual void SetUp() {
         mockExeNetworkInternal = make_shared<MockIExecutableNetworkInternal>();
-        mockMemoryStateInternal = make_shared<MockIMemoryStateInternal>();
+        mockInferRequestInternal = make_shared<MockIAsyncInferRequestInternal>();
+        mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
     }
 };
 
-TEST_F(MemoryStateTests, ExecutableNetworkCanConvertOneMemoryStateFromCppToAPI) {
+TEST_F(VariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn(1);
-    toReturn[0] = mockMemoryStateInternal;
+    std::vector<IVariableStateInternal::Ptr> toReturn(1);
+    toReturn[0] = mockVariableStateInternal;
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
 
     auto state = net.QueryState();
     ASSERT_EQ(state.size(), 1);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, ExecutableNetworkCanConvertZeroMemoryStateFromCppToAPI) {
+TEST_F(VariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
+    std::vector<IVariableStateInternal::Ptr> toReturn;
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
 
     auto state = net.QueryState();
     ASSERT_EQ(state.size(), 0);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, ExecutableNetworkCanConvert2MemoryStatesFromCPPtoAPI) {
+TEST_F(VariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
 
     auto state = net.QueryState();
     ASSERT_EQ(state.size(), 2);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesReset) {
+TEST_F(VariableStateTests, VariableStatePropagatesReset) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).Times(1);
+    EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
 
     auto state = net.QueryState();
     state.front().Reset();
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesExceptionsFromReset) {
+TEST_F(VariableStateTests, VariableStatePropagatesExceptionsFromReset) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
+    EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
 
     auto state = net.QueryState();
     EXPECT_ANY_THROW(state.front().Reset());
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesGetName) {
+TEST_F(VariableStateTests, VariableStatePropagatesGetName) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
 
     auto state = net.QueryState();
     EXPECT_STREQ(state.front().GetName().c_str(), "someName");
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithZeroLen) {
+TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
 
-    IMemoryState::Ptr pState;
+    IVariableState::Ptr pState;
 
     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
     char *name = reinterpret_cast<char *>(1);
     EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
+    IE_SUPPRESS_DEPRECATED_END
 }
 
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfOne) {
+TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
 
-    IMemoryState::Ptr pState;
+    IVariableState::Ptr pState;
 
     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
     char name[1];
     EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
     EXPECT_STREQ(name, "");
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStatePropagatesGetNameWithLenOfTwo) {
+TEST_F(VariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
-    toReturn.push_back(mockMemoryStateInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), GetName()).WillOnce(Return("someName"));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
 
-    IMemoryState::Ptr pState;
+    IVariableState::Ptr pState;
 
     static_cast<IExecutableNetwork::Ptr>(net)->QueryState(pState, 0, nullptr);
     char name[2];
     EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
     EXPECT_STREQ(name, "s");
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
+TEST_F(VariableStateTests, VariableStateCanPropagateSetState) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
+    std::vector<IVariableStateInternal::Ptr> toReturn;
     Blob::Ptr saver;
-    toReturn.push_back(mockMemoryStateInternal);
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
+    EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
 
     float data[] = {123, 124, 125};
     auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
@@ -161,47 +192,50 @@ TEST_F(MemoryStateTests, MemoryStateCanPropagateSetState) {
     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
     ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-TEST_F(MemoryStateTests, MemoryStateCanPropagateGetLastState) {
+TEST_F(VariableStateTests, VariableStateCanPropagateGetLastState) {
+    IE_SUPPRESS_DEPRECATED_START
     auto net = make_executable_network(mockExeNetworkInternal);
-    std::vector<IMemoryStateInternal::Ptr> toReturn;
+    std::vector<IVariableStateInternal::Ptr> toReturn;
 
     float data[] = {123, 124, 125};
     auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
 
 
-    toReturn.push_back(mockMemoryStateInternal);
+    toReturn.push_back(mockVariableStateInternal);
 
     EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
-    EXPECT_CALL(*mockMemoryStateInternal.get(), GetLastState()).WillOnce(Return(stateBlob));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
 
 
-    auto saver = net.QueryState().front().GetLastState();
+    auto saver = net.QueryState().front().GetState();
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
+    IE_SUPPRESS_DEPRECATED_END
 }
 
-class MemoryStateInternalMockImpl : public MemoryStateInternal {
+class VariableStateInternalMockImpl : public VariableStateInternal {
  public:
-    using MemoryStateInternal::MemoryStateInternal;
+    using VariableStateInternal::VariableStateInternal;
     MOCK_METHOD0(Reset, void());
 };
 
-TEST_F(MemoryStateTests, MemoryStateInternalCanSaveName) {
-    IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
+TEST_F(VariableStateTests, VariableStateInternalCanSaveName) {
+    IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
     ASSERT_STREQ(pState->GetName().c_str(), "name");
 }
 
 
-TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
-    IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
+TEST_F(VariableStateTests, VariableStateInternalCanSaveState) {
+    IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
     float data[] = {123, 124, 125};
     auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
 
     pState->SetState(stateBlob);
-    auto saver = pState->GetLastState();
+    auto saver = pState->GetState();
 
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 123);
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 124);
@@ -209,8 +243,8 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveState) {
 }
 
 
-TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
-    IMemoryStateInternal::Ptr pState(new MemoryStateInternalMockImpl("name"));
+TEST_F(VariableStateTests, VariableStateInternalCanSaveStateByReference) {
+    IVariableStateInternal::Ptr pState(new VariableStateInternalMockImpl("name"));
     float data[] = {123, 124, 125};
     auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
 
@@ -219,9 +253,162 @@ TEST_F(MemoryStateTests, MemoryStateInternalCanSaveStateByReference) {
     data[0] = 121;
     data[1] = 122;
     data[2] = 123;
-    auto saver = pState->GetLastState();
+    auto saver = pState->GetState();
 
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[0], 121);
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[1], 122);
     ASSERT_FLOAT_EQ(saver->cbuffer().as<const float *>()[2], 123);
 }
+
+// Tests for InferRequest::QueryState
+TEST_F(VariableStateTests, InferRequestCanConvertOneVariableStateFromCppToAPI) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn(1);
+    toReturn[0] = mockVariableStateInternal;
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
+
+    auto state = req.QueryState();
+    ASSERT_EQ(state.size(), 1);
+}
+
+TEST_F(VariableStateTests, InferRequestCanConvertZeroVariableStateFromCppToAPI) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillOnce(Return(toReturn));
+
+    auto state = req.QueryState();
+    ASSERT_EQ(state.size(), 0);
+}
+
+TEST_F(VariableStateTests, InferRequestCanConvert2VariableStatesFromCPPtoAPI) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(3).WillRepeatedly(Return(toReturn));
+
+    auto state = req.QueryState();
+    ASSERT_EQ(state.size(), 2);
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesReset) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
+
+    auto state = req.QueryState();
+    state.front().Reset();
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesExceptionsFromReset) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
+
+    auto state = req.QueryState();
+    EXPECT_ANY_THROW(state.front().Reset());
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetName) {
+auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(2).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
+
+    auto state = req.QueryState();
+    EXPECT_STREQ(state.front().GetName().c_str(), "someName");
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithZeroLen) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
+
+    IVariableState::Ptr pState;
+
+    static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
+    char *name = reinterpret_cast<char *>(1);
+    EXPECT_NO_THROW(pState->GetName(name, 0, nullptr));
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfOne) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
+
+    IVariableState::Ptr pState;
+
+    static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
+    char name[1];
+    EXPECT_NO_THROW(pState->GetName(name, 1, nullptr));
+    EXPECT_STREQ(name, "");
+}
+
+TEST_F(VariableStateTests, InfReqVariableStatePropagatesGetNameWithLenOfTwo) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
+
+    IVariableState::Ptr pState;
+
+    static_cast<IInferRequest::Ptr>(req)->QueryState(pState, 0, nullptr);
+    char name[2];
+    EXPECT_NO_THROW(pState->GetName(name, 2, nullptr));
+    EXPECT_STREQ(name, "s");
+}
+
+TEST_F(VariableStateTests, InfReqVariableStateCanPropagateSetState) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+    Blob::Ptr saver;
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
+
+    float data[] = {123, 124, 125};
+    auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
+
+    EXPECT_NO_THROW(req.QueryState().front().SetState(stateBlob));
+    ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
+    ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
+    ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
+}
+
+TEST_F(VariableStateTests, InfReqVariableStateCanPropagateGetLastState) {
+    auto req = make_infer_request(mockInferRequestInternal);
+    std::vector<IVariableStateInternal::Ptr> toReturn;
+
+    float data[] = {123, 124, 125};
+    auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
+
+    toReturn.push_back(mockVariableStateInternal);
+
+    EXPECT_CALL(*mockInferRequestInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
+    EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
+
+    auto saver = req.QueryState().front().GetState();
+    ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
+    ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
+    ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
+}
index 0632a65..810341f 100644 (file)
@@ -127,6 +127,7 @@ TEST_F(ExecutableNetworkTests, OperatorAmpersand) {
     ASSERT_EQ(exeNet_p, mockIExeNet_p);
 }
 
+IE_SUPPRESS_DEPRECATED_START
 TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
     EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
             .Times(1)
@@ -138,21 +139,22 @@ TEST_F(ExecutableNetworkTests, QueryStateIfReturnOutOfBounds) {
     EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
             .Times(1)
             .WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
-    std::vector<InferenceEngine::MemoryState> MemState_;
+    std::vector<InferenceEngine::VariableState> MemState_;
     EXPECT_NO_THROW(MemState_ = exeNetwork->QueryState());
     EXPECT_EQ(MemState_.size(), 0);
 }
 
 TEST_F(ExecutableNetworkTests, QueryState) {
-    std::shared_ptr<MockIMemoryState> mockIMemState_p = std::make_shared<MockIMemoryState>();
+    std::shared_ptr<MockIVariableState> mockIMemState_p = std::make_shared<MockIVariableState>();
     EXPECT_CALL(*mockIExeNet_p.get(), QueryState(_, _, _))
             .Times(2)
             .WillOnce(DoAll(SetArgReferee<0>(mockIMemState_p), Return(InferenceEngine::OK)))
             .WillOnce(Return(InferenceEngine::OUT_OF_BOUNDS));
-    std::vector<InferenceEngine::MemoryState> MemState_v;
+    std::vector<InferenceEngine::VariableState> MemState_v;
     EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
     EXPECT_EQ(MemState_v.size(), 1);
 }
+IE_SUPPRESS_DEPRECATED_END
 
 class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
 protected:
index 43f7f9a..658c161 100644 (file)
@@ -207,6 +207,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
                     }
                 }
 
+                IE_SUPPRESS_DEPRECATED_START
                 if (fetchResult.reset) {
                     auto states = executableApi.QueryState();
                     ASSERT_FALSE(states.empty());
@@ -218,6 +219,7 @@ void Regression::Matchers::CustomMatcher::matchCustom() {
                     outputs["reset"] = nullptr;
                     //continue;
                 }
+                IE_SUPPRESS_DEPRECATED_END
 
                 //FAIL()<<"stop after one frame";
 
index 0e8443d..947e248 100644 (file)
@@ -808,6 +808,7 @@ void GNAQueryStateMatcher :: match() {
 
     EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
 #endif
+    IE_SUPPRESS_DEPRECATED_START
     try {
         loadNetwork();
         if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
@@ -830,6 +831,7 @@ void GNAQueryStateMatcher :: match() {
     catch(...) {
         FAIL() << "unknown exception thrown";
     }
+    IE_SUPPRESS_DEPRECATED_END
 }