1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "mkldnn_types.h"
21 #include "mkldnn_thread.hpp"
24 #include "jit_generator.hpp"
26 #include "jit_sse42_i8i8_pooling.hpp"
32 using namespace Xbyak;
34 using namespace mkldnn::impl::utils;
35 using namespace mkldnn::impl::memory_format;
36 using namespace mkldnn::impl::utils;
37 using namespace mkldnn::impl::types;
38 using namespace alg_kind;
// NOTE(review): this is an elided listing — the leading numbers are original
// source line numbers; gaps indicate lines missing from this view.

// Arguments passed to the generated kernel. Fields are elided here; src_i8,
// dst_i8, kw_range, kh_range and idivider are evidenced by the offsetof()
// reads in generate()/init_tmp_reg() and the writes in execute_forward().
40 struct call_params_t {

// JIT code generator for the SSE4.2 forward int8/int32 pooling kernel.
48 struct jit_sse42_i8i8_pool_fwd_ker_t : public jit_generator {
49 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_i8i8_pool_fwd_ker_t)

// General-purpose register map: base pointers to source/destination and
// auxiliary cursors for walking the pooling window rows/columns.
51 Reg64 reg_ptr_src_i8 = r8;
52 Reg64 reg_ptr_dst_i8 = r9;
60 Reg64 aux_reg_src_h = rax;
61 Reg64 aux_reg_src_w = rbx;

// Three width views (qword/dword/byte) of the same scratch register r15,
// used to shuttle scalars between memory and xmm registers.
64 Reg64 reg_src_64 = r15;
65 Reg32 reg_src_32 = r15d;
66 Reg8 reg_src_8 = r15b;

// Byte sizes of the configured source/destination data types.
68 size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
69 size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }

// Xmm(14)/Xmm(15) are reserved as scratch: broadcast constant and zeros.
72 Xmm vreg_tmp = Xmm(14);
73 Xmm vreg_zeros = Xmm(15);

