updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_pool_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 * Copyright 2018 YANDEX LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include "c_types_map.hpp"
19 #include "nstl.hpp"
20 #include "utils.hpp"
21 #include "cpu_pooling_pd.hpp"
22
23 #include "jit_uni_pool_kernel.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using namespace Xbyak;
30 using namespace alg_kind;
31 using namespace mkldnn::impl::memory_format;
32
33 #define GET_OFF(field) offsetof(jit_pool_call_s, field)
34
35 template <cpu_isa_t isa>
36 status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
37             const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
38             const memory_desc_wrapper &dst_d) {
39
40     bool args_ok = true
41         && utils::one_of(pd.alg_kind, pooling_max,
42                 pooling_avg_include_padding,
43                 pooling_avg_exclude_padding);
44     if (!args_ok) return status::unimplemented;
45
46     const int simd_w = isa == avx512_common ? 16 : 8;
47     const int ndims = src_d.ndims();
48
49     jpp.is_cpx = mayiuse(avx512_core_bf16);
50
51     jpp.ndims = ndims;
52     jpp.mb = src_d.dims()[0];
53
54     jpp.c = utils::rnd_up(src_d.dims()[1], simd_w);
55     if (jpp.c > src_d.blocking_desc().padding_dims[1])
56         return status::unimplemented;
57
58     jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
59     jpp.ih = src_d.dims()[ndims-2];
60     jpp.iw = src_d.dims()[ndims-1];
61     jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
62     jpp.oh = dst_d.dims()[ndims-2];
63     jpp.ow = dst_d.dims()[ndims-1];
64
65     jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1;
66     jpp.stride_h = pd.strides[ndims-4];
67     jpp.stride_w = pd.strides[ndims-3];
68     jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
69     jpp.kh = pd.kernel[ndims-4];
70     jpp.kw = pd.kernel[ndims-3];
71
72     jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0;
73     jpp.t_pad = pd.padding[0][ndims-4];
74     jpp.l_pad = pd.padding[0][ndims-3];
75     jpp.b_pad = pd.padding[1][ndims-4];
76     jpp.r_pad = pd.padding[1][ndims-3];
77     jpp.back_pad = pd.padding[1][ndims-2];
78
79 // This condition was relaxed in order to support old behavior
80 //    if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
81 //         || jpp.back_pad >= jpp.kd || jpp.b_pad >= jpp.kh || jpp.r_pad >= jpp.kw)
82 //        return status::unimplemented;
83     if (jpp.f_pad >= jpp.kd || jpp.back_pad >= jpp.kd)
84         return status::unimplemented;
85
86     jpp.alg = pd.alg_kind;
87
88     jpp.is_training = pd.prop_kind == prop_kind::forward_training;
89     jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
90     jpp.ind_dt = pooling_index_data_type(&pd);
91     jpp.is_bf16 = (src_d.data_type() == data_type::bf16
92                     && dst_d.data_type() == data_type::bf16);
93
94     if (!IMPLICATION(jpp.is_bf16, mayiuse(avx512_core)))
95         return status::unimplemented;
96
97     jpp.dt_size = (jpp.is_bf16) ? sizeof(mkldnn_bfloat16_t) : sizeof(float);
98
99     jpp.simple_alg = jpp.is_training
100         || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);
101
102     jpp.c_block = simd_w;
103
104     jpp.nb_c = jpp.c / jpp.c_block;
105     if (jpp.alg == pooling_max) {
106         jpp.ur_w = isa == avx512_common ? 16 : 4;
107         if (jpp.is_training)
108             jpp.ur_w = isa == avx512_common ? 9 : 3;
109         else if (jpp.is_backward)
110             jpp.ur_w = isa == avx512_common ? 6 : 3;
111     } else {
112         if (jpp.is_backward)
113             jpp.ur_w = isa == avx512_common ? 12 : 6;
114         else
115             jpp.ur_w = isa == avx512_common ? 24 : 12;
116     }
117     if (jpp.is_bf16) {
118         jpp.ur_w = (!jpp.is_cpx)
119                    ? jpp.ur_w - 4  // Free registers for AVX512 emulation
120                    : jpp.ur_w - 1; // Free register for cvt from bf16 to f32
121     }
122     if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow;
123     if (jpp.l_pad > jpp.ur_w) return status::unimplemented;
124     jpp.ur_w_tail = jpp.ow % jpp.ur_w;
125     return status::success;
126 }
127
128 template <cpu_isa_t isa>
129 inline void jit_uni_pool_kernel<isa>::maybe_recalculate_divisor(int jj,
130         int ur_w, int pad_l, int pad_r, int pad_r_logic) {
131     int kw = jpp.kw;
132     int stride_w = jpp.stride_w;
133
134     int non_zero_kw = kw;
135     if (jpp.alg == pooling_avg_exclude_padding) {
136         non_zero_kw -= nstl::max(0, pad_l - jj * stride_w);
137         non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj) * stride_w);
138     } else { //  jpp.alg == pooling_avg_include_padding
139         non_zero_kw -= nstl::max(0, pad_r_logic - (ur_w - 1 - jj) * stride_w);
140     }
141
142     if (non_zero_kw != prev_kw) {
143         mov(tmp_gpr, float2int((float)non_zero_kw));
144         movq(xmm_tmp, tmp_gpr);
145         uni_vbroadcastss(vmm_tmp, xmm_tmp);
146         uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
147         prev_kw = non_zero_kw;
148     }
149
150 }
151
152 template <cpu_isa_t isa>
153 inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int pad_l,
154         int pad_r, int pad_r_logic) {
155
156     int iw = jpp.iw;
157     int kw = jpp.kw;
158     int stride_w = jpp.stride_w;
159     int c_block = jpp.c_block;
160     Label kd_label, kh_label;
161
162     for (int jj = 0; jj < ur_w; jj++) {
163         if (jpp.is_backward) {
164             load(jj, reg_output, jpp.dt_size * jj * c_block);
165             maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r, pad_r_logic);
166             uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
167         } else {
168             uni_vpxor(vreg(jj), vreg(jj), vreg(jj));
169         }
170     }
171
172     if (jpp.simple_alg && jpp.ndims == 5) {
173         push(reg_input);
174         push(reg_output);
175         mov(aux_reg_input_d, reg_input);
176         mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
177         L(kd_label);
178         mov(aux_reg_input, aux_reg_input_d);
179     } else {
180         mov(aux_reg_input, reg_input);
181     }
182
183     xor_(kj, kj);
184     L(kh_label);
185     {
186         for (int ki = 0; ki < kw; ki++) {
187             int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
188             int jj_end = ur_w
189                 - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
190
191             for (int jj = jj_start; jj  < jj_end; jj++) {
192                 int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
193                 if (aux_input_offset > iw * c_block)
194                     continue;
195                 int input_offset = jpp.dt_size * aux_input_offset;
196                 if (jpp.is_backward) {
197                     load(ur_w + jj, aux_reg_input, input_offset);
198                     uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj));
199                     if (jpp.is_bf16) {
200                         if (!jpp.is_cpx)
201                             bf16_emu_->r_vcvtneps2bf16(
202                                     yreg(ur_w + jj), zreg(ur_w + jj));
203                         else
204                             vcvtneps2bf16(yreg(ur_w + jj), vreg(ur_w + jj));
205                         vmovdqu16(ptr[aux_reg_input + input_offset],
206                                 yreg(ur_w + jj));
207                     } else {
208                         uni_vmovups(vmmword[aux_reg_input + input_offset],
209                                 vreg(ur_w + jj));
210                     }
211                 } else {
212                     if (jpp.is_bf16) {
213                         vmovups(ymm_tmp_1, ptr[aux_reg_input + input_offset]);
214                         vpermw(vmm_tmp_1 | k_mask_cvt | T_z, vmm_idx(), vmm_tmp_1);
215
216                         uni_vaddps(vreg(jj), vreg(jj), vmm_tmp_1);
217                     } else {
218                         uni_vaddps(vreg(jj), vreg(jj),
219                                    ptr[aux_reg_input + input_offset]);
220                     }
221                 }
222             }
223         }
224         add(aux_reg_input,  jpp.dt_size * iw * c_block);
225         inc(kj);
226         cmp(kj, reg_kh);
227         jl(kh_label, T_NEAR);
228     }
229
230     if (jpp.simple_alg && jpp.ndims == 5)
231     {
232         add(aux_reg_input_d,  jpp.dt_size * jpp.ih * iw * c_block);
233         dec(ki);
234         cmp(ki, 0);
235         jg(kd_label, T_NEAR);
236         pop(reg_output);
237         pop(reg_input);
238     }
239
240     if (!jpp.is_backward) {
241         for (int jj = 0; jj < ur_w; jj++) {
242             maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r, pad_r_logic);
243             uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
244             if (jpp.is_bf16) {
245                 if (!jpp.is_cpx)
246                     bf16_emu_->r_vcvtneps2bf16(yreg(jj), zreg(jj));
247                 else
248                     vcvtneps2bf16(yreg(jj), vreg(jj));
249                 vmovdqu16(
250                         ptr[reg_output + jpp.dt_size * jj * c_block], yreg(jj));
251             } else {
252                 uni_vmovups(vmmword[reg_output + jpp.dt_size * jj * c_block],
253                         vreg(jj));
254             }
255         }
256     }
257 }
258
259 template <cpu_isa_t isa>
260 inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int pad_l,
261         int pad_r) {
262     int iw = jpp.iw;
263     int kw = jpp.kw;
264     int stride_w = jpp.stride_w;
265     int c_block = jpp.c_block;
266     Label kd_label, kh_label;
267
268     float lowest = nstl::numeric_limits<float>::lowest();
269     mov(tmp_gpr, float2int(lowest));
270     movq(xmm_tmp, tmp_gpr);
271     uni_vbroadcastss(vmm_tmp, xmm_tmp);
272
273     for (int jj = 0; jj < ur_w; jj++) {
274         uni_vmovups(vreg(jj), vmm_tmp);
275         if (jpp.is_training)
276             uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj));
277     }
278     if (jpp.is_training) {
279         movq(xmm_tmp, reg_k_shift);
280         uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
281     }
282
283     if (jpp.ndims == 5) {
284         push(reg_input);
285         push(reg_output);
286         mov(aux_reg_input_d, reg_input);
287         mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
288         L(kd_label);
289         mov(aux_reg_input, aux_reg_input_d);
290     } else {
291         mov(aux_reg_input, reg_input);
292     }
293     xor_(kj, kj);
294     L(kh_label);
295     {
296         for (int ki = 0; ki < kw; ki++) {
297             int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
298             int jj_end = ur_w
299                 - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
300             for (int jj = jj_start; jj  < jj_end; jj++) {
301                 int aux_input_offset = (ki + jj * stride_w - pad_l) * c_block;
302                 if (aux_input_offset > iw * c_block)
303                     continue;
304                 int input_offset = jpp.dt_size*aux_input_offset;
305                 load(ur_w + jj, aux_reg_input, input_offset);
306                 if (isa == sse42) {
307                     movups(vmm_mask, vreg(jj));
308                     cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os);
309                     blendvps(vreg(jj), vreg(ur_w+jj));
310                     if (jpp.is_training)
311                         blendvps(vreg(2*ur_w+jj), vmm_k_offset);
312                 } else if (isa == avx) {
313                     vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj),
314                            _cmp_lt_os);
315                     vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj),
316                               vreg(3*ur_w+jj));
317                     if (jpp.is_training)
318                         vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj),
319                                   vmm_k_offset, vreg(3*ur_w+jj));
320                 } else {
321                     vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os);
322                     vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj));
323                     if (jpp.is_training)
324                         vblendmps(vreg(2*ur_w+jj) | k_store_mask,
325                                   vreg(2*ur_w+jj), vmm_k_offset);
326                 }
327             }
328             if (jpp.is_training) {
329                 if (isa == avx && !mayiuse(avx2)) {
330                     avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
331                 } else {
332                     uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
333                 }
334             }
335         }
336         add(aux_reg_input,  jpp.dt_size * iw * c_block);
337         inc(kj);
338         cmp(kj, reg_kh);
339         jl(kh_label, T_NEAR);
340     }
341
342     if (jpp.ndims == 5)
343     {
344         add(aux_reg_input_d,  jpp.dt_size * jpp.ih * iw * c_block);
345         if (jpp.is_training) {
346             mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
347             movq(xmm_tmp, tmp_gpr);
348             uni_vpbroadcastd(vmm_tmp, xmm_tmp);
349             if (isa == avx && !mayiuse(avx2)) {
350                 Xmm t(vmm_mask.getIdx());
351                 avx_vpadd1(vmm_k_offset, xmm_tmp, t);
352             } else {
353                 uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
354             }
355         }
356
357         dec(ki);
358         cmp(ki, 0);
359         jg(kd_label, T_NEAR);
360         pop(reg_output);
361         pop(reg_input);
362     }
363
364     for (int jj = 0; jj < ur_w; jj++) {
365         if (jpp.is_bf16) {
366             if (!jpp.is_cpx)
367                 bf16_emu_->r_vcvtneps2bf16(yreg(jj), zreg(jj));
368             else
369                 vcvtneps2bf16(yreg(jj), vreg(jj));
370             vmovups(ptr[reg_output + jpp.dt_size*jj*c_block], yreg(jj));
371         } else {
372             uni_vmovups(vmmword[reg_output + jpp.dt_size*jj*c_block], vreg(jj));
373         }
374         if (jpp.is_training) {
375             const size_t step_index
376                 = jj * c_block * types::data_type_size(jpp.ind_dt);
377
378             auto x = xreg(2 * ur_w + jj);
379             if (jpp.ind_dt == data_type::u8) {
380                 if (isa == sse42) {
381                     for (int i = 0; i < 4; ++i)
382                         pextrb(ptr[reg_index + step_index + i], x, 4*i);
383                 } else if (isa == avx) {
384                     auto y = yreg(2 * ur_w + jj);
385                     if (jj == 0) {
386                         movd(xmm_tmp, reg_shuf_mask);
387                         uni_vpbroadcastd(vmm_tmp, xmm_tmp);
388                     }
389                     if (mayiuse(avx2)) {
390                         vpshufb(y, y, vmm_tmp);
391                         movd(ptr[reg_index + step_index], x);
392                         vperm2i128(y, y, y, 0x1u);
393                         movd(ptr[reg_index + step_index + 4], x);
394                     } else {
395                         Xmm t(vmm_mask.getIdx());
396                         vextractf128(t, y, 0);
397                         vpshufb(t, t, xmm_tmp);
398                         movd(ptr[reg_index + step_index], t);
399                         vextractf128(t, y, 1);
400                         vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0]
401                         movd(ptr[reg_index + step_index + 4], t);
402                     }
403                 } else {
404                     auto v = vreg(2 * ur_w + jj);
405                     vpmovusdb(x, v);
406                     vmovups(ptr[reg_index + step_index], v | k_index_mask);
407                 }
408             } else {
409                 uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj));
410             }
411         }
412     }
413 }
414
415 template <cpu_isa_t isa>
416 inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int pad_l,
417         int pad_r) {
418
419     int iw = jpp.iw;
420     int kw = jpp.kw;
421     int stride_w = jpp.stride_w;
422     int c_block = jpp.c_block;
423     Label kd_label, kh_label;
424
425     for (int jj = 0; jj < ur_w; jj++) {
426         load(jj, reg_output, jpp.dt_size * jj * c_block);
427         const size_t step_index
428             = jj * c_block * types::data_type_size(jpp.ind_dt);
429         if (jpp.ind_dt == data_type::u8) {
430             if (isa == sse42) {
431                 movd(xreg(ur_w+jj), ptr[reg_index + step_index]);
432                 pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
433             } else if (isa == avx) {
434                 movq(xreg(ur_w+jj), ptr[reg_index + step_index]);
435                 if (!mayiuse(avx2)) {
436                     avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp);
437                 } else {
438                     vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
439                 }
440             } else {
441                 vmovups(vreg(ur_w+jj) | k_index_mask,
442                         ptr[reg_index + step_index]);
443                 vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
444             }
445         } else {
446             uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]);
447         }
448     }
449     movq(xmm_tmp, reg_k_shift);
450     uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
451
452     if (jpp.simple_alg && jpp.ndims == 5) {
453         push(reg_input);
454         push(reg_output);
455         if (isa == sse42) {
456             // Save rdi since it is used in maskmovdqu
457             assert(dst_ptr == rdi);
458             push(dst_ptr);
459         }
460         mov(aux_reg_input_d, reg_input);
461         mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
462         mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
463         L(kd_label);
464         mov(aux_reg_input, aux_reg_input_d);
465     } else {
466         mov(aux_reg_input, reg_input);
467     }
468
469     xor_(kj, kj);
470     L(kh_label);
471     {
472         for (int ki = 0; ki < kw; ki++) {
473             int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
474             int jj_end = ur_w
475                 - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
476             for (int jj = jj_start; jj  < jj_end; jj++) {
477                 int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
478                 if (aux_input_offset > iw * c_block)
479                     continue;
480                 int input_offset = jpp.dt_size*aux_input_offset;
481                 load(2 * ur_w + jj, aux_reg_input, input_offset);
482                 if (isa == sse42) {
483                     mov(dst_ptr, aux_reg_input);
484                     add(dst_ptr, input_offset);
485
486                     movups(vreg(3*ur_w+jj), vreg(ur_w+jj));
487                     pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset);
488                     addps(vreg(2*ur_w+jj), vreg(jj));
489                     maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj));
490                 } else if (isa == avx) {
491                     if (mayiuse(avx2)) {
492                         vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset);
493                     } else {
494                         avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp);
495                     }
496                     vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj));
497                     vmaskmovps(vmmword[aux_reg_input + input_offset],
498                             vreg(3*ur_w+jj), vreg(2*ur_w+jj));
499                 } else {
500                     vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset);
501                     vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj));
502                     vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp);
503                     if (jpp.is_bf16) {
504                         if (!jpp.is_cpx)
505                             bf16_emu_->r_vcvtneps2bf16(yreg(2*ur_w+jj), zreg(2*ur_w+jj));
506                         else
507                             vcvtneps2bf16(yreg(2*ur_w+jj), vreg(2*ur_w+jj));
508                         vmovdqu16(ptr[aux_reg_input +
509                             jpp.dt_size*aux_input_offset], yreg(2*ur_w+jj));
510                     } else {
511                         vmovups(vmmword[aux_reg_input +
512                             jpp.dt_size*aux_input_offset], vreg(2*ur_w+jj));
513                     }
514                 }
515             }
516             if (isa == avx && !mayiuse(avx2)) {
517                 avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
518             } else {
519                 uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
520             }
521         }
522         add(aux_reg_input,  jpp.dt_size * iw * c_block);
523         inc(kj);
524         cmp(kj, reg_kh);
525         jl(kh_label, T_NEAR);
526     }
527     if (jpp.simple_alg && jpp.ndims == 5)
528     {
529         add(aux_reg_input_d,  jpp.dt_size * jpp.ih * iw * c_block);
530
531         mov(tmp_gpr, reg_kd_pad_shift);
532         movq(xmm_tmp, tmp_gpr);
533         uni_vpbroadcastd(vmm_tmp, xmm_tmp);
534         if (isa == avx && !mayiuse(avx2)) {
535             Xmm t(vmm_mask.getIdx());
536             avx_vpadd1(vmm_k_offset, vmm_tmp, t);
537         } else {
538             uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
539         }
540
541         dec(ki);
542         cmp(ki, 0);
543         jg(kd_label, T_NEAR);
544         if (isa == sse42) {
545             // Save rdi since it is used in maskmovdqu
546             assert(dst_ptr == rdi);
547             pop(dst_ptr);
548         }
549         pop(reg_output);
550         pop(reg_input);
551     }
552 }
553
554 template <cpu_isa_t isa>
555 void jit_uni_pool_kernel<isa>::maybe_zero_diff_src() {
556     assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0);
557     Label l_skip, l_zero;
558
559     auto reg_oh = tmp_gpr;
560     mov(reg_oh, ptr[reg_param + GET_OFF(oh)]);
561     cmp(reg_oh, 0);
562     jz(l_skip, T_NEAR);
563
564     if (jpp.ndims == 5) {
565         mov(zero_size, ptr[reg_param + GET_OFF(oh)]);
566         mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * jpp.dt_size);
567         imul(zero_size, tmp_gpr);
568     }
569
570     auto vzero = vmm_tmp;
571     auto yzero = ymm_tmp;
572     uni_vpxor(vzero, vzero, vzero);
573
574     auto reg_off = tmp_gpr;
575     xor_(reg_off, reg_off);
576
577     L(l_zero);
578     {
579         const int dim = jpp.iw * jpp.c_block * jpp.dt_size;
580         int step = (jpp.is_bf16)
581             ? cpu_isa_traits<isa>::vlen / 2
582             : cpu_isa_traits<isa>::vlen;
583         for (int i = 0; i < dim; i += step)
584             if (jpp.is_bf16) {
585                 vmovdqu16(ptr[reg_input + reg_off + i], yzero);
586             } else {
587                 uni_vmovups(ptr[reg_input + reg_off + i], vzero);
588             }
589         add(reg_off, dim);
590         if (jpp.ndims == 5) cmp(reg_off, zero_size);
591         else cmp(reg_off, jpp.ih * dim);
592         jl(l_zero, T_NEAR);
593     }
594
595     L(l_skip);
596 }
597
598 template <cpu_isa_t isa>
599 void jit_uni_pool_kernel<isa>::generate() {
600
601     this->preamble();
602
603     Label idx_table;
604
605     int ow = jpp.ow;
606     int iw = jpp.iw;
607     int kw = jpp.kw;
608     int ur_w = jpp.ur_w;
609     int c_block = jpp.c_block;
610     int stride_w = jpp.stride_w;
611     int l_pad = jpp.l_pad;
612     int ur_w_tail = jpp.ur_w_tail;
613
614     int n_oi = ow / ur_w;
615
616     prev_kw = 0;
617
618     int vlen = cpu_isa_traits<isa>::vlen;
619
620 #if defined(_WIN32)
621     // Always mimic the Unix ABI (see the note about maskmovdqu in the header
622     // file).
623     xor_(rdi, rcx);
624     xor_(rcx, rdi);
625     xor_(rdi, rcx);
626 #endif
627     if (!jpp.is_cpx && jpp.is_bf16)
628         bf16_emu_->init_vcvtneps2bf16();
629
630     mov(reg_input, ptr[reg_param + GET_OFF(src)]);
631     mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
632     if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
633         mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
634     mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
635     mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
636     mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
637
638     if (jpp.is_bf16) {
639         mov(tmp_gpr.cvt32(), 0xAAAAAAAA);
640         kmovd(k_mask_cvt, tmp_gpr.cvt32());
641
642         mov(tmp_gpr, idx_table);
643         vmovups(vmm_idx(), ptr[tmp_gpr]);
644     }
645
646     if (jpp.is_backward)
647         maybe_zero_diff_src();
648
649     if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
650         mov(tmp_gpr, 1);
651         movq(xmm_one, tmp_gpr);
652         uni_vpbroadcastd(vmm_one, xmm_one);
653
654         if (isa == avx) {
655             mov(reg_shuf_mask, 0x0c080400);
656         } else if (isa >= avx512_common) {
657             mov(tmp_gpr.cvt32(), 0x000f);
658             kmovw(k_index_mask, tmp_gpr.cvt32());
659         }
660     }
661
662     int r_pad  = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1));
663     int r_pad_log = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad + jpp.r_pad - 1));
664     int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
665     int r_pad1_log = nstl::max(0, r_pad1 - jpp.r_pad);
666     if (r_pad1 > 0) n_oi--;
667
668     movq(xmm_ker_area_h, reg_ker_area_h);
669     uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h);
670
671     if (l_pad > 0) {
672         n_oi--;
673         if (n_oi < 0 && r_pad1 > 0) {
674             step(ur_w, l_pad, r_pad1, r_pad1_log);
675         } else  {
676             step(ur_w, l_pad, 0, 0);
677         }
678
679         if (isa == sse42) {
680             if (n_oi < 0 && r_pad1 > 0) {
681                 step_high_half(ur_w, l_pad, r_pad1, r_pad1_log);
682             } else  {
683                 step_high_half(ur_w, l_pad, 0, 0);
684             }
685         }
686
687         if (isa == sse42) {
688             add(reg_input, jpp.dt_size*(ur_w*stride_w-l_pad)*c_block - vlen);
689             add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
690             if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
691                 add(reg_index, (2 * ur_w - 1) * c_block / 2
692                         * types::data_type_size(jpp.ind_dt));
693         } else {
694             add(reg_input, jpp.dt_size*(ur_w*stride_w - l_pad)*c_block);
695             add(reg_output, jpp.dt_size*ur_w*c_block);
696             if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
697                 add(reg_index, ur_w * c_block
698                         * types::data_type_size(jpp.ind_dt));
699         }
700     }
701
702     xor_(oi_iter, oi_iter);
703     if (n_oi > 0) {
704         Label ow_loop;
705         L(ow_loop); {
706             step(ur_w, 0, 0, 0);
707
708             if (isa == sse42) {
709                 step_high_half(ur_w, 0, 0, 0);
710             }
711
712             if (isa == sse42) {
713                 add(reg_input, jpp.dt_size*ur_w*stride_w*c_block - vlen);
714                 add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
715                 if (jpp.alg == pooling_max &&
716                     (jpp.is_training || jpp.is_backward))
717                     add(reg_index, (2 * ur_w - 1) * c_block / 2
718                             * types::data_type_size(jpp.ind_dt));
719             } else {
720                 add(reg_input, jpp.dt_size*ur_w*stride_w*c_block);
721                 add(reg_output, jpp.dt_size*ur_w*c_block);
722                 if (jpp.alg == pooling_max &&
723                     (jpp.is_training || jpp.is_backward))
724                     add(reg_index, ur_w * c_block
725                             * types::data_type_size(jpp.ind_dt));
726             }
727
728             inc(oi_iter);
729             cmp(oi_iter, n_oi);
730             jl(ow_loop, T_NEAR);
731         }
732     }
733
734     if (r_pad1 > 0 && n_oi >= 0) {
735         step(ur_w, 0, r_pad1, r_pad1_log);
736
737         if (isa == sse42) {
738             step_high_half(ur_w, 0, r_pad1, r_pad1_log);
739         }
740
741         if (isa == sse42) {
742             add(reg_input, jpp.dt_size*ur_w*stride_w*c_block - vlen);
743             add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
744             if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
745                 add(reg_index, (2 * ur_w - 1) * c_block / 2
746                         * types::data_type_size(jpp.ind_dt));
747         } else {
748             add(reg_input, jpp.dt_size*ur_w*stride_w*c_block);
749             add(reg_output, jpp.dt_size*ur_w*c_block);
750             if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
751                 add(reg_index, ur_w * c_block
752                         * types::data_type_size(jpp.ind_dt));
753         }
754     }
755
756     if (ur_w_tail != 0) {
757         step(ur_w_tail, 0, r_pad, r_pad_log);
758
759         if (isa == sse42) {
760             step_high_half(ur_w_tail, 0, r_pad, r_pad_log);
761         }
762     }
763
764     this->postamble();
765
766     if (jpp.is_bf16) {
767         align(64);
768         L(idx_table);
769         const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
770                                   9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
771         for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
772             dw(_idx[i]);
773     }
774 }
775
776 template struct jit_uni_pool_kernel<sse42>;
777 template struct jit_uni_pool_kernel<avx>; // implements both <avx> and <avx2>
778 template struct jit_uni_pool_kernel<avx512_common>;
779
780 }
781 }
782 }
783
784 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s