/** Accessor for the memory format recorded in the wrapped descriptor
 * (reads _md->format). */
memory_format_t format() const {
    return _md->format;
}
/** Returns true when format() is none of wino_fmt, rnn_packed, any, or undef.
 * blocking_desc() asserts this predicate before exposing layout_desc.blocking. */
bool is_blocking_desc() const {
return (format() != memory_format::wino_fmt
+ && format() != memory_format::rnn_packed
&& format() != memory_format::any
&& format() != memory_format::undef);
}
/** Tells whether format() is exactly memory_format::wino_fmt. */
bool is_wino_desc() const {
    const bool matches_wino = memory_format::wino_fmt == format();
    return matches_wino;
}
/** Returns true when format() is exactly memory_format::rnn_packed. */
+ bool is_rnn_packed_desc() const {
+ return (format() == memory_format::rnn_packed);
+ }
/** Returns a reference to _md->layout_desc.blocking; asserts
 * is_blocking_desc() first. */
const blocking_desc_t &blocking_desc() const {
assert(is_blocking_desc());
return _md->layout_desc.blocking;
// NOTE(review): the two statements below are unreachable after the return
// above. They look like the body of a separate wino_desc() accessor whose
// signature is not visible in this excerpt -- confirm against the full file.
assert(is_wino_desc());
return _md->layout_desc.wino_desc;
}
/** Returns a reference to _md->layout_desc.rnn_packed_desc; asserts
 * is_rnn_packed_desc() first. */
+ const rnn_packed_data_t &rnn_packed_desc() const {
+ assert(is_rnn_packed_desc());
+ return _md->layout_desc.rnn_packed_desc;
+ }
/* some useful function */
/** returns the number of elements including padding if \param with_padding
 * is true, and the number of data elements otherwise */
size_t nelems(bool with_padding = false) const {
// A descriptor for which is_zero() holds has no elements.
if (is_zero()) return 0;
// Product over the first ndims() entries of either
// blocking_desc().padding_dims (with_padding) or dims() (otherwise).
// The replacement line changes the first template argument of
// array_product from int to ptrdiff_t.
- return (utils::array_product<int, size_t>(with_padding
+ return (utils::array_product<ptrdiff_t, size_t>(with_padding
? blocking_desc().padding_dims : dims(), ndims()));
}
/** Returns sizeof(int32_t) when format() is one of the *_s8s8 formats listed
 * in the one_of() call, and 0 for every other format.
 * NOTE(review): presumably this is the per-element size of the extra buffer
 * attached to those formats -- confirm against the caller that computes the
 * total additional buffer size. */
size_t additional_buffer_data_size() const {
using namespace mkldnn::impl::memory_format;
return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
- gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8, OhIw8o4i_s8s8, gOhIw8o4i_s8s8))
+ gOIhw4o4i_s8s8,
+ gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
+ gOIhw2i8o4i_s8s8,
+ gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
+ Goihw16g_s8s8))
? sizeof(int32_t) : 0;
}
/** Returns true when format() is one of the *_s8s8 formats listed in the
 * one_of() call, false otherwise. The list is the same one used by
 * additional_buffer_data_size(); keep the two in sync when adding formats. */
bool is_additional_buffer() const {
using namespace mkldnn::impl::memory_format;
return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
- gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8, OhIw8o4i_s8s8, gOhIw8o4i_s8s8))
+ gOIhw4o4i_s8s8,
+ gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
+ gOIhw2i8o4i_s8s8,
+ gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
+ Goihw16g_s8s8))
? true : false;
}
const auto &padding_dims = blocking_desc().padding_dims;
switch(format()) {
case hwigo_s8s8:
+ case gOIhw4o4i_s8s8:
+ case gOIhw2i8o4i_s8s8:
case gOIhw4i16o4i_s8s8:
case gOhIw8o4i_s8s8:
return size_t(padding_dims[0]) * size_t(padding_dims[1])
* additional_buffer_data_size();
+ case Goihw16g_s8s8:
case hwio_s8s8:
case OIhw4i16o4i_s8s8:
case OhIw8o4i_s8s8:
assert((false
|| types::format_normalize(format()) == blocked
|| types::is_format_double_blocked(format())
- || format() == wino_fmt)
+ || format() == wino_fmt
+ || format() == rnn_packed)
&& "unknown format");
if (format() == wino_fmt) {
return wino_desc().size;
+ } else if (format() == rnn_packed) {
+ return rnn_packed_desc().size;
} else {
if (blocking_desc().offset_padding != 0) return 0;
max_size = nstl::max(max_size,
size_t(block * strides[1][d]));
}
- return max_size * data_type_size() + additional_buffer_size();;
+
+ return max_size * data_type_size() + additional_buffer_size();
}
}
const int ic_4 = pos[with_groups + 1] % 4;
phys_offset += 4 * oc_16 + ic_4 - (oc_16 + 16 * ic_4);
}
+ if (utils::one_of(format(), gOIhw2i8o4i, gOIhw2i8o4i_s8s8)) {
+ // TODO: Fix temporary workaround for formats with double blocking
+ const bool with_groups = true;
+ const int oc_8 = pos[with_groups + 0] % 8;
+ const int ic_4 = pos[with_groups + 1] % 4;
+ phys_offset += 4 * oc_8 + ic_4 - (oc_8 + 8 * ic_4);
+ }
if (format() == gOIw8i16o2i || format() == OIw8i16o2i) {
// TODO: Fix temporary workaround for formats with double blocking
const bool with_groups = format() == gOIw8i16o2i;
&& utils::array_cmp(dims(), rhs.dims(), ndims())
&& data_type() == rhs.data_type()
&& ((is_blocking_desc() && rhs.is_blocking_desc())
- || (is_wino_desc() && rhs.is_wino_desc()))
+ || (is_wino_desc() && rhs.is_wino_desc())
+ || (is_rnn_packed_desc() && rhs.is_rnn_packed_desc()))
&& (is_blocking_desc() ? blocking_desc_is_equal(blocking_desc(),
rhs.blocking_desc(), ndims()) :
true)
&& (is_wino_desc() ? wino_desc_is_equal(
wino_desc(), rhs.wino_desc()) :
- true);
+ true)
+ && (is_rnn_packed_desc() ?
+ rnn_packed_desc_is_equal(rnn_packed_desc(),
+ rhs.rnn_packed_desc()) :
+ true);
}
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
using namespace utils;
if (utils::one_of(format(), memory_format::undef, memory_format::any))
return false;
- if (is_wino_desc() || rhs.is_wino_desc())
+ if (is_wino_desc() || rhs.is_wino_desc() || is_rnn_packed_desc()
+ || rhs.is_rnn_packed_desc())
return false;
const int ds = dim_start;