Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / memory_desc_wrapper.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MEMORY_DESC_WRAPPER_HPP
18 #define MEMORY_DESC_WRAPPER_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "nstl.hpp"
24 #include "utils.hpp"
25
26 #include "type_helpers.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30
31 /** thin wrapper class over \struct memory_desc_t which allows easy
32  * manipulatings with underlying C structure, which is taken by refernce */
33 struct memory_desc_wrapper: public c_compatible {
34     const memory_desc_t *_md;
35
36     /** constructor which takes a reference to a constant underlying C memory
37      * descriptor \param md */
38     memory_desc_wrapper(const memory_desc_t &md) : _md(&md) {}
39     memory_desc_wrapper(const memory_desc_t *md) : _md(md) {}
40     memory_desc_wrapper(const memory_pd_t *m_pd);
41
42     /* implementing attrubutes */
43     inline int ndims() const { return _md->ndims; }
44     const dims_t &dims() const { return _md->dims; }
45     data_type_t data_type() const { return _md->data_type; }
46     memory_format_t format() const { return _md->format; }
47     bool is_blocking_desc() const {
48         return (format() != memory_format::wino_fmt
49                 && format() != memory_format::rnn_packed
50                 && format() != memory_format::any
51                 && format() != memory_format::undef);
52     }
53     bool is_wino_desc() const {
54         return (format() == memory_format::wino_fmt);
55     }
56     bool is_rnn_packed_desc() const {
57         return (format() == memory_format::rnn_packed);
58     }
59     const blocking_desc_t &blocking_desc() const {
60         assert(is_blocking_desc());
61         return _md->layout_desc.blocking;
62     }
63     const wino_data_t &wino_desc() const {
64         assert(is_wino_desc());
65         return _md->layout_desc.wino_desc;
66     }
67     const rnn_packed_data_t &rnn_packed_desc() const {
68         assert(is_rnn_packed_desc());
69         return _md->layout_desc.rnn_packed_desc;
70     }
71
72     /* some useful function */
73
74     /** returns the number of elements including padding if \param with_padding
75      * is true, and the number of data elements otherwise */
76     size_t nelems(bool with_padding = false) const {
77         if (is_zero()) return 0;
78         return (utils::array_product<ptrdiff_t, size_t>(with_padding
79                 ? blocking_desc().padding_dims : dims(), ndims()));
80     }
81
82     /** returns true if memory descriptor is zero */
83     bool is_zero() const { return ndims() == 0; }
84
85     /** returns true if memory descriptor contains zero as one of its dim */
86     bool has_zero_dim() const { return nelems() == 0; }
87
88     /** return the size of data type (a shortcut) */
89     size_t data_type_size() const
90     { return types::data_type_size(data_type()); }
91
92     /** return the size of data type of additional buffer */
93     size_t additional_buffer_data_size() const {
94         using namespace mkldnn::impl::memory_format;
95         return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
96                     gOIhw4o4i_s8s8,
97                     gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
98                     gOIhw2i8o4i_s8s8,
99                     gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
100                     Goihw16g_s8s8))
101             ? sizeof(int32_t) : 0;
102     }
103
104     /** return true if memory format has additional buffer */
105     bool is_additional_buffer() const {
106         using namespace mkldnn::impl::memory_format;
107         return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
108                     gOIhw4o4i_s8s8,
109                     gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
110                     gOIhw2i8o4i_s8s8,
111                     gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
112                     Goihw16g_s8s8))
113             ? true : false;
114     }
115
116     /** returns the size of additional buffer */
117     size_t additional_buffer_size() const {
118         using namespace mkldnn::impl::memory_format;
119         const auto &padding_dims = blocking_desc().padding_dims;
120         switch(format()) {
121             case hwigo_s8s8:
122             case gOIhw4o4i_s8s8:
123             case gOIhw2i8o4i_s8s8:
124             case gOIhw4i16o4i_s8s8:
125             case gOhIw8o4i_s8s8:
126                 return size_t(padding_dims[0]) * size_t(padding_dims[1])
127                     * additional_buffer_data_size();
128             case Goihw16g_s8s8:
129             case hwio_s8s8:
130             case OIhw4i16o4i_s8s8:
131             case OhIw8o4i_s8s8:
132                 return size_t(padding_dims[0]) * additional_buffer_data_size();
133             default:
134                 return 0;
135         }
136     }
137
138     /** returns the size required to store described memory
139      * note: if offset_padding != 0 returns 0 (need to specify the behavior) */
140     size_t size() const {
141         using namespace mkldnn::impl::memory_format;
142         if (is_zero() || has_zero_dim() || format() == memory_format::any)
143             return 0;
144
145         assert((false
146                     || types::format_normalize(format()) == blocked
147                     || types::is_format_double_blocked(format())
148                     || format() == wino_fmt
149                     || format() == rnn_packed)
150                 && "unknown format");
151
152         if (format() == wino_fmt) {
153             return wino_desc().size;
154         } else if (format() == rnn_packed) {
155             return rnn_packed_desc().size;
156         } else {
157             if (blocking_desc().offset_padding != 0) return 0;
158
159             const auto &block_dims = blocking_desc().block_dims;
160             const auto &strides = blocking_desc().strides;
161             const auto &padding_dims = blocking_desc().padding_dims;
162
163             size_t max_size = 0;
164             for (int d = 0; d < ndims(); ++d) {
165                 auto block = block_dims[d];
166                 max_size = nstl::max(max_size,
167                     size_t(padding_dims[d] / block) * strides[0][d]);
168                 if (block > 1)
169                     max_size = nstl::max(max_size,
170                             size_t(block * strides[1][d]));
171             }
172
173             return max_size * data_type_size() + additional_buffer_size();
174         }
175     }
176
177     /** returns true if data is dense in memory */
178     bool is_dense(bool with_padding = false) const;
179
180     /** returns true if memory desc is fully defined */
181     bool is_defined() const { return format() != memory_format::any; }
182
183     /** returns true if the only (potentially) padded dim is \param dim */
184     bool only_padded_dim(int dim) const {
185         assert(is_blocking_desc());
186         const auto pdims = blocking_desc().padding_dims;
187         for (int d = 0; d < ndims(); ++d)
188             if (d != dim && dims()[d] != pdims[d])
189                 return false;
190         return true;
191     }
192
193     /** returns true if memory desc has blocked layout and block dims are 1s */
194     bool is_plain() const {
195         if (!is_blocking_desc()) return false;
196         return
197             utils::array_product(blocking_desc().block_dims, ndims()) == 1;
198     }
199
200     /* comparison section */
201
202     inline bool operator==(const memory_desc_wrapper &rhs) const;
203     inline bool operator!=(const memory_desc_wrapper &rhs) const
204     { return !operator==(rhs); }
205     inline bool operator==(const memory_desc_t &rhs) const
206     { return operator==(memory_desc_wrapper(rhs)); }
207     inline bool operator!=(const memory_desc_t &rhs) const
208     { return !operator==(rhs); }
209
210     /** returns true if data (w/o padding if with_padding == false and w/
211      * padding otherwise) have the same physical structure, i.e. dimensions,
212      * strides, and blocked structure. depending on with_data_type flag
213      * data_type is taken or not taken into account. dim_start allows to chech
214      * similarity for the logical part of data [dim_start .. ndims()].
215      * CAUTION: format any and undef are not similiar to whatever, hence the
216      * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
217     /* TODO: revise */
218     inline bool similar_to(const memory_desc_wrapper &rhs,
219             bool with_padding = true, bool with_data_type = true,
220             int dim_start = 0) const;
221
222     /** returns true if one memory can be reordered to another */
223     inline bool consistent_with(const memory_desc_wrapper &rhs) const;
224
225     /* offset section */
226
227     /** returns physical offset by logical one. logical offset is represented by
228      * an array \param pos. if \param is_pos_padded is true \param pos
229      * represents the position in already padded area */
230     inline size_t off_v(const dims_t pos, bool is_pos_padded = false) const {
231         using namespace mkldnn::impl::memory_format;
232         assert(format() != memory_format::any);
233         assert(is_blocking_desc());
234         const blocking_desc_t &blk = blocking_desc();
235         const dims_t &optd = blk.offset_padding_to_data;
236
237         size_t phys_offset = blk.offset_padding;
238         for (int d = 0; d < ndims(); ++d) {
239             const int block = blk.block_dims[d];
240
241             const int p = pos[d] + (is_pos_padded ? 0 : optd[d]);
242             const int pos_within_block = p % block;
243             const int pos_block = p / block;
244
245             phys_offset += pos_block * blk.strides[0][d];
246             phys_offset += pos_within_block * blk.strides[1][d];
247         }
248         if (utils::one_of(format(), gOIhw4i16o4i, OIhw4i16o4i, gOIhw4i16o4i_s8s8,
249                             OIhw4i16o4i_s8s8)) {
250             // TODO: Fix temporary workaround for formats with double blocking
251             const bool with_groups = (format() == gOIhw4i16o4i
252                                       || format() == gOIhw4i16o4i_s8s8);
253             const int oc_16 = pos[with_groups + 0] % 16;
254             const int ic_4  = pos[with_groups + 1] % 4;
255             phys_offset += 4 * oc_16 + ic_4 - (oc_16 + 16 * ic_4);
256         }
257         if (utils::one_of(format(), gOIhw2i8o4i,  gOIhw2i8o4i_s8s8)) {
258             // TODO: Fix temporary workaround for formats with double blocking
259             const bool with_groups = true;
260             const int oc_8 = pos[with_groups + 0] % 8;
261             const int ic_4 = pos[with_groups + 1] % 4;
262             phys_offset += 4 * oc_8 + ic_4 - (oc_8 + 8 * ic_4);
263         }
264         if (format() == gOIw8i16o2i || format() == OIw8i16o2i) {
265             // TODO: Fix temporary workaround for formats with double blocking
266             const bool with_groups = format() == gOIw8i16o2i;
267             const int oc_16 = pos[with_groups + 0] % 16;
268             const int ic_2  = pos[with_groups + 1] % 2;
269             phys_offset += -16 * ic_2 + oc_16 + ic_2;
270         }
271         if (format() == gOIhw8i16o2i || format() == OIhw8i16o2i) {
272             // TODO: Fix temporary workaround for formats with double blocking
273             const bool with_groups = format() == gOIhw8i16o2i;
274             const int oc_16 = pos[with_groups + 0] % 16;
275             const int ic_2  = pos[with_groups + 1] % 2;
276             phys_offset += -16 * ic_2 + oc_16 + ic_2;
277         }
278         if (format() == gOIdhw8i16o2i || format() == OIdhw8i16o2i) {
279             // TODO: Fix temporary workaround for formats with double blocking
280             const bool with_groups = format() == gOIdhw8i16o2i;
281             const int oc_16 = pos[with_groups + 0] % 16;
282             const int ic_2  = pos[with_groups + 1] % 2;
283             phys_offset += -16 * ic_2 + oc_16 + ic_2;
284         }
285         if (format() == gOIhw8o16i2o || format() == OIhw8o16i2o) {
286             // TODO: Fix temporary workaround for formats with double blocking
287             const bool with_groups = format() == gOIhw8o16i2o;
288             const int ic_16 = pos[with_groups + 1] % 16;
289             const int oc_2  = pos[with_groups + 0] % 2;
290             phys_offset += -16 * oc_2 + ic_16 + oc_2;
291         }
292         if (format() == gOIw8o16i2o || format() == OIw8o16i2o) {
293             // TODO: Fix temporary workaround for formats with double blocking
294             const bool with_groups = format() == gOIw8o16i2o;
295             const int ic_16 = pos[with_groups + 1] % 16;
296             const int oc_2  = pos[with_groups + 0] % 2;
297             phys_offset += -16 * oc_2 + ic_16 + oc_2;
298         }
299         return phys_offset;
300     }
301
302     /** returns physical offset by logical one. logical offset is represented by
303      * a scalar \param l_offset. if \param is_pos_padded is true, \param
304      * l_offset represents logical offset in already padded area */
305     inline size_t off_l(size_t l_offset, bool is_pos_padded = false) const {
306         assert(is_blocking_desc());
307         const dims_t &padding_dims = blocking_desc().padding_dims;
308         dims_t pos;
309         for (int rd = 0; rd < ndims(); ++rd) {
310             const int d = ndims() - 1 - rd;
311             const int cur_dim = is_pos_padded ? padding_dims[d] : dims()[d];
312             pos[d] = l_offset % cur_dim;
313             l_offset /= cur_dim;
314         }
315         return off_v(pos, is_pos_padded);
316     }
317
318     /** returns physical offset by logical one. logical offset is represented by
319      * a tuple of indeces (\param xn, ..., \param x1, \param x0) */
320     template<typename... Args> inline size_t off(Args... args) const {
321         assert(sizeof...(args) == ndims());
322         dims_t pos = { args... };
323         return off_v(pos, false);
324     }
325
326     /** returns physical offset by logical one. logical offset is represented by
327      * a tuple of indeces (\param xn, ..., \param x1, \param x0) in already
328      * padded area */
329     template<typename... Args> inline size_t off_padding(Args... args) const {
330         assert(sizeof...(args) == ndims());
331         dims_t pos = { args... };
332         return off_v(pos, true);
333     }
334
335     /** returns physical offset by logical one. Logical offset is represented by
336      * a tuple of block indeces (\param bn, ..., \param b1, \param b0). It is a
337      * user responsibility to adjust the result to get offset within blocks */
338     template<typename ...Args> inline size_t blk_off(Args... args) const {
339         return _blk_off<sizeof...(args), Args...>(args...);
340     }
341
342     template<bool skip_first, typename T, typename ...Args>
343     inline size_t blk_off(T xn, Args... args) const {
344         return skip_first
345             ? blk_off<Args...>(args...)
346             : blk_off<T, Args...>(xn, args...);
347     }
348
349     /* static functions section */
350     /* TODO: replace with non-static, once _md becomes non-const ref */
351
352     static status_t compute_blocking(memory_desc_t &memory_desc);
353
354 private:
355     /* TODO: put logical_offset in utils */
356     template<typename T>
357     inline size_t logical_offset(T x0) const { return size_t(x0); }
358
359     template<typename T, typename... Args>
360     inline size_t logical_offset(T xn, Args... args) const {
361         const size_t n_args = sizeof...(args);
362         return size_t(xn)*utils::array_product<n_args>(
363                 &dims()[ndims() - n_args]) + logical_offset(args...);
364     }
365
366     template<int ORIG_LEN, typename ...Void>
367     inline size_t _blk_off() const {
368         assert(is_blocking_desc());
369         return blocking_desc().offset_padding;
370     }
371
372     template<int ORIG_LEN, typename T, typename ...Args>
373     inline size_t _blk_off(T xc, Args ...args) const {
374         assert(is_blocking_desc());
375         constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
376         return size_t(xc) * blocking_desc().strides[0][dc]
377             + _blk_off<ORIG_LEN, Args...>(args...);
378     }
379 };
380
381 inline bool memory_desc_wrapper::is_dense(bool with_padding) const {
382     if (utils::one_of(format(), memory_format::undef, memory_format::any))
383         return false;
384     return nelems(with_padding) * data_type_size() == size();
385 }
386
387 inline bool memory_desc_wrapper::operator==(const memory_desc_wrapper &rhs)
388     const
389 {
390     using namespace impl::types;
391     return ndims() == rhs.ndims()
392             && utils::array_cmp(dims(), rhs.dims(), ndims())
393             && data_type() == rhs.data_type()
394             && ((is_blocking_desc() && rhs.is_blocking_desc())
395                        || (is_wino_desc() && rhs.is_wino_desc())
396                        || (is_rnn_packed_desc() && rhs.is_rnn_packed_desc()))
397             && (is_blocking_desc() ? blocking_desc_is_equal(blocking_desc(),
398                                              rhs.blocking_desc(), ndims()) :
399                                      true)
400             && (is_wino_desc() ? wino_desc_is_equal(
401                                          wino_desc(), rhs.wino_desc()) :
402                                  true)
403             && (is_rnn_packed_desc() ?
404                                rnn_packed_desc_is_equal(rnn_packed_desc(),
405                                        rhs.rnn_packed_desc()) :
406                                true);
407 }
408
409 inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
410         bool with_padding, bool with_data_type, int dim_start) const {
411     using namespace impl::types;
412     using namespace utils;
413     if (utils::one_of(format(), memory_format::undef, memory_format::any))
414         return false;
415     if (is_wino_desc() || rhs.is_wino_desc() || is_rnn_packed_desc()
416             || rhs.is_rnn_packed_desc())
417         return false;
418
419     const int ds = dim_start;
420     const auto &blk = blocking_desc();
421     const auto &r_blk = rhs.blocking_desc();
422
423     return ndims() == rhs.ndims()
424         && dim_start <= ndims() /* guard */
425         && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
426         && format_normalize(format()) == format_normalize(rhs.format())
427         && IMPLICATION(with_data_type, data_type() == rhs.data_type())
428         && array_cmp(blk.block_dims + ds, r_blk.block_dims + ds, ndims() - ds)
429         && array_cmp(blk.strides[0] + ds, r_blk.strides[0] + ds, ndims() - ds)
430         && array_cmp(blk.strides[1] + ds, r_blk.strides[1] + ds, ndims() - ds)
431         && IMPLICATION(with_padding,
432                 array_cmp(blk.padding_dims + ds, r_blk.padding_dims + ds,
433                     ndims() - ds)
434                 && array_cmp(blk.offset_padding_to_data + ds,
435                     r_blk.offset_padding_to_data + ds, ndims() - ds));
436 }
437
438 inline bool memory_desc_wrapper::consistent_with(
439         const memory_desc_wrapper &rhs) const {
440     if (ndims() == rhs.ndims()) {
441         for (int d = 0; d < ndims(); ++d) {
442             if (dims()[d] != rhs.dims()[d]) return false;
443         }
444         return true;
445     } else {
446         /* TODO: revise.
447          * is the following possible?
448          * [1, a, b] <--reorder--> [a, b]
449          * [a, 1, b] <--reorder--> [a, b]
450          * not, at least for now */
451         return false;
452     }
453 }
454
455 }
456 }
457
458 #endif
459
460 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
461