Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_memory.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <memory>
8 #include <vector>
9
10 #include "inference_engine.hpp"
11 #include "mkldnn_dims.h"
12 #include <mkldnn.hpp>
13 #include <string>
14 #include <mkldnn_types.h>
15 #include <functional>
16
17 namespace MKLDNNPlugin {
18
19 class MKLDNNMemoryDesc {
20 public:
21     MKLDNNMemoryDesc(): desc({}, mkldnn::memory::data_type::f32, mkldnn::memory::format::format_undef) {}
22     explicit MKLDNNMemoryDesc(const InferenceEngine::TensorDesc& tDesc);
23     explicit MKLDNNMemoryDesc(const mkldnn::memory::desc& desc): desc(desc) {}
24     MKLDNNMemoryDesc(mkldnn::memory::dims dims, mkldnn::memory::data_type dataType, mkldnn::memory::format format);
25
26     mkldnn::memory::format getFormat() const {
27         return static_cast<mkldnn::memory::format>(desc.data.format);
28     }
29
30     mkldnn::memory::data_type getDataType() const {
31         return static_cast<mkldnn::memory::data_type>(desc.data.data_type);
32     }
33
34     MKLDNNDims getDims() const {
35         return MKLDNNDims(desc.data.dims, desc.data.ndims);
36     }
37
38     bool blocksExtended() const;
39     operator bool() const {
40         return getFormat() != mkldnn::memory::format::any && getFormat() != mkldnn::memory::format::format_undef;
41     }
42
43     bool operator == (const MKLDNNMemoryDesc& rhs) const;
44     bool operator != (const MKLDNNMemoryDesc& rhs) const;
45
46     operator mkldnn::memory::desc() const;
47     operator InferenceEngine::TensorDesc() const;
48
49 private:
50     mkldnn::memory::desc desc;
51 };
52
53
54 class MKLDNNMemory;
55
56 using MKLDNNMemoryPtr = std::shared_ptr<MKLDNNMemory>;
57
58 class MKLDNNMemory {
59 public:
60     explicit MKLDNNMemory(const mkldnn::engine& eng);
61
62     const mkldnn::memory& GetPrimitive() const {
63         return *prim;
64     }
65
66     const std::shared_ptr<mkldnn::memory>& GetPrimitivePtr() const {
67         return prim;
68     }
69
70     mkldnn::memory::desc GetDescriptor() const {
71         return prim->get_primitive_desc().desc();
72     }
73
74     mkldnn::memory::primitive_desc GetPrimitiveDescriptor() const {
75         return prim->get_primitive_desc();
76     }
77
78     void* GetData() const {
79         void* data = prim->get_data_handle();
80         if (data == nullptr)
81             THROW_IE_EXCEPTION << "Cannot get memory!";
82         return data;
83     }
84
85     mkldnn::memory::data_type GetDataType() const {
86         return static_cast<mkldnn::memory::data_type>(GetDescriptor().data.data_type);
87     }
88
89     size_t GetSize() const;
90
91     mkldnn::memory::format GetFormat() const {
92         return static_cast<mkldnn::memory::format>(prim->get_primitive_desc().desc().data.format);
93     }
94
95     mkldnn::memory::dims GetDims() const {
96         auto data = GetDescriptor().data;
97
98         return std::vector<ptrdiff_t>(data.dims, data.dims + data.ndims);
99     }
100
101     void Create(mkldnn::memory::dims dims, mkldnn::memory::data_type data_type, mkldnn::memory::format format,
102                 const void* data = nullptr);
103
104     void Create(const mkldnn::memory::desc& desc, const void* data = nullptr);
105
106     void SetData(mkldnn::memory::data_type dataType, mkldnn::memory::format format, const void* data, size_t size, bool ftz = true) const;
107     void SetData(const MKLDNNMemory& memory, bool ftz = true) const;
108
109     void FillZero();
110
111     static bool IsPlainFormat(mkldnn::memory::format format);
112     static mkldnn::memory::format GetPlainFormat(mkldnn::memory::dims dims);
113     static InferenceEngine::Layout GetPlainLayout(mkldnn::memory::dims dims);
114     static bool isConsistant(mkldnn::memory::dims dims, mkldnn::memory::format format);
115     static mkldnn::memory::format Convert(const InferenceEngine::Layout layout);
116
117     static std::string formatToString(mkldnn::memory::format fmt);
118
119     static void CreateBlockingDesc(mkldnn::memory::desc& desc);
120
121 private:
122     std::shared_ptr<mkldnn::memory> prim;
123     mkldnn::engine eng;
124 };
125
126
127 }  // namespace MKLDNNPlugin