/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>

#include "mkldnn_types.h"

#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "jit_generator.hpp"

#include "jit_avx512_core_i8i8_pooling.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {
using namespace Xbyak;

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::types;
using namespace alg_kind;
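
/* Xbyak-based generator for the forward int8 pooling kernel: emits one
 * AVX512 routine that pools a single output pixel over its kh x kw window
 * across the whole channel extent. */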
struct jit_avx512_core_i8i8_pool_fwd_ker_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_i8i8_pool_fwd_ker_t)

    struct call_params_t {
        const char *src_i8;
        const char *dst_i8;
        size_t kw_range;
        size_t kh_range;
        float idivider;
    };
    Reg64 reg_ptr_src_i8 = r8;
    Reg64 reg_ptr_dst_i8 = r9;

    // kernel-window and channel-loop counters / bounds
    Reg64 ki = r10;
    Reg64 kj = r11;
    Reg64 reg_kw = r12;
    Reg64 reg_kh = r13;
    Reg64 c_iter = r14;

    Reg64 aux_reg_src_h = rax;
    Reg64 aux_reg_src_w = rbx;

    Reg64 reg_tmp = rdx;
    Reg64 reg_mask = r15;
    Opmask k_cmp_mask = Opmask(7);
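    // Tail masks occupy k6..k3 (mask(0)..mask(3)); k7 is reserved for the
    // max-pooling compare mask above.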
    Opmask mask(int idx) {
        return Opmask(6 - idx);
    }
    Xmm xmm_tmp = Xmm(0);
    Zmm vreg_tmp = Zmm(30);
    Zmm vreg_zeros = Zmm(31);
    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
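
    /* max pooling: ur_c source vectors followed by ur_c running-max vectors */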
    Zmm vreg_src(int idx) {
        return Zmm(idx);
    }

    Zmm vreg_dst(int idx) {
        return Zmm(jpp.ur_c + idx);
    }
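
    /* avg pooling: 12 vectors per unrolled channel block -- 4 source, 4 s32
     * accumulators, 4 f32 temporaries, one per 16-lane quarter of the block */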
    Zmm vreg_src_s32(int jj, int ll) {
        return Zmm(12*jj + ll);
    }

    Zmm vreg_dst_s32(int jj, int ll) {
        return Zmm(12*jj + ll + 4);
    }

    Zmm vreg_dst_f32(int jj, int ll) {
        return Zmm(12*jj + ll + 8);
    }
    void (*ker_)(const call_params_t *);
    jit_pool_conf_t jpp;

    void init_tmp_reg();
    void init_mask();
    void load_src(int jj, int ll, int c_tail);
    void store_dst(int jj, int ll, int c_tail);

    void compute_avg_step(int ur_c, int c_tail);
    void compute_max_step(int ur_c, int c_tail);
    void compute_step(int ur_c, int c_tail);

    void compute_c_block();
    void generate();
    static status_t init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d);
    jit_avx512_core_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
           : jpp(jpp_) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
                       getCode()));
    }
};
void jit_avx512_core_i8i8_pool_fwd_ker_t::load_src(int jj, int ll, int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj*c_block*sizeof_src_dt();
            if (jj == ur_c - 1 && c_tail) {
                if (jpp.src_dt == data_type::s32) {
                    vmovups(vreg_src(jj) | mask(0),
                            ptr[aux_reg_src_w + offset]);
                } else {
                    vmovdqu8(vreg_src(jj) | mask(0),
                            ptr[aux_reg_src_w + offset]);
                }
            } else {
                vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
            }
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll*(c_block/4) + jj*c_block)*sizeof_src_dt();
            if (jj == jpp.ur_c - 1 && c_tail) {
                switch (jpp.src_dt) {
                    case s32:
                        vmovups(vreg_src_s32(jj, ll) | mask(ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case s8:
                        vpmovsxbd(vreg_src_s32(jj, ll) | mask(ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case u8:
                        vpmovzxbd(vreg_src_s32(jj, ll) | mask(ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    default: assert(!"unsupported src data type");
                }
            } else {
                switch (jpp.src_dt) {
                    case s32:
                        vmovups(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case s8:
                        vpmovsxbd(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case u8:
                        vpmovzxbd(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    default: assert(!"unsupported src data type");
                }
            }
            break;
        }
        default: assert(!"unsupported algorithm");
    }
}
void jit_avx512_core_i8i8_pool_fwd_ker_t::store_dst(int jj, int ll,
        int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj*c_block*sizeof_dst_dt();
            if (jj == ur_c - 1 && c_tail) {
                if (jpp.src_dt == data_type::s32) {
                    vmovups(ptr[reg_ptr_dst_i8 + offset],
                        vreg_dst(jj) | mask(0));
                } else {
                    vmovdqu8(ptr[reg_ptr_dst_i8 + offset],
                        vreg_dst(jj) | mask(0));
                }
            } else {
                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
            }
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll*(c_block/4) + jj*c_block)*sizeof_dst_dt();
            if (jj == ur_c - 1 && c_tail) {
                switch (jpp.dst_dt) {
                    case s32:
                        vmovups(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll) | mask(ll));
                        break;
                    case s8:
                        vpmovdb(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll) | mask(ll));
                        break;
                    case u8:
                        vpmovusdb(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll) | mask(ll));
                        break;
                    default: assert(!"unsupported dst data_type");
                }
            } else {
                switch (jpp.dst_dt) {
                    case s32:
                        vmovups(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll));
                        break;
                    case s8:
                        vpmovdb(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll));
                        break;
                    case u8:
                        vpmovusdb(ptr[reg_ptr_dst_i8 + offset],
                            vreg_dst_s32(jj, ll));
                        break;
                    default: assert(!"unsupported dst data_type");
                }
            }
            break;
        }
        default: assert(!"unsupported pooling algorithm");
    }
}
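
// Running max over the kh x kw window; the destination vectors are seeded
// from vreg_tmp, which init_tmp_reg() fills with the type's lowest value.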
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_max_step(int ur_c, int c_tail)
{
    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    for (int jj = 0; jj < ur_c; jj++)
        vmovups(vreg_dst(jj), vreg_tmp);

    mov(aux_reg_src_h, reg_ptr_src_i8);

    xor_(kj, kj);
    L(l_kh);
    {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw);
        {
            for (int jj = 0; jj < ur_c; jj++) {
                load_src(jj, 0, c_tail);
                if (jpp.src_dt == data_type::s32) {
                    vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
                    vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj),
                            vreg_src(jj));
                } else {
                    if (jpp.src_dt == data_type::s8)
                        vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj),
                                _cmp_lt_os);
                    else
                        vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj),
                                _cmp_lt_os);
                    vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj),
                            vreg_src(jj));
                }
            }
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki);
            cmp(ki, reg_kw);
            jl(l_kw, T_NEAR);
        }
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj);
        cmp(kj, reg_kh);
        jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++)
        store_dst(jj, 0, c_tail);
}
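
// Average pooling: accumulate s32 sums over the window, then scale by
// 1/divider in f32 (held in vreg_tmp) and round back to s32 for the store.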
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_avg_step(int ur_c, int c_tail)
{
    using namespace data_type;

    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    int num_ll = jpp.src_dt == data_type::s32 ? 1 : 4;

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < 4; ll++) {
            uni_vpxor(vreg_src_s32(jj, ll),
                    vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
            uni_vpxor(vreg_dst_s32(jj, ll),
                    vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
        }
    }

    mov(aux_reg_src_h, reg_ptr_src_i8);

    xor_(kj, kj);
    L(l_kh);
    {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw);
        {
            for (int jj = 0; jj < ur_c; jj++) {
                for (int ll = 0; ll < num_ll; ll++) {
                    load_src(jj, ll, c_tail);
                    vpaddd(vreg_dst_s32(jj, ll),
                            vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll));
                }
            }
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki);
            cmp(ki, reg_kw);
            jl(l_kw, T_NEAR);
        }
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj);
        cmp(kj, reg_kh);
        jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
            vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
            vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));

            store_dst(jj, ll, c_tail);
        }
    }
}
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_step(int ur_c, int c_tail) {
    switch (jpp.alg) {
        case pooling_max:
            compute_max_step(ur_c, c_tail); break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            compute_avg_step(ur_c, c_tail); break;
        default: assert(!"unsupported pooling algorithm");
    }
}
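
// Channel loop: c_steps full steps of ur_c blocks each, then one masked
// step covering the remaining blocks and the channel tail.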
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_c_block() {
    Label l_main_loop;

    int nb_c = jpp.nb_c;
    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;
    int ur_c_tail = jpp.ur_c_tail;
    int c_steps = nb_c / ur_c;
    int c_tail = jpp.c_tail;

    xor_(c_iter, c_iter);
    if (c_steps > 0) {
        L(l_main_loop); {
            compute_step(ur_c, 0);
            add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
            add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
            inc(c_iter);
            cmp(c_iter, c_steps);
            jl(l_main_loop, T_NEAR);
        }
    }

    if (ur_c_tail != 0) {
        compute_step(ur_c_tail, c_tail);
    }
}
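
// Load the four 16-bit tail masks prepared by init_conf into the opmask
// registers mask(0)..mask(3).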
void jit_avx512_core_i8i8_pool_fwd_ker_t::init_mask() {
    for (int i = 0; i < 4; i++) {
        mov(reg_mask, jpp.tail[i]);
        kmovq(mask(i), reg_mask);
    }
}
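
// vreg_tmp holds the kernel-wide constant: 1/divider for avg pooling, or
// the source type's lowest value for max pooling.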
void jit_avx512_core_i8i8_pool_fwd_ker_t::init_tmp_reg() {
    using namespace data_type;

    switch (jpp.alg) {
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
            movq(xmm_tmp, reg_tmp);
            vpbroadcastd(vreg_tmp, xmm_tmp);
            break;
        case pooling_max:
            switch (jpp.src_dt) {
                case s32:
                    mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
                    break;
                case s8:
                    mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
                    break;
                case u8:
                    mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
                    break;
                default: assert(!"unsupported src data_type");
            }

            movq(xmm_tmp, reg_tmp);
            if (jpp.src_dt == s32)
                vpbroadcastd(vreg_tmp, xmm_tmp);
            else
                vpbroadcastb(vreg_tmp, xmm_tmp);
            break;
        default: assert(!"unsupported pooling algorithm");
    }
}
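
// Kernel entry point: read the call parameters, materialize constants and
// tail masks, then emit the channel loop.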
void jit_avx512_core_i8i8_pool_fwd_ker_t::generate() {
    preamble();

#   define READ_PARAM(reg, field) \
        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src_i8, src_i8);
    READ_PARAM(reg_ptr_dst_i8, dst_i8);
    READ_PARAM(reg_kw, kw_range);
    READ_PARAM(reg_kh, kh_range);
#   undef READ_PARAM

    init_tmp_reg();
    init_mask();

    uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);

    compute_c_block();

    postamble();
}
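
// init_conf derives the blocking from the descriptor: the channel block is
// one 64-byte vector (16 s32 lanes or 64 i8 lanes), and jpp.tail holds the
// bit masks for the channel remainder.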
status_t jit_avx512_core_i8i8_pool_fwd_ker_t::init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d) {
    if (!mayiuse(avx512_core)) {
        return status::unimplemented;
    }

    jpp.mb = src_d.dims()[0];
    jpp.c = src_d.dims()[1];
    jpp.ih = src_d.dims()[2];
    jpp.iw = src_d.dims()[3];
    jpp.oh = dst_d.dims()[2];
    jpp.ow = dst_d.dims()[3];

    jpp.stride_h = pd.strides[0];
    jpp.stride_w = pd.strides[1];
    jpp.kh = pd.kernel[0];
    jpp.kw = pd.kernel[1];

    jpp.t_pad = pd.padding[0][0];
    jpp.l_pad = pd.padding[0][1];

    jpp.alg = pd.alg_kind;

    jpp.src_dt = pd.src_desc.data_type;
    jpp.dst_dt = pd.dst_desc.data_type;

    jpp.c_block = 64 / (jpp.src_dt == data_type::s32 ? 4 : 1);
    jpp.c_tail = jpp.c % jpp.c_block;
    jpp.nb_c = jpp.c / jpp.c_block;
    jpp.ur_c = 1;
    jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
            (jpp.c_tail != 0);

    size_t tail_mask = (1ULL << jpp.c_tail) - 1;

    switch (jpp.alg) {
        case pooling_max:
            jpp.tail[0] = tail_mask;
            jpp.tail[1] = 0;
            jpp.tail[2] = 0;
            jpp.tail[3] = 0;
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            jpp.tail[0] = tail_mask & 0xffff;
            for (size_t i = 1, m = tail_mask; i < 4; i++) {
                m = m >> 16;
                jpp.tail[i] = m & 0xffff;
            }
            break;
        default: return status::unimplemented;
    }

    return status::success;
}
status_t jit_avx512_core_i8i8_pooling_fwd_t::pd_t::jit_conf() {
    return jit_avx512_core_i8i8_pool_fwd_ker_t::init_conf(jpp_,
       desc_, src_pd_.desc(), dst_pd_.desc());
}
jit_avx512_core_i8i8_pooling_fwd_t::
jit_avx512_core_i8i8_pooling_fwd_t(const pd_t *pd,
          const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr)
{ ker_ = new jit_avx512_core_i8i8_pool_fwd_ker_t(conf_.jpp_); }

jit_avx512_core_i8i8_pooling_fwd_t::
~jit_avx512_core_i8i8_pooling_fwd_t() { delete ker_; }
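
// Host-side driver: one kernel invocation per (mb, oh, ow) output point via
// parallel_nd; the kernel walks the pooling window and full channel extent.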
void jit_avx512_core_i8i8_pooling_fwd_t::execute_forward() {
    auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
    auto dst_i8 = reinterpret_cast<char *>(memory());

    const memory_desc_wrapper src_d(conf_.src_pd());
    const memory_desc_wrapper dst_d(conf_.dst_pd());

    const auto &jpp = conf_.jpp_;

    parallel_nd(jpp.mb, jpp.oh, jpp.ow,
            [&](int n, int oh, int ow) {
        const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
        const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);

        const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
        const int kh_end = nstl::min(jpp.kh,
                jpp.ih + jpp.t_pad - oh * jpp.stride_h);
        const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
        const int kw_end = nstl::min(jpp.kw,
                jpp.iw + jpp.l_pad - ow * jpp.stride_w);

        auto p = jit_avx512_core_i8i8_pool_fwd_ker_t::call_params_t();
        p.src_i8 = &src_i8[
                src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
        p.dst_i8 = &dst_i8[
                dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
        p.kw_range = (size_t)(kw_end - kw_start);
        p.kh_range = (size_t)(kh_end - kh_start);
        p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
                p.kh_range*p.kw_range : jpp.kw*jpp.kh);