1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef TYPE_HELPERS_HPP
18 #define TYPE_HELPERS_HPP
25 #include "c_types_map.hpp"
26 #include "mkldnn_traits.hpp"
29 #include "math_utils.hpp"
35 status_t safe_ptr_assign(T * &lhs, T* rhs) {
36 if (rhs == nullptr) return status::out_of_memory;
38 return status::success;
41 template <typename T, typename U> struct is_subset
42 { static constexpr bool value = false; };
43 template <typename T> struct is_subset<T, T>
44 { static constexpr bool value = true; };
45 template <typename T> struct is_subset<T,
46 typename utils::enable_if<nstl::is_integral<T>::value, float>::type>
47 { static constexpr bool value = true; };
48 #define ISSPEC(t1, t2) template <> \
49 struct is_subset<t1, t2> { static constexpr bool value = true; }
50 ISSPEC(int16_t, int32_t);
51 ISSPEC(int8_t, int32_t);
52 ISSPEC(uint8_t, int32_t);
53 ISSPEC(int8_t, int16_t);
54 ISSPEC(uint8_t, int16_t);
59 inline size_t data_type_size(data_type_t data_type) {
60 using namespace data_type;
62 case f32: return sizeof(prec_traits<f32>::type);
63 case s32: return sizeof(prec_traits<s32>::type);
64 case s16: return sizeof(prec_traits<s16>::type);
65 case s8: return sizeof(prec_traits<s8>::type);
66 case u8: return sizeof(prec_traits<u8>::type);
67 case bin: return sizeof(prec_traits<u8>::type);
68 case data_type::undef:
69 default: assert(!"unknown data_type");
71 return 0; /* not supposed to be reachable */
74 inline memory_format_t flat_memory_format(int ndims) {
76 case 1: return memory_format::x;
77 case 2: return memory_format::nc;
78 case 3: return memory_format::ncw;
79 case 4: return memory_format::nchw;
80 case 5: return memory_format::ncdhw;
81 default: return memory_format::undef;
83 return memory_format::undef;
/* format_normalize(): returns memory_format::blocked when `fmt` is one of
 * the formats passed to the utils::one_of() call below (the list starts
 * with `blocked` itself), and returns `fmt` unchanged otherwise.
 * NOTE(review): in this copy the one_of() argument list stops right after
 * `blocked,` and the next line is already the return statement, so the
 * call is never closed and the function has no closing brace. The rest
 * of the format list appears to have been lost and must be restored from
 * upstream before this compiles. */
86 inline memory_format_t format_normalize(const memory_format_t fmt) {
87 using namespace memory_format;
88 /* FIXME: double blocked formats are special cases -- the blocking
89 * structure doesn't correctly describe memory layout (wrt
90 * the strides within blocks). Though as long as the code
91 * uses memory_desc_wrapper::off() or explicit offset
92 * calculations everything should be fine. */
93 const bool is_blocked = utils::one_of(fmt, blocked,
/* NOTE(review): remaining one_of() arguments and ')' missing here. */
230 return is_blocked ? blocked : fmt;
233 inline bool is_format_double_blocked(memory_format_t fmt) {
234 using namespace memory_format;
235 return utils::one_of(OIw8o16i2o, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i,
236 OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8,
237 gOIw8o16i2o, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o,
238 gOIhw4i16o4i, gOIhw4i16o4i_s8s8, gOIhw2i8o4i, gOIhw2i8o4i_s8s8);
241 inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
242 const blocking_desc_t &rhs, int ndims = TENSOR_MAX_DIMS) {
243 using mkldnn::impl::utils::array_cmp;
244 return lhs.offset_padding == rhs.offset_padding
245 && array_cmp(lhs.block_dims, rhs.block_dims, ndims)
246 && array_cmp(lhs.strides[0], rhs.strides[0], ndims)
247 && array_cmp(lhs.strides[1], rhs.strides[1], ndims)
248 && array_cmp(lhs.padding_dims, rhs.padding_dims, ndims)
249 && array_cmp(lhs.offset_padding_to_data, rhs.offset_padding_to_data,
/* wino_desc_is_equal(): field-by-field comparison of two Winograd layout
 * descriptors. The visible chain compares wino_format, alpha, ic_block,
 * oc_block, ic2_block and oc2_block.
 * NOTE(review): in this copy the && chain has no terminating ';' and the
 * function has no closing brace. The embedded line numbers jump from 256
 * to 259 and stop at 262, so further comparisons appear to have been
 * dropped. Restore them from upstream before relying on the field list
 * above. */
253 inline bool wino_desc_is_equal(const wino_data_t &lhs,
254 const wino_data_t &rhs) {
255 return lhs.wino_format == rhs.wino_format
256 && lhs.alpha == rhs.alpha
/* NOTE(review): two comparisons (original lines 257-258) missing here. */
259 && lhs.ic_block == rhs.ic_block
260 && lhs.oc_block == rhs.oc_block
261 && lhs.ic2_block == rhs.ic2_block
262 && lhs.oc2_block == rhs.oc2_block
/* rnn_packed_desc_is_equal(): compares two packed-RNN layout descriptors.
 * First the scalar fields (format, n_parts, offset_compensation, size) are
 * compared. Then parts[] and part_pack_size[] are compared element-wise
 * over rhs.n_parts entries. Because each loop step is `ok = ok && ...`,
 * no array element is read once `ok` is false (e.g. when n_parts differ).
 * NOTE(review): in this copy the `ok` initializer ends without ';'
 * (original lines 271-274 are missing) and no return statement or closing
 * brace is present. Restore from upstream. */
266 inline bool rnn_packed_desc_is_equal(
267 const rnn_packed_data_t &lhs, const rnn_packed_data_t &rhs) {
268 bool ok = lhs.format == rhs.format && lhs.n_parts == rhs.n_parts
269 && lhs.offset_compensation == rhs.offset_compensation
270 && lhs.size == rhs.size
275 for (int i = 0; i < rhs.n_parts; i++)
276 ok = ok && lhs.parts[i] == rhs.parts[i];
277 for (int i = 0; i < rhs.n_parts; i++)
278 ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
282 inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
283 assert(lhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
284 assert(rhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
285 bool base_equal = true
286 && lhs.ndims == rhs.ndims
287 && mkldnn::impl::utils::array_cmp(lhs.dims, rhs.dims, lhs.ndims)
288 && lhs.data_type == rhs.data_type
289 && lhs.format == rhs.format; /* FIXME: normalize format? */
290 if (!base_equal) return false;
291 if (lhs.format == memory_format::blocked)
292 return blocking_desc_is_equal(lhs.layout_desc.blocking,
293 rhs.layout_desc.blocking, lhs.ndims);
294 else if (lhs.format == memory_format::wino_fmt)
295 return wino_desc_is_equal(lhs.layout_desc.wino_desc,
296 rhs.layout_desc.wino_desc);
297 else if (lhs.format == memory_format::rnn_packed)
298 return rnn_packed_desc_is_equal(lhs.layout_desc.rnn_packed_desc,
299 rhs.layout_desc.rnn_packed_desc);
303 inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
304 return !operator==(lhs, rhs);
307 inline memory_desc_t zero_md() {
308 auto zero = memory_desc_t();
309 zero.primitive_kind = primitive_kind::memory;
313 inline bool is_zero_md(const memory_desc_t *md) {
314 return md == nullptr || *md == zero_md();
317 inline status_t set_default_format(memory_desc_t &md, memory_format_t fmt) {
318 return mkldnn_memory_desc_init(&md, md.ndims, md.dims, md.data_type, fmt);
/* default_accum_data_type(src_dt, dst_dt): picks an accumulation data type
 * from the source and destination types. The checks run in order and the
 * first match wins: f32 if either type is f32. Otherwise s32 if either
 * type is s32, s16, bin, s8 or u8. (utils::one_of(x, a, b) presumably
 * tests whether x equals a or b.) Any other combination hits the assert.
 * NOTE(review): the statement after the assert (the fallback return value)
 * and the closing brace are missing in this copy, so the fallback result
 * cannot be documented from here -- restore from upstream. */
321 inline data_type_t default_accum_data_type(data_type_t src_dt,
322 data_type_t dst_dt) {
323 using namespace utils;
324 using namespace data_type;
326 if (one_of(f32, src_dt, dst_dt)) return f32;
327 if (one_of(s32, src_dt, dst_dt)) return s32;
328 if (one_of(s16, src_dt, dst_dt)) return s32;
329 if (one_of(bin, src_dt, dst_dt)) return s32;
331 if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
333 assert(!"unimplemented use-case: no default parameters available");
/* default_accum_data_type(src_dt, wei_dt, dst_dt, prop_kind): picks an
 * accumulation data type for a src/weights/dst combination.
 * The all-f32 case returns f32 regardless of prop_kind. The remaining
 * combinations are recognized per propagation kind:
 *   forward_training/forward_inference: s16/s16/s32; u8|s8 src with s8
 *     weights and f32|s32|s8|u8 dst; bin/bin with f32|bin dst.
 *   backward_data: s32/s16/s16; f32|s32|s8|u8 src with s8 weights and
 *     s8|u8 dst.
 *   backward_weights: s16/s32/s16.
 * Anything else reaches the assert.
 * NOTE(review): in this copy the statement controlled by each inner `if`
 * is missing (original lines 348, 351, 353, 356, 359, 362). As written,
 * each `if` would take the next `if` as its body, which is a different
 * meaning. The backward_weights branch's closing brace, the fallback
 * return after the assert, and the function's closing brace are also
 * absent. Restore from upstream; the per-case result types cannot be
 * documented from here. */
337 inline data_type_t default_accum_data_type(data_type_t src_dt,
338 data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
339 using namespace utils;
340 using namespace data_type;
341 using namespace prop_kind;
343 /* prop_kind doesn't matter */
344 if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32;
346 if (one_of(prop_kind, forward_training, forward_inference)) {
347 if (src_dt == s16 && wei_dt == s16 && dst_dt == s32)
/* NOTE(review): body of the preceding `if` missing. */
349 if ((src_dt == u8 || src_dt == s8)
350 && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
/* NOTE(review): body of the preceding `if` missing. */
352 if (src_dt == bin && wei_dt == bin && (dst_dt == f32 || dst_dt == bin))
/* NOTE(review): body of the preceding `if` missing. */
354 } else if (prop_kind == backward_data) {
355 if (src_dt == s32 && wei_dt == s16 && dst_dt == s16)
/* NOTE(review): body of the preceding `if` missing. */
357 if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
358 one_of(dst_dt, s8, u8))
/* NOTE(review): body of the preceding `if` missing. */
360 } else if (prop_kind == backward_weights) {
361 if (src_dt == s16 && wei_dt == s32 && dst_dt == s16)
/* NOTE(review): body of the preceding `if` and the branch's '}' missing. */
365 assert(!"unimplemented use-case: no default parameters available");
373 #include "memory_desc_wrapper.hpp"
377 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s