/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "jit_uni_i8i8_pooling.hpp"

#include "mkldnn_types.h"
#include "mkldnn_thread.hpp"

#include "jit_generator.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {
using namespace Xbyak;

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::types;
using namespace alg_kind;
template <cpu_isa_t isa>
struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)

    struct call_params_t {
        const char *src_i8;
        const char *dst_i8;
        size_t kw_range;
        size_t kh_range;
        float idivider;
    };
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    Xmm xreg(int idx) const { return Xmm(idx); }
    Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
    Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
    // Rounding modes for avx2
    enum:uint8_t { rnd_op_nearest = 0x0 };
    // In case of avx2 with data type i8 we need to use the
    // maskmovdqu instruction, which has its destination hardcoded in rdi.
    // Windows ABI: abi_param1 is rcx - nothing else to do.
    // Unix ABI: abi_param1 is rdi - copy it to rcx and use rcx instead.
    Reg64 reg_param = rcx; // Our "unified" abi_param1
    Reg64 reg_ptr_src_i8 = r8;
    Reg64 reg_ptr_dst_i8 = r9;
    Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi

    Reg64 ki = r10;
    Reg64 kj = r11;
    Reg64 reg_kw = r12;
    Reg64 reg_kh = r13;
    Reg64 c_iter = r14;

    Reg64 aux_reg_src_h = rax;
    Reg64 aux_reg_src_w = rbx;

    Reg64 reg_tmp = rdx;
    Reg64 reg_mask = r15;
    Opmask k_cmp_mask = Opmask(7);

    Opmask mask(int idx) {
        return Opmask(6 - idx);
    }
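    // mask(ll) hands out k6, k5, k4, k3 for ll = 0..3 (max_num_ll below),
    // keeping k7 for k_cmp_mask; k0 is skipped since it cannot serve as a
    // write mask.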
    // ref to any of XYZ-regs via xreg/yreg/vreg functions
    Xmm xmm_tmp = xreg(0);  // temp to init vreg_tmp
    Vmm vreg_tmp = vreg(0); // max pooling: holds minimum values for data_type
    Vmm vreg_zeros = vreg(1);

    // only in case of <isa> == avx2
    Vmm vreg_mask    = vreg(2); // full byte-mask
    Xmm xreg_mask_lo = xreg(2); // low 128 bits of byte-mask (alias for xmm part of vreg_mask)
    Xmm xreg_mask_hi = xreg(3); // "max" - high 128 bits of byte-mask (stored separately)
    Xmm xreg_mask_q  = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
    Vmm vreg_mask_q  = vreg(3); // "avg" - 1/4 part for non-zero tails
    enum:int {vidx_base = isa == avx2 ? 4 : 2};
    Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }

    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }

    /* max pooling */
    Vmm vreg_src(int idx) const { return base_vr(idx); }            // [0    .. ur_c-1]
    Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
    /* avg pooling */
    // s32 is used for processing of s8/u8 data,
    // so we need to take into account the ratio of sizes s32/i8 = 4
    static constexpr data_type_t avg_proc_dt = data_type::s32;
    enum:int {
        s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
                / sizeof(typename prec_traits<data_type::u8>::type),
        max_num_ll = s32_to_i8_ratio
    };
    Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..3 [0..3]
    Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..3 [4..7]
    Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..3 [8..11]
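    // Layout example (avx2: vidx_base = 4, max_num_ll = 4, jj = 0):
    //   vreg_src_s32(0, 0..3) -> base_vr(0..3)  -> ymm4..ymm7
    //   vreg_dst_s32(0, 0..3) -> base_vr(4..7)  -> ymm8..ymm11
    //   vreg_dst_f32(0, 0..3) -> base_vr(8..11) -> ymm12..ymm15
    // i.e. with ur_c == 1 the "avg" path uses all 12 remaining Ymm regs.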
    void (*ker_)(const call_params_t *);
    jit_pool_conf_t jpp;

    void init_tmp_reg();
    void init_mask();

    void load_vreg_mask_q(int ll) {};

    void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void load_src(int jj, int ll, int c_tail);

    void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void store_dst(int jj, int ll, int c_tail);
    void compute_avg_step(int ur_c, int c_tail);
    void compute_max_op(const int jj);
    void compute_max_step(int ur_c, int c_tail);
    void compute_step(int ur_c, int c_tail);

    void compute_c_block();
    void generate();
    static status_t init_conf(jit_pool_conf_t &jpp,
            const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d);

    jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
            : jpp(jpp_) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
                getCode()));
    }
};
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {

    // extract ll-th part of mask (ll-th QWORD)
    vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << 2*ll); // 0x3 - mask for 2 x DWORD

    // Move mask from ll-th pos to 0-th pos
    if (ll > 0)
        vpermq(vreg_mask_q, vreg_mask_q, ll);
};
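// Example for ll == 1: the vpblendd immediate is 0x3 << 2 = 0b1100, which
// picks DWORD lanes 2 and 3 (i.e. QWORD 1) of vreg_mask; vpermq with imm 1
// then moves that QWORD into position 0, where the 8-byte vpblendvb and
// maskmovdqu consumers of vreg_mask_q/xreg_mask_q expect it.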
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32)
            vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset],
                    static_cast<uint8_t>(msk));
        else
            vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset],
                    vreg_mask);
    } else
        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32)
            vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
        else
            vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
    } else
        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk)
        return;

    auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
        // Need to use mask of tail?
        if (masked) {
            // load ll-th part of mask into vreg_mask_q
            load_vreg_mask_q(ll);

            // Load by mask from mem into register vr_src
            vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);

            // Conversion s8/u8 -> s32
            if (is_signed)
                vpmovsxbd(vr_src, vr_src);
            else
                vpmovzxbd(vr_src, vr_src);
        } else {
            // Load from mem into vr_src with conversion
            if (is_signed)
                vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
            else
                vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
        }
    };

    switch (jpp.src_dt) {
        case s32:
            if (masked)
                vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
                        static_cast<uint8_t>(msk));
            else
                vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
            break;
        case s8:
            load_i8(true, vreg_src_s32(jj, ll));
            break;
        case u8:
            load_i8(false, vreg_src_s32(jj, ll));
            break;
        default: assert(!"unsupported src data type");
    }
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk)
        return;

    const Vmm& vr_src = masked ?
            vreg_src_s32(jj, ll) | mask(ll) :
            vreg_src_s32(jj, ll);

    switch (jpp.src_dt) {
        case s32:
            vmovups(vr_src, ptr[aux_reg_src_w + offset]);
            break;
        case s8:
            vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
            break;
        case u8:
            vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
            break;
        default: assert(!"unsupported src data type");
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj*c_block*sizeof_src_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        default: assert(!"unsupported algorithm");
    }
}
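// Offset example for the "avg" case (avx2, u8): c_block = 32 and
// max_num_ll = 4, so quarter ll of block jj starts at (ll*8 + jj*32) bytes -
// each 8-byte group widens into the 8 s32 lanes of one Ymm.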
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    int c_block = jpp.c_block;

    if (masked) {
        switch (jpp.src_dt) {
            case s32:
                vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
                break;
            case s8:
            case u8: {
                // Store low half by mask (bytes 0...15)
                lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
                maskmovdqu(vreg_dst(jj), xreg_mask_lo);

                // Do we need to store the high half (bytes 16...31)?
                const uint64_t low_mask = (1ULL << (c_block/2))-1;
                if (msk & ~low_mask) {
                    vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
                    add(reg_ptr_maskmovdqu_dst, c_block / 2);
                    maskmovdqu(vreg_dst(jj), xreg_mask_hi);
                }
                break;
            }
            default: assert(!"unsupported src data type");
        }
    } else
        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
}
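// Masked-store example: for c_block = 32 (s8/u8) and c_tail = 20,
// msk = 0xFFFFF and low_mask = 0xFFFF. Bytes 0..15 go out via xreg_mask_lo;
// msk & ~low_mask = 0xF0000 is non-zero, so the high xmm half is extracted
// and bytes 16..19 are stored via xreg_mask_hi.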
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        switch (jpp.src_dt) {
            case s32:
                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
                break;
            case s8:
            case u8:
                vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
                break;
            default: assert(!"unsupported src data type");
        }
    } else
        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk){
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk)
        return;

    auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
        // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
        // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
        if (is_signed)
            vpackssdw(vr_dst, vr_dst, vreg_zeros);
        else
            vpackusdw(vr_dst, vr_dst, vreg_zeros);

        // Permute qwords to restore original order
        // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
        vpermq(vr_dst, vr_dst, 0x58);
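        // Immediate 0x58 = 0b01'01'10'00 selects source qwords {0, 2, 1, 1}
        // for destination qwords 0..3, so {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}.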
        // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
        // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
        if (is_signed)
            vpacksswb(vr_dst, vr_dst, vreg_zeros);
        else
            vpackuswb(vr_dst, vr_dst, vreg_zeros);
    };
    auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
        // Conversion s32 -> s8/u8
        s32_to_i8(is_signed, vr_dst);

        // Need to use mask of tail?
        if (is_masked) {
            // load ll-th part of mask into vreg_mask_q
            load_vreg_mask_q(ll);
        }

        // Store low 8 bytes by byte-mask in xreg_mask_q
        lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
        maskmovdqu(vr_dst, xreg_mask_q);
    };

    switch (jpp.dst_dt) {
        case s32:
            if (masked)
                vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
            else
                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
            break;
        case s8:
            store_i8(true, masked, vreg_dst_s32(jj, ll));
            break;
        case u8:
            store_i8(false, masked, vreg_dst_s32(jj, ll));
            break;
        default: assert(!"unsupported dst data_type");
    }
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
        size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk)
        return;

    const Vmm& vr_dst = masked ?
            vreg_dst_s32(jj, ll) | mask(ll) :
            vreg_dst_s32(jj, ll);

    switch (jpp.dst_dt) {
        case s32:
            vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
            break;
        case s8:
            vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
            break;
        case u8:
            vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
            break;
        default: assert(!"unsupported dst data_type");
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
        int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj*c_block*sizeof_dst_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        default: assert(!"unsupported pooling algorithm");
    }
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
{
    using namespace data_type;
    switch (jpp.src_dt) {
        case s32: vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        case s8:  vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        case u8:  vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        default: assert(!"unsupported src data type");
    }
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
{
    using namespace data_type;

    // Compare
    switch (jpp.src_dt) {
        case s32: vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); break;
        case s8:  vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); break;
        case u8:  vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); break;
        default: assert(!"unsupported src data type");
    }

    // move max values into vreg_dst
    if (jpp.src_dt == s32)
        vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
    else
        vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
{
    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    for (int jj = 0; jj < ur_c; jj++)
        vmovups(vreg_dst(jj), vreg_tmp);

    mov(aux_reg_src_h, reg_ptr_src_i8);
    xor_(kj, kj);
    L(l_kh); {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw); {
            for (int jj = 0; jj < ur_c; jj++) {
                load_src(jj, 0, c_tail);
                compute_max_op(jj);
            }
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki); cmp(ki, reg_kw); jl(l_kw, T_NEAR);
        }
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj); cmp(kj, reg_kh); jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++)
        store_dst(jj, 0, c_tail);
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
{
    using namespace data_type;

    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            bool masked = jj == ur_c - 1 && c_tail;
            size_t msk = jpp.tail[ll];
            if (!(masked && !msk)) {
                uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
                uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
            }
        }
    }

    mov(aux_reg_src_h, reg_ptr_src_i8);
    xor_(kj, kj);
    L(l_kh); {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw); {
            for (int jj = 0; jj < ur_c; jj++) {
                for (int ll = 0; ll < num_ll; ll++) {
                    bool masked = jj == ur_c - 1 && c_tail;
                    size_t msk = jpp.tail[ll];
                    if (!(masked && !msk)) {
                        load_src(jj, ll, c_tail);
                        vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
                                vreg_src_s32(jj, ll));
                    }
                }
            }
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki); cmp(ki, reg_kw); jl(l_kw, T_NEAR);
        }
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj); cmp(kj, reg_kh); jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            bool masked = jj == ur_c - 1 && c_tail;
            size_t msk = jpp.tail[ll];
            if (!(masked && !msk)) {
                vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
                vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
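                // i.e. dst_f32 = dst_f32 * idivider + 0: vreg_tmp holds the
                // broadcast averaging divider and vreg_zeros the addend.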
                if (isa == avx2) {
                    uni_vroundps(vreg_dst_f32(jj, ll), vreg_dst_f32(jj, ll), rnd_op_nearest);
                    vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
                } else if (isa >= avx512_common) {
                    // AVX512: use of EVEX-embedded static rounding override
                    vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));
                }

                store_dst(jj, ll, c_tail);
            }
        }
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
    switch (jpp.alg) {
        case pooling_max:
            compute_max_step(ur_c, c_tail); break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            compute_avg_step(ur_c, c_tail); break;
        default: assert(!"unsupported pooling algorithm");
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
    Label l_main_loop;

    int nb_c = jpp.nb_c;
    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;
    int ur_c_tail = jpp.ur_c_tail;
    int c_steps = nb_c / ur_c;
    int c_tail = jpp.c_tail;

    xor_(c_iter, c_iter);
    if (c_steps > 0) {
        L(l_main_loop); {
            compute_step(ur_c, 0);
            add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
            add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
            inc(c_iter);
            cmp(c_iter, c_steps);
            jl(l_main_loop, T_NEAR);
        }
    }

    if (ur_c_tail != 0) {
        compute_step(ur_c_tail, c_tail);
    }
}
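// Worked example (avx2, u8, c = 100): c_block = 32, so nb_c = 3 and
// c_tail = 4. With ur_c = 1, c_steps = 3 and ur_c_tail = 1: the main loop
// emits three full 32-channel steps, then one masked step covers the
// remaining 4 channels.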
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
    using namespace data_type;
    using cpu_isa = cpu_isa_traits<avx2>;

    // AVX2 mask initialization: mask stored in Ymm-regs
    auto init = [&](uint64_t bit_mask, bool init_mask_q) {
        const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);

        uint64_t vmask[QW_PER_VREG];
        for (size_t i = 0; i < QW_PER_VREG; i++){

            uint64_t qw_vmask=0ULL;
            const size_t DBITS = 8*sizeof_src_dt();
            const uint64_t VMSK = 1ULL << (DBITS-1);
            const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
            for (size_t j = 0; j < D_PER_QW; j++) {
                if (bit_mask & 1)
                    qw_vmask |= VMSK << DBITS * j;
                bit_mask >>= 1;
            }
            vmask[i] = qw_vmask;
        }
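        // Expansion example for s8/u8 (DBITS = 8, VMSK = 0x80, D_PER_QW = 8):
        // low bit_mask bits 0b00000101 become qword 0x0000000000800080 -
        // the MSB of each live channel's byte, which is what vpblendvb and
        // maskmovdqu test.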
        // Put QWORDS with target mask into xmm regs
        const int xdst_i[QW_PER_VREG] = {
                xreg_mask_lo.getIdx(),
                xreg_mask_lo.getIdx(),
                xreg_mask_hi.getIdx(),
                xreg_mask_hi.getIdx()
        };
        const int xsrc_i[QW_PER_VREG] = {
                vreg_zeros.getIdx(),   // 0-th qword insert in zeros -> {qw0, 0}
                xreg_mask_lo.getIdx(), // 1-st and 0-th merge        -> {qw0, qw1}
                vreg_zeros.getIdx(),   // 2-nd qword insert in zeros -> {qw2, 0}
                xreg_mask_hi.getIdx()  // 3-rd and 2-nd merge        -> {qw2, qw3}
        };
        const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg

        for (size_t i = 0; i < QW_PER_VREG; i++) {
            mov(reg_mask, vmask[i]);
            vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
        }

        // Merge Low (xreg_mask_lo alias for xmm part of vreg_mask)
        // and High (xreg_mask_hi) into full vreg_mask
        // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
        vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);

        // Keep only low qword of mask in xreg_mask_q
        if (init_mask_q) {
            mov(reg_mask, vmask[0]);
            vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
        }
    };
    uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
    switch (jpp.alg) {
        case pooling_max:
            // For "max" we need mask only in case of non-zero tail
            if (tail_mask)
                init(tail_mask, false);
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            // For "avg" we need mask:
            // - s32   - in case of the non-zero tail
            // - s8/u8 - irrespective of the tail
            switch (jpp.src_dt) {
                case s32:
                    if (tail_mask)
                        init(tail_mask, false);
                    break;
                case s8:
                case u8:
                    init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
                    break;
                default: assert(!"unsupported src data type");
            }
            break;
        default: assert(!"unsupported pooling algorithm");
    }
}
template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {

    for (int ll = 0; ll < max_num_ll; ll++) {
        mov(reg_mask, jpp.tail[ll]);
        kmovq(mask(ll), reg_mask);
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
    using namespace data_type;

    switch (jpp.alg) {
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
            movq(xmm_tmp, reg_tmp);
            vpbroadcastd(vreg_tmp, xmm_tmp);
            break;
        case pooling_max:
            switch (jpp.src_dt) {
                case s32:
                    mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
                    break;
                case s8:
                    mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
                    break;
                case u8:
                    mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
                    break;
                default: assert(!"unsupported src data_type");
            }

            movq(xmm_tmp, reg_tmp);
            if (jpp.src_dt == s32)
                vpbroadcastd(vreg_tmp, xmm_tmp);
            else
                vpbroadcastb(vreg_tmp, xmm_tmp);
            break;
        default: assert(!"unsupported pooling algorithm");
    }
}
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
    preamble();

#if !defined(_WIN32)
    // Always use rcx as abi_param1 -
    // see the note about maskmovdqu near reg_param.
    mov(rcx, rdi);
#endif

#   define READ_PARAM(reg, field) \
        mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src_i8, src_i8);
    READ_PARAM(reg_ptr_dst_i8, dst_i8);
    READ_PARAM(reg_kw, kw_range);
    READ_PARAM(reg_kh, kh_range);
#   undef READ_PARAM

    uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);

    init_mask();
    init_tmp_reg();
    compute_c_block();

    postamble();
}
template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d) {
    if (!mayiuse(isa))
        return status::unimplemented;

    jpp.mb = src_d.dims()[0];
    jpp.c = src_d.dims()[1];
    jpp.ih = src_d.dims()[2];
    jpp.iw = src_d.dims()[3];
    jpp.oh = dst_d.dims()[2];
    jpp.ow = dst_d.dims()[3];

    jpp.stride_h = pd.strides[0];
    jpp.stride_w = pd.strides[1];
    jpp.kh = pd.kernel[0];
    jpp.kw = pd.kernel[1];

    jpp.t_pad = pd.padding[0][0];
    jpp.l_pad = pd.padding[0][1];

    int right_pad = (jpp.ow - 1) * jpp.stride_w
            + jpp.kw - 1 - (jpp.iw + jpp.l_pad - 1);
    int bottom_pad = (jpp.oh - 1) * jpp.stride_h
            + jpp.kh - 1 - (jpp.ih + jpp.t_pad - 1);
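    // E.g. iw = 5, kw = 3, stride_w = 1, l_pad = 1, ow = 5:
    // right_pad = 4*1 + 3 - 1 - (5 + 1 - 1) = 1, i.e. one implicit padding
    // column on the right is needed to produce all ow outputs.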
    if (jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
            || bottom_pad >= jpp.kh || right_pad >= jpp.kw)
        return status::unimplemented;

    jpp.alg = pd.alg_kind;

    jpp.src_dt = pd.src_desc.data_type;
    jpp.dst_dt = pd.dst_desc.data_type;
    // data_type items per one vreg on the <isa>
    //     isa == avx2    : 32 bytes -> 32 for s8/u8, 8 for s32
    //     isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
    int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);

    jpp.c_block = simd_w;
    jpp.c_tail = jpp.c % jpp.c_block;
    jpp.nb_c = jpp.c / jpp.c_block;
    jpp.ur_c = 1;
    jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
            (jpp.c_tail != 0);

    size_t tail_mask = (1ULL << jpp.c_tail) - 1;
    switch (jpp.alg) {
        case pooling_max:
            jpp.tail[0] = tail_mask;
            jpp.tail[1] = 0;
            jpp.tail[2] = 0;
            jpp.tail[3] = 0;
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
            // avx2 : 8, avx512 : 16
            const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
            const size_t msk_msk = (1ULL << msk_gran) - 1;
            size_t m = tail_mask;
            for (size_t ll = 0; ll < max_num_ll; ll++) {
                jpp.tail[ll] = m & msk_msk;
                m = m >> msk_gran;
            }
            break;
        }
        default: return status::unimplemented;
    }
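    // E.g. avx2, u8, c_tail = 21: tail_mask = 0x1FFFFF and msk_gran = 8, so
    // jpp.tail[] = { 0xFF, 0xFF, 0x1F, 0 } - one 8-channel group per "ll"
    // quarter, matching the 8 s32 lanes of a Ymm.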
    return status::success;
}
template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
    return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_,
            desc_, src_pd_.desc(), dst_pd_.desc());
}
template <cpu_isa_t isa>
jit_uni_i8i8_pooling_fwd_t<isa>::
jit_uni_i8i8_pooling_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }

template <cpu_isa_t isa>
jit_uni_i8i8_pooling_fwd_t<isa>::
~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() const {
    auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
    auto dst_i8 = reinterpret_cast<char *>(memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());

    const auto &jpp = pd()->jpp_;

    parallel_nd(jpp.mb, jpp.oh, jpp.ow,
            [&](int n, int oh, int ow) {
        const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
        const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);

        const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
        const int kh_end = nstl::min(jpp.kh,
                jpp.ih + jpp.t_pad - oh * jpp.stride_h);
        const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
        const int kw_end = nstl::min(jpp.kw,
                jpp.iw + jpp.l_pad - ow * jpp.stride_w);
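        // Clipping example: oh = 0, stride_h = 1, t_pad = 1, kh = 3, ih = 5
        // gives ih(start row) = max(-1, 0) = 0, kh_start = 1 and
        // kh_end = min(3, 6) = 3, so kh_range = 2 - only the two kernel rows
        // overlapping real input are visited.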
        auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
        p.src_i8 = &src_i8[
                src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
        p.dst_i8 = &dst_i8[
                dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
        p.kw_range = (size_t)(kw_end - kw_start);
        p.kh_range = (size_t)(kh_end - kh_start);
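        // For avg_exclude_padding the divider is the number of kernel
        // elements overlapping real input (e.g. 2*2 = 4 in a 3x3 corner);
        // avg_include_padding always divides by the full kw*kh.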
        p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
                p.kh_range*p.kw_range : jpp.kw*jpp.kh);

        ker_->ker_(&p);
    });
}
// Explicit instantiation only for supported <isa> values.
template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;

template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
template struct jit_uni_i8i8_pooling_fwd_t<avx2>;

}
}
}