Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / type_helpers.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TYPE_HELPERS_HPP
18 #define TYPE_HELPERS_HPP
19
20 #include <assert.h>
21 #include <math.h>
22
23 #include "mkldnn.h"
24
25 #include "c_types_map.hpp"
26 #include "mkldnn_traits.hpp"
27 #include "nstl.hpp"
28 #include "utils.hpp"
29 #include "math_utils.hpp"
30
31 namespace mkldnn {
32 namespace impl {
33
34 template <typename T>
35 status_t safe_ptr_assign(T * &lhs, T* rhs) {
36     if (rhs == nullptr) return status::out_of_memory;
37     lhs = rhs;
38     return status::success;
39 }
40
41 template <typename T, typename U> struct is_subset
42 { static constexpr bool value = false; };
43 template <typename T> struct is_subset<T, T>
44 { static constexpr bool value = true; };
45 template <typename T> struct is_subset<T,
46          typename utils::enable_if<nstl::is_integral<T>::value, float>::type>
47 { static constexpr bool value = true; };
48 #define ISSPEC(t1, t2) template <> \
49     struct is_subset<t1, t2> { static constexpr bool value = true; }
50 ISSPEC(int16_t, int32_t);
51 ISSPEC(int8_t, int32_t);
52 ISSPEC(uint8_t, int32_t);
53 ISSPEC(int8_t, int16_t);
54 ISSPEC(uint8_t, int16_t);
55 #undef ISSPEC
56
57 namespace types {
58
59 inline size_t data_type_size(data_type_t data_type) {
60     using namespace data_type;
61     switch (data_type) {
62     case f32: return sizeof(prec_traits<f32>::type);
63     case s32: return sizeof(prec_traits<s32>::type);
64     case s16: return sizeof(prec_traits<s16>::type);
65     case s8: return sizeof(prec_traits<s8>::type);
66     case u8: return sizeof(prec_traits<u8>::type);
67     case bin: return sizeof(prec_traits<u8>::type);
68     case data_type::undef:
69     default: assert(!"unknown data_type");
70     }
71     return 0; /* not supposed to be reachable */
72 }
73
74 inline memory_format_t flat_memory_format(int ndims) {
75     switch (ndims) {
76     case 1: return memory_format::x;
77     case 2: return memory_format::nc;
78     case 3: return memory_format::ncw;
79     case 4: return memory_format::nchw;
80     case 5: return memory_format::ncdhw;
81     default: return memory_format::undef;
82     }
83     return memory_format::undef;
84 }
85
86 inline memory_format_t format_normalize(const memory_format_t fmt) {
87     using namespace memory_format;
88     /* FIXME: double blocked formats are special cases -- the blocking
89      *        structure doesn't correctly describe memory layout (wrt
90      *        the strides within blocks). Though as long as the code
91      *        uses memory_desc_wrapper::off() or explicit offset
92      *        calculations everything should be fine. */
93     const bool is_blocked = utils::one_of(fmt, blocked,
94             x,
95             nc,
96             ncw,
97             nwc,
98             nCw4c,
99             nCw8c,
100             nCw16c,
101             nchw,
102             nhwc,
103             chwn,
104             nChw4c,
105             nChw8c,
106             nChw16c,
107             ncdhw,
108             ndhwc,
109             nCdhw4c,
110             nCdhw8c,
111             nCdhw16c,
112             oi,
113             io,
114             oiw,
115             wio,
116             Owi4o,
117             OIw4i4o,
118             Owi8o,
119             OIw8i8o,
120             OIw8o8i,
121             OIw16i16o,
122             OIw16o16i,
123             Oiw4o,
124             Oiw16o,
125             Owi16o,
126             OIw8i16o2i,
127             OIw8o16i2o,
128             IOw16o16i,
129             oihw,
130             ihwo,
131             hwio,
132             iohw,
133             hwio_s8s8,
134             dhwio,
135             oidhw,
136             OIdhw4i4o,
137             Odhwi4o,
138             OIdhw8i8o,
139             OIdhw8o8i,
140             Odhwi8o,
141             OIdhw16i16o,
142             OIdhw16o16i,
143             Oidhw4o,
144             Oidhw16o,
145             Odhwi16o,
146             oIhw8i,
147             oIhw16i,
148             oIdhw8i,
149             oIdhw16i,
150             OIhw4i4o,
151             OIhw8i8o,
152             OIhw16i16o,
153             OIhw4i16o4i,
154             OIhw4i16o4i_s8s8,
155             OIhw8i16o2i,
156             OIdhw8i16o2i,
157             OIhw8o16i2o,
158             OIhw8o8i,
159             OhIw8o4i,
160             OhIw8o32i,
161             OhIw16o32i,
162             OhIw8o4i_s8s8,
163             OIhw16o16i,
164             IOhw16o16i,
165             Oihw4o,
166             Oihw16o,
167             Ohwi8o,
168             Ohwi4o,
169             Ohwi16o,
170             goiw,
171             gOwi4o,
172             gOIw4i4o,
173             gOwi8o,
174             gOIw8i8o,
175             gOIw8o8i,
176             gOIw16i16o,
177             gOIw16o16i,
178             gOiw4o,
179             gOiw16o,
180             gOwi16o,
181             gOIw8i16o2i,
182             gOIw8o16i2o,
183             gIOw16o16i,
184             goihw,
185             hwigo,
186             giohw,
187             hwigo_s8s8,
188             gOIhw4i4o,
189             gOIhw8i8o,
190             gOIhw16i16o,
191             gOIhw4i16o4i,
192             gOIhw4i16o4i_s8s8,
193             gOIhw2i8o4i,
194             gOIhw2i8o4i_s8s8,
195             gOIhw8i16o2i,
196             gOIdhw8i16o2i,
197             gOIhw8o16i2o,
198             gOIhw4o4i,
199             gOIhw4o4i_s8s8,
200             gOIhw8o8i,
201             gOhIw8o4i,
202             gOhIw8o4i_s8s8,
203             gOIhw16o16i,
204             gIOhw16o16i,
205             gOihw4o,
206             gOihw16o,
207             gOhwi8o,
208             gOhwi4o,
209             gOhwi16o,
210             Goihw8g,
211             Goihw16g,
212             Goihw16g_s8s8,
213             goidhw,
214             gOIdhw4i4o,
215             gOdhwi4o,
216             gOIdhw8i8o,
217             gOIdhw8o8i,
218             gOdhwi8o,
219             gOIdhw16i16o,
220             gOIdhw16o16i,
221             gOidhw16o,
222             gOidhw4o,
223             gOdhwi16o,
224             ntc,
225             tnc,
226             ldsnc,
227             ldigo,
228             ldgoi,
229             ldgo);
230     return is_blocked ? blocked : fmt;
231 }
232
233 inline bool is_format_double_blocked(memory_format_t fmt) {
234     using namespace memory_format;
235     return utils::one_of(OIw8o16i2o, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i,
236             OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8,
237             gOIw8o16i2o, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o,
238             gOIhw4i16o4i, gOIhw4i16o4i_s8s8, gOIhw2i8o4i, gOIhw2i8o4i_s8s8);
239 }
240
241 inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
242         const blocking_desc_t &rhs, int ndims = TENSOR_MAX_DIMS) {
243     using mkldnn::impl::utils::array_cmp;
244     return lhs.offset_padding == rhs.offset_padding
245         && array_cmp(lhs.block_dims, rhs.block_dims, ndims)
246         && array_cmp(lhs.strides[0], rhs.strides[0], ndims)
247         && array_cmp(lhs.strides[1], rhs.strides[1], ndims)
248         && array_cmp(lhs.padding_dims, rhs.padding_dims, ndims)
249         && array_cmp(lhs.offset_padding_to_data, rhs.offset_padding_to_data,
250                 ndims);
251 }
252
253 inline bool wino_desc_is_equal(const wino_data_t &lhs,
254     const wino_data_t &rhs) {
255     return lhs.wino_format == rhs.wino_format
256         && lhs.alpha == rhs.alpha
257         && lhs.ic == rhs.ic
258         && lhs.oc == rhs.oc
259         && lhs.ic_block == rhs.ic_block
260         && lhs.oc_block == rhs.oc_block
261         && lhs.ic2_block == rhs.ic2_block
262         && lhs.oc2_block == rhs.oc2_block
263         && lhs.r == rhs.r;
264 }
265
266 inline bool rnn_packed_desc_is_equal(
267         const rnn_packed_data_t &lhs, const rnn_packed_data_t &rhs) {
268     bool ok = lhs.format == rhs.format && lhs.n_parts == rhs.n_parts
269             && lhs.offset_compensation == rhs.offset_compensation
270             && lhs.size == rhs.size
271             && lhs.n == rhs.n;
272     if (!ok)
273         return false;
274
275     for (int i = 0; i < rhs.n_parts; i++)
276         ok = ok && lhs.parts[i] == rhs.parts[i];
277     for (int i = 0; i < rhs.n_parts; i++)
278         ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
279     return ok;
280 }
281
282 inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
283     assert(lhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
284     assert(rhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
285     bool base_equal = true
286         && lhs.ndims == rhs.ndims
287         && mkldnn::impl::utils::array_cmp(lhs.dims, rhs.dims, lhs.ndims)
288         && lhs.data_type == rhs.data_type
289         && lhs.format == rhs.format; /* FIXME: normalize format? */
290     if (!base_equal) return false;
291     if (lhs.format == memory_format::blocked)
292         return blocking_desc_is_equal(lhs.layout_desc.blocking,
293                 rhs.layout_desc.blocking, lhs.ndims);
294     else if (lhs.format == memory_format::wino_fmt)
295         return wino_desc_is_equal(lhs.layout_desc.wino_desc,
296             rhs.layout_desc.wino_desc);
297     else if (lhs.format == memory_format::rnn_packed)
298         return rnn_packed_desc_is_equal(lhs.layout_desc.rnn_packed_desc,
299                 rhs.layout_desc.rnn_packed_desc);
300     return true;
301 }
302
303 inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
304     return !operator==(lhs, rhs);
305 }
306
307 inline memory_desc_t zero_md() {
308     auto zero = memory_desc_t();
309     zero.primitive_kind = primitive_kind::memory;
310     return zero;
311 }
312
313 inline bool is_zero_md(const memory_desc_t *md) {
314     return md == nullptr || *md == zero_md();
315 }
316
317 inline status_t set_default_format(memory_desc_t &md, memory_format_t fmt) {
318     return mkldnn_memory_desc_init(&md, md.ndims, md.dims, md.data_type, fmt);
319 }
320
321 inline data_type_t default_accum_data_type(data_type_t src_dt,
322         data_type_t dst_dt) {
323     using namespace utils;
324     using namespace data_type;
325
326     if (one_of(f32, src_dt, dst_dt)) return f32;
327     if (one_of(s32, src_dt, dst_dt)) return s32;
328     if (one_of(s16, src_dt, dst_dt)) return s32;
329     if (one_of(bin, src_dt, dst_dt)) return s32;
330
331     if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
332
333     assert(!"unimplemented use-case: no default parameters available");
334     return dst_dt;
335 }
336
337 inline data_type_t default_accum_data_type(data_type_t src_dt,
338         data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
339     using namespace utils;
340     using namespace data_type;
341     using namespace prop_kind;
342
343     /* prop_kind doesn't matter */
344     if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32;
345
346     if (one_of(prop_kind, forward_training, forward_inference)) {
347         if (src_dt == s16 && wei_dt == s16 && dst_dt == s32)
348             return s32;
349         if ((src_dt == u8 || src_dt == s8)
350             && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
351             return s32;
352         if (src_dt == bin && wei_dt == bin && (dst_dt == f32 || dst_dt == bin))
353             return s32;
354     } else if (prop_kind == backward_data) {
355         if (src_dt == s32 && wei_dt == s16 && dst_dt == s16)
356             return s32;
357         if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
358                 one_of(dst_dt, s8, u8))
359             return s32;
360     } else if (prop_kind == backward_weights) {
361         if (src_dt == s16 && wei_dt == s32 && dst_dt == s16)
362             return s32;
363     }
364
365     assert(!"unimplemented use-case: no default parameters available");
366     return dst_dt;
367 }
368
369 }
370 }
371 }
372
373 #include "memory_desc_wrapper.hpp"
374
375 #endif
376
377 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s