* limitations under the License.
*******************************************************************************/
-#include "mkldnn_types.h"
#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace nstl;
using jit_conv_ker_t = void (*)(jit_conv_call_s *);
// Computes the flat offset into the weights memory descriptor `d`, including
// the leading group index `g` only when the primitive descriptor reports
// grouped weights. (Diff note: this change renames the accessor conf_ -> pd().)
#define wht_blk_off(d, g, ...) \
- (conf_.with_groups() \
+ (pd()->with_groups() \
? (d).blk_off((g), __VA_ARGS__) \
: (d).blk_off(__VA_ARGS__))
// Forward int8 (x8s8s32x) convolution driver: sets up per-call pointers and
// scale/compensation buffers, then walks the (mb, groups, oc-chunks, oh, ow)
// iteration space calling the JIT kernel once per output-height row.
// Diff note: drops the with_relu template parameter (post-ops subsume it),
// renames conf_ -> pd(), makes execute_forward() const, and adds support for
// channel-block grouping (nb_ch_blocking) and the loop_nhwcg loop order.
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_convolution_fwd_t<with_relu, src_type, dst_type>::
-execute_forward()
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward() const
{
// Raw primitive arguments: src(0), weights(1), optional bias(2), dst.
auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const char *>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Bias element size in bytes (0 when there is no bias), so `bias` can be
// indexed generically regardless of its data type.
- const size_t bia_dt_size = conf_.with_bias()
- ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
// Diff note: jcp now lives on the primitive descriptor, not the kernel.
- const auto &jcp = kernel_->jcp;
+ const auto &jcp = pd()->jcp_;
assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+ assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
// For signed input on pre-VNNI hardware the kernel works with weights that
// were scaled by wei_adj_scale at preparation time, so the output scales
// are divided back by that factor here, into a scratchpad-backed copy.
// When there is a single common scale it is broadcast to 16 entries
// (presumably one zmm lane each — opening context not in this view).
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales = scratchpad().template get<float>(
+ key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
// The s8s8 compensation buffer trails the weights data; the diff replaces
// the hand-computed element count with the descriptor's own size minus its
// additional-buffer size, which also covers padded/blocked layouts.
- size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ size_t offset = weights_d.size() - weights_d.additional_buffer_size();
auto w = const_cast<wei_data_t *>(weights);
int32_t* compensation = (jcp.signed_input)
? reinterpret_cast<int32_t *>(&w[offset]) : 0;
- const auto &oscales = conf_.attr()->output_scales_;
int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
// Groups are now iterated in chunks of nb_ch_blocking blocks ("gg" index).
- int nb_groups = jcp.nb_ch;
+ int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
int group_block = jcp.ch_block;
int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
// Per-thread nd-iterator state; `start`/`end` come from the enclosing
// parallel region, which is outside this chunk's view.
- int n{ 0 }, gb{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
+ int n{ 0 }, gg{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
if (jcp.loop_order == loop_cwgn)
- nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gb,
+ nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
nb_groups, n, jcp.mb, oh_s, jcp.oh);
else if (jcp.loop_order == loop_gncw)
- nd_iterator_init(start, gb, nb_groups, n, jcp.mb, occ, oc_chunks,
+ nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
owb, jcp.nb_ow, oh_s, jcp.oh);
else if (jcp.loop_order == loop_ngcw)
- nd_iterator_init(start, n, jcp.mb, gb, nb_groups, occ, oc_chunks,
+ nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
owb, jcp.nb_ow, oh_s, jcp.oh)
+ else if (jcp.loop_order == loop_nhwcg)
+ nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+ occ, oc_chunks, gg, nb_groups);
else
assert(!"unsupported loop order");
while (start < end) {
int ocb = occ * jcp.nb_oc_blocking;
// gb = first channel block of this group chunk; g = first group index.
+ int gb = gg * jcp.nb_ch_blocking;
int g = gb * group_block;
int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
int work_rem = end - start;
int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
// loop_nhwcg advances one oh row per nd_iterator_step (see loop tail),
// so the inner height loop must cover exactly one row.
+ if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; // step instead
int ow_s = owb * jcp.ow_block;
int iw_s = ow_s * jcp.stride_w;
auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
auto wht_w = weights + wht_blk_off(weights_d, gb, ocb, 0);
// Adjusted scales (when applicable) were folded into `oscales` above,
// so a single indexing expression suffices here now.
- auto scales = (jcp.signed_input && jcp.ver != ver_vnni)
- ? &local_scales_[jcp.is_oc_scale * g_oc]
- : &oscales.scales_[jcp.is_oc_scale * g_oc];
+ auto scales = &oscales[jcp.is_oc_scale * g_oc];
// One kernel invocation per output row; `p` (jit_conv_call_s) is
// populated in code outside this chunk's view.
for (int oj = oh_s, ij = ih_s; oj < oh_e;
++oj, ij += jcp.stride_h) {
p.b_overflow = i_b_overflow;
p.owb = owb;
// Byte offset of this group/oc-chunk's first channel, used by the
// kernel for per-channel post-op data — presumably float-sized.
+ p.oc_off = g_oc * sizeof(float);
+
kernel_->jit_ker(&p);
src_w += src_h_stride * jcp.stride_h;
dst_w += dst_h_stride;
}
// Advance the nd-iterator past the rows just processed (jump), except
// for loop_nhwcg which moves exactly one unit of work per iteration.
if (jcp.loop_order == loop_cwgn)
- nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gb,
+ nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
nb_groups, n, jcp.mb, oh_s, jcp.oh);
else if (jcp.loop_order == loop_gncw)
- nd_iterator_jump(start, end, gb, nb_groups, n, jcp.mb, occ,
+ nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
else if (jcp.loop_order == loop_ngcw)
- nd_iterator_jump(start, end, n, jcp.mb, gb, nb_groups, occ,
+ nd_iterator_jump(start, end, n, jcp.mb, gg, nb_groups, occ,
oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_nhwcg) {
+ ++start;
+ nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
+ oc_chunks, gg, nb_groups);
+ }
else
assert(!"unsupported loop order");
}
});
}
// Explicit instantiations for every supported <src, dst> data-type pair.
// Diff note: the with_relu template parameter is removed, so each former
// <false, ...>/<true, ...> pair collapses into a single instantiation.
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
- data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
- data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
- data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
- data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
- data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
- data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
- data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
- data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
data_type::u8, data_type::f32>;
}
}