#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
+
+#include "cpu_barrier.hpp"
#include "cpu_memory.hpp"
#include "jit_avx512_common_conv_kernel.hpp"
namespace cpu {
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
inline bool is_1stconv(const jit_conv_conf_t &jcp) {
if (mayiuse(avx512_core) && !mayiuse(avx512_core_vnni))
- return jcp.ic < 16;
+ return (jcp.ic < 16 && jcp.ngroups == 1);
else
return one_of(jcp.ic, 1, 3);
}
-inline bool is_1D_conv(const jit_conv_conf_t &jcp) {
- return (jcp.ih == 1 && jcp.kh == 1);
-}
-inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
- return (is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)
- && !(jcp.ver == ver_fma && mayiuse(avx512_mic)));
-}
+
inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
return (jcp.nb_ow > 1);
}
-inline bool is_1D_prefetching(const jit_conv_conf_t &jcp) {
- return (jcp.ver == ver_4fma && is_1D_conv(jcp) && is_ow_threading_on(jcp));
+
+inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
+ return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
}
+
}
-void jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w)
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
{
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- vpxord(zmm, zmm, zmm);
- if (!is_1D_prefetching(jcp)) {
+ Vmm vmm = vmm_out(j, k);
+ vpxord(vmm, vmm, vmm);
+ if (!is_owb_prefetching(jcp)) {
size_t aux_output_offset = get_output_offset(j, k);
mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
aux_output_offset, reg_out_long_offt));
}
}
-void jit_avx512_common_conv_fwd_kernel::store_output(int ur_w)
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
{
Label no_update_label, store_label, postproc_label;
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
+ Vmm vmm = vmm_out(j, k);
size_t aux_output_offset = get_output_offset(j, k);
- vadd(zmm,
+ vadd(vmm,
make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
}
for (int k = 0; k < jcp.nb_oc_blocking; k++) {
int bias_offset = jcp.typesize_out * k * jcp.oc_block;
for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- vadd(zmm, EVEX_compress_addr(reg_bias, bias_offset));
+ Vmm vmm = vmm_out(j, k);
+ vadd(vmm, EVEX_compress_addr(reg_bias, bias_offset));
}
mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
}
int depthwise_inj_idx = 0;
const auto &p = attr_.post_ops_;
- if (p.len_ == 0 && eltwise_injectors.size() == 1) {
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- eltwise_injectors[0]->compute_vector_range(
- k*jcp.ur_w, k*jcp.ur_w + ur_w);
- }
-
for (int i = 0; i < p.len_; i++) {
auto& post_op = p.entry_[i];
if (post_op.is_eltwise()) {
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- eltwise_injectors[eltwise_inj_idx]->compute_vector_range(
- k*jcp.ur_w, k*jcp.ur_w + ur_w);
+ if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
+ Vmm vmm_zero = vmm_wei;
+ vpxord(vmm_zero, vmm_zero, vmm_zero);
+
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ vpcmpd(k1, vmm, vmm_zero, _cmp_lt_os);
+ vpmulld(vmm | k1, vmm, vmm_zero);
+ }
+ } else {
+ if (ur_w == jcp.ur_w) {
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0,
+ jcp.nb_oc_blocking * jcp.ur_w);
+ } else {
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w,
+ k * jcp.ur_w + ur_w);
+ }
+ }
eltwise_inj_idx++;
} else if (post_op.is_depthwise()) {
L(store_label);
for (int k = 0; k < jcp.nb_oc_blocking; k++)
for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
+ Vmm vmm = vmm_out(j, k);
size_t aux_output_offset = (size_t)typesize *
((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
- reg_out_long_offt), zmm);
- if (!is_1D_prefetching(jcp))
+ reg_out_long_offt), vmm);
+ if (!is_owb_prefetching(jcp))
mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
aux_output_offset, reg_out_long_offt));
}
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
+ int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
int pad_l, int pad_r)
{
assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
- Label kh_label, kd_label, skip_kd_loop;
-
- prepare_output(ur_w);
+ Label kh_label, kd_label;
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_inp_d, reg_inp);
mov(aux_reg_inp_d_prf, reg_inp_prf);
- if ((jcp.kd - 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
L(kd_label);
}
mov(reg_kj, reg_kh);
- Label skip_kh_loop;
- if ((jcp.kh - 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
if (jcp.ndims == 5) {
mov(aux_reg_inp, aux_reg_inp_d);
mov(aux_reg_ker, aux_reg_ker_d);
* ((ki + i) * oc_block
+ ic * kw * jcp.kh * jcp.kd * oc_block);
if (ki + i < kw)
- vmovups(zmm_ker(i),
+ vmovups(vmm_ker(i),
EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
else
- vpxord(zmm_ker(i), zmm_ker(i), zmm_ker(i));
+ vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
}
int j_start = get_ow_start(ki, pad_l);
size_t aux_input_offset = (size_t)jcp.typesize_in
* ((size_t)(ki + j * stride_w
- pad_l) + (size_t)ic * iw * ih * jcp.id);
- v4fmaddps(zmm_out(j, 0), zmm_ker(0),
+ v4fmaddps(vmm_out(j, 0), vmm_ker(0),
EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
reg_long_offt));
if (ki + prf_count < kw && prf_count < 4
cmp(reg_kj, 0);
jg(kh_label, T_NEAR);
- L(skip_kh_loop);
-
if (jcp.ndims == 5) {
add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_out);
pop(reg_out_prf);
}
- store_output(ur_w);
if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
+ int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
int pad_l, int pad_r)
{
int stride_w = jcp.stride_w;
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
- Label kh_label, last_iter_label, loop_end_label, kd_label, skip_kd_loop;
+ Label kh_label, last_iter_label, loop_end_label, kd_label;
int ker_load_number = 4;
int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
auto kernel_loads = [=](int ki, int ic, int kk) {
for (int ii = 0; ii < ker_load_number; ii++) {
int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
- vmovups(zmm_ker(ii),
+ vmovups(vmm_ker(ii),
EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
}
};
}
};
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
mov(aux_reg_inp_d, reg_inp);
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
-
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
L(kd_label);
mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
} else {
mov(reg_kj, reg_kh);
}
- Label skip_kh_loop;
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
if (jcp.ndims == 5) {
mov(aux_reg_inp, aux_reg_inp_d);
mov(aux_reg_ker, aux_reg_ker_d);
* ((ki * (jcp.dilate_w + 1) + oi * stride_w
- pad_l) * ic_block
+ ic);
- v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
EVEX_compress_addr(aux_reg_inp, aux_input_offset));
if (oi % 2) {
* ((ki * (jcp.dilate_w + 1) + oi * stride_w
- pad_l) * ic_block
+ ic);
- v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
EVEX_compress_addr(aux_reg_inp,
aux_input_offset));
if (oi % 2) {
int aux_input_offset = typesize
* ((ki * (jcp.dilate_w + 1) + oi * stride_w
- pad_l) * ic_block + ic);
- v4fmaddps(zmm_out(oi, kk), zmm_ker(0),
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
EVEX_compress_addr(aux_reg_inp,
aux_input_offset));
- if (!is_1D_prefetching(jcp)) {
+ if (!is_owb_prefetching(jcp)) {
if ((oi % 2) && (prf_count_t1 < 4)) {
mic_prefetcht1(EVEX_compress_addr(
aux_reg_ker_prf, kernel_offset(kk,
prf_count_t0++;
}
}
- if (!is_1D_prefetching(jcp)) {
+ if (!is_owb_prefetching(jcp)) {
if (pref_current_inp) {
if (ki == 0 && ic == 0 && kk == 0)
mic_prefetcht0(EVEX_compress_addr(
cmp(reg_kj, 0);
jg(kh_label, T_NEAR);
- L(skip_kh_loop);
-
if (jcp.ndims == 5) {
add(aux_reg_inp_d,
typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_out);
pop(reg_out_prf);
}
-
- store_output(ur_w);
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
int pad_l, int pad_r)
{
bool prf_ker = true;
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
int nb_oc_block = jcp.nb_oc_blocking;
- Label kh_label, kd_label, skip_kd_loop;
+ Label kh_label, kd_label;
int ker_pipeline_depth = 4;
assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
assert(oc_block >= ker_pipeline_depth);
int num_ker_loads = ic_block * nb_oc_block * kw;
- const int simd_w = 16;
int num_ker_prfs = prf_ker ? num_ker_loads : 0;
int num_inp_prfs = prf_inp ?
ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
0;
if (jcp.is_1stconv && prf_inp) {
- num_inp_prfs = div_up(num_inp_prfs, simd_w) * ic_block;
+ num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
}
int num_prfs = num_ker_prfs + num_inp_prfs;
int num_fmas = num_ker_loads * ur_w;
int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
int inp_mul = !jcp.is_1stconv ? ic_block : 1;
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
L(kd_label);
mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
} else {
mov(reg_kj, reg_kh);
}
- Label skip_kh_loop;
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
if (jcp.ndims == 5) {
mov(aux_reg_inp, aux_reg_inp_d);
if (step == 0) {
for (int i = 0; i < ker_pipeline_depth; i++) {
aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
- vmovups(zmm_ker(i), EVEX_compress_addr(
+ vmovups(vmm_ker(i), EVEX_compress_addr(
aux_reg_ker, aux_kernel_offset));
}
} else if (step < num_ker_loads - ker_pipeline_depth + 1) {
= (step + load_offset) % ker_pipeline_depth;
aux_kernel_offset
= get_kernel_offset(ki, ic, 0, load_offset);
- vmovups(zmm_ker(ker_load_reg_idx),
+ vmovups(vmm_ker(ker_load_reg_idx),
EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
}
bool ker_prf_inserted = false;
- Zmm zmm_kernel = zmm_ker(step % ker_pipeline_depth);
+ Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
int j_start = get_ow_start(ki, pad_l);
int j_end = get_ow_end(ur_w, ki, pad_r);
for (int j = j_start; j < j_end; j++) {
size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
auto addr = EVEX_compress_addr_safe(aux_reg_inp,
aux_input_offset, reg_long_offt, true);
- vfmadd231ps(zmm_out(j, 0), zmm_kernel, addr);
+ vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
int fma_idx = step * ur_w + j;
int prf_slot_idx = fma_idx / prf_inst_spacing;
if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
size_t ic_prf_stride =
(size_t)jcp.typesize_in * iw * ih * id;
size_t iw_prf_stride
- = jcp.typesize_in * simd_w;
+ = jcp.typesize_in * jcp.simd_w;
inp_prf_offset = ((inp_prf_idx / ic_block)
* iw_prf_stride
+ (inp_prf_idx % ic_block)
jg(kh_label, T_NEAR);
}
- L(skip_kh_loop);
if (jcp.ndims == 5) {
add(aux_reg_inp_d,
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_out);
pop(reg_out_prf);
}
if (max_input_offset > INT_MAX) pop(reg_inp_prf);
- store_output(ur_w);
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
int pad_l, int pad_r)
{
int kw = jcp.kw;
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
int nb_oc_block = jcp.nb_oc_blocking;
- Label kh_label, skip_kh_loop, kd_label, skip_kd_loop;
+ Label kh_label, kd_label;
int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
* jcp.ic_block;
int inp_mul = !jcp.is_1stconv ? ic_block : 1;
* (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
};
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_ker, reg_ker);
mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
mov(aux_reg_inp_d, reg_inp);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
L(kd_label);
mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
} else {
mov(reg_kj, reg_kh);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
if (jcp.ndims == 5) {
mov(aux_reg_inp, aux_reg_inp_d);
if (jcp.kernel_kind == expl_bcast) {
for (int jj = jj_start; jj < jj_end; jj++) {
size_t aux_input_offset = input_offset(jj, ic, ki);
- vbroadcastss(zmm_inp(jj, nb_oc_block),
+ vbroadcastss(vmm_inp(jj, nb_oc_block),
EVEX_compress_addr_safe(aux_reg_inp,
aux_input_offset, reg_long_offt));
}
* (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
* oc_block + ki * ic_block * oc_block + ic * oc_block);
if (jj_end - jj_start > 0)
- vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
+ vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
aux_kernel_offset));
for (int jj = jj_start; jj < jj_end; jj++)
if (jcp.kernel_kind == expl_bcast)
- vfmadd231ps(zmm_out(jj, ii),
- zmm_inp(jj, nb_oc_block), zmm_wei);
+ vfmadd231ps(vmm_out(jj, ii),
+ vmm_inp(jj, nb_oc_block), vmm_wei);
else {
size_t aux_input_offset = input_offset(jj, ic, ki);
- vfmadd231ps(zmm_out(jj, ii), zmm_wei,
+ vfmadd231ps(vmm_out(jj, ii), vmm_wei,
EVEX_compress_addr_safe(aux_reg_inp,
aux_input_offset, reg_long_offt, true));
}
cmp(reg_kj, 0);
jg(kh_label, T_NEAR);
}
- L(skip_kh_loop);
if (jcp.ndims == 5) {
add(aux_reg_inp_d,
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_out);
}
+}
- store_output(ur_w);
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_vnni(
+ int ur_w, int pad_l, int pad_r)
+{
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop_vnni(
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_vnni(
int ur_w, int pad_l, int pad_r)
{
Label kh_label, kd_label;
assert(reg_inp_prf == reg_long_offt);
if (max_input_offset > INT_MAX) push(reg_inp_prf);
- prepare_output(ur_w);
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_inp, reg_inp);
mov(aux_reg_inp_prf, reg_inp_prf);
}
- Label skip_kh_loop, skip_kd_loop;
-
if (jcp.ndims == 5) {
push(reg_out_prf);
push(reg_out);
mov(aux_reg_inp_d_prf, reg_inp_prf);
mov(aux_reg_ker_d_prf, reg_ker_prf);
- if ((jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
L(kd_label);
mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
} else {
mov(reg_kj, reg_kh);
}
- if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
if (jcp.ndims == 5) {
mov(aux_reg_inp, aux_reg_inp_d);
mov(aux_reg_ker, aux_reg_ker_d);
if (jcp.kernel_kind == expl_bcast) {
for (int oi = ow_start; oi < ow_end; oi++) {
size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
- vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
+ vpbroadcastd(vmm_inp(oi, jcp.nb_oc_blocking),
EVEX_compress_addr_safe(aux_reg_inp, input_offset,
reg_long_offt));
}
for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
if (jcp.kernel_kind == expl_bcast) {
int kernel_offset = get_kernel_offset(ki, ic, kk, 0);
- vmovups(zmm_wei,
+ vmovups(vmm_wei,
EVEX_compress_addr(aux_reg_ker, kernel_offset));
} else {
for (int ii = 0; ii < ker_load_number; ii++) {
for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) {
size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
if (jcp.kernel_kind == expl_bcast) {
- vpdpwssd(zmm_out(oi, kk), zmm_wei,
- zmm_inp(oi, jcp.nb_oc_blocking));
+ vpdpwssd(vmm_out(oi, kk), vmm_wei,
+ vmm_inp(oi, jcp.nb_oc_blocking));
} else {
- vpXdpwssd(zmm_out(oi, kk), Zmm(ker_reg_base_idx),
- EVEX_compress_addr_safe(aux_reg_inp, input_offset,
- reg_long_offt, jcp.ver != ver_4vnni));
+ if (jcp.ver == ver_4vnni)
+ vp4dpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
+ EVEX_compress_addr_safe(aux_reg_inp,
+ input_offset, reg_long_offt, false));
+ else
+ vpdpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
+ EVEX_compress_addr_safe(aux_reg_inp,
+ input_offset, reg_long_offt, true));
}
if ((oi % 2) && (prf_count < ker_load_number)) {
int kernel_offset = get_kernel_offset(
jg(kh_label, T_NEAR);
}
- L(skip_kh_loop);
-
if (jcp.ndims == 5) {
add(aux_reg_inp_d, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
add(aux_reg_ker_d, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_out);
pop(reg_out_prf);
}
if (max_input_offset > INT_MAX) pop(reg_inp_prf);
- store_output(ur_w);
}
-void jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w,
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
int pad_l, int pad_r)
{
if (jcp.ndims == 5) push(reg_oi);
+
+ prepare_output(ur_w);
+
+ Label skip_compute_loop;
+ if (jcp.ndims == 5) {
+ if ((jcp.dilate_d >= jcp.id)
+ || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
+ mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+ }
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+
if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
compute_loop_vnni(ur_w, pad_l, pad_r);
else if (jcp.ver == ver_4fma)
compute_loop_fma_core(ur_w, pad_l, pad_r);
else
assert(!"unknown convolution version");
+
+ L(skip_compute_loop);
+ store_output(ur_w);
if (jcp.ndims == 5) pop(reg_oi);
}
-void jit_avx512_common_conv_fwd_kernel::generate()
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
{
- if (jcp.with_eltwise) {
- eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
- ));
- }
-
const auto &p = attr_.post_ops_;
for (int i = 0; i < p.len_; i++) {
auto &post_op = p.entry_[i];
auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
switch (p.len_) {
- case 0: return true; // no post_ops
- case 1:
- return true // sum OR eltwise OR depthwise
- && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
- case 2:
- return true // sum->relu
- && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
- (is_simple(0) && is_simple(1)));
- case 3:
- return true // sum->relu
- && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
+ case 0: return true;
+ case 1: return is_simple(0) || is_sum(0);
+ case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+ case 3: return is_sum(0) && is_simple(1) && is_simple(2);
default: return false;
}
}
status_t jit_avx512_common_conv_fwd_kernel::init_conf(
- jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
- cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
- cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
- int nthreads, bool with_relu, float relu_negative_slope)
+ jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+ cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
+ cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
+ const primitive_attr_t &attr, int nthreads)
{
using namespace prop_kind;
if (!mayiuse(avx512_common))
return status::unimplemented;
- const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
const memory_desc_wrapper src_d(&src_pd);
const memory_desc_wrapper weights_d(&weights_pd);
const memory_desc_wrapper dst_d(&dst_pd);
const memory_desc_wrapper bias_d(&bias_pd);
- int regs = 28;
+ const int regs = 28;
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
int ndims = src_d.ndims();
jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
jcp.stride_w = cd.strides[ndims-3];
jcp.src_fmt = src_d.format();
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alg = mkldnn_eltwise_relu;
- jcp.eltwise_alpha = relu_negative_slope;
jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
jcp.is_1stconv = is_1stconv(jcp);
- jcp.oc_block = simd_w;
- jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
- jcp.aligned_threads = 0;
-
bool ok_to_pad_channels = true
&& jcp.ngroups == 1
&& src_d.data_type() == data_type::f32;
+ const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ jcp.simd_w = full_simd_w;
+ bool ok_to_try_xmm = true
+ && mayiuse(avx512_core)
+ && src_d.data_type() == data_type::f32
+ && !jcp.is_1stconv
+ && !ok_to_pad_channels
+ && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
+ && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
+ if (ok_to_try_xmm)
+ jcp.simd_w = 4;
+
+ jcp.oc_block = jcp.simd_w;
+ jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
+ jcp.aligned_threads = 0;
+
if (ok_to_pad_channels) {
jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
const auto &p = attr.post_ops_;
jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+ }
auto src_format = jcp.is_1stconv
? pick(ndims - 3, ncw, nchw, ncdhw)
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
+ : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
+ auto dst_format = (jcp.simd_w == 4)
+ ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
: pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
- auto dst_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
auto wei_format = with_groups
- ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
- : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
+ : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
+ : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
if (src_d.format() == any)
CHECK(src_pd.set_format(src_format));
jcp.ver = ver_fma;
if (jcp.ver == ver_4fma) {
const auto w_format = with_groups
- ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
- : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
+ : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
+ : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
if (weights_d.format() == any)
CHECK(weights_pd.set_format(w_format));
if (weights_d.format() != w_format)
return status::unimplemented;
} else {
const auto w_format = with_groups
- ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
- : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
+ : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
+ : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
if (weights_d.format() == any)
CHECK(weights_pd.set_format(w_format));
if (weights_d.format() != w_format)
}
}
+ /* Grouped channel offset to support 'non-blocked data' format for
+ * convolution sizes with '(input_channel / ngroups) < simd' */
+ jcp.nonblk_group_off
+ = (jcp.ngroups > 1 && one_of(src_d.format(), ncw, nchw, ncdhw)) ?
+ jcp.ic :
+ 1;
+
jcp.nb_ic = jcp.ic / jcp.ic_block;
jcp.nb_oc = jcp.oc / jcp.oc_block;
jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+ auto is_ow_threading_applicable = [=]() {
+ return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
+ && IMPLICATION(mayiuse(avx512_mic),
+ jcp.ver == ver_4fma
+ && IMPLICATION(jcp.mb != 1,
+ jcp.ih == 1 && jcp.kh == 1)));
+ };
+
if (jcp.ver == ver_4vnni) {
jcp.kernel_kind = embd_bcast;
}
}
if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) {
- if (jcp.kw == 3 && jcp.kh == 3 && jcp.ow == 7 && jcp.oh == 7) {
- if (jcp.nb_oc % 2 == 0)
+ if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
+ && jcp.oh <= 8 && jcp.ow == jcp.oh)
+ || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
+ if (jcp.nb_oc % 2 == 0) {
jcp.nb_oc_blocking = 2;
+ jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
+ }
} else {
for (int i = jcp.nb_oc; i > 0; i--)
if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
break;
}
}
- if (jcp.ver == ver_4fma
- && is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)) {
- if (jcp.nb_oc % 2 == 0) {
+ if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
+ if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
+ && jcp.ow != 2 * jcp.ur_w) {
jcp.nb_oc_blocking = 2;
jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
}
}
}
+ jcp.ow_block = jcp.ow;
+
+ auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
+ int nb_ow = div_up(jcp.ow, ow_block);
+ int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
+ int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
+ float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
+ float thr_eff = disbalance * (float)work_amount
+ / rnd_up(work_amount, nthreads);
+ return thr_eff;
+ };
+
+ auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
+ int res_ow_block = jcp.ow;
+ eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+ if (!is_ow_threading_applicable())
+ return res_ow_block;
+
+ int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
+ if (jcp.ver == ver_4fma)
+ L2_part /= 2;
+ int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
+ int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
+ int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
+ * jcp.kw * jcp.kh;
+ int nurw_cache = (L2_part - 2 * size_wei_chunk)
+ / (2 * size_dst_chunk + 2 * size_src_chunk);
+ // current design of generate() requires ow_block >= 2 * ur_w
+ int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
+
+ int ow_block_thr = ow_block_cache;
+ eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
+
+ int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
+ int start_nb_ow = div_up(jcp.ow, ow_block_thr);
+ for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
+ int ow_block
+ = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
+ float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
+ if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
+ break;
+ if (div_up(jcp.ow, ow_block) != nb_ow)
+ continue;
+ float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
+ float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
+ if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
+ ow_block_thr = ow_block;
+ eff = thr_eff;
+ }
+ eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
+ if (eff > eff_threshold)
+ break;
+ }
+ res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
+ eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+ return res_ow_block;
+ };
+
+
if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
int try_nb_oc_blocking = 2;
unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
&& !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
if (jcp.mb == 1) {
- jcp.kernel_kind = embd_bcast;
unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
* div_up(jcp.iw, jcp.stride_w) * jcp.ic;
unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
}
}
}
- } else if (jcp.kw > 3
- || (jcp.stride_w == 1 && jcp.stride_h == 1
- && embd_bcast_condition)
- || ((jcp.stride_w != 1 || jcp.stride_h != 1)
- && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
- && embd_bcast_condition)))
- ) {
+ }
+
+ if (jcp.kw > 3
+ || (jcp.stride_w == 1 && jcp.stride_h == 1
+ && embd_bcast_condition)
+ || ((jcp.stride_w != 1 || jcp.stride_h != 1)
+ && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
+ && embd_bcast_condition)))
+ || (jcp.mb == 1
+ && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
+ || (jcp.ow <= 147 && jcp.oc <= 96)))) {
jcp.kernel_kind = embd_bcast;
jcp.ur_w = nstl::min(jcp.ow, regs);
jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
- && jcp.kw <= 3) {
- if (jcp.nb_oc % try_nb_oc_blocking == 0 && !jcp.is_1stconv) {
- jcp.nb_oc_blocking = try_nb_oc_blocking;
- jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
- if (jcp.ow < jcp.ur_w)
- jcp.ur_w = jcp.ow;
- }
+ && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
+ && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
+ && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
+ jcp.nb_oc_blocking = try_nb_oc_blocking;
+ jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
}
} else {
jcp.kernel_kind = expl_bcast;
jcp.nb_ic_blocking = 1;
- jcp.nb_oc_blocking = 4;
- if (jcp.nb_oc < jcp.nb_oc_blocking) jcp.nb_oc_blocking = jcp.nb_oc;
- if (jcp.nb_oc % jcp.nb_oc_blocking != 0)
- for (int i = jcp.nb_oc_blocking; i > 0; i--)
+ if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
+ float best_thr_eff = 0.f;
+ int best_nb_oc_blocking = 1;
+ for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
if (jcp.nb_oc % i == 0) {
- jcp.nb_oc_blocking = i;
- break;
+ float thr_eff;
+ int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
+ get_ow_block(i, ur_w, thr_eff);
+ if (thr_eff > 1.05f * best_thr_eff) {
+ best_nb_oc_blocking = i;
+ best_thr_eff = thr_eff;
+ }
}
- jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
- if (jcp.ow < jcp.ur_w)
- jcp.ur_w = jcp.ow;
+ }
+ jcp.nb_oc_blocking = best_nb_oc_blocking;
+ jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+ }
}
}
jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- jcp.ow_block = jcp.ow;
- if (is_ow_threading_available(jcp)) {
- const int L1_part = get_cache_size(1) * 5 / 8;
- int size_src_chunk = typesize * jcp.ic_block * jcp.ur_w;
- int size_dst_chunk = typesize
- * jcp.oc_block * jcp.nb_oc_blocking * jcp.ur_w;
- int size_wei_chunk = typesize
- * jcp.oc_block * jcp.ic_block * jcp.nb_oc_blocking * jcp.kw;
- int nurw = (L1_part - size_wei_chunk)
- / (size_dst_chunk + size_src_chunk);
- // current design of generate() requires ow_block >= 2 * ur_w
- jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
- }
- jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
-
args_ok = true
&& jcp.l_pad <= jcp.ur_w
&& jcp.ic <= src_d.blocking_desc().padding_dims[1]
jcp.nb_ic_L2 = jcp.nb_ic;
+ float thr_eff;
+ jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
+ jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
+
const int L2_size = get_cache_size(2, true) / sizeof(float);
// Source and output data needs to fit in L2,
// leaving some space for weights and prefetching.
- int h_L2 = int(((0.6f * L2_size) / simd_w
+ int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
- nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
/ (jcp.stride_h * jcp.iw + jcp.ow));
jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
break;
}
}
- } else {
+ } else if (jcp.ic > 64) {
jcp.nb_ic_L2 = 2; /* according to performance data*/
}
}
return status::success;
}
+void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
+
void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
{
for (int k = 0; k < jcp.nb_ic_blocking; k++) {
int kw = jcp.kw;
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
- Label kh_label, last_iter_label, loop_end_label, kd_label, skip_kd_loop;
+ Label kh_label, last_iter_label, loop_end_label, kd_label;
int ker_load_number = 4;
int shift_ker_ptr = typesize * kw * oc_block * ic_block;
int shift_dst_ptr = typesize * ow * oc_block;
}
};
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_dst, reg_dst);
mov(aux_reg_ker, reg_ker);
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_src);
pop(reg_src_prf);
}
-
- store_output(ur_w);
}
void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
return jcp.typesize_in * (blk_offset + oc_offset);
};
- prepare_output(ur_w);
-
mov(aux_reg_dst, reg_dst);
mov(aux_reg_ker, reg_ker);
mov(aux_reg_dst_prf, reg_dst_prf);
cmp(reg_kj, 0);
jg(kh_label, T_NEAR);
}
-
- store_output(ur_w);
}
void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
int ur_w, int l_overflow, int r_overflow)
{
- Label kh_label, kd_label, skip_kd_loop;
- Label store_output_label;
+ Label kh_label, kd_label;
int kw = jcp.kw;
int ow = jcp.ow;
int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_dst, reg_dst);
mov(aux_reg_ker, reg_ker);
push(reg_src);
mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
- cmp(reg_ki, 0);
- je(store_output_label, T_NEAR);
-
mov(aux_reg_dst_d, reg_dst);
mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
mov(aux_reg_dst_d_prf, reg_dst_prf);
} else {
mov(reg_kj, reg_kh);
}
- cmp(reg_kj, 0);
- je(store_output_label, T_NEAR);
if (jcp.ndims == 5) {
mov(aux_reg_dst, aux_reg_dst_d);
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
}
- L(store_output_label); {
- if (jcp.ndims == 5)
- {
- pop(reg_src);
- pop(reg_src_prf);
- }
- store_output(ur_w);
+ if (jcp.ndims == 5)
+ {
+ pop(reg_src);
+ pop(reg_src_prf);
}
}
int ic_block = jcp.ic_block;
int oc_block = jcp.oc_block;
int nb_ic_block = jcp.nb_ic_blocking;
- Label kh_label, skip_kh_loop, kd_label, skip_kd_loop;
+ Label kh_label, kd_label;
int shift_ker_ptr = typesize * kw * oc_block * ic_block;
int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
return typesize * (blk_offset + oc_offset);
};
- prepare_output(ur_w);
-
if (one_of(jcp.ndims, 3, 4)) {
mov(aux_reg_dst, reg_dst);
mov(aux_reg_ker, reg_ker);
} else {
mov(reg_kj, reg_kh);
}
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
if (jcp.ndims == 5) {
mov(aux_reg_dst, aux_reg_dst_d);
cmp(reg_kj, 0);
jg(kh_label, T_NEAR);
}
- L(skip_kh_loop);
if (jcp.ndims == 5) {
sub(aux_reg_dst_d,
dec(reg_ki);
cmp(reg_ki, 0);
jg(kd_label, T_NEAR);
- L(skip_kd_loop);
pop(reg_src);
pop(reg_src_prf);
}
-
- store_output(ur_w);
}
inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
 int ur_w, int l_overflow, int r_overflow)
{
if (jcp.ndims == 5) push(reg_oi);
+
+ // Zero the output accumulators here, hoisted out of the individual
+ // compute_loop_* variants (their prepare_output calls were removed).
+ prepare_output(ur_w);
+
+ // Hoisted guard: if the effective kd/kh range is empty (fully padded
+ // row), jump straight to the store so zeroed output is still written.
+ Label skip_compute_loop;
+ if (jcp.ndims == 5) {
+ mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+ mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+
if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
compute_loop_vnni(ur_w, l_overflow, r_overflow);
else if (jcp.ver == ver_4fma)
compute_loop_fma_core(ur_w, l_overflow, r_overflow);
else
assert("!unknown convolution version");
+ // NOTE(review): the assert above passes a non-empty string literal,
+ // which is always true and never fires; it should read
+ // assert(!"unknown convolution version"). Left untouched here since
+ // it is an unmodified context line of this patch.
+
+ // Single store site for all variants (their store_output calls were
+ // likewise hoisted out); also the landing pad for the empty-kh/kd skip.
+ L(skip_compute_loop);
+ store_output(ur_w);
if (jcp.ndims == 5) pop(reg_oi);
}
{
if (!mayiuse(avx512_common)) return status::unimplemented;
- const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ jcp = zero<decltype(jcp)>();
+
+ jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
int ndims = diff_src_d.ndims();
jcp.is_1stconv = false;
- jcp.oc_block = simd_w;
- jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
+ jcp.oc_block = jcp.simd_w;
+ jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
bool ok_to_pad_channels = true
&& jcp.ngroups == 1
&& jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
&& jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
&& jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
+ if (!args_ok) return status::unimplemented;
- return args_ok ? status::success : status::unimplemented;
+ return status::success;
+}
+
+/* The bwd-data kernel books no scratchpad memory; this empty hook keeps
+ * the interface symmetric with the fwd and bwd-weights kernels. */
+void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
}
const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
jit_conv_conf_t &jcp, const convolution_desc_t &cd,
cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
- cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd)
-{
+ cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd) {
if (!mayiuse(avx512_common))
return status::unimplemented;
- const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
const memory_desc_wrapper src_d(&src_pd);
const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
int ndims = src_d.ndims();
jcp = zero<decltype(jcp)>();
+
+ jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
jcp.ndims = ndims;
jcp.prop_kind = cd.prop_kind;
/* check for the 1st convolution */
jcp.is_1stconv = is_1stconv(jcp);
- jcp.oc_block = simd_w;
+ jcp.oc_block = jcp.simd_w;
bool ok_to_pad_channels = true
&& jcp.ngroups == 1
&& src_d.data_type() == data_type::f32;
if (ok_to_pad_channels)
- jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
if (jcp.oc % jcp.oc_block)
return status::unimplemented;
&& everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
&& jcp.kw <= 28 - jcp.with_bias
&& jcp.stride_w == 4
- && tr_ld / simd_w <= 4 /* [bwd_w:tr_src:r1] */
+ && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
&& IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
&& IMPLICATION(diff_weights_d.format() != any,
diff_weights_d.format() == want_4fma_wfmt);
if (!ok)
return status::unimplemented;
- jcp.ic_block = simd_w;
+ jcp.ic_block = jcp.simd_w;
if (ok_to_pad_channels)
jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
jcp.nb_ic = jcp.ic / jcp.ic_block;
&& jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
&& jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
&& jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
+ if (!args_ok) return status::unimplemented;
+
+ { // balancing
+ int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
+ balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
+ jcp.nthr = nthr;
+ jcp.nthr_mb = nthr_mb;
+ jcp.nthr_g = nthr_g;
+ jcp.nthr_oc_b = nthr_oc_b;
+ jcp.nthr_ic_b = nthr_ic_b;
+ }
+
+ return status::success;
+}
+
+/* Registers all scratchpad buffers the bwd-weights driver needs:
+ * - a transposed src copy (tr_src) for the 4fma/4vnni/vnni kernels,
+ * - a transposed diff_dst copy for the 4vnni/vnni kernels,
+ * - simple_barrier contexts to synchronize those transpositions,
+ * - weights/bias reduction buffers when the minibatch dimension is
+ * split across threads (nthr_mb > 1),
+ * - a zero-padded bias buffer when oc was rounded up to the SIMD width.
+ * Sizes rely on the thread partition computed by balance() in
+ * init_conf(), so init_conf() must run first. */
+void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+ if (jcp.is_1stconv) {
+ const size_t tr_src_size =
+ jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+ } else {
+ // XXX: See the comment about tr_iw and guarding elements in
+ // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+ const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
+ const size_t min_tr_src_size_per_thr
+ = jcp.ih * jcp.ic_block * jcp.tr_iw;
+ const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
+ + jcp.tr_src_num_guard_elems;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+ }
+
+ /* prepare synchronization contexts */
+ if (jcp.nthr_oc_b > 1) {
+ const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
+ scratchpad.book(key_conv_tr_src_bctx,
+ sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
+ }
+
+ if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
+ const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
+ * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
+ scratchpad.book(key_conv_tr_diff_dst,
+ jcp.typesize_in * tr_diff_dst_size);
+
+ /* prepare synchronization contexts */
+ if (jcp.nthr_ic_b > 1) {
+ const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
+ scratchpad.book(key_conv_tr_diff_dst_bctx,
+ sizeof(simple_barrier::ctx_t) * tr_diff_dst_bctx_size);
+ }
+ }
+ }
+
+ if (jcp.nthr_mb > 1) {
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
+ * jcp.kh * jcp.kw * jcp.kd;
+ const int bia_size = jcp.ngroups * jcp.oc;
+ const size_t wei_bia_reduction_size = wei_size + bia_size;
+
+ /* one private wei+bia copy per extra minibatch thread; thread 0
+ * accumulates directly into the destination, hence nthr_mb - 1 */
+ scratchpad.book(key_conv_wei_bia_reduction,
+ jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
+ scratchpad.book(key_conv_wei_bia_reduction_bctx,
+ sizeof(simple_barrier::ctx_t));
+ }
+
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
- return args_ok ? status::success : status::unimplemented;
+/* Splits max_threads across the independent parallelization dimensions
+ * of the bwd-weights driver: minibatch(+od) reduction (nthr_mb_),
+ * groups (nthr_g_), oc blocks (nthr_oc_b_) and ic blocks (nthr_ic_b_).
+ * Step 1 picks the partition with the lowest heuristic per-thread
+ * memory cost; step 2 (skipped for vnni / avx512_mic) then trades up
+ * to ~10% extra memory traffic for lower compute cost. All five output
+ * references are assigned on every path, including early returns. */
+void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
+ const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
+ int &nthr_oc_b_, int &nthr_ic_b_)
+{
+ nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
+
+ const int max_threads = mkldnn_get_max_threads();
+
+ if (max_threads < j.ngroups) {
+ /* simplification... fortunately it doesn't hurt much */
+ return;
+ }
+
+ if (!mkldnn_thr_syncable()
+ && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
+ // should not happen -- the driver is not ready
+ // for TBB-like non-synchronous threading yet
+ return;
+ }
+
+ /* 1st convolution with 4fma: fixed strategy -- parallelize over ic
+ * blocks first, then minibatch; no group/oc-block threading */
+ if (j.ver == ver_4fma && j.is_1stconv) {
+ nthr_g_ = 1;
+ nthr_oc_b_ = 1;
+ nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
+ nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
+ nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
+ return;
+ }
+
+ nthr_g_ = j.ngroups;
+ const int nthr = max_threads / nthr_g_;
+
+ auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+ /* calculate per thread memory cost (read/write). high level optimizer
+ * tries to minimize memory consumption. few notes:
+ * (n1) unclear why, but that essentially helps first convolution...
+ * (n2) assuming the reduction over minibatch is always there:
+ * - instead of 8 it should be 5 here (write ~= 2 read):
+ * kernel: temporal workspace 1 write
+ * reduction: 1 read from workspace and 1 write to the diff_wei
+ * - but experiments showed 8 works better than 5 or 6... */
+
+ const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
+ const int dst_coef = 1;
+ const int wei_coef = j.ver == ver_vnni ? 4 : 8;
+
+ return 0
+ + src_coef
+ * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
+ / j.stride_d / j.stride_h / j.stride_w /* (n1) */
+ + dst_coef
+ * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
+ + wei_coef /* (n2) */
+ * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
+ * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
+ };
+
+ int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+
+ /* step 1: find the best thread distribution with lowest memory cost */
+ const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
+ for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+ const int nthr_par = nthr / nthr_mb;
+ const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+ for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+ int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+
+ int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+ if (mem_cost <= best_mem_cost) {
+ best_mem_cost = mem_cost;
+ nthr_mb_ = nthr_mb;
+ nthr_oc_b_ = nthr_oc_b;
+ nthr_ic_b_ = nthr_ic_b;
+ }
+ }
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+ }
+
+ if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
+ auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+ return 1
+ * div_up(j.mb, nthr_mb)
+ * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b)
+ * div_up(j.nb_ic, nthr_ic_b);
+ };
+
+ /* step 2: search for a thread distribution with lower compute cost.
+ * the constraints:
+ * - memory cost cannot exceed 110% of the best found in the step 1
+ * - unless compute cost is at least 25% lower than the current
+ * best case (opt2 below: 4 * comp_cost <= 3 * best_comp_cost)
+ * note: both constants were found empirically */
+ int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+ for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+ const int nthr_par = nthr / nthr_mb;
+ const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+ for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+ int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+ int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+ int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+
+ const bool opt1 = comp_cost <= best_comp_cost
+ && mem_cost < 1.1 * best_mem_cost;
+ const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
+
+ if (opt1 || opt2) {
+ best_comp_cost = comp_cost;
+ nthr_mb_ = nthr_mb;
+ nthr_oc_b_ = nthr_oc_b;
+ nthr_ic_b_ = nthr_ic_b;
+ }
+ }
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+ }
+ }
+
+ /* prefer saturating minibatch threading when it is already dominant */
+ if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
+ nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
+ nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
+
+ assert(nthr_ <= max_threads);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
}
+template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
+template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
+
}
}
}