1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include "inference_engine.hpp"
11 #include "mkldnn_dims.h"
14 #include <mkldnn_types.h>
17 namespace MKLDNNPlugin {
// MKLDNNMemoryDesc: value wrapper around an mkldnn::memory::desc (held in the `desc` member).
// Provides accessors for format, data type and dims, plus conversions from and to
// InferenceEngine::TensorDesc.
19 class MKLDNNMemoryDesc {
// Default descriptor: default-constructed (empty) dims, f32 data type, format_undef.
// Because the format is format_undef, operator bool() is expected to report false for it.
21 MKLDNNMemoryDesc(): desc({}, mkldnn::memory::data_type::f32, mkldnn::memory::format::format_undef) {}
// Builds a descriptor from an Inference Engine tensor descriptor.
// Defined out of line; the exact layout/precision mapping is not visible here.
22 explicit MKLDNNMemoryDesc(const InferenceEngine::TensorDesc& tDesc);
23 explicit MKLDNNMemoryDesc(const mkldnn::memory::desc& desc): desc(desc) {}
// Builds a descriptor from explicit dims, element data type and format tag.
// Defined out of line. How special formats such as `any` are treated is not visible here.
24 MKLDNNMemoryDesc(mkldnn::memory::dims dims, mkldnn::memory::data_type dataType, mkldnn::memory::format format);
// Returns the format tag stored in the wrapped descriptor (desc.data.format),
// cast from the raw C field to the C++ mkldnn::memory::format enum.
26 mkldnn::memory::format getFormat() const {
27 return static_cast<mkldnn::memory::format>(desc.data.format);
// Returns the element data type stored in the wrapped descriptor (desc.data.data_type),
// cast to the C++ mkldnn::memory::data_type enum.
30 mkldnn::memory::data_type getDataType() const {
31 return static_cast<mkldnn::memory::data_type>(desc.data.data_type);
// Returns the descriptor's dimensions as MKLDNNDims, built from desc.data.dims and
// desc.data.ndims. Presumably MKLDNNDims takes the first ndims entries; confirm in mkldnn_dims.h.
34 MKLDNNDims getDims() const {
35 return MKLDNNDims(desc.data.dims, desc.data.ndims);
// Defined out of line.
// NOTE(review): judging by the name, this presumably reports whether blocked-layout padding
// makes the physical dims exceed the logical ones. TODO confirm against the implementation.
38 bool blocksExtended() const;
// True when the descriptor names a concrete layout, i.e. its format is neither
// `any` nor `format_undef`.
// NOTE(review): this conversion is not `explicit`, so the object also converts implicitly
// to bool (and from there to integer types). Consider `explicit` once all callers are
// checked; `if (d)` and `!d` would keep working.
39 operator bool() const {
40 return getFormat() != mkldnn::memory::format::any && getFormat() != mkldnn::memory::format::format_undef;
// Descriptor equality / inequality. Both are defined out of line; which descriptor fields
// are compared is not visible here.
43 bool operator == (const MKLDNNMemoryDesc& rhs) const;
44 bool operator != (const MKLDNNMemoryDesc& rhs) const;
// Implicit conversions back to the raw MKL-DNN descriptor and to an Inference Engine
// TensorDesc. Both are defined out of line.
46 operator mkldnn::memory::desc() const;
47 operator InferenceEngine::TensorDesc() const;
// The wrapped MKL-DNN memory descriptor; the inline accessors read its `data` field.
50 mkldnn::memory::desc desc;
// Shared-ownership pointer alias for MKLDNNMemory.
56 using MKLDNNMemoryPtr = std::shared_ptr<MKLDNNMemory>;
// Constructs a memory wrapper for the given MKL-DNN engine (defined out of line).
// NOTE(review): the primitive (`prim`) is presumably created later by Create(). The inline
// accessors dereference `prim`, so confirm they are only used after Create().
60 explicit MKLDNNMemory(const mkldnn::engine& eng);
// Accessor for the wrapped mkldnn::memory, returned by const reference.
// NOTE(review): presumably returns *prim, so `prim` must be non-null. Confirm.
62 const mkldnn::memory& GetPrimitive() const {
// Returns a const reference to the shared pointer holding the wrapped primitive
// (presumably `prim`). The caller only gains shared ownership if it copies the pointer.
66 const std::shared_ptr<mkldnn::memory>& GetPrimitivePtr() const {
// Returns, by value, the memory descriptor of the wrapped primitive
// (prim->get_primitive_desc().desc()). Requires `prim` to be non-null.
70 mkldnn::memory::desc GetDescriptor() const {
71 return prim->get_primitive_desc().desc();
// Returns the primitive descriptor of the wrapped memory (prim->get_primitive_desc()).
// Requires `prim` to be non-null.
74 mkldnn::memory::primitive_desc GetPrimitiveDescriptor() const {
75 return prim->get_primitive_desc();
// Returns the raw data handle of the wrapped primitive (prim->get_data_handle()).
// Throws an Inference Engine exception with the message "Cannot get memory!".
// NOTE(review): the throw is presumably guarded by a null-handle check that is not shown
// here. Confirm that a valid handle is returned unchanged.
78 void* GetData() const {
79 void* data = prim->get_data_handle();
81 THROW_IE_EXCEPTION << "Cannot get memory!";
// Returns the element data type of the wrapped memory, read from
// GetDescriptor().data.data_type.
85 mkldnn::memory::data_type GetDataType() const {
86 return static_cast<mkldnn::memory::data_type>(GetDescriptor().data.data_type);
// Size of the wrapped memory (defined out of line).
// NOTE(review): the unit (bytes vs. elements) is not visible here. Confirm in the .cpp.
89 size_t GetSize() const;
91 mkldnn::memory::format GetFormat() const {
92 return static_cast<mkldnn::memory::format>(prim->get_primitive_desc().desc().data.format);
95 mkldnn::memory::dims GetDims() const {
96 auto data = GetDescriptor().data;
98 return std::vector<ptrdiff_t>(data.dims, data.dims + data.ndims);
// Creates the wrapped primitive, either from dims + data type + format or from a complete
// descriptor. Both overloads are defined out of line; `data` defaults to nullptr.
// NOTE(review): whether a non-null `data` is copied or adopted as an external buffer is
// not visible here. Confirm in the .cpp before relying on its lifetime.
101 void Create(mkldnn::memory::dims dims, mkldnn::memory::data_type data_type, mkldnn::memory::format format,
102 const void* data = nullptr);
104 void Create(const mkldnn::memory::desc& desc, const void* data = nullptr);
// Fills this memory from either a raw buffer (described by dataType, format and size) or
// another MKLDNNMemory. Both overloads are defined out of line; `ftz` defaults to true.
// NOTE(review): `ftz` presumably means flush-to-zero of denormals, and `size` presumably
// counts bytes. Both need confirming. The methods are declared const even though they
// presumably write the buffer contents; const appears to cover only the wrapper.
106 void SetData(mkldnn::memory::data_type dataType, mkldnn::memory::format format, const void* data, size_t size, bool ftz = true) const;
107 void SetData(const MKLDNNMemory& memory, bool ftz = true) const;
// Static format/layout helpers, all defined out of line. Descriptions below are inferred
// from the signatures and should be checked against the .cpp.
// Presumably: whether `format` is a plain (non-blocked) layout.
111 static bool IsPlainFormat(mkldnn::memory::format format);
// Presumably: the plain MKL-DNN format / IE layout matching the rank of `dims`.
112 static mkldnn::memory::format GetPlainFormat(mkldnn::memory::dims dims);
113 static InferenceEngine::Layout GetPlainLayout(mkldnn::memory::dims dims);
// Presumably: whether `format` is usable with `dims`.
// NOTE: the name is misspelled ("Consistent") but it is public API, so it is kept as is.
114 static bool isConsistant(mkldnn::memory::dims dims, mkldnn::memory::format format);
// Maps an Inference Engine layout to an MKL-DNN format tag.
115 static mkldnn::memory::format Convert(const InferenceEngine::Layout layout);
// Human-readable name of a format tag.
117 static std::string formatToString(mkldnn::memory::format fmt);
// Modifies `desc` in place (taken by non-const reference); presumably fills in blocking info.
119 static void CreateBlockingDesc(mkldnn::memory::desc& desc);
// Shared handle to the wrapped MKL-DNN memory primitive; the inline accessors
// dereference it, so it must be non-null when they are called.
122 std::shared_ptr<mkldnn::memory> prim;
127 } // namespace MKLDNNPlugin