case s16: return sizeof(prec_traits<s16>::type);
case s8: return sizeof(prec_traits<s8>::type);
case u8: return sizeof(prec_traits<u8>::type);
+ case bin: return sizeof(prec_traits<u8>::type);
case data_type::undef:
default: assert(!"unknown data_type");
}
nc,
ncw,
nwc,
+ nCw4c,
nCw8c,
nCw16c,
nchw,
nhwc,
chwn,
+ nChw4c,
nChw8c,
nChw16c,
ncdhw,
ndhwc,
+ nCdhw4c,
nCdhw8c,
nCdhw16c,
oi,
io,
oiw,
wio,
+ Owi4o,
+ OIw4i4o,
Owi8o,
OIw8i8o,
OIw8o8i,
OIw16i16o,
OIw16o16i,
+ Oiw4o,
Oiw16o,
Owi16o,
OIw8i16o2i,
oihw,
ihwo,
hwio,
+ iohw,
hwio_s8s8,
dhwio,
oidhw,
+ OIdhw4i4o,
+ Odhwi4o,
OIdhw8i8o,
OIdhw8o8i,
Odhwi8o,
OIdhw16i16o,
OIdhw16o16i,
+ Oidhw4o,
Oidhw16o,
Odhwi16o,
oIhw8i,
oIhw16i,
oIdhw8i,
oIdhw16i,
+ OIhw4i4o,
OIhw8i8o,
OIhw16i16o,
OIhw4i16o4i,
OIhw8o16i2o,
OIhw8o8i,
OhIw8o4i,
+ OhIw8o32i,
+ OhIw16o32i,
OhIw8o4i_s8s8,
OIhw16o16i,
IOhw16o16i,
+ Oihw4o,
Oihw16o,
Ohwi8o,
+ Ohwi4o,
Ohwi16o,
goiw,
+ gOwi4o,
+ gOIw4i4o,
gOwi8o,
gOIw8i8o,
gOIw8o8i,
gOIw16i16o,
gOIw16o16i,
+ gOiw4o,
gOiw16o,
gOwi16o,
gOIw8i16o2i,
gIOw16o16i,
goihw,
hwigo,
+ giohw,
hwigo_s8s8,
+ gOIhw4i4o,
gOIhw8i8o,
gOIhw16i16o,
gOIhw4i16o4i,
gOIhw4i16o4i_s8s8,
+ gOIhw2i8o4i,
+ gOIhw2i8o4i_s8s8,
gOIhw8i16o2i,
gOIdhw8i16o2i,
gOIhw8o16i2o,
+ gOIhw4o4i,
+ gOIhw4o4i_s8s8,
gOIhw8o8i,
gOhIw8o4i,
gOhIw8o4i_s8s8,
gOIhw16o16i,
gIOhw16o16i,
+ gOihw4o,
gOihw16o,
gOhwi8o,
+ gOhwi4o,
gOhwi16o,
Goihw8g,
Goihw16g,
+ Goihw16g_s8s8,
goidhw,
+ gOIdhw4i4o,
+ gOdhwi4o,
gOIdhw8i8o,
gOIdhw8o8i,
gOdhwi8o,
gOIdhw16i16o,
gOIdhw16o16i,
gOidhw16o,
+ gOidhw4o,
gOdhwi16o,
ntc,
tnc,
inline bool is_format_double_blocked(memory_format_t fmt) {
using namespace memory_format;
return utils::one_of(OIw8o16i2o, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i,
- OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8, gOIw8o16i2o, gOIw8i16o2i,
- gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o, gOIhw4i16o4i,
- gOIhw4i16o4i_s8s8);
+ OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8,
+ gOIw8o16i2o, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o,
+ gOIhw4i16o4i, gOIhw4i16o4i_s8s8, gOIhw2i8o4i, gOIhw2i8o4i_s8s8);
}
inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
&& lhs.r == rhs.r;
}
+inline bool rnn_packed_desc_is_equal(
+ const rnn_packed_data_t &lhs, const rnn_packed_data_t &rhs) {
+ bool ok = lhs.format == rhs.format && lhs.n_parts == rhs.n_parts
+ && lhs.offset_compensation == rhs.offset_compensation
+ && lhs.size == rhs.size
+ && lhs.n == rhs.n;
+ if (!ok)
+ return false;
+
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.parts[i] == rhs.parts[i];
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
+ return ok;
+}
+
inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
assert(lhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
assert(rhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
else if (lhs.format == memory_format::wino_fmt)
return wino_desc_is_equal(lhs.layout_desc.wino_desc,
rhs.layout_desc.wino_desc);
+ else if (lhs.format == memory_format::rnn_packed)
+ return rnn_packed_desc_is_equal(lhs.layout_desc.rnn_packed_desc,
+ rhs.layout_desc.rnn_packed_desc);
return true;
}
if (one_of(f32, src_dt, dst_dt)) return f32;
if (one_of(s32, src_dt, dst_dt)) return s32;
if (one_of(s16, src_dt, dst_dt)) return s32;
+ if (one_of(bin, src_dt, dst_dt)) return s32;
if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
if ((src_dt == u8 || src_dt == s8)
&& wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
return s32;
+ if (src_dt == bin && wei_dt == bin && (dst_dt == f32 || dst_dt == bin))
+ return s32;
} else if (prop_kind == backward_data) {
if (src_dt == s32 && wei_dt == s16 && dst_dt == s16)
return s32;
- if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && dst_dt == u8)
+ if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
+ one_of(dst_dt, s8, u8))
return s32;
} else if (prop_kind == backward_weights) {
if (src_dt == s16 && wei_dt == s32 && dst_dt == s16)