/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>

#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"

#include "cpu_barrier.hpp"
#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "jit_uni_batch_normalization.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
using namespace Xbyak;
namespace barrier = simple_barrier;

typedef float data_t;
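/* This translation unit has three layers: jit_bnorm_t<isa> emits the
 * batch-normalization kernel at run time via Xbyak, uni_bnorm_driver_t<isa>
 * partitions the work across threads and scratchpad buffers, and the
 * jit_uni_batch_normalization_{fwd,bwd}_t primitives at the bottom glue the
 * driver into the mkldnn primitive interface. */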
template <cpu_isa_t isa>
struct jit_bnorm_t: public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        size_t N_ithr, N_nthr;
        size_t coff_max, soff_max;
        size_t mb_stride_Bc, spat_size, spat_size_loc;
        size_t S_s, S_tail;
        size_t is_cblk_tail;
        data_t chan_size, eps, one;
        const data_t *scale_shift;
        const data_t *mean, *var;
        const data_t *diff_scale_shift;
        const data_t *src, *dst;
        const data_t *diff_src, *diff_dst;
        const data_t *rbuf1, *rbuf2;
        const uint8_t *ws;
        barrier::ctx_t *barrier;
    };
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    /* cpu specific part */
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
            isa == avx2, Ymm, Zmm>::type;
    const AddressFrame &vmmword = (isa == sse42) ? xword :
            (isa == avx2) ? yword : zword;

    const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
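    // Note: for sse42 the kernel still behaves as if the vector length were
    // 32 bytes: every channel block is processed as two 16-byte halves (see
    // the vlen / 2 offset fixups in forward()/backward()), so vlen is set to
    // 32 rather than to cpu_isa_traits<sse42>::vlen.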
    const batch_normalization_pd_t *bdesc_;
    bool is_spatial_thr_;

    void (*ker)(const call_params_t *);
    void operator()(const call_params_t *p) { (*ker)(p); }
    Reg64 reg_param = abi_param1;

    Reg64 reg_scale_shift = rbx;
    Reg64 reg_rbuf1 = abi_not_param1;
    Reg64 reg_rbuf2 = rdx;

    Reg64 reg_mean = rbp;
    Reg64 reg_var = reg_param;
    Reg64 reg_diff_scale_shift = rax;

    Reg64 reg_coff = r8;
    Reg64 reg_coff_max = r9;
    Reg64 reg_soff = r10;
    Reg64 reg_soff_max = r11;
    Reg64 reg_ctr = r12;
    Reg64 reg_roff = r13;

    Reg64 reg_mb_stride_Bc = r14;

    Reg64 reg_src = r15;
    Reg64 reg_diff_src = reg_rbuf1;
    Reg64 reg_dst = rsi;
    Reg64 reg_diff_dst = reg_dst;

    Reg64 reg_tmp_off = reg_roff;

    // Reuse loop counters
    Reg64 reg_bar = reg_coff;
    Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
    Reg64 reg_tmp = reg_ctr;
    bool with_relu, with_relu_inf_only;
    Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
    Reg64 reg_ws = reg_roff;
    Label l_relu_mask_avx2;
    Opmask kstore_mask = Opmask(1);

    // channel tail processing
    Opmask ktail_mask = Opmask(2);
    size_t unroll_blocks;
    size_t unroll_regs;
    Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
    Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
    Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
    Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
    Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
    Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
    Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
    Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
    Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
    Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
    Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);
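    // Vmm registers 20..30 (avx512) or 5..15 (sse42/avx2) hold per-channel
    // constants for the lifetime of the kernel; the lower registers stay free
    // for the unrolled spatial loops (see spat_loop(), which indexes Vmm's
    // from base_reg upward).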
    size_t t0_pf_offt;
    size_t t1_pf_offt;
    size_t spat_size;
    size_t chan_data_offt;

    enum {
        stack_off_N_nthr = 0,
        stack_off_N_ithr = 8,
        stack_off_src = 16,
        stack_off_dst = 24,
        stack_off_diff_src = 32,
        stack_off_diff_dst = 40,
        stack_off_diff_scale_shift = 48,
        stack_off_ws = 56,
        stack_off_barrier = 64,
        stack_off_spat_size_loc = 72,
        stack_off_s_s = 80,
        stack_off_s_tail = 88,
        stack_off_is_cblk_tail = 96,
        stack_size_required = 104,
    };
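    // Call parameters that do not fit in registers are spilled to a small
    // rsp-relative scratch area: the generated code reserves it with
    // sub(rsp, stack_size_required) and releases it with the matching add
    // before the epilogue.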
    bool is_c_padded() const {
        const memory_desc_wrapper data_d(bdesc_->src_pd());
        return bdesc_->C() != data_d.blocking_desc().padding_dims[1];
    }
    void compute_static_strides() {
        spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
        chan_data_offt = bdesc_->C() * sizeof(data_t);

        if (isa == avx512_mic) {
            t0_pf_offt = 4096;
            t1_pf_offt = 0;
        } else {
            t0_pf_offt = 0;
            t1_pf_offt = 0;
        }
    }
    void load_common_params() {
#       define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
        if (bdesc_->is_bwd())
            mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
        mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
        mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
        mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
        shl(reg_coff_max, 2);
        shl(reg_soff_max, 2);
        shl(reg_mb_stride_Bc, 2);
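        // coff_max, soff_max and mb_stride_Bc arrive as element counts; the
        // shifts above turn them into byte offsets (sizeof(data_t) == 4).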
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);

        uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
        uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);

        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
        mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
        mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
        mov(ptr[rsp + stack_off_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
        mov(ptr[rsp + stack_off_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
        mov(ptr[rsp + stack_off_diff_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
        mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
        mov(ptr[rsp + stack_off_ws], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
        mov(ptr[rsp + stack_off_barrier], reg_tmp);
        if (is_spatial_thr_) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
            mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
            mov(ptr[rsp + stack_off_s_s], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
            mov(ptr[rsp + stack_off_s_tail], reg_tmp);
        }
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
        mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);

        if (bdesc_->is_fwd()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        } else {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
            mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        }
#       undef PARAM_OFF
    }
    void prepare_tail_mask_avx512_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        const int mask = (1 << tail) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }
    void prepare_tail_mask_avx2_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                0, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
        vmovups(vtail_mask, ptr[reg_tmp]);
    }
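    // Both helpers build a mask for the last (partially filled) channel block
    // when C is not a multiple of the vector width: avx512 keeps it in an
    // opmask register, while avx2 has no opmasks and instead loads a per-lane
    // dword mask into vtail_mask for use with vmaskmovps.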
    void prepare_relu() {
        with_relu = bdesc_->is_fwd()
            ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
            : bdesc_->fuse_bn_relu();
        with_relu_inf_only = with_relu && bdesc_->is_fwd()
            && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());

        vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
        uni_vpxor(vzero, vzero, vzero);
        if (!bdesc_->is_fwd() && isa == avx2)
            prepare_l_relu_mask_avx2();
    }

    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        jmp(l_mask_after);
        align(32);
        L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i) dd(1<<i);
        L(l_mask_after);
    }
    void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
        Reg64 reg_store_mask = reg_diff_scale_shift;
        shr(reg_soff, 5);
        vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
        vmovmskps(reg_store_mask, vstore_mask);
        mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
        vblendvps(vdst, vzero, vdst, vstore_mask);
        shl(reg_soff, 5);
    }
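    // The workspace ws stores the ReLU activity mask one bit per element, so
    // 32 bytes of float data map onto one mask byte; reg_soff (a byte offset
    // into the data) is temporarily scaled down by 1 << 5 to index it, and
    // the same convention holds for the avx512 and backward variants below.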
    void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
        shr(reg_soff, 5);
        vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
        kmovw(ptr[reg_ws + reg_soff + offt / (1 << 5)], kstore_mask);
        vblendmps(vdst | kstore_mask, vzero, vdst);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
        shr(reg_soff, 5);
        vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
        shr(reg_soff, 5);
        kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
        shl(reg_soff, 5);
    }
    void uni_vmovups_tail_avx2_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
        else
            vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
        jmp(l_ret);
    }

    void uni_vmovups_tail_avx512_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
        jmp(l_ret);
    }
    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;

        if (is_c_padded()) {
            mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
            cmp(reg_tmp, 0);
            jz(l_no_mask);

            lea(reg_tmp, ptr[reg_coff + vlen]);
            cmp(reg_tmp, reg_coff_max);
            jl(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        L(l_no_mask);
        if (dst.isMEM())
            uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
        L(l_ret);
    }
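    // Only the very last vector of a row can be a partial one, so the masked
    // path is taken when is_cblk_tail is set and reg_coff + vlen would run
    // past reg_coff_max; everything else falls through to a plain load/store.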
    void barrier() {
        mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
        mov(reg_bar, ptr[rsp + stack_off_barrier]);
        simple_barrier::generate(*this, reg_bar, reg_nnthr);
    }
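    // All threads cooperating on one channel slice meet here; the barrier
    // context pointer and the group thread count were stashed on the stack by
    // load_common_params().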
    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
    }

    Address diff_gamma_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 0 * chan_data_offt];
    }

    Address diff_beta_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 1 * chan_data_offt];
    }

    Address gamma_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
    }

    Address beta_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
    }
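    // scale_shift (and diff_scale_shift) is laid out as two contiguous
    // C-sized arrays: gamma lives at offset 0 and beta at chan_data_offt
    // (C * sizeof(data_t)), which is why the beta accessors add
    // 1 * chan_data_offt.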
    template <typename init_t, typename body_t, typename fini_t>
    void spat_loop(size_t len, size_t blocks, size_t regs,
            init_t init, body_t body, fini_t fini) {
        size_t factor = regs * blocks;
        size_t loop_unroll = len / factor * factor;
        size_t loop_tail = len - loop_unroll;
        size_t num_active_regs = (len < regs) ? len : regs;
        for (size_t i = 0; i < num_active_regs; i++)
            init(i);

        if (is_spatial_thr_) {
            mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
            add(reg_soff, ptr[rsp + stack_off_s_s]);
        } else {
            mov(reg_ctr, loop_unroll);
        }

        Label label;
        L(label); {
            for (size_t i = 0; i < factor; i++) {
                size_t base_reg = i % regs;
                body(base_reg, i);
            }
            add(reg_soff, factor * vlen);
            sub(reg_ctr, factor);
            jnz(label);
        }

        if (is_spatial_thr_) {
            add(reg_soff, ptr[rsp + stack_off_s_tail]);
        }

        for (size_t i = 0; i < loop_tail; i++) {
            size_t base_reg = i % regs;
            body(base_reg, loop_unroll + i);
        }
        add(reg_soff, loop_tail * vlen);

        for (size_t i = 0; i < num_active_regs; i++)
            fini(i);
    }
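    // spat_loop() unrolls the spatial dimension by blocks * regs iterations,
    // rotating the accumulator registers via base_reg = i % regs; the
    // leftover len % (blocks * regs) iterations are peeled off into the tail
    // loop, and init/fini set up and reduce the per-register accumulators.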
    void mean_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks,
                    unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v0 = Vmm(base_reg * 2 + 0);
                        Vmm v1 = Vmm(base_reg * 2 + 1);
                        size_t offt = i * vlen;
                        uni_vmovups(v1,
                                vmmword[reg_src + reg_soff + offt]);
                        uni_vaddps(v0, v0, v1);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
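    // mean_channels() only produces per-thread partial sums in rbuf1; the
    // division by chan_size and the store to mean_ptr() happen later in the
    // cross-thread reduction inside compute_mean_variance().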
    void var_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks, unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v = Vmm(3 * base_reg);
                        Vmm vtmp0 = Vmm(3 * base_reg + 1);
                        Vmm vtmp1 = Vmm(3 * base_reg + 2);
                        size_t offt = i * vlen;
                        uni_vmovups(vtmp0,
                                vmmword[reg_src + reg_soff + offt]);
                        if (isa == sse42) {
                            movups(vtmp1, vmean);
                            subps(vtmp1, vtmp0);
                        } else {
                            vsubps(vtmp1, vmean, vtmp0);
                        }
                        uni_vfmadd231ps(v, vtmp1, vtmp1);

                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
    void compute_mean_variance() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf;
        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);

        xor_(reg_soff, reg_soff);
        Label mean_spatial;
        L(mean_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            mean_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                mean_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(mean_spatial);
        }

        Label no_mean_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_mean_reduction);
            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label mean_reduction_channels;
            L(mean_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label mean_reduction_thrs;
                L(mean_reduction_thrs); {
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(mean_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));

                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(mean_reduction_channels);
            }
        }
        L(no_mean_reduction);
        barrier();
        xor_(reg_soff, reg_soff);
        Label var_spatial;
        L(var_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            var_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                var_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(var_spatial);
        }

        Label no_var_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_var_reduction);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label var_reduction_channels;
            L(var_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label var_reduction_thrs;
                L(var_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(var_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(var_reduction_channels);
            }
        }
        L(no_var_reduction);
        barrier();
    }
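    // compute_mean_variance() is two barrier-separated passes: each thread
    // accumulates partial sums into its rbuf1 slice, thread 0 of the group
    // reduces them into mean, and the same pattern repeats for the variance
    // (which needs the finished mean as input).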
    void forward_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);

            if (isa == sse42) {
                movups(vbuf, vone);
                divps(vbuf, vsqrtvar);
                movups(vsqrtvar, vbuf);
            } else {
                vdivps(vsqrtvar, vone, vsqrtvar);
            }

            if (bdesc_->use_scaleshift()) {
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
                uni_vmovups_maybe_tail(vbeta, beta_ptr());
            }

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v = Vmm(base_reg);
                            size_t offt = i * vlen;
                            uni_vmovups(v,
                                    vmmword[reg_src + reg_soff + offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                            uni_vsubps(v, v, vmean);
                            uni_vmulps(v, v, vsqrtvar);
                            if (bdesc_->use_scaleshift()) {
                                uni_vfmadd213ps(v, vgamma, vbeta);
                            }
                            if (with_relu_inf_only) {
                                uni_vmaxps(v, v, vzero);
                            } else if (with_relu) {
                                if (isa == avx512_common)
                                    fwd_process_relu_avx512_common(v, offt);
                                else
                                    fwd_process_relu_avx2(v, offt, Vmm(3));
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                        vmmword[reg_dst + reg_soff + offt], v);
                            } else {
                                uni_vmovups(
                                        vmmword[reg_dst + reg_soff + offt], v);
                            }
                        },
                        [](size_t base_reg) {UNUSED(base_reg);});
            };

            Label unaligned_store, end_store;
            test(reg_dst, vlen - 1);
            jnz(unaligned_store, T_NEAR);
            compute(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                compute(false);
            }
            L(end_store);

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
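    // The compute lambda is emitted twice: reg_dst alignment is only known at
    // run time, so the kernel tests it once per row and jumps to either the
    // streaming-store (uni_vmovntps) or the unaligned (uni_vmovups) variant.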
    void forward() {
        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_dst, ptr[rsp + stack_off_dst]);
        mov(reg_ws, ptr[rsp + stack_off_ws]);

        xor_(reg_soff, reg_soff);
        Label dst_spatial;
        L(dst_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            forward_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                add(reg_dst, vlen / 2);
                mov(reg_coff, vlen / 2);

                forward_channels();

                sub(reg_src, vlen / 2);
                sub(reg_dst, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jnz(dst_spatial);
        }
    }
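    // For sse42 each 8-channel block is two xmm halves: the loop re-runs
    // forward_channels() with src/dst advanced by vlen / 2 and reg_coff
    // seeded to vlen / 2 so the second half uses the correct channel offsets.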
    void backward_sh_channels() {
        Label sh_channels;
        L(sh_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
            spat_loop(spat_size, 1, 1,
                    [=](size_t base_reg) {
                        if (base_reg > 0) {
                            for (int i = 0; i < 2; i++) {
                                Vmm v(base_reg * 5 + i);
                                uni_vpxor(v, v, v);
                            }
                        }
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm o0 = Vmm(base_reg * 5 + 0);
                        Vmm o1 = Vmm(base_reg * 5 + 1);
                        Vmm t1 = Vmm(base_reg * 5 + 2);
                        Vmm t2 = Vmm(base_reg * 5 + 3);
                        Vmm t3 = Vmm(base_reg * 5 + 4);
                        size_t offt = i * vlen;
                        uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]);
                        uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff
                                + offt]);
                        if (with_relu) {
                            if (isa == avx512_common)
                                bwd_process_relu_avx512_common(t2, offt);
                            else if (isa == avx2)
                                bwd_process_relu_avx2(t2, offt, t3);
                            else
                                assert(false);
                        }
                        uni_vsubps(t3, vmean, t1, t3);
                        if (isa == sse42) {
                            mulps(t3, t2);
                            subps(o0, t3);
                        } else {
                            vfnmadd231ps(o0, t3, t2);
                        }
                        uni_vaddps(o1, o1, t2);
                        mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
                                + t1_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b0 = Vmm(0);
                        Vmm b1 = Vmm(1);
                        if (base_reg) {
                            uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
                            uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(sh_channels);
        }
    }
    void backward_diff_channels() {
        Label diff_channels;
        L(diff_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);
            uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
            if (bdesc_->use_scaleshift())
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
            uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
            uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
            uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [=](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v(base_reg * 2 + 0);
                            Vmm t(base_reg * 2 + 1);
                            Vmm t1(base_reg * 2 + 2);
                            size_t offt = i * vlen;
                            uni_vmovups(v, vmmword[reg_diff_dst + reg_soff
                                    + offt]);
                            if (with_relu) {
                                if (isa == avx512_common)
                                    bwd_process_relu_avx512_common(v, offt);
                                else if (isa == avx2)
                                    bwd_process_relu_avx2(v, offt, t);
                                else
                                    assert(false);
                            }
                            if (!bdesc_->use_global_stats()) {
                                uni_vsubps(v, v, vdiff_beta);
                                uni_vmovups(t, vmmword[reg_src + reg_soff
                                        + offt]);
                                uni_vsubps(t, vmean, t, t1);
                                uni_vmulps(t, t, vdiff_gamma);
                                uni_vaddps(v, v, t);
                            }
                            uni_vmulps(v, v, vsqrtvar);
                            if (bdesc_->use_scaleshift()) {
                                uni_vmulps(v, v, vgamma);
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                        vmmword[reg_diff_src + reg_soff + offt],
                                        v);
                            } else {
                                uni_vmovups(
                                        vmmword[reg_diff_src + reg_soff + offt],
                                        v);
                            }
                            mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_diff_dst + reg_soff
                                    + offt + t1_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                        },
                        [=](size_t base_reg) {UNUSED(base_reg);});
            };

            Label unaligned_store, end_store;
            test(reg_diff_src, vlen - 1);
            jnz(unaligned_store, T_NEAR);
            compute(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                compute(false);
            }
            L(end_store);

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(diff_channels);
        }
    }
    void backward() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf, sh_spatial;

        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        L(sh_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            backward_sh_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_sh_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(sh_spatial);
        }

        mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);

        Label no_sh_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            Label sh_reduction_channels;
            jne(no_sh_reduction, T_NEAR);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            L(sh_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
                uni_vaddps(vsqrtvar, vsqrtvar, veps);
                uni_vsqrtps(vsqrtvar, vsqrtvar);
                uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
                mov(reg_ctr, reg_nnthr);
                Label sh_reduction_thrs;
                L(sh_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(sh_reduction_thrs);
                }
                uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
                uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
                uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
                cmp(reg_coff, reg_coff_max);
                jne(sh_reduction_channels);
            }
        }
        L(no_sh_reduction);
        barrier();

        mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        Label diff_spatial;
        L(diff_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            backward_diff_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_diff_src, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_diff_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_diff_src, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(diff_spatial);
        }
    }
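    // backward() mirrors forward(): one spatial pass accumulates the
    // diff_gamma/diff_beta partial sums, a barrier-protected reduction
    // finalizes them, and a second spatial pass produces diff_src.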
    jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
        static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
                || isa == avx512_mic, "unsupported isa");

        const int simd_w = isa == sse42 ? 8 :
            cpu_isa_traits<isa>::vlen / sizeof(data_t);
        is_spatial_thr_ =
            bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));

        unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
        unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;

        preamble();

        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();

        compute_static_strides();
        sub(rsp, stack_size_required);
        load_common_params();
        prepare_relu();

        if (bdesc_->is_fwd()) {
            if (!bdesc_->stats_is_src()) {
                compute_mean_variance();
            }
            forward();
        } else {
            backward();
        }
        add(rsp, stack_size_required);
        postamble();

        // cast the generated code buffer to a callable entry point
        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    getCode()));
    }
};
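/* Driver: picks the threading/blocking strategy and feeds call_params_t
 * batches to the generated kernel. */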
template <cpu_isa_t isa>
struct uni_bnorm_driver_t: public c_compatible {
    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), ker_(bdesc_)
    {
        const int nthrs = mkldnn_get_max_threads();
        const int C_PADDED = get_c_padded(bdesc_);

        size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
            * bdesc_->D() * bdesc_->H() * bdesc_->W();
        l3_size_ = get_cache_size(3, true) * nthrs / 2;
        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    }
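    // Heuristic: switch to channel blocking when the tensor does not fit in
    // half of the aggregate L3; get_cache_size(3, true) reports the per-core
    // share, hence the scaling by the thread count.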
    ~uni_bnorm_driver_t() {}

    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const batch_normalization_pd_t *bdesc) {
        int nthrs = mkldnn_get_max_threads();
        int C_PADDED = get_c_padded(bdesc);

        int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
        int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
        int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;

        scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
        scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
        scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);

        if (mkldnn_thr_syncable()) {
            int n_barriers = C_PADDED / simd_w;
            scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
        }
    }
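    // Barrier contexts are only booked when the threading runtime can
    // synchronize (mkldnn_thr_syncable()); otherwise each channel slice is
    // handled by a single thread group and no barrier is needed.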
    void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
            data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
            data_t *diff_scale_shift, const data_t *mean, const data_t *var,
            const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
        auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
        auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
        auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);

        size_t N = bdesc_->MB();
        size_t C = bdesc_->C();
        size_t C_PADDED = get_c_padded(bdesc_);
        size_t D = bdesc_->D();
        size_t H = bdesc_->H();
        size_t W = bdesc_->W();
        int SP = D * H * W;
        size_t img_size = C_PADDED * D * H * W;
        const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;

        typename jit_bnorm_t<isa>::call_params_t p;

        p.eps = bdesc_->desc()->batch_norm_epsilon;
        p.one = 1.0f;
        p.spat_size = D * H * W;
        p.chan_size = 1.0f * N * p.spat_size;

        int C_blks = C_PADDED / simd_w;

        int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
        int C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};

        int C_blks_per_iter{ 1 }, iters{ 1 };
        if (do_blocking_) {
            int num_tensors = bdesc_->is_fwd() ? 1 : 2;
            size_t working_set_size
                = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors;
            bnorm_utils::cache_balance(working_set_size, C_blks,
                    C_blks_per_iter, iters);
        }

        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
                SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
                S_ithr, S_nthr, S_s, S_e);

        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));

        p.N_ithr = SP_N_ithr;
        p.N_nthr = SP_N_nthr;

        int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;

        int global_barriers_per_iter = C_nthr;
        for (int it = 0; it < iters; it++) {
            if (it == iters - 1 && iters > 1) {
                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);

                // Update call parameters for JIT, last iteration
                p.N_ithr = N_ithr * S_nthr + S_ithr;
                p.N_nthr = N_nthr * S_nthr;
            }

            int global_C_blk_s = do_blocking_ ?
                    (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
                    C_blk_s;

            int C_blks_thr = C_blk_e - C_blk_s;
            int N_thr = N_e - N_s;

            size_t coff_base = global_C_blk_s * simd_w;
            size_t soff_base
                    = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;

            p.spat_size_loc = S_e - S_s;
            p.S_s = S_s * vlen;
            p.S_tail = (p.spat_size - S_e) * vlen;
            p.coff_max = C_blks_thr * simd_w;
            p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
            p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
            p.scale_shift = scale_shift + coff_base;
            p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
                    ? pbuf : diff_scale_shift) + coff_base;

            p.soff_max = N_thr * img_size;
            p.src = src + soff_base;
            p.dst = dst + soff_base;
            p.diff_src = diff_src + soff_base;
            p.diff_dst = diff_dst + soff_base;
            p.ws = ws + soff_base / 8;

            p.mb_stride_Bc = img_size - p.coff_max * p.spat_size;

            // use SP_N_nthr which is the same as p.N_nthr except maybe for
            // the last iteration.
            p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
                    + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
            // rbuf1 and rbuf2 have to be disjoint
            p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
            p.is_cblk_tail =
                    (size_t)((it * C_blks_per_iter + C_blk_e) * simd_w) > C;

            size_t iter_barriers
                    = do_blocking_ ? it * global_barriers_per_iter : 0;
            p.barrier = barriers + C_ithr + iter_barriers;
            if (p.soff_max != 0 && p.coff_max != 0)
                ker_(&p);
        }
    }
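    // mb_stride_Bc (set above to img_size - coff_max * spat_size) is the jump
    // from the end of one image's channel slice to the start of the same
    // slice in the next image, letting the kernel walk the minibatch without
    // touching other threads' channels.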
    void init_barriers(const memory_tracking::grantor_t &scratchpad) {
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
        if (barriers) {
            const int n_barriers = get_c_padded(bdesc_) / simd_w;
            for (int i = 0; i < n_barriers; ++i)
                barrier::ctx_init(&barriers[i]);
        }
    }

private:
    enum {
        simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
    };

    static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
        return true
            && !bdesc->stats_is_src()
            && bdesc->desc()->prop_kind == prop_kind::forward_inference;
    }

    static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
    {
        return false
            || (bdesc->is_bwd() && !bdesc->use_scaleshift())
            || bdesc->desc()->prop_kind == prop_kind::backward_data;
    }

    static int get_c_padded(const batch_normalization_pd_t *bdesc)
    { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }

    const batch_normalization_pd_t *bdesc_;
    bool do_blocking_;
    size_t l3_size_;

    jit_bnorm_t<isa> ker_;
};
using namespace data_type;
using namespace memory_format;
using namespace utils;
template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? isa == avx512_common ? nChw16c : nChw8c
        : isa == avx512_common ? nCdhw16c : nCdhw8c;

    bool ok = true
        && mayiuse(isa)
        && is_fwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && desc()->data_desc.data_type == f32
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && desc()->data_desc.format == desired_fmt
        && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;
    if (is_training() && fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
    }

    if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
            != this->C() && isa < avx2)
        return status::unimplemented;
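    /* Tail handling for padded channel blocks relies on the masked moves
     * prepared in jit_bnorm_t, which exist only for avx2 and avx512; on
     * sse42 a padded C is therefore rejected here. */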
    if (stats_is_src() || is_training()) {
        memory_desc_t stats_d;
        dims_t stats_dims = { C() };
        mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
        mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
        variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
    }

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
template <cpu_isa_t isa>
jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
        const pd_t *apd, const input_vector &inputs,
        const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
template <cpu_isa_t isa>
void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t*>(this->memory(0));
    auto mean = reinterpret_cast<data_t*>(pd()->stats_is_src()
            ? const_cast<char*>(this->input_memory(1))
            : this->memory(1));
    auto var = reinterpret_cast<data_t*>(pd()->stats_is_src()
            ? const_cast<char*>(this->input_memory(2))
            : this->memory(2));

    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
    auto scale_shift =
        reinterpret_cast<const data_t *>(this->input_memory(idx_scale_shift));
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
                scale_shift, nullptr, mean, var, ws, scratchpad);
    });

    e->set_state(event_t::ready);
}
template <cpu_isa_t isa>
jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
{ delete bnorm_driver_; }
template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
        : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;

    bool ok = true
        && mayiuse(isa)
        && is_bwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && everyone_is(f32, desc()->data_desc.data_type,
                desc()->diff_data_desc.data_type)
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && everyone_is(desired_fmt, desc()->diff_data_desc.format,
                desc()->data_desc.format)
        && attr()->has_default_values();
    if (!ok) return status::unimplemented;
    if (memory_desc_wrapper(&data_pd_).blocking_desc()
            .padding_dims[1] != this->C() && isa < avx2)
        return status::unimplemented;

    if (fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
        size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();

        bool ws_ok = true
            && hint_fwd_pd_->workspace_pd()
            && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
                == this_ws_sz;
        if (!ws_ok) return status::unimplemented;
    }

    /* TODO: extra checks required */

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
template <cpu_isa_t isa>
jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
        const pd_t *apd, const input_vector &inputs,
        const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
template <cpu_isa_t isa>
void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto var = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
    auto scale_shift = reinterpret_cast<const data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
    auto diff_scale_shift = reinterpret_cast<data_t *>(this->memory(1));
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
                scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
    });

    e->set_state(event_t::ready);
}
template <cpu_isa_t isa>
jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
{ delete bnorm_driver_; }
/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<sse42>;
template struct jit_uni_batch_normalization_bwd_t<sse42>;
template struct jit_uni_batch_normalization_fwd_t<avx2>;
template struct jit_uni_batch_normalization_bwd_t<avx2>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common>;

}
}
}