Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_plugin.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "mkldnn_graph.h"
8 #include <string>
9 #include <map>
10 #include <unordered_map>
11 #include <memory>
12 #include <functional>
13 #include <cpp_interfaces/impl/ie_plugin_internal.hpp>
14
15 namespace MKLDNNPlugin {
16
17 class SimpleDataHash {
18 public:
19     SimpleDataHash() {
20         for (int i = 0; i < kTableSize; i++) {
21             uint64_t c = i;
22             for (int j = 0; j < 8; j++)
23                 c = ((c & 1) ? 0xc96c5795d7870f42 : 0) ^ (c >> 1);
24             table[i] = c;
25         }
26     }
27     // Computes 64-bit "cyclic redundancy check" sum, as specified in ECMA-182
28     uint64_t hash(const unsigned char* data, size_t size) const {
29         uint64_t crc = 0;
30         for (size_t idx = 0; idx < size; idx++)
31             crc = table[(unsigned char)crc ^ data[idx]] ^ (crc >> 8);
32
33         return ~crc;
34     }
35
36 protected:
37     static const int kTableSize = 256;
38     uint64_t table[kTableSize];
39 };
40
41 class MKLDNNWeightsSharing {
42 public:
43     MKLDNNMemoryPtr findOrCreate(const std::string& name_hash,
44                              std::function<MKLDNNMemoryPtr(void)> create) {
45         std::unique_lock<std::mutex> lock(guard);
46         auto found = sharedWeights.find(name_hash);
47
48         MKLDNNMemoryPtr ptr;
49         if (found == sharedWeights.end() || !(ptr = found->second.lock())) {
50             ptr = create();
51             sharedWeights[name_hash] = ptr;
52         }
53         return ptr;
54     }
55     static const SimpleDataHash& GetHashFunc () { return simpleCRC; }
56
57 protected:
58     std::unordered_map<std::string, std::weak_ptr<MKLDNNMemory>> sharedWeights;
59     std::mutex guard;
60     static const SimpleDataHash simpleCRC;
61 };
62
63 class Engine : public InferenceEngine::InferencePluginInternal {
64 public:
65     Engine() = default;
66     ~Engine() override = default;
67
68     InferenceEngine::ExecutableNetworkInternal::Ptr
69     LoadExeNetworkImpl(InferenceEngine::ICNNNetwork &network,
70                        const std::map<std::string, std::string> &config) override;
71
72     void AddExtension(InferenceEngine::IExtensionPtr extension) override;
73     /**
74      * @deprecated
75      * @param config
76      */
77     void SetConfig(const std::map<std::string, std::string> &config) override;
78
79     /**
80      * @deprecated Use the version with config parameter
81      */
82     void QueryNetwork(const InferenceEngine::ICNNNetwork& network, InferenceEngine::QueryNetworkResult& res) const override;
83     void QueryNetwork(const InferenceEngine::ICNNNetwork& network,
84                       const std::map<std::string, std::string>& config, InferenceEngine::QueryNetworkResult& res) const override;
85
86     static MKLDNNWeightsSharing& GetWeightsSharing() { return weightsSharing; }
87
88 private:
89     Config engConfig;
90     MKLDNNExtensionManager::Ptr extensionManager = std::make_shared<MKLDNNExtensionManager>();
91
92 protected:
93     static MKLDNNWeightsSharing weightsSharing;
94 };
95
96 }  // namespace MKLDNNPlugin