/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"

#include "cpu_barrier.hpp"
#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "jit_avx512_core_bf16cvt.hpp"
#include "jit_uni_batch_normalization.hpp"

using namespace memory_tracking::names;
using namespace Xbyak;
namespace barrier = simple_barrier;
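
// All statistics and intermediate accumulation are kept in fp32 (acc_data_t),
// even when the source/destination data type is bf16.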
typedef float acc_data_t;

template <cpu_isa_t isa>
struct jit_bnorm_t: public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        size_t N_ithr, N_nthr;
        size_t coff_max, soff_max;
        size_t mb_stride_Bc, spat_size, spat_size_loc;
        size_t S_s, S_tail;
        size_t is_cblk_tail;
        acc_data_t chan_size, eps, one;
        const acc_data_t *scale_shift;
        const acc_data_t *mean, *var;
        const acc_data_t *diff_scale_shift;
        const void *src, *dst;
        const void *diff_src, *diff_dst;
        const acc_data_t *rbuf1, *rbuf2;
        const uint8_t *ws;
        barrier::ctx_t *barrier;
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    /* cpu specific part */
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
            isa == avx2, Ymm, Zmm>::type;
    const AddressFrame &vmmword = (isa == sse42) ? xword :
                                  (isa == avx2) ? yword : zword;
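
    // For sse42 an 8-channel block is processed as two xmm halves, so the
    // logical vector length is kept at 32 bytes rather than the hardware 16.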
    const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
    int vlen_spat_data_; // set by ctor depending on data type (BF16 or FP32)

    const batch_normalization_pd_t *bdesc_;
    bool is_spatial_thr_;
    bool is_bf16_;

    void (*ker)(const call_params_t *);
    void operator()(const call_params_t *p) { (*ker)(p); }

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale_shift = rbx;
    Reg64 reg_rbuf1 = abi_not_param1;
    Reg64 reg_rbuf2 = rdx;

    Reg64 reg_mean = rbp;
    Reg64 reg_var = reg_param;
    Reg64 reg_diff_scale_shift = rax;

    Reg64 reg_coff = r8;
    Reg64 reg_coff_max = r9;
    Reg64 reg_soff = r10;
    Reg64 reg_soff_max = r11;
    Reg64 reg_ctr = r12;
    Reg64 reg_roff = r13;

    Reg64 reg_mb_stride_Bc = r14;

    Reg64 reg_src = r15;
    Reg64 reg_diff_src = reg_rbuf1;
    Reg64 reg_dst = rsi;
    Reg64 reg_diff_dst = reg_dst;

    Reg64 reg_tmp_off = reg_roff;

    // Reuse loop counters
    Reg64 reg_bar = reg_coff;
    Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
    Reg64 reg_tmp = reg_ctr;

    bool with_relu, with_relu_inf_only;
    Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
    Reg64 reg_ws = reg_roff;
    Label l_relu_mask_avx2;
    Opmask kstore_mask = Opmask(1);

    // channel tail processing
    Opmask ktail_mask = Opmask(2);

    // FP32->BF16 emulation
    bf16_emulation_t *bf16_emu_;
    Reg64 reg_bf16_tmp = reg_tmp;
    Zmm vcvt_bf16_one = Zmm(16);
    Zmm vcvt_bf16_eve = Zmm(17);
    Zmm vcvt_bf16_sel = Zmm(18);
    Zmm vcvt_bf16_tmp = Zmm(19);

    size_t unroll_blocks;
    size_t unroll_regs;
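
    // Named constants live in the upper vector registers (zmm20..zmm30 on
    // avx512_common, xmm/ymm 5..15 otherwise); the lower registers stay free
    // for spat_loop unrolling.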
    Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
    Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
    Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
    Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
    Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
    Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
    Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
    Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
    Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
    Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
    Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);

    size_t t0_pf_offt;
    size_t t1_pf_offt;
    size_t spat_size;
    size_t chan_data_offt;
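
    // Stack slots mirror the call_params_t fields that cannot stay resident
    // in registers for the whole kernel; every slot is 8 bytes wide.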
    enum {
        stack_off_N_nthr = 0,
        stack_off_N_ithr = 8,
        stack_off_src = 16,
        stack_off_dst = 24,
        stack_off_diff_src = 32,
        stack_off_diff_dst = 40,
        stack_off_diff_scale_shift = 48,
        stack_off_ws = 56,
        stack_off_barrier = 64,
        stack_off_spat_size_loc = 72,
        stack_off_s_s = 80,
        stack_off_s_tail = 88,
        stack_off_is_cblk_tail = 96,
        stack_size_required = 104,
    };
    bool is_c_padded() const {
        const memory_desc_wrapper data_d(bdesc_->src_pd());
        return bdesc_->C() != data_d.blocking_desc().padding_dims[1];
    }

    void compute_static_strides() {
        spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
        chan_data_offt = bdesc_->C() * sizeof(acc_data_t);

        if (isa == avx512_mic) {
            t0_pf_offt = 4096;
            t1_pf_offt = 8192;
        } else {
            t0_pf_offt = 0;
            t1_pf_offt = 0;
        }
    }

    void load_common_params() {
#define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
        if (bdesc_->is_bwd())
            mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
        mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
        mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
        mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
        shl(reg_coff_max, 2);

        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);

        uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
        uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);

        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
        mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
        mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
        mov(ptr[rsp + stack_off_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
        mov(ptr[rsp + stack_off_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
        mov(ptr[rsp + stack_off_diff_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
        mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
        mov(ptr[rsp + stack_off_ws], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
        mov(ptr[rsp + stack_off_barrier], reg_tmp);
        if (is_spatial_thr_) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
            mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
            mov(ptr[rsp + stack_off_s_s], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
            mov(ptr[rsp + stack_off_s_tail], reg_tmp);
        }
        if (is_c_padded()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
            mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
        }

        if (bdesc_->is_fwd()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        } else {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
            mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        }
#undef PARAM_OFF
    }
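
    // When C is padded to the block size, the last channel block is only
    // partially valid; the masks below restrict tail loads/stores to the
    // real channels.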
    void prepare_tail_mask_avx512_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        const int mask = (1 << tail) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }

    void prepare_tail_mask_avx2_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
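        // Slide a window of `tail` ones out of the table below; vmaskmovps
        // with the resulting ymm mask then touches only the valid channels.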
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                0, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
        vmovups(vtail_mask, ptr[reg_tmp]);
    }

    void prepare_relu() {
        with_relu = bdesc_->is_fwd()
            ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
            : bdesc_->fuse_bn_relu();
        with_relu_inf_only = with_relu && bdesc_->is_fwd()
            && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());

        vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
        uni_vpxor(vzero, vzero, vzero);
        if (!bdesc_->is_fwd() && isa == avx2)
            prepare_l_relu_mask_avx2();
    }

    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        jmp(l_mask_after);
        align(32);
        L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i) dd(1 << i);
        L(l_mask_after);
    }
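
    // Forward ReLU: compare the result against zero, pack the 8 per-lane mask
    // bits with vmovmskps, store them as one workspace byte, then zero the
    // negative lanes in the destination.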
    void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
        Reg64 reg_store_mask = reg_diff_scale_shift;
        shr(reg_soff, 5);
        vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
        vmovmskps(reg_store_mask, vstore_mask);
        mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
        vblendvps(vdst, vzero, vdst, vstore_mask);
        shl(reg_soff, 5);
    }

    void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
        int bs = 5 - is_bf16_; // bit shift depends on data type
        shr(reg_soff, bs);
        vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
        kmovw(ptr[reg_ws + reg_soff + offt / (1 << bs)], kstore_mask);
        vblendmps(vdst | kstore_mask, vzero, vdst);
        shl(reg_soff, bs);
    }
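
    // Backward ReLU: reload the stored bitmask, expand it back into a
    // per-lane vector mask (bit-pattern table on avx2, opmask on avx512),
    // and zero the gradients of lanes that were clamped in the forward pass.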
    void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
        shr(reg_soff, 5);
        vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
        int bs = 5 - is_bf16_; // bit shift depends on data type
        shr(reg_soff, bs);
        kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << bs)]);
        vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
        shl(reg_soff, bs);
    }
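
    // Move one vector of spatial data, converting between fp32 and bf16 on
    // the fly when needed (vcvtneps2bf16 directly on avx512_core_bf16
    // machines, otherwise through the bf16_emu_ helper).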
    void uni_vmovups_spat_data(const Operand &dst, const Operand &src) {
        if (dst.isMEM()) {
            if (is_bf16_) {
                if (mayiuse(avx512_core_bf16))
                    vcvtneps2bf16(Ymm(src.getIdx()), Zmm(src.getIdx()));
                else
                    bf16_emu_->r_vcvtneps2bf16(
                            Ymm(src.getIdx()), Zmm(src.getIdx()));
                vmovdqu16(dst.getAddress(), Ymm(src.getIdx()));
            } else {
                uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
            }
        } else {
            if (is_bf16_) {
                vpmovzxwd(Zmm(dst.getIdx()), src.getAddress());
                vpslld(Zmm(dst.getIdx()), Zmm(dst.getIdx()), 0x10);
            } else {
                uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
            }
        }
    }

    void uni_vmovups_tail_avx2_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
        else
            vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
        jmp(l_ret);
    }

    void uni_vmovups_tail_avx512_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
        jmp(l_ret);
    }

    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;

        if (is_c_padded()) {
            mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
            cmp(reg_tmp, 0);
            jz(l_no_mask);

            lea(reg_tmp, ptr[reg_coff + vlen]);
            cmp(reg_tmp, reg_coff_max);
            jl(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        L(l_no_mask);
        if (dst.isMEM())
            uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
        L(l_ret);
    }

    void barrier() {
        mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
        mov(reg_bar, ptr[rsp + stack_off_barrier]);
        simple_barrier::generate(*this, reg_bar, reg_nnthr);
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
    }

    Address diff_gamma_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 0 * chan_data_offt];
    }

    Address diff_beta_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 1 * chan_data_offt];
    }

    Address gamma_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
    }

    Address beta_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
    }
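
    // Generic spatial loop: `regs` accumulator registers are unrolled over
    // `blocks` iterations in the main loop; init/body/fini emit per-register
    // setup, the loop payload and the final reduction. A scalar remainder
    // loop covers len % (regs * blocks) iterations.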
    template <typename init_t, typename body_t, typename fini_t>
    void spat_loop(size_t len, size_t blocks, size_t regs,
            init_t init, body_t body, fini_t fini) {
        size_t factor = regs * blocks;
        size_t loop_unroll = len / factor * factor;
        size_t loop_tail = len - loop_unroll;
        size_t num_active_regs = (len < regs) ? len : regs;
        for (size_t i = 0; i < num_active_regs; i++)
            init(i);
        if (loop_unroll) {
            if (is_spatial_thr_) {
                mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
                add(reg_soff, ptr[rsp + stack_off_s_s]);
            } else {
                mov(reg_ctr, loop_unroll);
            }
            Label label;
            L(label); {
                for (size_t i = 0; i < factor; i++) {
                    size_t base_reg = i % regs;
                    body(base_reg, i);
                }
                add(reg_soff, factor * vlen_spat_data_);
                sub(reg_ctr, factor);
                jnz(label);
            }
            if (is_spatial_thr_) {
                add(reg_soff, ptr[rsp + stack_off_s_tail]);
            }
        }

        for (size_t i = 0; i < loop_tail; i++) {
            size_t base_reg = i % regs;
            body(base_reg, loop_unroll + i);
        }
        if (loop_tail)
            add(reg_soff, loop_tail * vlen_spat_data_);

        for (size_t i = 0; i < num_active_regs; i++)
            fini(i);
    }
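
    // Accumulate per-channel sums of src over the spatial extent and the
    // local minibatch slice into rbuf1; each pass of the surrounding loop
    // handles one channel block.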
    void mean_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks,
                    unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v0 = Vmm(base_reg * 2 + 0);
                        Vmm v1 = Vmm(base_reg * 2 + 1);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                v1, vmmword[reg_src + reg_soff + offt]);
                        uni_vaddps(v0, v0, v1);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        if (base_reg) {
                            Vmm v = Vmm(base_reg * 2);
                            uni_vaddps(Vmm(0), Vmm(0), v);
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
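
    // Same loop structure as mean_channels, but accumulating the squared
    // distance from the already-reduced mean: sum over space of
    // (src - mean)^2.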
    void var_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks, unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg > 0)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v = Vmm(3 * base_reg);
                        Vmm vtmp0 = Vmm(3 * base_reg + 1);
                        Vmm vtmp1 = Vmm(3 * base_reg + 2);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                vtmp0, vmmword[reg_src + reg_soff + offt]);
                        if (isa == sse42) {
                            movups(vtmp1, vmean);
                            subps(vtmp1, vtmp0);
                        } else {
                            vsubps(vtmp1, vmean, vtmp0);
                        }
                        uni_vfmadd231ps(v, vtmp1, vtmp1);

                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        if (base_reg > 0) {
                            Vmm v = Vmm(base_reg * 3);
                            uni_vaddps(Vmm(0), Vmm(0), v);
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
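
    // Two-pass statistics: every thread accumulates partial per-channel sums
    // into its own slice of rbuf1, all threads meet at a barrier, and the
    // first thread along N reduces the slices and divides by chan_size
    // (= N * D * H * W).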
    void compute_mean_variance() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf;
        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);

        xor_(reg_soff, reg_soff);
        Label mean_spatial;
        L(mean_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            mean_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                mean_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(mean_spatial);
        }
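
        // Only the first thread along N/S performs the reduction; the others
        // just participate in the barriers and pick up the result.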
        Label no_mean_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_mean_reduction);
            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label mean_reduction_channels;
            L(mean_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label mean_reduction_thrs;
                L(mean_reduction_thrs); {
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(mean_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));

                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(mean_reduction_channels);
            }
        }
        L(no_mean_reduction);
        barrier();

        xor_(reg_soff, reg_soff);
        Label var_spatial;
        L(var_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            var_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                var_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(var_spatial);
        }

        Label no_var_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_var_reduction);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label var_reduction_channels;
            L(var_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label var_reduction_thrs;
                L(var_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(var_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(var_reduction_channels);
            }
        }
        L(no_var_reduction);
        barrier();
    }
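
    // Forward normalization of one channel block:
    //     dst = gamma * (src - mean) / sqrt(var + eps) + beta
    // with gamma pre-divided by sqrt(var + eps) where scale/shift is used,
    // plus an optional fused ReLU on the result.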
    void forward_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);

            if (bdesc_->use_scaleshift()) {
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
                uni_vmovups_maybe_tail(vbeta, beta_ptr());
            }

            Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone;
            Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar;

            if (isa == sse42) {
                movups(vbuf, vscale);
                divps(vbuf, vsqrtvar);
                movups(vdiv, vbuf);
            } else {
                vdivps(vdiv, vscale, vsqrtvar);
            }

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v = Vmm(base_reg);
                            size_t offt = i * vlen_spat_data_;
                            uni_vmovups_spat_data(
                                    v, vmmword[reg_src + reg_soff + offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                            uni_vsubps(v, v, vmean);
                            if (bdesc_->use_scaleshift()) {
                                uni_vfmadd213ps(v, vgamma, vbeta);
                            } else {
                                uni_vmulps(v, v, vsqrtvar);
                            }
                            if (with_relu_inf_only) {
                                uni_vmaxps(v, v, vzero);
                            } else if (with_relu) {
                                if (isa == avx512_common)
                                    fwd_process_relu_avx512_common(v, offt);
                                else
                                    fwd_process_relu_avx2(v, offt, Vmm(3));
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                        vmmword[reg_dst + reg_soff + offt], v);
                            } else {
                                uni_vmovups_spat_data(
                                        vmmword[reg_dst + reg_soff + offt], v);
                            }
                        },
                        [](size_t base_reg) {UNUSED(base_reg);});
            };

            if (is_bf16_) {
                compute(false); // no mask-able NT store for BF16
            } else {
                Label unaligned_store, end_store;
                test(reg_dst, vlen - 1);
                jnz(unaligned_store, T_NEAR);
                compute(true);
                jmp(end_store, T_NEAR);
                L(unaligned_store); {
                    compute(false);
                }
                L(end_store);
            }

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

    void forward() {
        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_dst, ptr[rsp + stack_off_dst]);
        mov(reg_ws, ptr[rsp + stack_off_ws]);

        xor_(reg_soff, reg_soff);
        Label dst_spatial;
        L(dst_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            forward_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                add(reg_dst, vlen / 2);
                mov(reg_coff, vlen / 2);

                forward_channels();

                sub(reg_src, vlen / 2);
                sub(reg_dst, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(dst_spatial);
        }
    }
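
    // Accumulate the per-channel partials needed for backward:
    //     rbuf1 += sum_over_space(diff_dst * (src - mean))  -> diff_gamma
    //     rbuf2 += sum_over_space(diff_dst)                 -> diff_beta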
    void backward_sh_channels() {
        Label sh_channels;
        L(sh_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
            spat_loop(spat_size, 1, 1,
                    [=](size_t base_reg) {
                        if (base_reg > 0) {
                            for (int i = 0; i < 2; i++) {
                                Vmm v(base_reg * 5 + i);
                                uni_vpxor(v, v, v);
                            }
                        }
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm o0 = Vmm(base_reg * 5 + 0);
                        Vmm o1 = Vmm(base_reg * 5 + 1);
                        Vmm t1 = Vmm(base_reg * 5 + 2);
                        Vmm t2 = Vmm(base_reg * 5 + 3);
                        Vmm t3 = Vmm(base_reg * 5 + 4);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                t1, vmmword[reg_src + reg_soff + offt]);
                        uni_vmovups_spat_data(
                                t2, vmmword[reg_diff_dst + reg_soff + offt]);
                        if (with_relu) {
                            if (isa == avx512_common)
                                bwd_process_relu_avx512_common(t2, offt);
                            else if (isa == avx2)
                                bwd_process_relu_avx2(t2, offt, t3);
                            else
                                assert(false);
                        }
                        uni_vsubps(t3, vmean, t1, t3);
                        if (isa == sse42) {
                            mulps(t3, t2);
                            subps(o0, t3);
                        } else {
                            vfnmadd231ps(o0, t3, t2);
                        }
                        uni_vaddps(o1, o1, t2);
                        mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
                                + t1_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b0 = Vmm(0);
                        Vmm b1 = Vmm(1);
                        if (base_reg) {
                            uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
                            uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(sh_channels);
        }
    }
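
    // diff_src = gamma / sqrt(var + eps) * (diff_dst - diff_beta / chan_size
    //            - (src - mean) * diff_gamma / (sqrt(var + eps) * chan_size));
    // the two correction terms are dropped when use_global_stats() is set.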
    void backward_diff_channels() {
        Label diff_channels;
        L(diff_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);
            uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
            if (bdesc_->use_scaleshift())
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
            uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
            uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
            uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [=](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v(base_reg * 2 + 0);
                            Vmm t(base_reg * 2 + 1);
                            Vmm t1(base_reg * 2 + 2);
                            size_t offt = i * vlen_spat_data_;
                            uni_vmovups_spat_data(
                                    v, vmmword[reg_diff_dst + reg_soff + offt]);
                            if (with_relu) {
                                if (isa == avx512_common)
                                    bwd_process_relu_avx512_common(v, offt);
                                else if (isa == avx2)
                                    bwd_process_relu_avx2(v, offt, t);
                                else
                                    assert(false);
                            }
                            if (!bdesc_->use_global_stats()) {
                                uni_vsubps(v, v, vdiff_beta);
                                uni_vmovups_spat_data(
                                        t, vmmword[reg_src + reg_soff + offt]);
                                uni_vsubps(t, vmean, t, t1);
                                uni_vmulps(t, t, vdiff_gamma);
                                uni_vaddps(v, v, t);
                            }
                            uni_vmulps(v, v, vsqrtvar);
                            if (bdesc_->use_scaleshift()) {
                                uni_vmulps(v, v, vgamma);
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                        vmmword[reg_diff_src + reg_soff + offt],
                                        v);
                            } else {
                                uni_vmovups_spat_data(
                                        vmmword[reg_diff_src + reg_soff + offt],
                                        v);
                            }
                            mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_diff_dst + reg_soff
                                    + offt + t1_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                        },
                        [=](size_t base_reg) {UNUSED(base_reg);});
            };

            if (is_bf16_) {
                compute(false); // no mask-able NT store for BF16
            } else {
                Label unaligned_store, end_store;
                test(reg_diff_src, vlen - 1);
                jnz(unaligned_store, T_NEAR);
                compute(true);
                jmp(end_store, T_NEAR);
                L(unaligned_store); {
                    compute(false);
                }
                L(end_store);
            }

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(diff_channels);
        }
    }
    void backward() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf, sh_spatial;

        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        L(sh_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            backward_sh_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_sh_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(sh_spatial);
        }

        mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);

        Label no_sh_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            Label sh_reduction_channels;
            jne(no_sh_reduction, T_NEAR);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            L(sh_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
                uni_vaddps(vsqrtvar, vsqrtvar, veps);
                uni_vsqrtps(vsqrtvar, vsqrtvar);
                uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
                mov(reg_ctr, reg_nnthr);
                Label sh_reduction_thrs;
                L(sh_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(sh_reduction_thrs);
                }
                uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
                uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
                uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
                cmp(reg_coff, reg_coff_max);
                jne(sh_reduction_channels);
            }
        }
        L(no_sh_reduction);
        barrier();

        mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        Label diff_spatial;
        L(diff_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            backward_diff_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_diff_src, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_diff_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_diff_src, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(diff_spatial);
        }
    }
    jit_bnorm_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), bf16_emu_() {
        static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
                || isa == avx512_mic, "unsupported isa");

        is_bf16_ = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        size_t dt_size = is_bf16_ ? types::data_type_size(data_type::bf16)
                                  : sizeof(acc_data_t);
        const int simd_w = isa == sse42 ? 8 :
                cpu_isa_traits<isa>::vlen / sizeof(float);
        is_spatial_thr_ =
                bnorm_utils::is_spatial_thr(bdesc_, simd_w, dt_size);
        vlen_spat_data_ = vlen / (1 + is_bf16_); // 32B of BF16 -> 64B of FP32

        unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
        unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;

        preamble();

        if (is_bf16_) {
            // init emulation of bfloat16 operations
            if (!mayiuse(avx512_core_bf16)) {
                bf16_emu_ = new bf16_emulation_t(this, vcvt_bf16_one,
                        vcvt_bf16_eve, vcvt_bf16_sel, reg_bf16_tmp,
                        vcvt_bf16_tmp, vcvt_bf16_tmp);
                bf16_emu_->init_vcvtneps2bf16();
            }
        }

        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();

        compute_static_strides();
        sub(rsp, stack_size_required);
        load_common_params();
        prepare_relu();

        if (bdesc_->is_fwd()) {
            if (!bdesc_->stats_is_src()) {
                compute_mean_variance();
            }
            forward();
        } else {
            backward();
        }
        add(rsp, stack_size_required);
        postamble();

        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    this->getCode()));
    }

    ~jit_bnorm_t() { delete bf16_emu_; }
};

template <cpu_isa_t isa>
struct uni_bnorm_driver_t: public c_compatible {
    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), ker_(bdesc_) {
        const int nthrs = mkldnn_get_max_threads();
        const dim_t C_PADDED = get_c_padded(bdesc_);

        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        dt_size_ = is_bf16 ? types::data_type_size(data_type::bf16)
                           : sizeof(acc_data_t);
        size_t data_size = dt_size_ * bdesc_->MB() * C_PADDED
                * bdesc_->D() * bdesc_->H() * bdesc_->W();
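        // Enable channel blocking only when one sweep over the data would
        // not comfortably fit in the aggregate L3 cache.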
        l3_size_ = get_cache_size(3, true) * nthrs / 2;
        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    }

    ~uni_bnorm_driver_t() {}

    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const batch_normalization_pd_t *bdesc) {
        int nthrs = mkldnn_get_max_threads();
        dim_t C_PADDED = get_c_padded(bdesc);

        int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
        int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
        int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;

        scratchpad.book(key_bnorm_tmp_stats, sizeof(acc_data_t) * sbuf_sz);
        scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(acc_data_t) * pbuf_sz);
        scratchpad.book(key_bnorm_reduction, sizeof(acc_data_t) * rbuf_sz);

        if (mkldnn_thr_syncable()) {
            int n_barriers = C_PADDED / simd_w;
            scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
        }
    }

    void exec(int ithr, int nthr, const void *src, void *diff_src, void *dst,
            const void *diff_dst, const acc_data_t *scale_shift,
            acc_data_t *diff_scale_shift, const acc_data_t *mean,
            const acc_data_t *var, const uint8_t *ws,
            const memory_tracking::grantor_t &scratchpad) {
        auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
        auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
        auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);

        size_t N = bdesc_->MB();
        size_t C = bdesc_->C();
        size_t C_PADDED = get_c_padded(bdesc_);
        size_t D = bdesc_->D();
        size_t H = bdesc_->H();
        size_t W = bdesc_->W();

        size_t img_size = C_PADDED * D * H * W;
        const int vlen_spat_data = ker_.vlen_spat_data_;

        typename jit_bnorm_t<isa>::call_params_t p;

        p.eps = bdesc_->desc()->batch_norm_epsilon;
        p.one = 1.0f;
        p.spat_size = D * H * W;
        p.chan_size = 1.0f * N * p.spat_size;

        int C_blks = C_PADDED / simd_w;
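
        // Work is decomposed over channel blocks (C), the minibatch (N) and,
        // when profitable, the spatial extent (S); thread_balance() picks
        // the actual split.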
        int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
        int C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};

        int C_blks_per_iter{ 1 }, iters{ 1 };

        int num_tensors = bdesc_->is_fwd() ? 1 : 2;
        size_t working_set_size
                = dt_size_ * (N * D * H * W * simd_w) * num_tensors;
        bnorm_utils::cache_balance(working_set_size, C_blks,
                C_blks_per_iter, iters);

        int SP = D * H * W;
        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
                SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
                S_ithr, S_nthr, S_s, S_e);

        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));

        p.N_ithr = SP_N_ithr;
        p.N_nthr = SP_N_nthr;

        int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;

        int global_barriers_per_iter = C_nthr;

        for (int it = 0; it < iters; it++) {
            if (it == iters - 1 && iters > 1) {
                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);

                // Update call parameters for JIT, last iteration
                p.N_ithr = N_ithr * S_nthr + S_ithr;
                p.N_nthr = N_nthr * S_nthr;
            }

            int global_C_blk_s;
            global_C_blk_s = do_blocking_ ?
                    (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
                    C_blk_s;

            int C_blks_thr = C_blk_e - C_blk_s;
            int N_thr = N_e - N_s;

            size_t coff_base = global_C_blk_s * simd_w;
            size_t soff_base
                    = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;

            p.spat_size_loc = S_e - S_s;
            p.S_s = S_s * vlen_spat_data;
            p.S_tail = (p.spat_size - S_e) * vlen_spat_data;
            p.coff_max = C_blks_thr * simd_w;
            p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
            p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
            p.scale_shift = scale_shift + coff_base;
            p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
                    ? pbuf : diff_scale_shift) + coff_base;
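
            // The ReLU workspace keeps one bit per element, hence the /8
            // when turning an element offset into a byte offset below.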
            p.soff_max = dt_size_ * N_thr * img_size;
            p.src = (void *)((char *)src + soff_base * dt_size_);
            p.dst = (void *)((char *)dst + soff_base * dt_size_);
            p.diff_src = (void *)((char *)diff_src + soff_base * dt_size_);
            p.diff_dst = (void *)((char *)diff_dst + soff_base * dt_size_);
            p.ws = ws + soff_base / 8;

            p.mb_stride_Bc = dt_size_ * (img_size - p.coff_max * p.spat_size);

            // use SP_N_nthr which is the same as p.N_nthr except maybe for
            // the last iteration.
            p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
                    + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
            // rbuf1 and rbuf2 have to be disjoint
            p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
            p.is_cblk_tail =
                    (size_t)((it * C_blks_per_iter + C_blk_e) * simd_w) > C;

            size_t iter_barriers
                    = do_blocking_ ? it * global_barriers_per_iter : 0;
            p.barrier = barriers + C_ithr + iter_barriers;
            if (p.soff_max != 0 && p.coff_max != 0)
                ker_(&p);
        }
    }
    void init_barriers(const memory_tracking::grantor_t &scratchpad) {
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
        if (mkldnn_thr_syncable()) {
            const int n_barriers = get_c_padded(bdesc_) / simd_w;
            for (int i = 0; i < n_barriers; ++i)
                barrier::ctx_init(&barriers[i]);
        }
    }

private:
    enum {
        simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen
                / sizeof(acc_data_t) // BF16 will expand to FP32
    };

    static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
        return true
            && !bdesc->stats_is_src()
            && bdesc->desc()->prop_kind == prop_kind::forward_inference;
    }

    static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
    {
        return false
            || (bdesc->is_bwd() && !bdesc->use_scaleshift())
            || bdesc->desc()->prop_kind == prop_kind::backward_data;
    }

    static dim_t get_c_padded(const batch_normalization_pd_t *bdesc)
    { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }

    const batch_normalization_pd_t *bdesc_;
    jit_bnorm_t<isa> ker_;
    bool do_blocking_;
    size_t l3_size_;

    acc_data_t *buf_, *sbuf_, *rbuf_, *pbuf_;

    size_t dt_size_;
};

using namespace data_type;
using namespace memory_format;
using namespace utils;

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_batch_normalization_fwd_t<isa, d_type>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? isa == avx512_common ? nChw16c : nChw8c
        : isa == avx512_common ? nCdhw16c : nCdhw8c;

    bool ok = true
        && mayiuse(isa)
        && is_fwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && desc()->data_desc.data_type == d_type
        && IMPLICATION(d_type == bf16, mayiuse(avx512_core))
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && desc()->data_desc.format == desired_fmt
        && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    if (is_training() && fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
    }

    if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
            != this->C() && isa < avx2)
        return status::unimplemented;

    if (stats_is_src() || is_training()) {
        memory_desc_t stats_d;
        dims_t stats_dims = { C() };
        mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
        mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
        variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
    }

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_fwd_t<isa,
        d_type>::jit_uni_batch_normalization_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs) {
    bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd());
}

template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_batch_normalization_fwd_t<isa, d_type>::execute(event_t *e) const {
    auto src = reinterpret_cast<const void *>(this->input_memory(0));
    auto dst = reinterpret_cast<void *>(this->memory(0));
    auto mean = reinterpret_cast<acc_data_t *>(pd()->stats_is_src()
            ? const_cast<char *>(this->input_memory(1))
            : this->memory(1));
    auto var = reinterpret_cast<acc_data_t *>(pd()->stats_is_src()
            ? const_cast<char *>(this->input_memory(2))
            : this->memory(2));

    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);
    auto scale_shift = reinterpret_cast<const acc_data_t *>(
            this->input_memory(idx_scale_shift));

    parallel(0, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
                scale_shift, nullptr, mean, var, ws, scratchpad);
    });

    e->set_state(event_t::ready);
}
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_fwd_t<isa,
        d_type>::~jit_uni_batch_normalization_fwd_t() {
    delete bnorm_driver_;
}

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_batch_normalization_bwd_t<isa, d_type>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
        : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;

    bool ok = true
        && mayiuse(isa)
        && is_bwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && everyone_is(d_type, desc()->data_desc.data_type,
                desc()->diff_data_desc.data_type)
        && IMPLICATION(d_type == bf16, mayiuse(avx512_core))
        && IMPLICATION(use_scaleshift(), utils::everyone_is(f32,
                desc()->data_scaleshift_desc.data_type,
                desc()->diff_data_scaleshift_desc.data_type))
        && everyone_is(desired_fmt, desc()->diff_data_desc.format,
                desc()->data_desc.format)
        && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    if (memory_desc_wrapper(&data_pd_).blocking_desc()
            .padding_dims[1] != this->C() && isa < avx2)
        return status::unimplemented;

    if (fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
        size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();

        bool ws_ok = true
            && hint_fwd_pd_->workspace_pd()
            && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
                    == this_ws_sz;
        if (!ws_ok) return status::unimplemented;
    }

    /* TODO: extra checks required */

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_bwd_t<isa,
        d_type>::jit_uni_batch_normalization_bwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs) {
    bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd());
}

template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_batch_normalization_bwd_t<isa, d_type>::execute(event_t *e) const {
    auto src = reinterpret_cast<const void *>(this->input_memory(0));
    auto mean = reinterpret_cast<const acc_data_t *>(this->input_memory(1));
    auto var = reinterpret_cast<const acc_data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const void *>(this->input_memory(3));
    auto scale_shift
            = reinterpret_cast<const acc_data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<void *>(this->memory(0));
    auto diff_scale_shift = reinterpret_cast<acc_data_t *>(this->memory(1));
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
                scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
    });

    e->set_state(event_t::ready);
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_bwd_t<isa,
        d_type>::~jit_uni_batch_normalization_bwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<sse42, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<sse42, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx2, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<avx2, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common, data_type::bf16>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common, data_type::bf16>;