Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / gna_plugin / gna_executable_network.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6 #include <memory>
7 #include <string>
8 #include <map>
9 #include <vector>
10
11 #include <cpp_interfaces/impl/ie_executable_network_thread_safe_default.hpp>
12 #include "gna_infer_request.hpp"
13 #include "gna_plugin.hpp"
14 #include <cpp_interfaces/ie_executor_manager.hpp>
15 #include <cpp_interfaces/impl/ie_executable_network_thread_safe_async_only.hpp>
16
17 namespace GNAPluginNS {
18
19 class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafeAsyncOnly {
20     std::shared_ptr<GNAPlugin> plg;
21
22  public:
23     GNAExecutableNetwork(const std::string &aotFileName, const std::map<std::string, std::string> &config) :
24         plg(std::make_shared<GNAPlugin>(config)) {
25         plg->ImportNetwork(aotFileName);
26         _networkInputs  = plg->GetInputs();
27         _networkOutputs = plg->GetOutputs();
28     }
29
30     GNAExecutableNetwork(InferenceEngine::ICNNNetwork &network, const std::map<std::string, std::string> &config)
31         : plg(std::make_shared<GNAPlugin>(config)) {
32         plg->LoadNetwork(network);
33     }
34
35     InferenceEngine::AsyncInferRequestInternal::Ptr
36         CreateAsyncInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
37                                     InferenceEngine::OutputsDataMap networkOutputs) override {
38         return std::make_shared<GNAInferRequest>(plg, networkInputs, networkOutputs);
39     }
40
41
42
43     std::vector<InferenceEngine::IMemoryStateInternal::Ptr>  QueryState() override {
44         auto pluginStates = plg->QueryState();
45         std::vector<InferenceEngine::IMemoryStateInternal::Ptr> state(pluginStates.begin(), pluginStates.end());
46         return plg->QueryState();
47     }
48
49     void Export(const std::string &modelFileName) override {
50         plg->Export(modelFileName);
51     }
52 };
53 }  // namespace GNAPluginNS