Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / mkldnn_memory.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef _MKLDNN_MEMORY_HPP
18 #define _MKLDNN_MEMORY_HPP
19
20 #include "mkldnn_common.hpp"
21
22 struct dnn_mem_t {
23     dnn_mem_t(): active_(false) {}
24
25     dnn_mem_t(const mkldnn_memory_desc_t &md, void *data = NULL)
26         : active_(initialize(md, data) == OK) {}
27
28     dnn_mem_t(int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t dt,
29             mkldnn_memory_format_t fmt, void *data = NULL)
30         : active_(initialize(ndims, dims, dt, fmt, data) == OK) {}
31
32     dnn_mem_t(const mkldnn_memory_desc_t &md, mkldnn_data_type_t dt,
33             mkldnn_memory_format_t fmt = mkldnn_format_undef,
34             void *data = NULL)
35         : active_(initialize(md, dt, fmt, data) == OK) {}
36
37     dnn_mem_t(const dnn_mem_t &rhs, mkldnn_data_type_t dt,
38             mkldnn_memory_format_t fmt = mkldnn_format_undef,
39             void *data = NULL): dnn_mem_t(rhs.md_, dt, fmt, data)
40     { if (active_) reorder(rhs); }
41
42     /* FIXME: ugly RT assert... need better mkldnn memory handling */
43     dnn_mem_t &operator=(const dnn_mem_t &rhs)
44     { []() { SAFE(FAIL, CRIT); return 0; }(); return *this; }
45     dnn_mem_t(const dnn_mem_t &rhs)
46     { []() { SAFE(FAIL, CRIT); return 0; }(); }
47
48     ~dnn_mem_t() { cleanup(); }
49
50     int reorder(const dnn_mem_t &rhs) { return reorder(rhs, NULL); }
51     int reorder(const dnn_mem_t &rhs, const mkldnn_primitive_attr_t &attr) {
52         if (this == &rhs) return OK;
53
54         mkldnn_primitive_desc_t rpd;
55         mkldnn_primitive_t r;
56         DNN_SAFE(mkldnn_reorder_primitive_desc_create_v2(&rpd, rhs.mpd_,
57                     mpd_, attr), WARN);
58         mkldnn_primitive_at_t i = {rhs.p_, 0};
59         const_mkldnn_primitive_t o = p_;
60         DNN_SAFE(mkldnn_primitive_create(&r, rpd, &i, &o), WARN);
61         SAFE(execute(r), WARN);
62         DNN_SAFE(mkldnn_primitive_desc_destroy(rpd), CRIT);
63         DNN_SAFE(mkldnn_primitive_destroy(r), CRIT);
64
65         return OK;
66     }
67
68     int N() { return md_.dims[0]; }
69     int with_G() { return md_.ndims == 5; }
70     int G() { return md_.ndims == 5 ? md_.dims[0] : 1; }
71
72     int C() { return md_.ndims == 1 ? md_.dims[0] : md_.dims[1]; }
73     int OC() { return md_.dims[with_G() + 0]; }
74     int IC() { return md_.dims[with_G() + 1]; }
75     int H() { return md_.dims[with_G() + 2]; } // works for both IH and KH
76     int W() { return md_.dims[with_G() + 3]; } // works for both IW and KW
77
78     size_t size() const { return mkldnn_memory_primitive_desc_get_size(mpd_); }
79
80     size_t nelems(bool with_padding_dims = false) const {
81         auto dims = with_padding_dims
82             ? md_.layout_desc.blocking.padding_dims
83             : md_.dims;
84         size_t n = 1;
85         for (int i = 0; i < md_.ndims; ++i)
86             n *= dims[i];
87         return n;
88     }
89
90     mkldnn_data_type_t dt() const { return md_.data_type; }
91     size_t sizeof_dt() const { return ::sizeof_dt(dt()); }
92
93     template <typename T>
94     explicit operator T*() const { return static_cast<T*>(data_); }
95
96     float get_elem(size_t idx) const {
97         float elem = 0.0;
98         switch (dt()) {
99             case mkldnn_s8: elem = static_cast<int8_t *>(data_)[idx]; break;
100             case mkldnn_u8: elem = static_cast<uint8_t *>(data_)[idx]; break;
101             case mkldnn_s16: elem = static_cast<int16_t *>(data_)[idx]; break;
102             case mkldnn_s32: elem = static_cast<int32_t *>(data_)[idx]; break;
103             case mkldnn_f32: elem = static_cast<float *>(data_)[idx]; break;
104             default: assert(!"bad data type");
105         }
106         return elem;
107     }
108
109     void set_elem(size_t idx, float value) {
110         switch (dt()) {
111             case mkldnn_s8: ((int8_t *)data_)[idx] = value; break;
112             case mkldnn_u8: ((uint8_t *)data_)[idx] = value; break;
113             case mkldnn_s16: ((int16_t *)data_)[idx] = value; break;
114             case mkldnn_s32: ((int32_t *)data_)[idx] = value; break;
115             case mkldnn_f32: ((float *)data_)[idx] = value; break;
116             default: assert(!"bad data type");
117         }
118     }
119
120     size_t get_scale_idx(size_t data_idx, int scale_mask) const {
121         const int ndims = md_.ndims;
122         const auto &dims = md_.dims;
123         size_t stride = 1;
124         size_t offset = 0;
125
126         if (scale_mask != 0) {
127             for (int i = 0; i < ndims; ++i) {
128                 size_t d = md_.ndims - 1 - i;
129                 auto pos = data_idx % dims[d];
130                 data_idx /= dims[d];
131                 if (scale_mask & (1 << d)) {
132                     offset += pos * stride;
133                     stride *= dims[d];
134                 }
135             }
136         }
137
138         return offset;
139     }
140
141     /* fields */
142
143     mkldnn_memory_desc_t md_;
144     mkldnn_primitive_desc_t mpd_;
145     mkldnn_primitive_t p_;
146     void *data_;
147     bool is_data_owner_, active_;
148
149 private:
150     int initialize(const mkldnn_memory_desc_t &md, mkldnn_data_type_t dt,
151             mkldnn_memory_format_t fmt, void *data) {
152         if (fmt == mkldnn_format_undef || fmt == mkldnn_blocked) {
153             md_ = md;
154             md_.data_type = dt;
155         } else {
156             DNN_SAFE(mkldnn_memory_desc_init(&md_, md.ndims, md.dims, dt, fmt),
157                     CRIT);
158         }
159         DNN_SAFE(mkldnn_memory_primitive_desc_create(&mpd_, &md_, engine),
160                 CRIT);
161         DNN_SAFE(mkldnn_primitive_create(&p_, mpd_, NULL, NULL), CRIT);
162         is_data_owner_ = data == NULL;
163         if (data == NULL) {
164             const size_t alignment = 1024 * 1024 * 2;
165             size_t sz = mkldnn_memory_primitive_desc_get_size(mpd_);
166             data_ = zmalloc(sz, alignment);
167             DNN_SAFE(data_ == NULL ? mkldnn_out_of_memory : mkldnn_success,
168                     WARN);
169         } else {
170             data_ = data;
171         }
172         DNN_SAFE(mkldnn_memory_set_data_handle(p_, data_), CRIT);
173
174         return OK;
175     }
176
177     int initialize(const mkldnn_memory_desc_t &md, void *data) {
178         return initialize(md, md.data_type, mkldnn_format_undef, data);
179     }
180
181     int initialize(int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t dt,
182                     mkldnn_memory_format_t fmt, void* data) {
183         mkldnn_memory_desc_t xmd;
184         DNN_SAFE(mkldnn_memory_desc_init(&xmd, ndims, dims, dt, fmt), CRIT);
185         SAFE(initialize(xmd, data), CRIT);
186         return OK;
187     }
188
189     int cleanup() {
190         if (!active_) return OK;
191         DNN_SAFE(mkldnn_primitive_desc_destroy(mpd_), CRIT);
192         DNN_SAFE(mkldnn_primitive_destroy(p_), CRIT);
193         if (is_data_owner_) zfree(data_);
194         return OK;
195     }
196 };
197
198 #endif