// Vector-register accessors. Registers are laid out in banks of
// 2 * jpp.ur_c, indexed by unroll position jj and half-block ii:
// src in [0, 2*ur_c), dst in [2*ur_c, 4*ur_c), f32 temps from 4*ur_c.
76 Xmm vmm_src(int jj, int ii) {
77 return Xmm(2*jj + ii);
84 Xmm vmm_dst(int jj, int ii) {
85 return Xmm(2*jj + ii + 2 * jpp.ur_c);
// (enclosing accessor's name elided in this view — likely the xmm alias
// of vmm_dst; confirm against the full source.)
89 return Xmm(2*jj + 2 * jpp.ur_c);
93 Xmm vmm_src_s32(int jj, int ii) {
94 return Xmm(2*jj + ii);
97 Xmm xmm_src_s32(int jj, int ii) {
98 return Xmm(2*jj + ii);
101 Xmm vmm_dst_s32(int jj, int ii) {
102 return Xmm(2*jj + ii + 2 * jpp.ur_c);
105 Ymm ymm_dst_s32(int jj, int ii) {
106 return Ymm(2*jj + ii + 2 * jpp.ur_c);
109 Xmm xmm_dst_s32(int jj, int ii) {
110 return Xmm(2*jj + ii + 2 * jpp.ur_c);
113 Xmm vmm_dst_f32(int jj, int ii) {
114 return Xmm(2*jj + ii + 4 * jpp.ur_c);

// Entry point of the generated code; called with a call_params_t*.
117 void (*ker_)(const call_params_t *);

// Code-emission helpers, one per stage of the kernel.
122 void load_src(int jj, int c_step);
123 void store_dst(int jj, int c_step);
125 void compute_avg_step(int ur_c, int c_step);
126 void compute_max_step(int ur_c, int c_step);
127 void compute_step(int ur_c, int c_step);
129 void compute_c_block();

// Fills jpp from the pooling descriptor and memory descriptors;
// returns unimplemented when SSE4.2 is unavailable (see definition).
132 static status_t init_conf(jit_pool_conf_t &jpp,
133 const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
134 const memory_desc_wrapper &dst_d);

// Ctor generates the machine code and publishes it through ker_.
136 jit_sse42_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
139 ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
// Emits code loading one unrolled channel group into vector registers.
// jj: unroll position; c_step: channels handled per step (jpp.c_block for
// full blocks, 1 for the scalar tail). repeats == 2 covers the two xmm
// halves of a full block. The switch-on-jpp.alg scaffolding (and the
// pooling_max case label) is elided from this view.
144 void jit_sse42_i8i8_pool_fwd_ker_t::load_src(int jj, int c_step) {
145 using namespace data_type;
147 int repeats = c_step != 1 ? 2 : 1;
// First section (presumably the pooling_max case — confirm in full
// source): raw loads into the max-comparison registers vmm_src.
150 auto offset = jj*c_step*sizeof_src_dt();
151 if (c_step == jpp.c_block) {
152 for (int ii = 0; ii < repeats; ii++)
153 uni_vmovups(vmm_src(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
154 } else if (c_step == 1) {
155 if (jpp.src_dt == s32) {
156 movsd(xmm_src(jj), ptr[aux_reg_src_w + offset]);
// Single byte tail: load through the GPR byte alias, then into xmm.
158 mov(reg_src_8, ptr[aux_reg_src_w + offset]);
159 movq(xmm_src(jj), reg_src_64);
// Average path: every source type is widened to s32 for accumulation.
164 case pooling_avg_include_padding:
165 case pooling_avg_exclude_padding: {
166 auto offset = jj*c_step*sizeof_src_dt();
167 switch (jpp.src_dt) {
// s32 source: load as-is.
169 if (c_step == jpp.c_block) {
170 for (int ii = 0; ii < repeats; ii++)
171 uni_vmovups(vmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
172 } else if (c_step == 1) {
173 movsd(xmm_src_s32(jj, 0), ptr[aux_reg_src_w + offset]);
// s8 source: sign-extend 4 bytes -> 4 dwords (pmovsxbd).
177 if (c_step == jpp.c_block) {
178 for (int ii = 0; ii < repeats; ii++) {
179 movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
181 uni_vpmovsxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
183 } else if (c_step == 1) {
184 movsx(reg_src_32, ptr[aux_reg_src_w + offset]);
185 movq(xmm_src_s32(jj, 0), reg_src_64);
// u8 source: zero-extend 4 bytes -> 4 dwords (pmovzxbd).
189 if (c_step == jpp.c_block) {
190 for (int ii = 0; ii < repeats; ii++) {
191 movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
193 uni_vpmovzxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
195 } else if (c_step == 1) {
196 movzx(reg_src_32, ptr[aux_reg_src_w + offset]);
197 movq(xmm_src_s32(jj, 0), reg_src_64);
200 default: assert(!"unsupported src data type");
204 default: assert(!"unsupported algorithm");
// Emits code storing one unrolled channel group back to memory,
// narrowing the s32 accumulators to the destination type for the avg
// path. Mirrors load_src; switch scaffolding and the pooling_max case
// label are elided from this view.
208 void jit_sse42_i8i8_pool_fwd_ker_t::store_dst(int jj, int c_step) {
209 using namespace data_type;
211 int repeats = c_step != 1 ? 2 : 1;
// First section (presumably pooling_max — dst type equals src type, so
// the raw max result is stored unchanged).
214 auto offset = jj*c_step*sizeof_dst_dt();
215 if (c_step == jpp.c_block) {
216 for (int ii = 0; ii < repeats; ii++)
217 uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst(jj, ii));
218 } else if (c_step == 1) {
219 if (jpp.src_dt == s32) {
220 movq(reg_src_64, xmm_dst(jj));
221 mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
// Scalar i8 tail: route through the GPR byte alias.
223 movq(reg_src_64, xmm_dst(jj));
224 mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
// Average path: narrow s32 accumulators to the destination type.
229 case pooling_avg_include_padding:
230 case pooling_avg_exclude_padding: {
231 auto offset = jj*c_step*sizeof_dst_dt();
232 switch (jpp.dst_dt) {
// s32 destination: store accumulators directly.
234 if (c_step == jpp.c_block) {
235 for (int ii = 0; ii < repeats; ii++)
236 uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst_s32(jj, ii));
237 } else if (c_step == 1) {
238 movq(reg_src_64, xmm_dst_s32(jj, 0));
239 mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
// s8 destination: saturating pack s32 -> s16 -> s8.
243 if (c_step == jpp.c_block) {
244 for (int ii = 0; ii < repeats; ii++) {
245 uni_vpackssdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
246 uni_vpacksswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
248 movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
250 } else if (c_step == 1) {
// NOTE(review): bare vpackssdw/vpacksswb (AVX encodings) in the scalar
// tail, vs the uni_ dispatch helpers in the block path above — verify
// this is intended on a pure-SSE4.2 target.
251 vpackssdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
252 vpacksswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
253 movq(reg_src_64, xmm_dst_s32(jj, 0));
254 mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
// u8 destination: unsigned saturating pack s32 -> u16 -> u8.
258 if (c_step == jpp.c_block) {
259 for (int ii = 0; ii < repeats; ii++) {
260 uni_vpackusdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
261 uni_vpackuswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
263 movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
265 } else if (c_step == 1) {
266 vpackusdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
267 vpackuswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
268 movq(reg_src_64, xmm_dst_s32(jj, 0));
269 mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
// TYPO(review): "unsuppotred" should read "unsupported" — assert message
// only; left untouched here since it is a string literal.
272 default: assert(!"unsuppotred dst data_type");
276 default: assert(!"unsupported pooling algorithm");
// Emits the max-pooling inner kernel for ur_c unrolled channel groups.
// The kh/kw loop labels and decrement/branch lines are elided from view.
280 void jit_sse42_i8i8_pool_fwd_ker_t::compute_max_step(int ur_c, int c_step)
287 int repeats = c_step != 1 ? 2 : 1;
// Seed accumulators with the source type's lowest value, broadcast into
// vreg_tmp by init_tmp_reg().
289 for (int jj = 0; jj < ur_c; jj++) {
290 for (int ii = 0; ii < repeats; ii++) {
291 uni_vmovups(vmm_dst(jj, ii), vreg_tmp);
// Walk the pooling window: aux_reg_src_h tracks the row, aux_reg_src_w
// the element within the row.
295 mov(aux_reg_src_h, reg_ptr_src_i8);
300 mov(aux_reg_src_w, aux_reg_src_h);
304 for (int jj = 0; jj < ur_c; jj++) {
305 load_src(jj, c_step);
// Element-wise max with the data-type-appropriate instruction:
// signed dword (s32), signed byte (s8), unsigned byte (u8).
307 for (int ii = 0; ii < repeats; ii++) {
308 if (jpp.src_dt == data_type::s32) {
309 uni_vpmaxsd(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
311 if (jpp.src_dt == data_type::s8)
312 uni_vpmaxsb(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
314 uni_vpmaxub(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
// Strides: one channel vector (c elements) per kw step, one input row
// (iw * c elements) per kh step — NHWC-style layout assumed; confirm.
318 add(aux_reg_src_w, c * sizeof_src_dt());
323 add(aux_reg_src_h, iw * c * sizeof_src_dt());
329 for (int jj = 0; jj < ur_c; jj++)
330 store_dst(jj, c_step);
// Emits the average-pooling inner kernel: accumulate the window in s32,
// scale by call_params_t::idivider in f32, convert back, store.
// Loop labels and branch lines are elided from view.
333 void jit_sse42_i8i8_pool_fwd_ker_t::compute_avg_step(int ur_c, int c_step)
335 using namespace data_type;
342 int repeats = c_step != 1 ? 2 : 1;
// Zero both the load staging registers and the s32 accumulators.
344 for (int jj = 0; jj < ur_c; jj++) {
345 for (int ii = 0; ii < repeats; ii++) {
346 uni_vpxor(vmm_src_s32(jj, ii), vmm_src_s32(jj, ii), vmm_src_s32(jj, ii));
347 uni_vpxor(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
// Walk the pooling window, summing widened source values.
351 mov(aux_reg_src_h, reg_ptr_src_i8);
356 mov(aux_reg_src_w, aux_reg_src_h);
360 for (int jj = 0; jj < ur_c; jj++) {
361 load_src(jj, c_step);
363 for (int ii = 0; ii < repeats; ii++) {
364 uni_vpaddd(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_src_s32(jj, ii));
// Same strides as the max path: channel vector per kw, row per kh.
367 add(aux_reg_src_w, c * sizeof_src_dt());
372 add(aux_reg_src_h, iw * c * sizeof_src_dt());
// Average: s32 -> f32, multiply by 1/divider (broadcast in vreg_tmp by
// init_tmp_reg), round back to s32 (cvtps2dq), then narrow and store.
378 for (int jj = 0; jj < ur_c; jj++) {
379 for (int ii = 0; ii < repeats; ii++) {
380 uni_vcvtdq2ps(vmm_dst_f32(jj, ii), vmm_dst_s32(jj, ii));
382 mulps(vmm_dst_f32(jj, ii), vreg_tmp);
384 uni_vcvtps2dq(vmm_dst_s32(jj, ii), vmm_dst_f32(jj, ii));
387 store_dst(jj, c_step);
// Dispatches one unrolled step to the max or avg emitter based on the
// configured algorithm. (The switch header and pooling_max case label
// are elided from this view.)
391 void jit_sse42_i8i8_pool_fwd_ker_t::compute_step(int ur_c, int c_step) {
394 compute_max_step(ur_c, c_step); break;
395 case pooling_avg_include_padding:
396 case pooling_avg_exclude_padding:
397 compute_avg_step(ur_c, c_step); break;
398 default: assert(!"unsupported pooling algorithm");
// Emits the channel loop: a main loop that handles ur_c * c_block
// channels per iteration, followed by a scalar tail loop (c_step == 1).
// Loop labels, jumps and the loop back-edges are elided from this view.
402 void jit_sse42_i8i8_pool_fwd_ker_t::compute_c_block() {
409 xor_(c_iter, c_iter);
// Main loop: while c_iter <= jpp.c - ur_c * c_block remains, process a
// full block and advance the src/dst pointers.
413 cmp(c_iter, jpp.c - ur_c * jpp.c_block);
414 jg(l_tail_loop, T_NEAR);
416 compute_step(ur_c, jpp.c_block);
418 add(reg_ptr_src_i8, ur_c * jpp.c_block * sizeof_src_dt());
419 add(reg_ptr_dst_i8, ur_c * jpp.c_block * sizeof_dst_dt());
420 add(c_iter, ur_c * jpp.c_block);
// Tail loop: remaining channels one at a time.
426 cmp(c_iter, jpp.c - ur_c);
429 compute_step(ur_c, 1);
431 add(reg_ptr_src_i8, ur_c * sizeof_src_dt());
432 add(reg_ptr_dst_i8, ur_c * sizeof_dst_dt());
// Loads the per-algorithm constant into vreg_tmp:
//  - avg: broadcast of call_params_t::idivider (1/divider as f32);
//  - max: broadcast of the source type's lowest value (identity of max).
// Switch header, break statements and the pooling_max case label are
// elided from this view.
440 void jit_sse42_i8i8_pool_fwd_ker_t::init_tmp_reg() {
441 using namespace data_type;
444 case pooling_avg_include_padding:
445 case pooling_avg_exclude_padding:
446 mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
447 movq(xmm_tmp, reg_tmp);
448 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
// Max path: pick the per-type lowest value as the running-max seed.
451 switch (jpp.src_dt) {
453 mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
456 mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
// u8 lowest is 0 — kept symmetric with the signed cases.
459 mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
461 default: assert(!"unsupported src data_type");
// Broadcast: dword broadcast for s32; for byte types, replicate byte 0
// across the register via pshufb with an all-zero shuffle mask.
464 movq(xmm_tmp, reg_tmp);
465 if (jpp.src_dt == s32) {
466 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
468 movups(vreg_tmp, xmm_tmp);
469 uni_vpxor(xmm_tmp, xmm_tmp, xmm_tmp);
470 pshufb(vreg_tmp, xmm_tmp);
473 default: assert(!"unsupported pooling algorithm");
// Top-level code generation: read the kernel arguments from the
// call_params_t pointed to by abi_param1, then emit the channel loop.
// The preamble/postamble and the compute_c_block()/init_tmp_reg() calls
// are elided from this view.
478 void jit_sse42_i8i8_pool_fwd_ker_t::generate() {
481 # define READ_PARAM(reg, field) \
482 mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
483 READ_PARAM(reg_ptr_src_i8, src_i8);
484 READ_PARAM(reg_ptr_dst_i8, dst_i8);
485 READ_PARAM(reg_kw, kw_range);
486 READ_PARAM(reg_kh, kh_range);
// Keep a zeroed register available for the whole kernel.
492 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
// Fills the jit_pool_conf_t from the pooling descriptor and the src/dst
// memory descriptors. Assumes 4D NCHW-ordered dims() (n, c, h, w in
// logical order); several validation lines are elided from this view.
499 status_t jit_sse42_i8i8_pool_fwd_ker_t::init_conf(jit_pool_conf_t &jpp,
500 const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
501 const memory_desc_wrapper &dst_d) {
// Hard requirement: this generator emits SSE4.2 instructions.
502 if (!mayiuse(sse42)) {
503 return status::unimplemented;
506 jpp.mb = src_d.dims()[0];
507 jpp.c = src_d.dims()[1];
508 jpp.ih = src_d.dims()[2];
509 jpp.iw = src_d.dims()[3];
510 jpp.oh = dst_d.dims()[2];
511 jpp.ow = dst_d.dims()[3];
513 jpp.stride_h = pd.strides[0];
514 jpp.stride_w = pd.strides[1];
515 jpp.kh = pd.kernel[0];
516 jpp.kw = pd.kernel[1];
518 jpp.t_pad = pd.padding[0][0];
519 jpp.l_pad = pd.padding[0][1];
521 jpp.alg = pd.alg_kind;
523 jpp.src_dt = pd.src_desc.data_type;
524 jpp.dst_dt = pd.dst_desc.data_type;
// Channel block size: max pooling works in the native type, so 32 bytes
// per block => 8 channels for s32, 32 for i8; avg widens to s32 and uses
// a fixed block of 8 channels.
526 jpp.c_block = jpp.alg == pooling_max ? 32 / (jpp.src_dt == data_type::s32 ? 4 : 1) : 8;
527 jpp.c_tail = jpp.c % jpp.c_block;
528 jpp.nb_c = jpp.c / jpp.c_block;
// Leftover blocks after ur_c-unrolling, plus one extra tail iteration
// when c is not a multiple of c_block. (jpp.ur_c assignment elided.)
530 jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + (jpp.c_tail != 0);
532 return status::success;
// Thin pd_t hook: derive the kernel configuration from this primitive's
// descriptor and memory descriptors.
535 status_t jit_sse42_i8i8_pooling_fwd_t::pd_t::jit_conf() {
536 return jit_sse42_i8i8_pool_fwd_ker_t::init_conf(jpp_,
537 desc_, src_pd_.desc(), dst_pd_.desc());
// Primitive ctor: eagerly builds the JIT kernel for this pd's config.
// ker_ is raw-owned here and released in the destructor (body elided
// below — presumably `delete ker_`; confirm in the full source).
540 jit_sse42_i8i8_pooling_fwd_t::jit_sse42_i8i8_pooling_fwd_t(const pd_t *apd,
541 const input_vector &inputs, const output_vector &outputs)
542 : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
543 { ker_ = new jit_sse42_i8i8_pool_fwd_ker_t(pd()->jpp_); }
545 jit_sse42_i8i8_pooling_fwd_t::~jit_sse42_i8i8_pooling_fwd_t() {
// (destructor body elided from this view)

// Driver: one kernel invocation per output pixel, parallelized over
// (mini-batch, output row, output column); the kernel itself loops over
// all channels. The kernel call at the end of the lambda is past the end
// of this view.
549 void jit_sse42_i8i8_pooling_fwd_t::execute_forward() const {
550 auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
551 auto dst_i8 = reinterpret_cast<char *>(memory());
553 const memory_desc_wrapper src_d(pd()->src_pd());
554 const memory_desc_wrapper dst_d(pd()->dst_pd());
556 const auto &jpp = pd()->jpp_;
558 parallel_nd(jpp.mb, jpp.oh, jpp.ow,
559 [&](int n, int oh, int ow) {
// Clamp the window origin to the input (top/left padding).
560 const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
561 const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);
// Effective kernel extent: drop the rows/columns of the window that fall
// into padding on either side.
563 const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
564 const int kh_end = nstl::min(jpp.kh,
565 jpp.ih + jpp.t_pad - oh * jpp.stride_h);
566 const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
567 const int kw_end = nstl::min(jpp.kw,
568 jpp.iw + jpp.l_pad - ow * jpp.stride_w);
// Fill the kernel arguments (p.src_i8/p.dst_i8 assignments partially
// elided; offsets are byte offsets into the blocked layouts).
570 auto p = call_params_t();
572 src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
574 dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
575 p.kw_range = (size_t) (kw_end - kw_start);
576 p.kh_range = (size_t) (kh_end - kh_start);
// Average divisor: actual (non-padded) window size for exclude_padding,
// full kernel size for include_padding.
577 p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
578 p.kh_range * p.kw_range : jpp.kw * jpp.kh);