/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "cpu_pooling_pd.hpp"

#include "jit_uni_pool_kernel.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace Xbyak;
using namespace alg_kind;
using namespace mkldnn::impl::memory_format;

#define GET_OFF(field) offsetof(jit_pool_call_s, field)
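
// Validate the pooling descriptor and fill the jit_pool_conf_t that drives
// code generation: tensor geometry, kernel/stride/padding, data type sizes,
// and the unroll factor ur_w (number of output columns computed per step).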
template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d) {

    bool args_ok = true
        && utils::one_of(pd.alg_kind, pooling_max,
                pooling_avg_include_padding,
                pooling_avg_exclude_padding);
    if (!args_ok) return status::unimplemented;

    const int simd_w = isa == avx512_common ? 16 : 8;
    const int ndims = src_d.ndims();

    jpp.is_cpx = mayiuse(avx512_core_bf16);

    jpp.ndims = ndims;
    jpp.mb = src_d.dims()[0];

    jpp.c = utils::rnd_up(src_d.dims()[1], simd_w);
    if (jpp.c > src_d.blocking_desc().padding_dims[1])
        return status::unimplemented;

    jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jpp.ih = src_d.dims()[ndims-2];
    jpp.iw = src_d.dims()[ndims-1];
    jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jpp.oh = dst_d.dims()[ndims-2];
    jpp.ow = dst_d.dims()[ndims-1];

    jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1;
    jpp.stride_h = pd.strides[ndims-4];
    jpp.stride_w = pd.strides[ndims-3];
    jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
    jpp.kh = pd.kernel[ndims-4];
    jpp.kw = pd.kernel[ndims-3];

    jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0;
    jpp.t_pad = pd.padding[0][ndims-4];
    jpp.l_pad = pd.padding[0][ndims-3];
    jpp.b_pad = pd.padding[1][ndims-4];
    jpp.r_pad = pd.padding[1][ndims-3];
    jpp.back_pad = (ndims == 5) ? pd.padding[1][0] : 0;

    // This condition was relaxed in order to support old behavior
    // if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
    //         || jpp.back_pad >= jpp.kd || jpp.b_pad >= jpp.kh
    //         || jpp.r_pad >= jpp.kw)
    //     return status::unimplemented;
    if (jpp.f_pad >= jpp.kd || jpp.back_pad >= jpp.kd)
        return status::unimplemented;

    jpp.alg = pd.alg_kind;

    jpp.is_training = pd.prop_kind == prop_kind::forward_training;
    jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
    jpp.ind_dt = pooling_index_data_type(&pd);
    jpp.is_bf16 = (src_d.data_type() == data_type::bf16
            && dst_d.data_type() == data_type::bf16);

    if (!IMPLICATION(jpp.is_bf16, mayiuse(avx512_core)))
        return status::unimplemented;

    jpp.dt_size = jpp.is_bf16 ? sizeof(mkldnn_bfloat16_t) : sizeof(float);

    jpp.simple_alg = jpp.is_training
            || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);

    jpp.c_block = simd_w;

    jpp.nb_c = jpp.c / jpp.c_block;
    if (jpp.alg == pooling_max) {
        jpp.ur_w = isa == avx512_common ? 16 : 4;
        if (jpp.is_training)
            jpp.ur_w = isa == avx512_common ? 9 : 3;
        else if (jpp.is_backward)
            jpp.ur_w = isa == avx512_common ? 6 : 3;
    } else {
        if (jpp.is_backward)
            jpp.ur_w = isa == avx512_common ? 12 : 6;
        else
            jpp.ur_w = isa == avx512_common ? 24 : 12;
    }
    if (jpp.is_bf16)
        jpp.ur_w = (!jpp.is_cpx)
                ? jpp.ur_w - 4 // Free registers for AVX512 emulation
                : jpp.ur_w - 1; // Free register for cvt from bf16 to f32

    if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow;
    if (jpp.l_pad > jpp.ur_w) return status::unimplemented;
    jpp.ur_w_tail = jpp.ow % jpp.ur_w;

    return status::success;
}
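
// For average pooling, rebuild the divisor in vmm_tmp (valid kernel width
// times the height/depth factor passed in vmm_ker_area_h) whenever the
// number of non-padded kernel columns covering output column jj changes;
// prev_kw caches the last width to avoid redundant broadcasts.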
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::maybe_recalculate_divisor(int jj,
        int ur_w, int pad_l, int pad_r, int pad_r_logic) {
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;

    int non_zero_kw = kw;
    if (jpp.alg == pooling_avg_exclude_padding) {
        non_zero_kw -= nstl::max(0, pad_l - jj * stride_w);
        non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj) * stride_w);
    } else { // jpp.alg == pooling_avg_include_padding
        non_zero_kw -= nstl::max(0, pad_r_logic - (ur_w - 1 - jj) * stride_w);
    }

    if (non_zero_kw != prev_kw) {
        mov(tmp_gpr, float2int((float)non_zero_kw));
        movq(xmm_tmp, tmp_gpr);
        uni_vbroadcastss(vmm_tmp, xmm_tmp);
        uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
        prev_kw = non_zero_kw;
    }
}
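
// One unrolled step of average pooling over ur_w output columns.
// Forward: accumulate inputs into vreg(0..ur_w-1), then divide by the kernel
// area and store. Backward: pre-divide the loaded diff_dst values and
// scatter-add them back into every diff_src element the window touches.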
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::avg_step(int ur_w, int pad_l,
        int pad_r, int pad_r_logic) {

    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    for (int jj = 0; jj < ur_w; jj++) {
        if (jpp.is_backward) {
            load(jj, reg_output, jpp.dt_size * jj * c_block);
            maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r, pad_r_logic);
            uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
        } else {
            uni_vpxor(vreg(jj), vreg(jj), vreg(jj));
        }
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj * stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = jpp.dt_size * aux_input_offset;
                if (jpp.is_backward) {
                    load(ur_w + jj, aux_reg_input, input_offset);
                    uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj));
                    if (jpp.is_bf16) {
                        if (!jpp.is_cpx)
                            bf16_emu_->r_vcvtneps2bf16(
                                    yreg(ur_w + jj), zreg(ur_w + jj));
                        else
                            vcvtneps2bf16(yreg(ur_w + jj), vreg(ur_w + jj));
                        vmovdqu16(ptr[aux_reg_input + input_offset],
                                yreg(ur_w + jj));
                    } else {
                        uni_vmovups(vmmword[aux_reg_input + input_offset],
                                vreg(ur_w + jj));
                    }
                } else {
                    if (jpp.is_bf16) {
                        // Load 16 bf16 values and expand them to f32 via the
                        // vpermw + k_mask_cvt trick (see idx_table below).
                        vmovups(ymm_tmp_1, ptr[aux_reg_input + input_offset]);
                        vpermw(vmm_tmp_1 | k_mask_cvt | T_z, vmm_idx(),
                                vmm_tmp_1);
                        uni_vaddps(vreg(jj), vreg(jj), vmm_tmp_1);
                    } else {
                        uni_vaddps(vreg(jj), vreg(jj),
                                ptr[aux_reg_input + input_offset]);
                    }
                }
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (!jpp.is_backward) {
        for (int jj = 0; jj < ur_w; jj++) {
            maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r, pad_r_logic);
            uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
            if (jpp.is_bf16) {
                if (!jpp.is_cpx)
                    bf16_emu_->r_vcvtneps2bf16(yreg(jj), zreg(jj));
                else
                    vcvtneps2bf16(yreg(jj), vreg(jj));
                vmovups(ptr[reg_output + jpp.dt_size * jj * c_block],
                        yreg(jj));
            } else {
                uni_vmovups(vmmword[reg_output + jpp.dt_size * jj * c_block],
                        vreg(jj));
            }
        }
    }
}
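
// One unrolled step of forward max pooling: vreg(jj) tracks the running
// maximum, and in training mode vreg(2*ur_w+jj) tracks the winning kernel
// offset (vmm_k_offset), which is stored to the workspace for backward.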
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_fwd(int ur_w, int pad_l,
        int pad_r, int pad_r_logic) {
    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    float lowest = nstl::numeric_limits<float>::lowest();
    mov(tmp_gpr, float2int(lowest));
    movq(xmm_tmp, tmp_gpr);
    uni_vbroadcastss(vmm_tmp, xmm_tmp);

    for (int jj = 0; jj < ur_w; jj++) {
        uni_vmovups(vreg(jj), vmm_tmp);
        if (jpp.is_training)
            uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj));
    }
    if (jpp.is_training) {
        movq(xmm_tmp, reg_k_shift);
        uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
    }

    if (jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj * stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = jpp.dt_size * aux_input_offset;
                load(ur_w + jj, aux_reg_input, input_offset);
                if (isa == sse42) {
                    movups(vmm_mask, vreg(jj));
                    cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os);
                    blendvps(vreg(jj), vreg(ur_w+jj));
                    if (jpp.is_training)
                        blendvps(vreg(2*ur_w+jj), vmm_k_offset);
                } else if (isa == avx) {
                    vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj),
                            _cmp_lt_os);
                    vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj),
                            vreg(3*ur_w+jj));
                    if (jpp.is_training)
                        vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj),
                                vmm_k_offset, vreg(3*ur_w+jj));
                } else {
                    vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os);
                    vblendmps(vreg(jj) | k_store_mask, vreg(jj),
                            vreg(ur_w+jj));
                    if (jpp.is_training)
                        vblendmps(vreg(2*ur_w+jj) | k_store_mask,
                                vreg(2*ur_w+jj), vmm_k_offset);
                }
            }
            if (jpp.is_training) {
                if (isa == avx && !mayiuse(avx2)) {
                    avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
                } else {
                    uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
                }
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_block);
        if (jpp.is_training) {
            mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
            movq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
            if (isa == avx && !mayiuse(avx2)) {
                Xmm t(vmm_mask.getIdx());
                avx_vpadd1(vmm_k_offset, xmm_tmp, t);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
            }
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    for (int jj = 0; jj < ur_w; jj++) {
        if (jpp.is_bf16) {
            if (!jpp.is_cpx)
                bf16_emu_->r_vcvtneps2bf16(yreg(jj), zreg(jj));
            else
                vcvtneps2bf16(yreg(jj), vreg(jj));
            vmovups(ptr[reg_output + jpp.dt_size*jj*c_block], yreg(jj));
        } else {
            uni_vmovups(vmmword[reg_output + jpp.dt_size*jj*c_block],
                    vreg(jj));
        }
        if (jpp.is_training) {
            const size_t step_index
                = jj * c_block * types::data_type_size(jpp.ind_dt);

            auto x = xreg(2 * ur_w + jj);
            if (jpp.ind_dt == data_type::u8) {
                if (isa == sse42) {
                    for (int i = 0; i < 4; ++i)
                        pextrb(ptr[reg_index + step_index + i], x, 4*i);
                } else if (isa == avx) {
                    auto y = yreg(2 * ur_w + jj);
                    movd(xmm_tmp, reg_shuf_mask);
                    uni_vpbroadcastd(vmm_tmp, xmm_tmp);
                    if (mayiuse(avx2)) {
                        vpshufb(y, y, vmm_tmp);
                        movd(ptr[reg_index + step_index], x);
                        vperm2i128(y, y, y, 0x1u);
                        movd(ptr[reg_index + step_index + 4], x);
                    } else {
                        Xmm t(vmm_mask.getIdx());
                        vextractf128(t, y, 0);
                        vpshufb(t, t, xmm_tmp);
                        movd(ptr[reg_index + step_index], t);
                        vextractf128(t, y, 1);
                        // the same 128-bit shuffle mask serves both halves
                        vpshufb(t, t, xmm_tmp);
                        movd(ptr[reg_index + step_index + 4], t);
                    }
                } else {
                    auto v = vreg(2 * ur_w + jj);
                    // pack the dword indices to bytes in place, then store
                    // the low 16 bytes through k_index_mask
                    vpmovusdb(x, v);
                    vmovups(ptr[reg_index + step_index], v | k_index_mask);
                }
            } else {
                uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj));
            }
        }
    }
}
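
// One unrolled step of backward max pooling: reload the saved indices,
// compare them against the running kernel offset, and use masked stores to
// add diff_dst only into the input elements that won the forward max.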
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel<isa>::max_step_bwd(int ur_w, int pad_l,
        int pad_r, int pad_r_logic) {
    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    for (int jj = 0; jj < ur_w; jj++) {
        load(jj, reg_output, jpp.dt_size * jj * c_block);
        const size_t step_index
            = jj * c_block * types::data_type_size(jpp.ind_dt);
        if (jpp.ind_dt == data_type::u8) {
            if (isa == sse42) {
                movd(xreg(ur_w+jj), ptr[reg_index + step_index]);
                pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
            } else if (isa == avx) {
                movq(xreg(ur_w+jj), ptr[reg_index + step_index]);
                if (!mayiuse(avx2)) {
                    avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp);
                } else {
                    vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
                }
            } else {
                vmovups(vreg(ur_w+jj) | k_index_mask,
                        ptr[reg_index + step_index]);
                vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
            }
        } else {
            uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]);
        }
    }
    movq(xmm_tmp, reg_k_shift);
    uni_vpbroadcastd(vmm_k_offset, xmm_tmp);

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        if (isa == sse42) {
            // Save rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            push(dst_ptr);
        }
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, utils::div_up(pad_l - ki, stride_w));
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj * stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = jpp.dt_size * aux_input_offset;
                load(2 * ur_w + jj, aux_reg_input, input_offset);
                if (isa == sse42) {
                    // maskmovdqu stores through dst_ptr (rdi) implicitly
                    mov(dst_ptr, aux_reg_input);
                    add(dst_ptr, input_offset);

                    movups(vreg(3*ur_w+jj), vreg(ur_w+jj));
                    pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset);
                    addps(vreg(2*ur_w+jj), vreg(jj));
                    maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj));
                } else if (isa == avx) {
                    if (mayiuse(avx2)) {
                        vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj),
                                vmm_k_offset);
                    } else {
                        avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj),
                                vmm_k_offset, xmm_tmp);
                    }
                    vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj));
                    vmaskmovps(vmmword[aux_reg_input + input_offset],
                            vreg(3*ur_w+jj), vreg(2*ur_w+jj));
                } else {
                    vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset);
                    vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj),
                            vreg(jj));
                    vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp);
                    if (jpp.is_bf16) {
                        if (!jpp.is_cpx)
                            bf16_emu_->r_vcvtneps2bf16(yreg(2*ur_w+jj),
                                    zreg(2*ur_w+jj));
                        else
                            vcvtneps2bf16(yreg(2*ur_w+jj), vreg(2*ur_w+jj));
                        vmovdqu16(ptr[aux_reg_input +
                                jpp.dt_size*aux_input_offset],
                                yreg(2*ur_w+jj));
                    } else {
                        vmovups(vmmword[aux_reg_input +
                                jpp.dt_size*aux_input_offset],
                                vreg(2*ur_w+jj));
                    }
                }
            }
            if (isa == avx && !mayiuse(avx2)) {
                avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
            }
        }
        add(aux_reg_input, jpp.dt_size * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        add(aux_reg_input_d, jpp.dt_size * jpp.ih * iw * c_block);

        mov(tmp_gpr, reg_kd_pad_shift);
        movq(xmm_tmp, tmp_gpr);
        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        if (isa == avx && !mayiuse(avx2)) {
            Xmm t(vmm_mask.getIdx());
            avx_vpadd1(vmm_k_offset, vmm_tmp, t);
        } else {
            uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        if (isa == sse42) {
            // Restore rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            pop(dst_ptr);
        }
        pop(reg_output);
        pop(reg_input);
    }
}
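
// Backward kernels accumulate into diff_src, so the rows this call will
// write are zeroed first; the 'oh' field of the call params carries the
// extent to clear (zero means skip).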
template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::maybe_zero_diff_src() {
    assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0);
    Label l_skip, l_zero;

    auto reg_oh = tmp_gpr;
    mov(reg_oh, ptr[reg_param + GET_OFF(oh)]);
    cmp(reg_oh, 0);
    jz(l_skip, T_NEAR);

    if (jpp.ndims == 5) {
        mov(zero_size, ptr[reg_param + GET_OFF(oh)]);
        mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * jpp.dt_size);
        imul(zero_size, tmp_gpr);
    }

    auto vzero = vmm_tmp;
    auto yzero = ymm_tmp;
    uni_vpxor(vzero, vzero, vzero);

    auto reg_off = tmp_gpr;
    xor_(reg_off, reg_off);

    L(l_zero);
    {
        const int dim = jpp.iw * jpp.c_block * jpp.dt_size;
        // bf16 stores use a ymm register (half of vlen bytes), so advance
        // the offset accordingly
        int step = jpp.is_bf16
            ? cpu_isa_traits<isa>::vlen / 2
            : cpu_isa_traits<isa>::vlen;
        for (int i = 0; i < dim; i += step) {
            if (jpp.is_bf16)
                vmovdqu16(ptr[reg_input + reg_off + i], yzero);
            else
                uni_vmovups(ptr[reg_input + reg_off + i], vzero);
        }
        add(reg_off, dim);
        if (jpp.ndims == 5) cmp(reg_off, zero_size);
        else cmp(reg_off, jpp.ih * dim);
        jl(l_zero, T_NEAR);
    }

    L(l_skip);
}
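
// Top-level code generation: walk the output row left to right as
// [left-padded step] [n_oi full steps] [right-padded step] [ur_w_tail step],
// advancing the input/output/index pointers between steps. On SSE4.2 each
// 8-channel block is processed as two half-register passes
// (step / step_high_half), hence the "- vlen" pointer adjustments.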
template <cpu_isa_t isa>
void jit_uni_pool_kernel<isa>::generate() {
    this->preamble();

    int ow = jpp.ow;
    int iw = jpp.iw;
    int kw = jpp.kw;
    int ur_w = jpp.ur_w;
    int c_block = jpp.c_block;
    int stride_w = jpp.stride_w;
    int l_pad = jpp.l_pad;
    int ur_w_tail = jpp.ur_w_tail;

    int n_oi = ow / ur_w;

    int vlen = cpu_isa_traits<isa>::vlen;

#if defined(_WIN32)
    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
    // file).
    xor_(rdi, rdi);
#endif

    if (!jpp.is_cpx && jpp.is_bf16)
        bf16_emu_->init_vcvtneps2bf16();

    mov(reg_input, ptr[reg_param + GET_OFF(src)]);
    mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
        mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
    mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
    mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
    mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);

    if (jpp.is_bf16) {
        mov(tmp_gpr.cvt32(), 0xAAAAAAAA);
        kmovd(k_mask_cvt, tmp_gpr.cvt32());

        mov(tmp_gpr, idx_table);
        vmovups(vmm_idx(), ptr[tmp_gpr]);
    }

    if (jpp.is_backward)
        maybe_zero_diff_src();

    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
        mov(tmp_gpr, 1);
        movq(xmm_one, tmp_gpr);
        uni_vpbroadcastd(vmm_one, xmm_one);

        if (isa == avx) {
            mov(reg_shuf_mask, 0x0c080400);
        } else if (isa >= avx512_common) {
            mov(tmp_gpr.cvt32(), 0x000f);
            kmovw(k_index_mask, tmp_gpr.cvt32());
        }
    }

    int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1));
    int r_pad_log = nstl::max(0,
            ((ow-1)*stride_w) + kw - 1 - (iw + l_pad + jpp.r_pad - 1));
    int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
    int r_pad1_log = nstl::max(0, r_pad1 - jpp.r_pad);
    if (r_pad1 > 0) n_oi--;

    movq(xmm_ker_area_h, reg_ker_area_h);
    uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h);

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0) {
            step(ur_w, l_pad, r_pad1, r_pad1_log);
        } else {
            step(ur_w, l_pad, 0, 0);
        }

        if (isa == sse42) {
            if (n_oi < 0 && r_pad1 > 0) {
                step_high_half(ur_w, l_pad, r_pad1, r_pad1_log);
            } else {
                step_high_half(ur_w, l_pad, 0, 0);
            }
        }

        if (isa == sse42) {
            add(reg_input, jpp.dt_size*(ur_w*stride_w-l_pad)*c_block - vlen);
            add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, (2 * ur_w - 1) * c_block / 2
                        * types::data_type_size(jpp.ind_dt));
        } else {
            add(reg_input, jpp.dt_size*(ur_w*stride_w - l_pad)*c_block);
            add(reg_output, jpp.dt_size*ur_w*c_block);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, ur_w * c_block
                        * types::data_type_size(jpp.ind_dt));
        }
    }

    xor_(oi_iter, oi_iter);
    if (n_oi > 0) {
        Label ow_loop;
        L(ow_loop); {
            step(ur_w, 0, 0, 0);

            if (isa == sse42) {
                step_high_half(ur_w, 0, 0, 0);
            }

            if (isa == sse42) {
                add(reg_input, jpp.dt_size*ur_w*stride_w*c_block - vlen);
                add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
                if (jpp.alg == pooling_max &&
                    (jpp.is_training || jpp.is_backward))
                    add(reg_index, (2 * ur_w - 1) * c_block / 2
                            * types::data_type_size(jpp.ind_dt));
            } else {
                add(reg_input, jpp.dt_size*ur_w*stride_w*c_block);
                add(reg_output, jpp.dt_size*ur_w*c_block);
                if (jpp.alg == pooling_max &&
                    (jpp.is_training || jpp.is_backward))
                    add(reg_index, ur_w * c_block
                            * types::data_type_size(jpp.ind_dt));
            }

            inc(oi_iter);
            cmp(oi_iter, n_oi);
            jl(ow_loop, T_NEAR);
        }
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        step(ur_w, 0, r_pad1, r_pad1_log);

        if (isa == sse42) {
            step_high_half(ur_w, 0, r_pad1, r_pad1_log);
        }

        if (isa == sse42) {
            add(reg_input, jpp.dt_size*ur_w*stride_w*c_block - vlen);
            add(reg_output, jpp.dt_size*ur_w*c_block - vlen);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, (2 * ur_w - 1) * c_block / 2
                        * types::data_type_size(jpp.ind_dt));
        } else {
            add(reg_input, jpp.dt_size*ur_w*stride_w*c_block);
            add(reg_output, jpp.dt_size*ur_w*c_block);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, ur_w * c_block
                        * types::data_type_size(jpp.ind_dt));
        }
    }

    if (ur_w_tail != 0) {
        step(ur_w_tail, 0, r_pad, r_pad_log);

        if (isa == sse42) {
            step_high_half(ur_w_tail, 0, r_pad, r_pad_log);
        }
    }

    this->postamble();

    if (jpp.is_bf16) {
        align(64);
        L(idx_table);
        // Each lane index appears twice: vpermw with the 0xAAAA... k_mask_cvt
        // mask and T_z places every bf16 word in the upper half of its
        // destination dword, which is exactly a bf16 -> f32 upconversion.
        const uint16_t _idx[] = { 0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,
                9,9,10,10,11,11,12,12,13,13,14,14,15,15 };
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dw(_idx[i]);
    }
}
template struct jit_uni_pool_kernel<sse42>;
template struct jit_uni_pool_kernel<avx>; // implements both <avx> and <avx2>
template struct jit_uni_pool_kernel<avx512_common>;

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s