/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "c_types_map.hpp"
#include "math_utils.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_barrier.hpp"
#include "cpu_batch_normalization_utils.hpp"
#include "jit_generator.hpp"

#include "jit_uni_batch_normalization.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

namespace {

using namespace memory_tracking::names;

using namespace Xbyak;
namespace barrier = simple_barrier;

typedef float data_t;
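
// jit_bnorm_t generates a single batch-normalization kernel at construction
// time: statistics computation plus the forward pass, or the backward pass,
// depending on the primitive descriptor. The isa template parameter selects
// the SSE4.2, AVX2, or AVX-512 code path.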
template <cpu_isa_t isa>
struct jit_bnorm_t: public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        size_t N_ithr, N_nthr;
        size_t coff_max, soff_max;
        size_t mb_stride_Bc, spat_size, spat_size_loc;
        size_t S_s, S_tail;
        size_t is_cblk_tail;
        data_t chan_size, eps, one;
        const data_t *scale_shift;
        const data_t *mean, *var;
        const data_t *diff_scale_shift;
        const data_t *src, *dst;
        const data_t *diff_src, *diff_dst;
        const data_t *rbuf1, *rbuf2;
        const uint8_t *ws;
        barrier::ctx_t *barrier;
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    /* cpu specific part */
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;
    const AddressFrame &vmmword = (isa == sse42) ? xword :
                                  (isa == avx2) ? yword : zword;

    const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *bdesc_;
    bool is_spatial_thr_;

    void (*ker)(const call_params_t *);
    void operator()(const call_params_t *p) { (*ker)(p); }

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale_shift = rbx;
    Reg64 reg_rbuf1 = abi_not_param1;
    Reg64 reg_rbuf2 = rdx;

    Reg64 reg_mean = rbp;
    Reg64 reg_var = reg_param;
    Reg64 reg_diff_scale_shift = rax;

    Reg64 reg_coff = r8;
    Reg64 reg_coff_max = r9;
    Reg64 reg_soff = r10;
    Reg64 reg_soff_max = r11;
    Reg64 reg_ctr = r12;
    Reg64 reg_roff = r13;

    Reg64 reg_mb_stride_Bc = r14;

    Reg64 reg_src = r15;
    Reg64 reg_diff_src = reg_rbuf1;
    Reg64 reg_dst = rsi;
    Reg64 reg_diff_dst = reg_dst;

    Reg64 reg_tmp_off = reg_roff;

    // Reuse loop counters
    Reg64 reg_bar = reg_coff;
    Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
    Reg64 reg_tmp = reg_ctr;

    // Relu section
    bool with_relu, with_relu_inf_only;
    Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
    Reg64 reg_ws = reg_roff;
    Label l_relu_mask_avx2;
    Opmask kstore_mask = Opmask(1);

    // channel tail processing
    Opmask ktail_mask = Opmask(2);

    size_t unroll_blocks;
    size_t unroll_regs;
    Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
    Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
    Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
    Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
    Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
    Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
    Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
    Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
    Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
    Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
    Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);

    size_t t0_pf_offt;
    size_t t1_pf_offt;
    size_t spat_size;
    size_t chan_data_offt;
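
    // Byte offsets into the stack area reserved by the generated code;
    // call_params_t fields that do not stay in registers are spilled here
    // by load_common_params().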
    enum {
        stack_off_N_nthr = 0,
        stack_off_N_ithr = 8,
        stack_off_src = 16,
        stack_off_dst = 24,
        stack_off_diff_src = 32,
        stack_off_diff_dst = 40,
        stack_off_diff_scale_shift = 48,
        stack_off_ws = 56,
        stack_off_barrier = 64,
        stack_off_spat_size_loc = 72,
        stack_off_s_s = 80,
        stack_off_s_tail = 88,
        stack_off_is_cblk_tail = 96,
        stack_size_required = 104,
    };

    bool is_c_padded() const {
        const memory_desc_wrapper data_d(bdesc_->src_pd());
        return bdesc_->C() != data_d.blocking_desc().padding_dims[1];
    }

    void compute_static_strides() {
        spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
        chan_data_offt = bdesc_->C() * sizeof(data_t);

        if (isa == avx512_mic) {
            t0_pf_offt = 4096;
            t1_pf_offt = 0;
        } else {
            t0_pf_offt = 0;
            t1_pf_offt = 0;
        }
    }

    void load_common_params() {
#       define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
        if (bdesc_->is_bwd())
            mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
        mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
        mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
        mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
        shl(reg_coff_max, 2);
        shl(reg_soff_max, 2);
        shl(reg_mb_stride_Bc, 2);

        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);

        uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
        uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);

        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
        mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
        mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
        mov(ptr[rsp + stack_off_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
        mov(ptr[rsp + stack_off_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
        mov(ptr[rsp + stack_off_diff_src], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
        mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
        mov(ptr[rsp + stack_off_ws], reg_tmp);
        mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
        mov(ptr[rsp + stack_off_barrier], reg_tmp);
        if (is_spatial_thr_) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
            mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
            mov(ptr[rsp + stack_off_s_s], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
            mov(ptr[rsp + stack_off_s_tail], reg_tmp);
        }
        if (is_c_padded()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
            mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
        }

        if (bdesc_->is_fwd()) {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        } else {
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
            mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
            mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
            mov(reg_var, reg_tmp);
        }
#       undef PARAM_OFF
    }
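
    // When the channel dimension is padded to the block size, the last
    // channel block is only partially valid. These helpers build the masks
    // used for predicated loads/stores of that tail: an opmask register on
    // AVX-512, a vector mask for vmaskmovps on AVX2.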
    void prepare_tail_mask_avx512_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        const int mask = (1 << tail) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }

    void prepare_tail_mask_avx2_common() {
        if (!is_c_padded()) return;

        const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                0, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
        vmovups(vtail_mask, ptr[reg_tmp]);
    }

    void prepare_relu() {
        with_relu = bdesc_->is_fwd()
            ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
            : bdesc_->fuse_bn_relu();
        with_relu_inf_only = with_relu && bdesc_->is_fwd()
            && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());

        vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
        if (with_relu) {
            uni_vpxor(vzero, vzero, vzero);
            if (!bdesc_->is_fwd() && isa == avx2)
                prepare_l_relu_mask_avx2();
        }
    }

    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        jmp(l_mask_after);
        align(32);
        L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i) dd(1<<i);
        L(l_mask_after);
    }
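
    // The fused-ReLU workspace keeps one bit per element (set where the
    // pre-ReLU value was positive), i.e. 1/32 of the data size in bytes, so
    // spatial byte offsets are rescaled by 32 (shr/shl by 5) when it is
    // addressed.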
    void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
        Reg64 reg_store_mask = reg_diff_scale_shift;
        shr(reg_soff, 5);
        vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
        vmovmskps(reg_store_mask, vstore_mask);
        mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
        vblendvps(vdst, vzero, vdst, vstore_mask);
        shl(reg_soff, 5);
    }

    void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
        shr(reg_soff, 5);
        vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
        kmovw(ptr[reg_ws + reg_soff + offt / (1 << 5)], kstore_mask);
        vblendmps(vdst | kstore_mask, vzero, vdst);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
        shr(reg_soff, 5);
        vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
        vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
        shl(reg_soff, 5);
    }

    void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
        shr(reg_soff, 5);
        kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
        vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
        shl(reg_soff, 5);
    }

    void uni_vmovups_tail_avx2_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM()) {
            vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
        } else {
            vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
        }
        jmp(l_ret);
    }

    void uni_vmovups_tail_avx512_common(const Operand &dst,
            const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());

        jmp(l_ret);
    }
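
    // Vector load/store that falls back to the masked variant at run time
    // when the current block is the padded channel tail.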
    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;

        if (is_c_padded()) {
            mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
            cmp(reg_tmp, 0);
            jz(l_no_mask);

            lea(reg_tmp, ptr[reg_coff + vlen]);
            cmp(reg_tmp, reg_coff_max);
            jl(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        L(l_no_mask);
        if (dst.isMEM())
            uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            uni_vmovups(Vmm(dst.getIdx()), src.getAddress());

        L(l_ret);
    }

    void barrier() {
        mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
        mov(reg_bar, ptr[rsp + stack_off_barrier]);
        simple_barrier::generate(*this, reg_bar, reg_nnthr);
    }
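
    // Per-channel address helpers. Mean/variance are indexed by the channel
    // offset; gamma/beta (and their gradients) are stored back to back,
    // chan_data_offt bytes apart.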
    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
    }

    Address diff_gamma_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 0 * chan_data_offt];
    }

    Address diff_beta_ptr(size_t offt = 0) {
        return vmmword[reg_diff_scale_shift + reg_coff + offt
            + 1 * chan_data_offt];
    }

    Address gamma_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
    }

    Address beta_ptr(size_t offt = 0) {
        return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
    }
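
    // Generic spatial loop: emits `blocks * regs` copies of `body` inside a
    // counted loop (register-blocked over `regs` accumulators), runs `init`
    // and `fini` once per accumulator, and emits the remainder that does not
    // fill a full unroll factor as straight-line code.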
    template <typename init_t, typename body_t, typename fini_t>
    void spat_loop(size_t len, size_t blocks, size_t regs,
            init_t init, body_t body, fini_t fini) {
        size_t factor = regs * blocks;
        size_t loop_unroll = len / factor * factor;
        size_t loop_tail = len - loop_unroll;
        size_t num_active_regs = (len < regs) ? len : regs;
        for (size_t i = 0; i < num_active_regs; i++)
            init(i);
        if (loop_unroll) {
            if (is_spatial_thr_) {
                mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
                add(reg_soff, ptr[rsp + stack_off_s_s]);
            } else {
                mov(reg_ctr, loop_unroll);
            }
            Label label;
            L(label); {
                for (size_t i = 0; i < factor; i++) {
                    size_t base_reg = i % regs;
                    body(base_reg, i);
                }
                add(reg_soff, factor * vlen);
                sub(reg_ctr, factor);
                jnz(label);
            }
            if (is_spatial_thr_) {
                add(reg_soff, ptr[rsp + stack_off_s_tail]);
            }
        }

        for (size_t i = 0; i < loop_tail; i++) {
            size_t base_reg = i % regs;
            body(base_reg, i);
        }
        if (loop_tail)
            add(reg_soff, loop_tail * vlen);

        for (size_t i = 0; i < num_active_regs; i++)
            fini(i);
    }

    void mean_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks,
                unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v0 = Vmm(base_reg * 2 + 0);
                        Vmm v1 = Vmm(base_reg * 2 + 1);
                        size_t offt = i * vlen;
                        uni_vmovups(v1,
                            vmmword[reg_src + reg_soff + offt]);
                        uni_vaddps(v0, v0, v1);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 2);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

    void var_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            spat_loop(spat_size, unroll_blocks, unroll_regs,
                    [=](size_t base_reg) {
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg > 0)
                            uni_vpxor(v, v, v);
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm v = Vmm(3 * base_reg);
                        Vmm vtmp0 = Vmm(3 * base_reg + 1);
                        Vmm vtmp1 = Vmm(3 * base_reg + 2);
                        size_t offt = i * vlen;
                        uni_vmovups(vtmp0,
                            vmmword[reg_src + reg_soff + offt]);
                        if (isa == sse42) {
                            movups(vtmp1, vmean);
                            subps(vtmp1, vtmp0);
                        } else {
                            vsubps(vtmp1, vmean, vtmp0);
                        }
                        uni_vfmadd231ps(v, vtmp1, vtmp1);

                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b = Vmm(0);
                        Vmm v = Vmm(base_reg * 3);
                        if (base_reg)
                            uni_vaddps(b, b, v);
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }
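
    // Two-pass statistics with an inter-thread reduction: every thread
    // accumulates partial per-channel sums into rbuf1, threads meet at a
    // barrier, thread 0 reduces the partials into mean, and the same scheme
    // is repeated for the variance.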
    void compute_mean_variance() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf;
        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);

        xor_(reg_soff, reg_soff);
        Label mean_spatial;
        L(mean_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            mean_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                mean_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(mean_spatial);
        }

        Label no_mean_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_mean_reduction);
            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label mean_reduction_channels;
            L(mean_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label mean_reduction_thrs;
                L(mean_reduction_thrs); {
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(mean_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));

                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(mean_reduction_channels);
            }
        }
        L(no_mean_reduction);
        barrier();

        xor_(reg_soff, reg_soff);
        Label var_spatial;
        L(var_spatial); {
            xor_(reg_coff, reg_coff);

            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            var_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);

                var_channels();

                sub(reg_src, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(var_spatial);
        }

        Label no_var_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            jne(no_var_reduction);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            Label var_reduction_channels;
            L(var_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                mov(reg_ctr, reg_nnthr);
                Label var_reduction_thrs;
                L(var_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(var_reduction_thrs);
                }
                uni_vdivps(Vmm(1), Vmm(1), vchan_size);
                uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);

                cmp(reg_coff, reg_coff_max);
                jne(var_reduction_channels);
            }
        }
        L(no_var_reduction);
        barrier();
    }
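
    // Forward normalization over one channel block:
    //   dst = gamma * (src - mean) / sqrt(var + eps) + beta
    // with optional fused ReLU. Non-temporal stores are used when dst is
    // vector-aligned.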
    void forward_channels() {
        Label ch_label;
        L(ch_label); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);

            if (isa == sse42) {
                movups(vbuf, vone);
                divps(vbuf, vsqrtvar);
                movups(vsqrtvar, vbuf);
            } else {
                vdivps(vsqrtvar, vone, vsqrtvar);
            }

            if (bdesc_->use_scaleshift()) {
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
                uni_vmovups_maybe_tail(vbeta, beta_ptr());
            }

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                             Vmm v = Vmm(base_reg);
                             size_t offt = i * vlen;
                             uni_vmovups(v,
                                 vmmword[reg_src + reg_soff + offt]);
                             mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                     + t0_pf_offt]);
                             mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                     + t1_pf_offt]);
                             uni_vsubps(v, v, vmean);
                             uni_vmulps(v, v, vsqrtvar);
                             if (bdesc_->use_scaleshift()) {
                                 uni_vfmadd213ps(v, vgamma, vbeta);
                             }
                             if (with_relu_inf_only) {
                                 uni_vmaxps(v, v, vzero);
                             } else if (with_relu) {
                                 if (isa == avx512_common)
                                     fwd_process_relu_avx512_common(v, offt);
                                 else
                                     fwd_process_relu_avx2(v, offt, Vmm(3));
                             }
                             if (output_is_aligned) {
                                 uni_vmovntps(
                                     vmmword[reg_dst + reg_soff + offt], v);
                             } else {
                                 uni_vmovups(
                                     vmmword[reg_dst + reg_soff + offt], v);
                             }
                        },
                        [](size_t base_reg) {UNUSED(base_reg);});
            };

            Label unaligned_store, end_store;
            test(reg_dst, vlen - 1);
            jnz(unaligned_store, T_NEAR);
            compute(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                compute(false);
            }
            L(end_store);

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(ch_label);
        }
    }

    void forward() {
        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_dst, ptr[rsp + stack_off_dst]);
        mov(reg_ws, ptr[rsp + stack_off_ws]);

        xor_(reg_soff, reg_soff);
        Label dst_spatial;
        L(dst_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42)
                mov(reg_tmp_off, reg_soff);

            forward_channels();

            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_src, vlen / 2);
                add(reg_dst, vlen / 2);
                mov(reg_coff, vlen / 2);

                forward_channels();

                sub(reg_src, vlen / 2);
                sub(reg_dst, vlen / 2);
            }

            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jnz(dst_spatial);
        }
    }
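
    // First backward pass: accumulates the per-channel partials
    // sum((src - mean) * diff_dst) into rbuf1 (for diff_gamma) and
    // sum(diff_dst) into rbuf2 (for diff_beta).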
    void backward_sh_channels() {
        Label sh_channels;
        L(sh_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
            uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
            spat_loop(spat_size, 1, 1,
                    [=](size_t base_reg) {
                        if (base_reg > 0) {
                            for (int i = 0; i < 2; i++) {
                                Vmm v(base_reg * 5 + i);
                                uni_vpxor(v, v, v);
                            }
                        }
                    },
                    [=](size_t base_reg, size_t i) {
                        Vmm o0 = Vmm(base_reg * 5 + 0);
                        Vmm o1 = Vmm(base_reg * 5 + 1);
                        Vmm t1 = Vmm(base_reg * 5 + 2);
                        Vmm t2 = Vmm(base_reg * 5 + 3);
                        Vmm t3 = Vmm(base_reg * 5 + 4);
                        size_t offt = i * vlen;
                        uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]);
                        uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff
                                + offt]);
                        if (with_relu) {
                            if (isa == avx512_common)
                                bwd_process_relu_avx512_common(t2, offt);
                            else if (isa == avx2)
                                bwd_process_relu_avx2(t2, offt, t3);
                            else
                                assert(false);
                        }
                        uni_vsubps(t3, vmean, t1, t3);
                        if (isa == sse42) {
                            mulps(t3, t2);
                            subps(o0, t3);
                        } else {
                            vfnmadd231ps(o0, t3, t2);
                        }
                        uni_vaddps(o1, o1, t2);
                        mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                + t0_pf_offt]);
                        mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
                                + t1_pf_offt]);
                        mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                + t1_pf_offt]);
                    },
                    [=](size_t base_reg) {
                        Vmm b0 = Vmm(0);
                        Vmm b1 = Vmm(1);
                        if (base_reg) {
                            uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
                            uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
                        }
                    });
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(sh_channels);
        }
    }
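
    // Second backward pass, per element:
    //   diff_src = gamma / sqrt(var + eps) * (diff_dst - diff_beta / NHW
    //       - (src - mean) * diff_gamma / (NHW * sqrt(var + eps)))
    // where the two correction terms are dropped when global stats are used.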
    void backward_diff_channels() {
        Label diff_channels;
        L(diff_channels); {
            uni_vmovups_maybe_tail(vmean, mean_ptr());
            uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
            uni_vaddps(vsqrtvar, vsqrtvar, veps);
            uni_vsqrtps(vsqrtvar, vsqrtvar);
            uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
            if (bdesc_->use_scaleshift())
                uni_vmovups_maybe_tail(vgamma, gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
            uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
            uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
            uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
            uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);

            auto compute = [=](bool output_is_aligned) {
                spat_loop(spat_size, unroll_blocks, unroll_regs,
                        [=](size_t base_reg) {UNUSED(base_reg);},
                        [=](size_t base_reg, size_t i) {
                            Vmm v(base_reg * 2 + 0);
                            Vmm t(base_reg * 2 + 1);
                            Vmm t1(base_reg * 2 + 2);
                            size_t offt = i * vlen;
                            uni_vmovups(v, vmmword[reg_diff_dst + reg_soff
                                    + offt]);
                            if (with_relu) {
                                if (isa == avx512_common)
                                    bwd_process_relu_avx512_common(v, offt);
                                else if (isa == avx2)
                                    bwd_process_relu_avx2(v, offt, t);
                                else
                                    assert(false);
                            }
                            if (!bdesc_->use_global_stats()) {
                                uni_vsubps(v, v, vdiff_beta);
                                uni_vmovups(t, vmmword[reg_src + reg_soff
                                        + offt]);
                                uni_vsubps(t, vmean, t, t1);
                                uni_vmulps(t, t, vdiff_gamma);
                                uni_vaddps(v, v, t);
                            }
                            uni_vmulps(v, v, vsqrtvar);
                            if (bdesc_->use_scaleshift()) {
                                uni_vmulps(v, v, vgamma);
                            }
                            if (output_is_aligned) {
                                uni_vmovntps(
                                    vmmword[reg_diff_src + reg_soff + offt],
                                    v);
                            } else {
                                uni_vmovups(
                                    vmmword[reg_diff_src + reg_soff + offt],
                                    v);
                            }
                            mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht0(ptr[reg_src + reg_soff + offt
                                    + t0_pf_offt]);
                            mic_prefetcht1(ptr[reg_diff_dst + reg_soff
                                    + offt + t1_pf_offt]);
                            mic_prefetcht1(ptr[reg_src + reg_soff + offt
                                    + t1_pf_offt]);
                        },
                        [=](size_t base_reg) {UNUSED(base_reg);});
            };

            Label unaligned_store, end_store;
            test(reg_diff_src, vlen - 1);
            jnz(unaligned_store, T_NEAR);
            compute(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                compute(false);
            }
            L(end_store);

            add(reg_coff, vlen);
            cmp(reg_coff, reg_coff_max);
            jl(diff_channels);
        }
    }

    void backward() {
        uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
        xor_(reg_coff, reg_coff);
        Label zero_rbuf, sh_spatial;

        L(zero_rbuf); {
            uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
            uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
            add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
            cmp(reg_coff, reg_coff_max);
            jne(zero_rbuf);
        }

        mov(reg_src, ptr[rsp + stack_off_src]);
        mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        L(sh_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42) {
                mov(reg_tmp_off, reg_soff);
            }
            backward_sh_channels();
            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_sh_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(sh_spatial);
        }

        mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);

        Label no_sh_reduction;
        barrier(); {
            mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
            cmp(reg_tmp, 0);
            Label sh_reduction_channels;
            jne(no_sh_reduction, T_NEAR);

            mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
            xor_(reg_coff, reg_coff);
            L(sh_reduction_channels); {
                mov(reg_roff, reg_coff);
                uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
                uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
                uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
                uni_vaddps(vsqrtvar, vsqrtvar, veps);
                uni_vsqrtps(vsqrtvar, vsqrtvar);
                uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
                mov(reg_ctr, reg_nnthr);
                Label sh_reduction_thrs;
                L(sh_reduction_thrs); { // TODO: unroll (?)
                    uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
                    uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
                    add(reg_roff, reg_coff_max);
                    sub(reg_ctr, 1);
                    jnz(sh_reduction_thrs);
                }
                uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
                uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
                uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
                add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
                cmp(reg_coff, reg_coff_max);
                jne(sh_reduction_channels);
            }
        }
        L(no_sh_reduction);
        barrier();

        mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
        if (with_relu) {
            assert(isa == avx2 || isa == avx512_common);
            mov(reg_ws, ptr[rsp + stack_off_ws]);
        }

        xor_(reg_soff, reg_soff);
        Label diff_spatial;
        L(diff_spatial); {
            xor_(reg_coff, reg_coff);
            if (isa == sse42) {
                mov(reg_tmp_off, reg_soff);
            }
            backward_diff_channels();
            if (isa == sse42) {
                mov(reg_soff, reg_tmp_off);
                add(reg_diff_dst, vlen / 2);
                add(reg_diff_src, vlen / 2);
                add(reg_src, vlen / 2);
                mov(reg_coff, vlen / 2);
                backward_diff_channels();
                sub(reg_diff_dst, vlen / 2);
                sub(reg_diff_src, vlen / 2);
                sub(reg_src, vlen / 2);
            }
            add(reg_soff, reg_mb_stride_Bc);
            cmp(reg_soff, reg_soff_max);
            jne(diff_spatial);
        }
    }
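
    // The constructor drives code generation: preamble, tail masks, stack
    // frame setup, then the forward or backward body, and finally publishes
    // the generated entry point through `ker`.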
    jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
        static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
                || isa == avx512_mic, "unsupported isa");

        const int simd_w = isa == sse42 ? 8 :
            cpu_isa_traits<isa>::vlen / sizeof(data_t);
        is_spatial_thr_ =
            bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));

        unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
        unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;

        preamble();

        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();

        compute_static_strides();
        sub(rsp, stack_size_required);
        load_common_params();
        prepare_relu();

        if (bdesc_->is_fwd()) {
            if (!bdesc_->stats_is_src()) {
                compute_mean_variance();
            }
            forward();
        } else {
            backward();
        }
        add(rsp, stack_size_required);
        postamble();

        ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
                    this->getCode()));
    }
};
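
// Thread driver: decides whether to block over channels so the working set
// fits in L3, books the scratchpad (temporary stats, reduction buffers,
// barriers), and computes per-thread offsets for each kernel invocation.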
template <cpu_isa_t isa>
struct uni_bnorm_driver_t: public c_compatible {
    uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
        : bdesc_(bdesc), ker_(bdesc_)
    {
        const int nthrs = mkldnn_get_max_threads();
        const int C_PADDED = get_c_padded(bdesc_);

        size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
            * bdesc_->D() * bdesc_->H() * bdesc_->W();
        l3_size_ = get_cache_size(3, true) * nthrs / 2;
        do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
    }

    ~uni_bnorm_driver_t() {}

    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const batch_normalization_pd_t *bdesc) {
        int nthrs = mkldnn_get_max_threads();
        int C_PADDED = get_c_padded(bdesc);

        int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
        int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
        int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;

        scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
        scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
        scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);

        if (mkldnn_thr_syncable()) {
            int n_barriers = C_PADDED / simd_w;
            scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
        }
    }

    void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
            data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
            data_t *diff_scale_shift, const data_t *mean, const data_t *var,
            const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
        auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
        auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
        auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);

        size_t N = bdesc_->MB();
        size_t C = bdesc_->C();
        size_t C_PADDED = get_c_padded(bdesc_);
        size_t D = bdesc_->D();
        size_t H = bdesc_->H();
        size_t W = bdesc_->W();
        int SP = D * H * W;
        size_t img_size = C_PADDED * D * H * W;
        const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;

        typename jit_bnorm_t<isa>::call_params_t p;

        p.eps = bdesc_->desc()->batch_norm_epsilon;
        p.one = 1.0f;
        p.spat_size = D * H * W;
        p.chan_size = 1.0f * N * p.spat_size;

        int C_blks = C_PADDED / simd_w;

        int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
        int C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};

        int C_blks_per_iter{ 1 }, iters{ 1 };
        if (do_blocking_) {
            int num_tensors = bdesc_->is_fwd() ? 1 : 2;
            size_t working_set_size
                = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors;
            bnorm_utils::cache_balance(working_set_size, C_blks,
                C_blks_per_iter, iters);
        }

        bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
                SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
                S_ithr, S_nthr, S_s, S_e);

        int SP_N_ithr = N_ithr * S_nthr + S_ithr;
        int SP_N_nthr = N_nthr * S_nthr;
        assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));

        p.N_ithr = SP_N_ithr;
        p.N_nthr = SP_N_nthr;

        int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;
        int global_C_blk_s;
        int global_barriers_per_iter = C_nthr;

        for (int it = 0; it < iters; it++) {
            if (it == iters - 1 && iters > 1) {
                C_blk_s = C_blk_e = N_s = N_e = 0;
                spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
                        spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
                        C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
                        N_e, S_ithr, S_nthr, S_s, S_e);

                // Update call parameters for JIT, last iteration
                p.N_ithr = N_ithr * S_nthr + S_ithr;
                p.N_nthr = N_nthr * S_nthr;
            }

            global_C_blk_s = do_blocking_ ?
                    (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
                    C_blk_s;

            int C_blks_thr = C_blk_e - C_blk_s;
            int N_thr = N_e - N_s;

            size_t coff_base = global_C_blk_s * simd_w;
            size_t soff_base
                    = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;

            p.spat_size_loc = S_e - S_s;
            p.S_s = S_s * vlen;
            p.S_tail = (p.spat_size - S_e) * vlen;
            p.coff_max = C_blks_thr * simd_w;
            p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
            p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
            p.scale_shift = scale_shift + coff_base;
            p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
                    ? pbuf : diff_scale_shift) + coff_base;

            p.soff_max = N_thr * img_size;
            p.src = src + soff_base;
            p.dst = dst + soff_base;
            p.diff_src = diff_src + soff_base;
            p.diff_dst = diff_dst + soff_base;
            p.ws = ws + soff_base / 8;

            p.mb_stride_Bc = img_size - p.coff_max * p.spat_size;

            // use SP_N_nthr which is the same as p.N_nthr except maybe for
            // the last iteration.
            p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
                    + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
            // rbuf1 and rbuf2 have to be disjoint
            p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
            p.is_cblk_tail =
                (size_t)((it * C_blks_per_iter + C_blk_e) * simd_w) > C;

            size_t iter_barriers
                    = do_blocking_ ? it * global_barriers_per_iter : 0;
            p.barrier = barriers + C_ithr + iter_barriers;
            if (p.soff_max != 0 && p.coff_max != 0)
                ker_(&p);
        }
    }

    void init_barriers(const memory_tracking::grantor_t &scratchpad) {
        auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
        if (barriers) {
            const int n_barriers = get_c_padded(bdesc_) / simd_w;
            for (int i = 0; i < n_barriers; ++i)
                barrier::ctx_init(&barriers[i]);
        }
    }

private:
    enum {
        simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
    };

    static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
        return true
            && !bdesc->stats_is_src()
            && bdesc->desc()->prop_kind == prop_kind::forward_inference;
    }

    static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
    {
        return false
            || (bdesc->is_bwd() && !bdesc->use_scaleshift())
            || bdesc->desc()->prop_kind == prop_kind::backward_data;
    }

    static int get_c_padded(const batch_normalization_pd_t *bdesc)
    { return bdesc->src_pd()->desc()->layout_desc.blocking.padding_dims[1]; }

    const batch_normalization_pd_t *bdesc_;
    bool do_blocking_;
    size_t l3_size_;

    jit_bnorm_t<isa> ker_;
};

}

using namespace data_type;
using namespace memory_format;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? isa == avx512_common ? nChw16c : nChw8c
        : isa == avx512_common ? nCdhw16c : nCdhw8c;

    bool ok = true
        && mayiuse(isa)
        && is_fwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && desc()->data_desc.data_type == f32
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && desc()->data_desc.format == desired_fmt
        && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    if (is_training() && fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
    }

    if (memory_desc_wrapper(&data_pd_).blocking_desc().padding_dims[1]
            != this->C() && isa < avx2)
        return status::unimplemented;

    if (stats_is_src() || is_training()) {
        memory_desc_t stats_d;
        dims_t stats_dims = { C() };
        mkldnn_memory_desc_init(&stats_d, 1, stats_dims, f32, x);
        mean_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
        variance_pd_ = cpu_memory_t::pd_t(engine_, &stats_d);
    }

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
        const pd_t *apd, const input_vector &inputs,
        const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }

template <cpu_isa_t isa>
void jit_uni_batch_normalization_fwd_t<isa>::execute(event_t *e) const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto dst = reinterpret_cast<data_t*>(this->memory(0));
    auto mean = reinterpret_cast<data_t*>(pd()->stats_is_src()
            ? const_cast<char*>(this->input_memory(1))
            : this->memory(1));
    auto var = reinterpret_cast<data_t*>(pd()->stats_is_src()
            ? const_cast<char*>(this->input_memory(2))
            : this->memory(2));

    auto idx_scale_shift = 1 + 2*pd()->stats_is_src();
    auto scale_shift =
        reinterpret_cast<const data_t *>(this->input_memory(idx_scale_shift));
    auto ws = reinterpret_cast<uint8_t *>(this->memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
                scale_shift, nullptr, mean, var, ws, scratchpad);
    });
    e->set_state(event_t::ready);
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
{ delete bnorm_driver_; }

/* bwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
    assert(engine()->kind() == engine_kind::cpu);
    auto desired_fmt = (ndims() == 4)
        ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
        : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;

    bool ok = true
        && mayiuse(isa)
        && is_bwd()
        && !has_zero_dim_memory()
        && one_of(ndims(), 4, 5)
        && everyone_is(f32, desc()->data_desc.data_type,
                desc()->diff_data_desc.data_type)
        && IMPLICATION(use_scaleshift(),
                desc()->data_scaleshift_desc.data_type == f32)
        && everyone_is(desired_fmt, desc()->diff_data_desc.format,
                desc()->data_desc.format)
        && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    if (memory_desc_wrapper(&data_pd_).blocking_desc()
            .padding_dims[1] != this->C() && isa < avx2)
        return status::unimplemented;

    if (fuse_bn_relu()) {
        if (isa < avx2) return status::unimplemented;
        bn_init_default_ws(this, this->workspace_pd_, 1);
        size_t this_ws_sz = memory_desc_wrapper(this->workspace_pd()).size();

        bool ws_ok = true
            && hint_fwd_pd_->workspace_pd()
            && memory_desc_wrapper(hint_fwd_pd_->workspace_pd()).size()
            == this_ws_sz;
        if (!ws_ok) return status::unimplemented;
    }

    /* TODO: extra checks required */

    auto scratchpad = scratchpad_registry().registrar();
    uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
        const pd_t *apd, const input_vector &inputs,
        const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }

template <cpu_isa_t isa>
void jit_uni_batch_normalization_bwd_t<isa>::execute(event_t *e) const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto mean = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto var = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(3));
    auto scale_shift = reinterpret_cast<const data_t *>(this->input_memory(4));
    auto diff_src = reinterpret_cast<data_t*>(this->memory(0));
    auto diff_scale_shift = reinterpret_cast<data_t *>(this->memory(1));
    auto ws = reinterpret_cast<const uint8_t *>(
            this->input_memory(pd()->ws_idx()));

    auto scratchpad = this->scratchpad();

    bnorm_driver_->init_barriers(scratchpad);

    parallel(0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
                scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
    });
    e->set_state(event_t::ready);
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
{ delete bnorm_driver_; }

/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<sse42>;
template struct jit_uni_batch_normalization_bwd_t<sse42>;
template struct jit_uni_batch_normalization_fwd_t<avx2>;
template struct jit_uni_batch_normalization_bwd_t<avx2>;
template struct jit_uni_batch_normalization_fwd_t<avx512_common>;
template struct jit_uni_batch_normalization_bwd_t<avx512_common>;

}
}
}