Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / gna_plugin / gna_plugin_config.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include <ie_icnn_network.hpp>
#include "ie_common.h"
#include "gna_plugin_log.hpp"
12
13 namespace GNAPluginNS {
14
15 using CNNNetworkPtr = std::shared_ptr<InferenceEngine::ICNNNetwork>;
16
17 struct Endpoint {
18     InferenceEngine::TargetDevice device;
19     InferenceEngine::Precision networkPrec;
20     std::function<CNNNetworkPtr(InferenceEngine::ICNNNetwork &network)> convert;
21
22     Endpoint(InferenceEngine::TargetDevice device,
23              InferenceEngine::Precision networkPrec,
24              std::function<CNNNetworkPtr(InferenceEngine::ICNNNetwork &network)> converter = [](InferenceEngine::ICNNNetwork &network) {
25                  return CNNNetworkPtr(&network, [](InferenceEngine::ICNNNetwork *nodelete) {});
26              }) : device(device), networkPrec(networkPrec), convert(converter) {
27     }
28 };
29
30 class Config {
31  public:
32     using Desc = std::vector<Endpoint>;
33     Desc supported;
34     InferenceEngine::TargetDevice _defaultDevice = InferenceEngine::TargetDevice::eDefault;
35
36  public:
37     explicit Config(std::vector<Endpoint> &&config)
38         : supported(std::move(config)) {
39     }
40
41     /**
42      * @brief default device value is plugin dependent, so it should be also set, to allow fallback
43      */
44     void setDefaultDevice(InferenceEngine::TargetDevice d) {
45         _defaultDevice = d;
46     }
47
48     inline Endpoint find_configuration(InferenceEngine::ICNNNetwork &network) {
49         auto device = network.getTargetDevice();
50         auto targetDevice = device == InferenceEngine::TargetDevice::eDefault ? _defaultDevice : device;
51
52         auto res = std::find_if(std::begin(supported), std::end(supported), [&](Endpoint &e) {
53             return e.networkPrec == network.getPrecision() && (
54                 e.device == device ||
55                     e.device == targetDevice);
56         });
57
58         if (res == std::end(supported)) {
59             THROW_GNA_EXCEPTION << "\"The plugin doesn't support target device: "
60                                << InferenceEngine::TargetDeviceInfo::name(network.getTargetDevice())
61                                << ".\nSupported target device: " << InferenceEngine::TargetDeviceInfo::name(InferenceEngine::TargetDevice::eGNA);
62         }
63
64         return *res;
65     }
66 };
67 }  // namespace GNAPluginNS