Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / gna_plugin / gna_infer_request.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <memory>
8 #include <string>
9 #include <map>
10
11 #include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
12 #include "gna_plugin.hpp"
13
14 namespace GNAPluginNS {
15
16 class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
17     std::shared_ptr<GNAPlugin> plg;
18     uint32_t inferRequestIdx = -1;
19
20  public:
21     GNAInferRequest(const std::shared_ptr<GNAPlugin>& plg,
22                     InferenceEngine::InputsDataMap networkInputs,
23                     InferenceEngine::OutputsDataMap networkOutputs)
24         : InferenceEngine::AsyncInferRequestInternal(networkInputs, networkOutputs), plg(plg) {
25         // TODO: internal connection API - better to generalize
26         if (networkOutputs.empty()) {
27             THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero outputs";
28         }
29         if (networkInputs.empty()) {
30             THROW_GNA_EXCEPTION << "GNAInferRequest :: network has zero inputs";
31         }
32
33         // copy inputs blobs since we need to have them in separate address space to allow simultaneous infer requests
34         _outputs[_networkOutputs.begin()->first] = plg->GetOutputBlob(networkOutputs.begin()->second->getPrecision());
35         for (auto input : _networkInputs) {
36             _inputs[input.first] =
37                 plg->GetInputBlob(input.first, networkInputs.begin()->second->getInputPrecision());
38         }
39     }
40     /**
41      * @brief Infers specified input(s) in synchronous mode
42      * @note blocks all method of IInferRequest while request is ongoing (running or waiting in queue)
43      */
44     void InferImpl() override {
45         // execute input pre-processing.
46         execDataPreprocessing(_inputs);
47         plg->Infer(_inputs, _outputs);
48     }
49
50     /**
51      * @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
52      *  Note: not all plugins may provide meaningful data
53      *  @param perfMap - a map of layer names to profiling information for that layer.
54      */
55     void GetPerformanceCounts(std::map<std::string,
56                                                InferenceEngine::InferenceEngineProfileInfo> &perfMap) const override {
57         plg->GetPerformanceCounts(perfMap);
58     }
59
60     /**
61         * @brief methods with _ThreadUnsafe prefix are to implement in plugins
62         * or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
63         */
64     void StartAsyncImpl() override {
65         // execute input pre-processing.
66         execDataPreprocessing(_inputs);
67         inferRequestIdx = plg->QueueInference(_inputs, _outputs);
68     }
69
70     InferenceEngine::StatusCode Wait(int64_t millis_timeout) override {
71         if (inferRequestIdx == -1) return InferenceEngine::INFER_NOT_STARTED;
72         plg->Wait(inferRequestIdx);
73         return InferenceEngine::OK;
74     }
75 };
76 }  // namespace GNAPluginNS