1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef _MKLDNN_MEMORY_HPP
18 #define _MKLDNN_MEMORY_HPP
20 #include "mkldnn_common.hpp"
23 dnn_mem_t(): active_(false) {}
25 dnn_mem_t(const mkldnn_memory_desc_t &md, void *data = NULL)
26 : active_(initialize(md, data) == OK) {}
28 dnn_mem_t(int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t dt,
29 mkldnn_memory_format_t fmt, void *data = NULL)
30 : active_(initialize(ndims, dims, dt, fmt, data) == OK) {}
/* Builds the memory by forwarding `md`, `dt`, `fmt` and `data` to
 * initialize(); active_ records whether that call returned OK.
 * `fmt` defaults to mkldnn_format_undef.
 * NOTE(review): the end of the parameter list (where `data` is declared) is
 * not visible between these lines. */
32 dnn_mem_t(const mkldnn_memory_desc_t &md, mkldnn_data_type_t dt,
33 mkldnn_memory_format_t fmt = mkldnn_format_undef,
35 : active_(initialize(md, dt, fmt, data) == OK) {}
37 dnn_mem_t(const dnn_mem_t &rhs, mkldnn_data_type_t dt,
38 mkldnn_memory_format_t fmt = mkldnn_format_undef,
39 void *data = NULL): dnn_mem_t(rhs.md_, dt, fmt, data)
40 { if (active_) reorder(rhs); }
42 /* FIXME: ugly RT assert... need better mkldnn memory handling */
43 dnn_mem_t &operator=(const dnn_mem_t &rhs)
44 { []() { SAFE(FAIL, CRIT); return 0; }(); return *this; }
45 dnn_mem_t(const dnn_mem_t &rhs)
46 { []() { SAFE(FAIL, CRIT); return 0; }(); }
48 ~dnn_mem_t() { cleanup(); }
50 int reorder(const dnn_mem_t &rhs) { return reorder(rhs, NULL); }
/* Fills this memory with the contents of `rhs` by creating and executing an
 * mkldnn reorder primitive whose input is rhs's primitive and whose output is
 * this memory's primitive. Reordering an object onto itself returns OK
 * without doing anything.
 * NOTE(review): parts of this method are not visible here (the rest of the
 * primitive-descriptor creation call, the declaration of `r`, and the end of
 * the function); `attr` is presumably consumed by the hidden part of that
 * creation call.
 * NOTE(review): if DNN_SAFE/SAFE return early on failure -- TODO confirm --
 * `rpd` and `r` are not destroyed on those error paths. */
51 int reorder(const dnn_mem_t &rhs, const mkldnn_primitive_attr_t &attr) {
52 if (this == &rhs) return OK; // self-reorder: nothing to do
54 mkldnn_primitive_desc_t rpd;
56 DNN_SAFE(mkldnn_reorder_primitive_desc_create_v2(&rpd, rhs.mpd_,
58 mkldnn_primitive_at_t i = {rhs.p_, 0}; // input: rhs's primitive, index 0
59 const_mkldnn_primitive_t o = p_; // output: this memory's primitive
60 DNN_SAFE(mkldnn_primitive_create(&r, rpd, &i, &o), WARN);
61 SAFE(execute(r), WARN);
/* Release the temporary descriptor and primitive once the reorder has run. */
62 DNN_SAFE(mkldnn_primitive_desc_destroy(rpd), CRIT);
63 DNN_SAFE(mkldnn_primitive_destroy(r), CRIT);
68 int N() { return md_.dims[0]; }
69 int with_G() { return md_.ndims == 5; }
70 int G() { return md_.ndims == 5 ? md_.dims[0] : 1; }
72 int C() { return md_.ndims == 1 ? md_.dims[0] : md_.dims[1]; }
73 int OC() { return md_.dims[with_G() + 0]; }
74 int IC() { return md_.dims[with_G() + 1]; }
75 int H() { return md_.dims[with_G() + 2]; } // works for both IH and KH
76 int W() { return md_.dims[with_G() + 3]; } // works for both IW and KW
78 size_t size() const { return mkldnn_memory_primitive_desc_get_size(mpd_); }
/* Returns the number of elements of the memory, iterating over its md_.ndims
 * dimensions. When `with_padding_dims` is true the padded dimensions of the
 * blocking layout are used.
 * NOTE(review): the alternative dims source, the accumulator and the return
 * statement are not visible here -- presumably the result is the product of
 * the selected dims. */
80 size_t nelems(bool with_padding_dims = false) const {
81 auto dims = with_padding_dims
82 ? md_.layout_desc.blocking.padding_dims
85 for (int i = 0; i < md_.ndims; ++i)
90 mkldnn_data_type_t dt() const { return md_.data_type; }
91 size_t sizeof_dt() const { return ::sizeof_dt(dt()); }
/* Explicit conversion to a typed pointer: returns data_ cast to T*.
 * NOTE(review): the template header that introduces T is not visible here. */
94 explicit operator T*() const { return static_cast<T*>(data_); }
/* Reads the element at flat index `idx` from data_ and returns it as float.
 * The buffer is interpreted as int8_t, uint8_t, int16_t, int32_t or float
 * depending on the mkldnn data type; any other type trips an assert.
 * NOTE(review): the switch header (presumably on dt()), the declaration of
 * `elem` and the return statement are not visible here. */
96 float get_elem(size_t idx) const {
99 case mkldnn_s8: elem = static_cast<int8_t *>(data_)[idx]; break;
100 case mkldnn_u8: elem = static_cast<uint8_t *>(data_)[idx]; break;
101 case mkldnn_s16: elem = static_cast<int16_t *>(data_)[idx]; break;
102 case mkldnn_s32: elem = static_cast<int32_t *>(data_)[idx]; break;
103 case mkldnn_f32: elem = static_cast<float *>(data_)[idx]; break;
104 default: assert(!"bad data type"); // unsupported data type
/* Stores `value` at flat index `idx` of data_, converting the float to the
 * element type selected by the mkldnn data type (int8_t, uint8_t, int16_t,
 * int32_t or float); any other type trips an assert.
 * The integer cases rely on the implicit float-to-integer conversion.
 * NOTE(review): the switch header (presumably on dt()) and the end of the
 * function are not visible here. */
109 void set_elem(size_t idx, float value) {
111 case mkldnn_s8: ((int8_t *)data_)[idx] = value; break;
112 case mkldnn_u8: ((uint8_t *)data_)[idx] = value; break;
113 case mkldnn_s16: ((int16_t *)data_)[idx] = value; break;
114 case mkldnn_s32: ((int32_t *)data_)[idx] = value; break;
115 case mkldnn_f32: ((float *)data_)[idx] = value; break;
116 default: assert(!"bad data type"); // unsupported data type
/* Maps the flat data index `data_idx` to an index selected by `scale_mask`.
 * When the mask is non-zero, dimensions are visited from the last to the
 * first; for each one the position along it is data_idx % dims[d], and every
 * dimension whose bit (1 << d) is set in scale_mask adds pos * stride to the
 * offset.
 * NOTE(review): the declarations and updates of `offset`, `stride` and
 * `data_idx`, as well as the return statement, are not visible here. */
120 size_t get_scale_idx(size_t data_idx, int scale_mask) const {
121 const int ndims = md_.ndims;
122 const auto &dims = md_.dims;
126 if (scale_mask != 0) {
127 for (int i = 0; i < ndims; ++i) {
128 size_t d = md_.ndims - 1 - i; // walk dimensions from last to first
129 auto pos = data_idx % dims[d]; // position along dimension d
131 if (scale_mask & (1 << d)) { // dimension d takes part in the scale index
132 offset += pos * stride;
143 mkldnn_memory_desc_t md_; // memory descriptor (dims, data type, layout)
144 mkldnn_primitive_desc_t mpd_; // memory primitive descriptor created from md_
145 mkldnn_primitive_t p_; // memory primitive created from mpd_
/* is_data_owner_: set by initialize() to (data == NULL); when true, data_ is
 *                 released with zfree() during cleanup.
 * active_:        set by the constructors; true when initialize() returned
 *                 OK, false for a default-constructed object. */
147 bool is_data_owner_, active_;
/* Sets up md_, mpd_ and p_ for a memory with data type `dt` and format `fmt`
 * based on the descriptor `md`, then attaches the data buffer to the memory
 * primitive. The object owns its buffer when `data` is NULL; a buffer of the
 * primitive descriptor's size is obtained from zmalloc() with 2 MiB
 * alignment, and a NULL result is reported as mkldnn_out_of_memory.
 * NOTE(review): several lines are not visible here, including the body of
 * the undef/blocked-format branch, the definition of `engine`, the condition
 * guarding the allocation (presumably is_data_owner_), the handling of a
 * caller-supplied buffer, and the return statement. */
150 int initialize(const mkldnn_memory_desc_t &md, mkldnn_data_type_t dt,
151 mkldnn_memory_format_t fmt, void *data) {
152 if (fmt == mkldnn_format_undef || fmt == mkldnn_blocked) {
/* Build a fresh descriptor from md's dimensions with the requested type and
 * format. */
156 DNN_SAFE(mkldnn_memory_desc_init(&md_, md.ndims, md.dims, dt, fmt),
159 DNN_SAFE(mkldnn_memory_primitive_desc_create(&mpd_, &md_, engine),
161 DNN_SAFE(mkldnn_primitive_create(&p_, mpd_, NULL, NULL), CRIT);
162 is_data_owner_ = data == NULL; // own the buffer only if none was supplied
164 const size_t alignment = 1024 * 1024 * 2; // 2 MiB
165 size_t sz = mkldnn_memory_primitive_desc_get_size(mpd_);
166 data_ = zmalloc(sz, alignment);
167 DNN_SAFE(data_ == NULL ? mkldnn_out_of_memory : mkldnn_success,
172 DNN_SAFE(mkldnn_memory_set_data_handle(p_, data_), CRIT);
/* Initializes the memory from `md`, keeping md's own data type and passing
 * mkldnn_format_undef as the format; returns the status of the four-argument
 * initialize().
 * NOTE(review): the closing brace of this function is not visible here. */
177 int initialize(const mkldnn_memory_desc_t &md, void *data) {
178 return initialize(md, md.data_type, mkldnn_format_undef, data);
/* Initializes the memory from explicit dimensions: builds a temporary memory
 * descriptor from (`ndims`, `dims`, `dt`, `fmt`) and forwards it, together
 * with `data`, to initialize(md, data).
 * NOTE(review): the end of this function (return statement and closing
 * brace) is not visible here. */
181 int initialize(int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t dt,
182 mkldnn_memory_format_t fmt, void* data) {
183 mkldnn_memory_desc_t xmd; // temporary descriptor filled in below
184 DNN_SAFE(mkldnn_memory_desc_init(&xmd, ndims, dims, dt, fmt), CRIT);
185 SAFE(initialize(xmd, data), CRIT);
/* NOTE(review): the signature of the enclosing function is not visible here;
 * the destructor calls cleanup(), which is presumably this function.
 * Does nothing for an inactive object. Otherwise destroys the memory
 * primitive descriptor and the memory primitive, and frees data_ only when
 * this object owns the buffer. */
190 if (!active_) return OK; // nothing was created, nothing to release
191 DNN_SAFE(mkldnn_primitive_desc_destroy(mpd_), CRIT);
192 DNN_SAFE(mkldnn_primitive_destroy(p_), CRIT);
193 if (is_data_owner_) zfree(data_); // caller-supplied buffers are left alone