/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_barrier.hpp"
#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "jit_avx512_core_bf16cvt.hpp"
#include "jit_uni_batch_normalization.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

namespace {

using namespace memory_tracking::names;

using namespace Xbyak;
namespace barrier = simple_barrier;

typedef float acc_data_t;

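// JIT kernel for blocked-layout batch normalization, templated on the target
// ISA (sse42 / avx2 / avx512_common / avx512_mic). A single generated kernel
// covers either the forward or the backward pass; the direction is fixed at
// code-generation time from the primitive descriptor.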
template <cpu_isa_t isa>
struct jit_bnorm_t: public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        size_t N_ithr, N_nthr;
        size_t coff_max, soff_max;
        size_t mb_stride_Bc, spat_size, spat_size_loc;
        size_t S_s, S_tail;
        size_t is_cblk_tail;
        acc_data_t chan_size, eps, one;
        const acc_data_t *scale_shift;
        const acc_data_t *mean, *var;
        const acc_data_t *diff_scale_shift;
        const void *src, *dst;
        const void *diff_src, *diff_dst;
        const acc_data_t *rbuf1, *rbuf2;
        const uint8_t *ws;
        barrier::ctx_t *barrier;
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    /* cpu specific part */
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;
    const AddressFrame &vmmword = (isa == sse42) ? xword :
                                  (isa == avx2) ? yword : zword;

    const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
    int vlen_spat_data_; // set by ctor depending on data type (BF16 or FP32)

    const batch_normalization_pd_t *bdesc_;
    bool is_spatial_thr_;
    bool is_bf16_;

    void (*ker)(const call_params_t *);
    void operator()(const call_params_t *p) { (*ker)(p); }

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale_shift = rbx;
    Reg64 reg_rbuf1 = abi_not_param1;
    Reg64 reg_rbuf2 = rdx;

    Reg64 reg_mean = rbp;
    Reg64 reg_var = reg_param;
    Reg64 reg_diff_scale_shift = rax;

    Reg64 reg_coff = r8;
    Reg64 reg_coff_max = r9;
    Reg64 reg_soff = r10;
    Reg64 reg_soff_max = r11;
    Reg64 reg_ctr = r12;
    Reg64 reg_roff = r13;

    Reg64 reg_mb_stride_Bc = r14;

    Reg64 reg_src = r15;
    Reg64 reg_diff_src = reg_rbuf1;
    Reg64 reg_dst = rsi;
    Reg64 reg_diff_dst = reg_dst;

    Reg64 reg_tmp_off = reg_roff;

    // Reuse loop counters
    Reg64 reg_bar = reg_coff;
    Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
    Reg64 reg_tmp = reg_ctr;

    // Relu section
    bool with_relu, with_relu_inf_only;
    Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
    Reg64 reg_ws = reg_roff;
    Label l_relu_mask_avx2;
    Opmask kstore_mask = Opmask(1);

    // channel tail processing
    Opmask ktail_mask = Opmask(2);

    // FP32->BF16 emulation
    bf16_emulation_t *bf16_emu_;
    Reg64 reg_bf16_tmp = reg_tmp;
    Zmm vcvt_bf16_one = Zmm(16);
    Zmm vcvt_bf16_eve = Zmm(17);
    Zmm vcvt_bf16_sel = Zmm(18);
    Zmm vcvt_bf16_tmp = Zmm(19);

    size_t unroll_blocks;
    size_t unroll_regs;
    Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
    Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
    Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
    Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
    Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
    Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
    Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
    Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
    Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
    Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
    Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);

    size_t t0_pf_offt;
    size_t t1_pf_offt;
    size_t spat_size;
    size_t chan_data_offt;

    enum {
        stack_off_N_nthr = 0,
        stack_off_N_ithr = 8,
        stack_off_src = 16,
        stack_off_dst = 24,
        stack_off_diff_src = 32,
        stack_off_diff_dst = 40,
        stack_off_diff_scale_shift = 48,
        stack_off_ws = 56,
        stack_off_barrier = 64,
        stack_off_spat_size_loc = 72,
        stack_off_s_s = 80,
        stack_off_s_tail = 88,
        stack_off_is_cblk_tail = 96,
        stack_size_required = 104,
    };

    bool is_c_padded() const {
        const memory_desc_wrapper data_d(bdesc_->src_pd());
        return bdesc_->C() != data_d.blocking_desc().padding_dims[1];
    }

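    // Precomputes the spatial size and the per-channel data offset. Software
    // prefetch distances are non-zero only for avx512_mic, which appears to
    // be the only target that benefits from explicit T0 prefetching here.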
    void compute_static_strides() {
        spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
        chan_data_offt = bdesc_->C() * sizeof(acc_data_t);

        if (isa == avx512_mic) {
            t0_pf_offt = 4096;
            t1_pf_offt = 0;
        } else {
            t0_pf_offt = 0;
            t1_pf_offt = 0;
        }
    }

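    // Loads the call_params_t fields; values that do not fit in the register
    // budget are spilled to the stack slots defined in the enum above.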
    void load_common_params() {
#       define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
        if (bdesc_->is_bwd())
            mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
        mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
        mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
        mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
        shl(reg_coff_max, 2);

        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);

        uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
        uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);

        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
        mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
        mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
        mov(ptr[rsp + stack_off_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
        mov(ptr[rsp + stack_off_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
        mov(ptr[rsp + stack_off_diff_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
        mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
        mov(ptr[rsp + stack_off_ws], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
        mov(ptr[rsp + stack_off_barrier], reg_tmp);
        if (is_spatial_thr_) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
            mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
            mov(ptr[rsp + stack_off_s_s], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
            mov(ptr[rsp + stack_off_s_tail], reg_tmp);
        }
        if (is_c_padded()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
            mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
        }

        if (bdesc_->is_fwd()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        } else {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
            mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        }
#       undef PARAM_OFF
    }

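    // Tail masks cover the last channel block when C is padded: avx512 uses
    // an opmask register, avx2 emulates it with a vector mask for vmaskmovps.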
    void prepare_tail_mask_avx512_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        const int mask = (1 << tail) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }

    void prepare_tail_mask_avx2_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                0, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
        vmovups(vtail_mask, ptr[reg_tmp]);
    }

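    // ReLU fusion setup. vzero aliases a vector register that is free in the
    // current direction (vdiff_beta on fwd, vbeta on bwd).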
    void prepare_relu() {
        with_relu = bdesc_->is_fwd()
            ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
            : bdesc_->fuse_bn_relu();
        with_relu_inf_only = with_relu && bdesc_->is_fwd()
            && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());

        vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
        if (with_relu) {
            uni_vpxor(vzero, vzero, vzero);
            if (!bdesc_->is_fwd() && isa == avx2)
                prepare_l_relu_mask_avx2();
        }
    }

    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        jmp(l_mask_after);
        align(32);
        L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i) dd(1<<i);
        L(l_mask_after);
    }

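    // The fwd helpers store a per-element ReLU mask into the workspace; the
    // shr/shl pairs temporarily rescale reg_soff from a data byte offset to
    // the packed mask offset in ws. The bwd helpers reload that mask and
    // zero the masked-off elements of diff_dst.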
    void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
        Reg64 reg_store_mask = reg_diff_scale_shift;
        shr(reg_soff, 5);
        vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
        vmovmskps(reg_store_mask, vstore_mask);
        mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
        vblendvps(vdst, vzero, vdst, vstore_mask);
        shl(reg_soff, 5);
    }

    void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
        int bs = 5 - is_bf16_; // bit shift depends on data type
        shr(reg_soff, bs);
        vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
        kmovw(ptr[reg_ws + reg_soff + offt / (1 << bs)], kstore_mask);
        vblendmps(vdst | kstore_mask, vzero, vdst);
        shl(reg_soff, bs);
    }

    void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
        shr(reg_soff, 5);
        vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
        int bs = 5 - is_bf16_; // bit shift depends on data type
        shr(reg_soff, bs);
        kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << bs)]);
        vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
        shl(reg_soff, bs);
    }

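    // Load/store of spatial data with data-type handling: the BF16 store
    // path converts FP32->BF16 (native vcvtneps2bf16 or the emulation), the
    // BF16 load path widens to FP32 via vpmovzxwd plus a 16-bit left shift.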
    void uni_vmovups_spat_data(const Operand &dst, const Operand &src) {
        if (dst.isMEM()) {
            if (is_bf16_) {
                if (mayiuse(avx512_core_bf16))
                    vcvtneps2bf16(Ymm(src.getIdx()), Zmm(src.getIdx()));
                else
                    bf16_emu_->r_vcvtneps2bf16(
                            Ymm(src.getIdx()), Zmm(src.getIdx()));
                vmovdqu16(dst.getAddress(), Ymm(src.getIdx()));
            } else {
                uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
            }
        } else {
            if (is_bf16_) {
                vpmovzxwd(Zmm(dst.getIdx()), src.getAddress());
                vpslld(Zmm(dst.getIdx()), Zmm(dst.getIdx()), 0x10);
            } else {
                uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
            }
        }
    }

    void uni_vmovups_tail_avx2_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM()) {
            vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
        } else {
            vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
        }
        jmp(l_ret);
    }

    void uni_vmovups_tail_avx512_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());

        jmp(l_ret);
    }

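    // Per-channel load/store that falls back to a masked move on the last
    // (padded) channel block.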
    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;

        if (is_c_padded()) {
            mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
            cmp(reg_tmp, 0);
            jz(l_no_mask);

            lea(reg_tmp, ptr[reg_coff + vlen]);
            cmp(reg_tmp, reg_coff_max);
            jl(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        L(l_no_mask);
        if (dst.isMEM())
            uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()), src.getAddress());

        L(l_ret);
    }

    void barrier() {
        mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
        mov(reg_bar, ptr[rsp + stack_off_barrier]);
        simple_barrier::generate(*this, reg_bar, reg_nnthr);
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
    }

    Address diff_gamma_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 0 * chan_data_offt];
    }

    Address diff_beta_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 1 * chan_data_offt];
    }

    Address gamma_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
    }

    Address beta_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
    }

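    // Generic spatial loop: 'init' zeroes the accumulators, 'body' is emitted
    // for each of the regs * blocks unrolled iterations, and 'fini' reduces
    // the partial accumulators. With spatial threading the trip count comes
    // from the stack (spat_size_loc) rather than from the unrolled length.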
    template <typename init_t, typename body_t, typename fini_t>
    void spat_loop(size_t len, size_t blocks, size_t regs,
            init_t init, body_t body, fini_t fini) {
        size_t factor = regs * blocks;
        size_t loop_unroll = len / factor * factor;
        size_t loop_tail = len - loop_unroll;
        size_t num_active_regs = (len < regs) ? len : regs;
        for (size_t i = 0; i < num_active_regs; i++)
            init(i);
        if (loop_unroll) {
            if (is_spatial_thr_) {
                mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
                add(reg_soff, ptr[rsp + stack_off_s_s]);
            } else {
                mov(reg_ctr, loop_unroll);
            }
            Label label;
            L(label); {
                for (size_t i = 0; i < factor; i++) {
                    size_t base_reg = i % regs;
                    body(base_reg, i);
                }
                add(reg_soff, factor * vlen_spat_data_);
                sub(reg_ctr, factor);
                jnz(label);
            }
            if (is_spatial_thr_) {
                add(reg_soff, ptr[rsp + stack_off_s_tail]);
            }
        }

        for (size_t i = 0; i < loop_tail; i++) {
            size_t base_reg = i % regs;
            body(base_reg, i);
        }
        if (loop_tail)
            add(reg_soff, loop_tail * vlen_spat_data_);

        for (size_t i = 0; i < num_active_regs; i++)
            fini(i);
    }

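    // Accumulates the per-channel sum over the spatial domain into rbuf1
    // (a partial mean; the cross-thread reduction happens later in
    // compute_mean_variance()).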
    void mean_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks, unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v0 = Vmm(base_reg * 2 + 0);
                        Vmm v1 = Vmm(base_reg * 2 + 1);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                v1, vmmword[reg_src + reg_soff + offt]);
                        uni_vaddps(v0, v0, v1);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

    void var_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks, unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg > 0)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v = Vmm(3 * base_reg);
                        Vmm vtmp0 = Vmm(3 * base_reg + 1);
                        Vmm vtmp1 = Vmm(3 * base_reg + 2);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                vtmp0, vmmword[reg_src + reg_soff + offt]);
                        if (isa == sse42) {
                            movups(vtmp1, vmean);
                            subps(vtmp1, vtmp0);
                        } else {
                            vsubps(vtmp1, vmean, vtmp0);
                        }
                        uni_vfmadd231ps(v, vtmp1, vtmp1);

                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

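    // Two-pass statistics: every thread accumulates partial sums into rbuf1,
    // thread 0 reduces them into mean after a barrier, and the same scheme is
    // then repeated for the variance using the freshly reduced mean.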
    void compute_mean_variance() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf;
        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);

        xor_(reg_soff, reg_soff);
        Label mean_spatial;
        L(mean_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            mean_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                mean_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(mean_spatial);
        }

        Label no_mean_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_mean_reduction);
            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label mean_reduction_channels;
            L(mean_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label mean_reduction_thrs;
                L(mean_reduction_thrs); {
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(mean_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));

                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(mean_reduction_channels);
            }
        }
        L(no_mean_reduction);
        barrier();

        xor_(reg_soff, reg_soff);
        Label var_spatial;
        L(var_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            var_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                var_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(var_spatial);
        }

        Label no_var_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_var_reduction);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label var_reduction_channels;
            L(var_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label var_reduction_thrs;
                L(var_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(var_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(var_reduction_channels);
            }
        }
        L(no_var_reduction);
        barrier();
    }

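    // Forward normalization: dst = gamma * (src - mean) / sqrt(var + eps)
    // + beta (gamma/beta are skipped without scale-shift), with optional
    // fused ReLU; aligned outputs go through non-temporal stores.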
    void forward_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);

            if (bdesc_->use_scaleshift()) {
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
                uni_vmovups_maybe_tail(vbeta, beta_ptr());
            }

            Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone;
            Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar;

            if (isa == sse42) {
                movups(vbuf, vscale);
                divps(vbuf, vsqrtvar);
                movups(vdiv, vbuf);
            } else {
                vdivps(vdiv, vscale, vsqrtvar);
            }

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v = Vmm(base_reg);
                            size_t offt = i * vlen_spat_data_;
                            uni_vmovups_spat_data(
                                    v, vmmword[reg_src + reg_soff + offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                            uni_vsubps(v, v, vmean);
                            if (bdesc_->use_scaleshift()) {
                                uni_vfmadd213ps(v, vgamma, vbeta);
                            } else {
                                uni_vmulps(v, v, vsqrtvar);
                            }
                            if (with_relu_inf_only) {
                                uni_vmaxps(v, v, vzero);
                            } else if (with_relu) {
                                if (isa == avx512_common)
                                    fwd_process_relu_avx512_common(v, offt);
                                else
                                    fwd_process_relu_avx2(v, offt, Vmm(3));
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                    vmmword[reg_dst + reg_soff + offt], v);
                            } else {
                                uni_vmovups_spat_data(
                                        vmmword[reg_dst + reg_soff + offt], v);
                            }
                        },
                        [](size_t base_reg) {UNUSED(base_reg);});
            };

            if (is_bf16_) {
                compute(false); // no mask-able NT store for BF16
            } else {
                Label unaligned_store, end_store;
                test(reg_dst, vlen - 1);
                jnz(unaligned_store, T_NEAR);
                compute(true);
                jmp(end_store, T_NEAR);
                L(unaligned_store); {
                    compute(false);
                }
                L(end_store);
            }

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

    void forward() {
        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_dst, ptr[rsp + stack_off_dst]);
        mov(reg_ws, ptr[rsp + stack_off_ws]);

        xor_(reg_soff, reg_soff);
        Label dst_spatial;
        L(dst_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            forward_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                add(reg_dst, vlen / 2);
                mov(reg_coff, vlen / 2);

                forward_channels();

                sub(reg_src, vlen / 2);
                sub(reg_dst, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jnz(dst_spatial);
        }
    }

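    // Backward, step 1: accumulate per-channel partial sums of
    // (src - mean) * diff_dst (diff_gamma before scaling by 1/sqrt(var+eps))
    // into rbuf1 and of diff_dst (diff_beta) into rbuf2.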
    void backward_sh_channels() {
        Label sh_channels;
        L(sh_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
            spat_loop(spat_size, 1, 1,
                    [=](size_t base_reg) {
                        if (base_reg > 0) {
                            for (int i = 0; i < 2; i++) {
                                Vmm v(base_reg * 5 + i);
                                uni_vpxor(v, v, v);
                            }
                        }
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm o0 = Vmm(base_reg * 5 + 0);
                        Vmm o1 = Vmm(base_reg * 5 + 1);
                        Vmm t1 = Vmm(base_reg * 5 + 2);
                        Vmm t2 = Vmm(base_reg * 5 + 3);
                        Vmm t3 = Vmm(base_reg * 5 + 4);
                        size_t offt = i * vlen_spat_data_;
                        uni_vmovups_spat_data(
                                t1, vmmword[reg_src + reg_soff + offt]);
                        uni_vmovups_spat_data(
                                t2, vmmword[reg_diff_dst + reg_soff + offt]);
                        if (with_relu) {
                            if (isa == avx512_common)
                                bwd_process_relu_avx512_common(t2, offt);
                            else if (isa == avx2)
                                bwd_process_relu_avx2(t2, offt, t3);
                            else
                                assert(false);
                        }
                        uni_vsubps(t3, vmean, t1, t3);
                        if (isa == sse42) {
                            mulps(t3, t2);
                            subps(o0, t3);
                        } else {
                            vfnmadd231ps(o0, t3, t2);
                        }
                        uni_vaddps(o1, o1, t2);
                        mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
                                + t1_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b0 = Vmm(0);
                        Vmm b1 = Vmm(1);
                        if (base_reg) {
                            uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
                            uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(sh_channels);
        }
    }

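    // Backward, step 2: compute diff_src from diff_dst using the reduced
    // diff_gamma / diff_beta; with use_global_stats() the terms propagating
    // through mean and variance are skipped.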
    void backward_diff_channels() {
        Label diff_channels;
        L(diff_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);
            uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
            if (bdesc_->use_scaleshift())
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
            uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
            uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
            uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [=](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v(base_reg * 2 + 0);
                            Vmm t(base_reg * 2 + 1);
                            Vmm t1(base_reg * 2 + 2);
                            size_t offt = i * vlen_spat_data_;
                            uni_vmovups_spat_data(
                                    v, vmmword[reg_diff_dst + reg_soff + offt]);
                            if (with_relu) {
                                if (isa == avx512_common)
                                    bwd_process_relu_avx512_common(v, offt);
                                else if (isa == avx2)
                                    bwd_process_relu_avx2(v, offt, t);
                                else
                                    assert(false);
                            }
                            if (!bdesc_->use_global_stats()) {
                                uni_vsubps(v, v, vdiff_beta);
                                uni_vmovups_spat_data(
                                        t, vmmword[reg_src + reg_soff + offt]);
                                uni_vsubps(t, vmean, t, t1);
                                uni_vmulps(t, t, vdiff_gamma);
                                uni_vaddps(v, v, t);
                            }
                            uni_vmulps(v, v, vsqrtvar);
                            if (bdesc_->use_scaleshift()) {
                                uni_vmulps(v, v, vgamma);
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                    vmmword[reg_diff_src + reg_soff + offt],
                                    v);
                            } else {
                                uni_vmovups_spat_data(
                                        vmmword[reg_diff_src + reg_soff + offt],
                                        v);
                            }
                            mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_diff_dst + reg_soff
                                    + offt + t1_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                        },
                        [=](size_t base_reg) {UNUSED(base_reg);});
            };

            if (is_bf16_) {
                compute(false); // no mask-able NT store for BF16
            } else {
                Label unaligned_store, end_store;
                test(reg_diff_src, vlen - 1);
                jnz(unaligned_store, T_NEAR);
                compute(true);
                jmp(end_store, T_NEAR);
                L(unaligned_store); {
                    compute(false);
                }
                L(end_store);
            }

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(diff_channels);
        }
    }

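    // Backward driver: zero the reduction buffers, accumulate the scale and
    // shift gradients, reduce them on thread 0, then compute diff_src.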
    void backward() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf, sh_spatial;

        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        L(sh_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42) {
                mov(reg_tmp_off, reg_soff);
            }
            backward_sh_channels();
            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_sh_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(sh_spatial);
        }

        mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);

        Label no_sh_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            Label sh_reduction_channels;
            jne(no_sh_reduction, T_NEAR);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            L(sh_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
                uni_vaddps(vsqrtvar, vsqrtvar, veps);
                uni_vsqrtps(vsqrtvar, vsqrtvar);
                uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
                mov(reg_ctr, reg_nnthr);
                Label sh_reduction_thrs;
                L(sh_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(sh_reduction_thrs);
                }
                uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
                uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
                uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
                cmp(reg_coff, reg_coff_max);
                jne(sh_reduction_channels);
            }
        }
        L(no_sh_reduction);
        barrier();

        mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        Label diff_spatial;
        L(diff_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42) {
                mov(reg_tmp_off, reg_soff);
            }
            backward_diff_channels();
            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_diff_src, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_diff_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_diff_src, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(diff_spatial);
        }
    }

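    // The constructor emits the whole kernel: preamble, optional BF16
    // emulation setup, tail masks, parameter loading, and finally the fwd or
    // bwd code path.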
    jit_bnorm_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), bf16_emu_() {
        static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
                || isa == avx512_mic, "unsupported isa");

        is_bf16_ = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        size_t dt_size = is_bf16_ ? types::data_type_size(data_type::bf16)
                                  : sizeof(acc_data_t);
        const int simd_w = isa == sse42 ? 8 :
            cpu_isa_traits<isa>::vlen / sizeof(float);
        is_spatial_thr_ =
            bnorm_utils::is_spatial_thr(bdesc_, simd_w, dt_size);
        vlen_spat_data_ = vlen / (1 + is_bf16_); // 32B of BF16 -> 64B of FP32

        unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
        unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;

        preamble();

        if (is_bf16_) {
            // init emulation of bfloat16 operations
            if (!mayiuse(avx512_core_bf16)) {
                bf16_emu_ = new bf16_emulation_t(this, vcvt_bf16_one,
                        vcvt_bf16_eve, vcvt_bf16_sel, reg_bf16_tmp,
                        vcvt_bf16_tmp, vcvt_bf16_tmp);
                bf16_emu_->init_vcvtneps2bf16();
            }
        }

        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();

        compute_static_strides();
        sub(rsp, stack_size_required);
        load_common_params();
        prepare_relu();

        if (bdesc_->is_fwd()) {
            if (!bdesc_->stats_is_src()) {
                compute_mean_variance();
            }
            forward();
        } else {
            backward();
        }
        add(rsp, stack_size_required);
        postamble();

        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    this->getCode()));
    }

    ~jit_bnorm_t() { delete bf16_emu_; }
};

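// Driver: owns the generated kernel, decides on cache blocking by comparing
// the estimated working set against the L3 size, and computes per-thread
// offsets and call parameters for every kernel invocation.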
template <cpu_isa_t isa>
struct uni_bnorm_driver_t: public c_compatible {
    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), ker_(bdesc_) {
        const int nthrs = mkldnn_get_max_threads();
        const dim_t C_PADDED = get_c_padded(bdesc_);

        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        dt_size_ = is_bf16 ? types::data_type_size(data_type::bf16)
                           : sizeof(acc_data_t);
        size_t data_size = dt_size_ * bdesc_->MB() * C_PADDED
            * bdesc_->D() * bdesc_->H() * bdesc_->W();
        l3_size_ = get_cache_size(3, true) * nthrs / 2;
        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    }

    ~uni_bnorm_driver_t() {}

    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const batch_normalization_pd_t *bdesc) {
        int nthrs = mkldnn_get_max_threads();
        dim_t C_PADDED = get_c_padded(bdesc);

        int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
        int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
        int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;

        scratchpad.book(key_bnorm_tmp_stats, sizeof(acc_data_t) * sbuf_sz);
        scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(acc_data_t) * pbuf_sz);
        scratchpad.book(key_bnorm_reduction, sizeof(acc_data_t) * rbuf_sz);

        if (mkldnn_thr_syncable()) {
            int n_barriers = C_PADDED / simd_w;
            scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
        }
    }

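    // Partitions work over channels (C), minibatch (N) and the spatial
    // domain (S). With cache blocking the channel range is processed in
    // C_blks_per_iter chunks and threads are rebalanced for the (smaller)
    // last chunk.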
    void exec(int ithr, int nthr, const void *src, void *diff_src, void *dst,
            const void *diff_dst, const acc_data_t *scale_shift,
            acc_data_t *diff_scale_shift, const acc_data_t *mean,
            const acc_data_t *var, const uint8_t *ws,
            const memory_tracking::grantor_t &scratchpad) {
        auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
        auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
        auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);

        size_t N = bdesc_->MB();
        size_t C = bdesc_->C();
        size_t C_PADDED = get_c_padded(bdesc_);
        size_t D = bdesc_->D();
        size_t H = bdesc_->H();
        size_t W = bdesc_->W();
        int SP = D * H * W;
        size_t img_size = C_PADDED * D * H * W;
        const int vlen_spat_data = ker_.vlen_spat_data_;

        typename jit_bnorm_t<isa>::call_params_t p;

        p.eps = bdesc_->desc()->batch_norm_epsilon;
        p.one = 1.0f;
        p.spat_size = D * H * W;
        p.chan_size = 1.0f * N * p.spat_size;

        int C_blks = C_PADDED / simd_w;

        int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
        int C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};

        int C_blks_per_iter{ 1 }, iters{ 1 };
        if (do_blocking_) {
            int num_tensors = bdesc_->is_fwd() ? 1 : 2;
            size_t working_set_size
                    = dt_size_ * (N * D * H * W * simd_w) * num_tensors;
            bnorm_utils::cache_balance(working_set_size, C_blks,
                C_blks_per_iter, iters);
        }

        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
                SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
                S_ithr, S_nthr, S_s, S_e);

        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));

        p.N_ithr = SP_N_ithr;
        p.N_nthr = SP_N_nthr;

        int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;
        int global_C_blk_s;
        int global_barriers_per_iter = C_nthr;

        for (int it = 0; it < iters; it++) {
            if (it == iters - 1 && iters > 1) {
                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);

                // Update call parameters for JIT, last iteration
                p.N_ithr = N_ithr * S_nthr + S_ithr;
                p.N_nthr = N_nthr * S_nthr;
            }

            global_C_blk_s = do_blocking_ ?
                    (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
                    C_blk_s;

            int C_blks_thr = C_blk_e - C_blk_s;
            int N_thr = N_e - N_s;

            size_t coff_base = global_C_blk_s * simd_w;
            size_t soff_base
                    = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;

            p.spat_size_loc = S_e - S_s;
            p.S_s = S_s * vlen_spat_data;
            p.S_tail = (p.spat_size - S_e) * vlen_spat_data;
            p.coff_max = C_blks_thr * simd_w;
            p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
            p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
            p.scale_shift = scale_shift + coff_base;
            p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
                    ? pbuf : diff_scale_shift) + coff_base;

            p.soff_max = dt_size_ * N_thr * img_size;
            p.src = (void *)((char *)src + soff_base * dt_size_);
            p.dst = (void *)((char *)dst + soff_base * dt_size_);
            p.diff_src = (void *)((char *)diff_src + soff_base * dt_size_);
            p.diff_dst = (void *)((char *)diff_dst + soff_base * dt_size_);
            p.ws = ws + soff_base / 8;

            p.mb_stride_Bc = dt_size_ * (img_size - p.coff_max * p.spat_size);

            // use SP_N_nthr which is the same as p.N_nthr except maybe for
            // the last iteration.
            p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
                    + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
            // rbuf1 and rbuf2 have to be disjoint
            p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
            p.is_cblk_tail =
                (size_t)((it * C_blks_per_iter + C_blk_e) * simd_w) > C;

            size_t iter_barriers
                    = do_blocking_ ? it * global_barriers_per_iter : 0;
            p.barrier = barriers + C_ithr + iter_barriers;
            if (p.soff_max != 0 && p.coff_max != 0)
                ker_(&p);
        }
    }

    void init_barriers(const memory_tracking::grantor_t &scratchpad) {
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
        if (barriers) {
            const int n_barriers = get_c_padded(bdesc_) / simd_w;
            for (int i = 0; i < n_barriers; ++i)
                barrier::ctx_init(&barriers[i]);
        }
    }

private:
    enum {
        simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen
                        / sizeof(acc_data_t) // BF16 will expand to FP32
    };

    static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
        return true
            && !bdesc->stats_is_src()
            && bdesc->desc()->prop_kind == prop_kind::forward_inference;
    }

    static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
    {
        return false
            || (bdesc->is_bwd() && !bdesc->use_scaleshift())
            || bdesc->desc()->prop_kind == prop_kind::backward_data;
    }

    static dim_t get_c_padded(const batch_normalization_pd_t *bdesc)
    { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }

    const batch_normalization_pd_t *bdesc_;
    jit_bnorm_t<isa> ker_;
    bool do_blocking_;
    size_t l3_size_;

    acc_data_t *buf_, *sbuf_, *rbuf_, *pbuf_;

    size_t dt_size_;
};

}

using namespace data_type;
using namespace memory_format;
using namespace utils;

/* fwd */
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_batch_normalization_fwd_t<isa, d_type>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? isa == avx512_common ? nChw16c : nChw8c
        : isa == avx512_common ? nCdhw16c : nCdhw8c;

    bool ok = true
        && mayiuse(isa)
        && is_fwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && desc()->data_desc.data_type == d_type
        && IMPLICATION(d_type == bf16, mayiuse(avx512_core))
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && desc()->data_desc.format == desired_fmt
        && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    if (is_training() && fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
    }

    if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
            != this->C() && isa < avx2)
        return status::unimplemented;

    if (stats_is_src() || is_training()) {
        memory_desc_t stats_d;
        dims_t stats_dims = { C() };
        mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
        mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
        variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
    }

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_fwd_t<isa,
        d_type>::jit_uni_batch_normalization_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs) {
    bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd());
}

template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_batch_normalization_fwd_t<isa, d_type>::execute(event_t *e) const {
    auto src = reinterpret_cast<const void *>(this->input_memory(0));
    auto dst = reinterpret_cast<void *>(this->memory(0));
    auto mean = reinterpret_cast<acc_data_t *>(pd()->stats_is_src()
                    ? const_cast<char *>(this->input_memory(1))
                    : this->memory(1));
    auto var = reinterpret_cast<acc_data_t *>(pd()->stats_is_src()
                    ? const_cast<char *>(this->input_memory(2))
                    : this->memory(2));

    auto idx_scale_shift = 1 + 2 * pd()->stats_is_src();
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);
    auto scale_shift = reinterpret_cast<const acc_data_t *>(
            this->input_memory(idx_scale_shift));

    parallel(0, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
                scale_shift, nullptr, mean, var, ws, scratchpad);
    });
    e->set_state(event_t::ready);
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_fwd_t<isa,
        d_type>::~jit_uni_batch_normalization_fwd_t() {
    delete bnorm_driver_;
}

/* bwd */
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_batch_normalization_bwd_t<isa, d_type>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
        : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;

    bool ok = true
        && mayiuse(isa)
        && is_bwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && everyone_is(d_type, desc()->data_desc.data_type,
                desc()->diff_data_desc.data_type)
        && IMPLICATION(d_type == bf16, mayiuse(avx512_core))
        && IMPLICATION(use_scaleshift(), utils::everyone_is(f32,
                desc()->data_scaleshift_desc.data_type,
                desc()->diff_data_scaleshift_desc.data_type))
        && everyone_is(desired_fmt, desc()->diff_data_desc.format,
                desc()->data_desc.format)
        && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    if (memory_desc_wrapper(&data_pd_).blocking_desc()
            .padding_dims[1] != this->C() && isa < avx2)
        return status::unimplemented;

    if (fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
        size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();

        bool ws_ok = true
            && hint_fwd_pd_->workspace_pd()
            && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
            == this_ws_sz;
        if (!ws_ok) return status::unimplemented;
    }

    /* TODO: extra checks required */

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_bwd_t<isa,
        d_type>::jit_uni_batch_normalization_bwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs) {
    bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd());
}

template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_batch_normalization_bwd_t<isa, d_type>::execute(event_t *e) const {
    auto src = reinterpret_cast<const void *>(this->input_memory(0));
    auto mean = reinterpret_cast<const acc_data_t *>(this->input_memory(1));
    auto var = reinterpret_cast<const acc_data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const void *>(this->input_memory(3));
    auto scale_shift
            = reinterpret_cast<const acc_data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<void *>(this->memory(0));
    auto diff_scale_shift = reinterpret_cast<acc_data_t *>(this->memory(1));
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
                scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
    });
    e->set_state(event_t::ready);
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_batch_normalization_bwd_t<isa,
        d_type>::~jit_uni_batch_normalization_bwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<sse42, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<sse42, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx2, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<avx2, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common, data_type::f32>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common, data_type::f32>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common, data_type::bf16>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common, data_type::bf16>;

}
}
}