1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_depthwise.hpp"
26 #define GET_OFF(field) offsetof(jit_args, field)
32 using namespace Xbyak;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
// Abstract base for JIT-generated depthwise kernels. Holds the depthwise
// descriptor and a raw function pointer `ker_` that derived classes fill in
// with the address of the generated machine code; operator() asserts the
// kernel was generated before dispatching to it.
44 struct jit_uni_depthwise_kernel_f32 : public c_compatible {
45     const depthwise_desc_t &desc_;
// Generated-code entry point; nullptr until a derived class assigns it.
46     void (*ker_)(const jit_args *);
49     void operator()(const jit_args *args) { assert(ker_); ker_(args); }
51     jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
52         : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
53     virtual ~jit_uni_depthwise_kernel_f32() {}
// Number of auxiliary vector registers the injector needs for the given
// depthwise algorithm: scale_shift operates in place on the source vector
// (none needed); prelu needs two (a compare mask and a scratch vector,
// cf. assign_regs / prelu_compute_vector).
56 template <cpu_isa_t isa>
57 int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
58     switch (depthwise_alg) {
59         case alg_kind::depthwise_scale_shift: return 0;
60         case alg_kind::depthwise_prelu: return 2;
61         default: assert(!"unsupported depthwise algorithm");
// Select which vector registers to commandeer as auxiliaries and spill them
// to the stack. Registers OUTSIDE the caller's live range [start_idx,
// end_idx) are preferred; if that does not yield enough, registers are taken
// from the front of the live range instead, and start_idx_tail is advanced
// past them so compute_vector_range can process those indices separately
// after injector_preamble_tail restores/remaps them.
67 template <cpu_isa_t isa>
68 void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
69     preserved_vecs_count = 0;
70     vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);
// First pass: grab free registers that lie outside [start_idx, end_idx).
72     for (size_t i = 0; i < vecs_count; i++) {
73         if (preserved_vecs_count >= vecs_to_preserve)
76         if (i < start_idx || i >= end_idx) {
77             preserved_vec_idxs[preserved_vecs_count] = i;
78             preserved_vecs_count++;
// Second pass: take the remainder from the front of the caller's range;
// start_idx_tail marks the first index NOT stolen from the caller.
82     start_idx_tail = start_idx;
83     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
84     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
85         preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
86         preserved_vecs_count++;
87         start_idx_tail = start_idx + i + 1;
// Spill all chosen registers to a freshly reserved stack area.
90     h->sub(h->rsp, preserved_vecs_count * vlen);
91     for (size_t i = 0; i < preserved_vecs_count; ++i)
92         h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
// Second phase of the spill dance: restore the registers that were borrowed
// from the caller's range (so the [start_idx, start_idx_tail) vectors can now
// be processed), remap those auxiliary slots onto different register indices,
// and re-spill under the new mapping. Stack pointer is temporarily bumped by
// idx_off*vlen so the tail slots line up at rsp, then restored.
97 template <cpu_isa_t isa>
98 void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
99     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
100     int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
102     if (tail_vecs_to_preserve > 0) {
103         h->add(h->rsp, idx_off * vlen);
// Reload the caller-owned vectors that the first preamble spilled.
104         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
105             h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);
// Shift the aux mapping to registers just past the tail range.
107         for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
108             preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
// Spill the newly chosen registers into the same stack slots.
111         for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112             h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
113         h->sub(h->rsp, idx_off * vlen);
// Undo injector_preamble: reload every preserved vector register from the
// stack spill area and release the reserved stack space.
119 template <cpu_isa_t isa>
120 void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
121     for (size_t i = 0; i < preserved_vecs_count; ++i)
122         h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
123     h->add(h->rsp, preserved_vecs_count * vlen);
// Bind the injector's named auxiliary registers (mask + scratch, as used by
// prelu_compute_vector) to the first two preserved register indices.
126 template <cpu_isa_t isa>
127 void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
128     vmm_mask = Vmm(preserved_vec_idxs[0]);
129     vmm_aux0 = Vmm(preserved_vec_idxs[1]);
// Emit scale-shift for one vector: vmm_src = vmm_src * [p_weights] + [p_bias],
// reading scale and shift straight from memory operands.
132 template <cpu_isa_t isa>
133 void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
134         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
135     h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
136     h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
// Emit PReLU for one vector (negative lanes are multiplied by the per-channel
// weight, positive lanes pass through), with an ISA-specific sequence:
// SSE4.2 uses cmpps + implicit-xmm0 blendvps, AVX2 uses vblendvps with an
// explicit mask register, AVX-512 uses an opmask (k_mask) with a masked
// vmulps. p_bias is unused by prelu.
139 template <cpu_isa_t isa>
140 void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
141         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
// Comparison predicate immediates for cmpps/vcmpps.
142     const unsigned char _cmp_gt_os = 6;
143     const unsigned char _cmp_lt_os = 1;
// NOTE(review): the sse42 branch header is not visible in this excerpt;
// the pxor/cmpps/blendvps run below is the SSE4.2 path.
146         h->pxor(vmm_mask, vmm_mask);
147         h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
148         h->movups(vmm_aux0, vmm_src);
149         h->mulps(vmm_aux0, h->ptr[p_weights]);
150         h->blendvps(vmm_src, vmm_aux0);
151     } else if (isa == avx2) {
152         h->vxorps(vmm_mask, vmm_mask, vmm_mask);
// mask = (src > 0): selects pass-through lanes in the blend below.
153         h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
154         h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
155         h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
156     } else if (isa == avx512_common) {
157         h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
158         h->vmovups(vmm_aux0, vmm_src);
// k_mask = (src < 0): only negative lanes get scaled by the weight.
159         h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
160         h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
// Apply the configured depthwise op to every vector register whose index is
// in [start_idx, end_idx), dispatching on depthwise_alg.
164 template <cpu_isa_t isa>
165 void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
166         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
167     for (size_t idx = start_idx; idx < end_idx; idx++) {
168         switch (depthwise_alg) {
169             case alg_kind::depthwise_scale_shift:
170                 scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
171             case alg_kind::depthwise_prelu:
172                 prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
173             default: assert(!"unsupported depthwise algorithm");
// Public entry: process registers [start_idx, end_idx). The range is split in
// two because injector_preamble may borrow registers from its front:
// [start_idx_tail, end_idx) is processed first, then injector_preamble_tail
// frees/remaps the borrowed ones so [start_idx, start_idx_tail) can be done.
178 template <cpu_isa_t isa>
179 void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
180         const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
181     injector_preamble(start_idx, end_idx);
182     compute_body(start_idx_tail, end_idx, p_weights, p_bias);
183     injector_preamble_tail(start_idx, end_idx);
184     compute_body(start_idx, start_idx_tail, p_weights, p_bias);
185     injector_postamble();
// Explicit instantiations for all supported ISAs.
188 template struct jit_uni_depthwise_injector_f32<avx512_common>;
189 template struct jit_uni_depthwise_injector_f32<avx2>;
190 template struct jit_uni_depthwise_injector_f32<sse42>;
// Standalone JIT kernel for depthwise scale-shift: dst = src * scale + shift.
// The constructor emits the whole kernel: argument loads, per-channel scale/
// shift setup, a vectorized main loop, and a scalar tail loop, then publishes
// the code address into ker_.
195 template <cpu_isa_t isa>
196 struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
199     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
200     jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
201         : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
202         assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
203         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// Flat (nchw) layout: one channel's scale/shift is broadcast across the
// whole vector; blocked layouts load per-lane channel values instead.
205         bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw ;
207         Reg64 param = abi_param1;
209         const int block_size = isa == avx512_common ? 16 : 8;
210         const int main_loop_step = isFlat ? block_size : 1;
// Load kernel arguments from the jit_args struct.
214         mov(reg_from, ptr[param + GET_OFF(from)]);
215         mov(reg_to, ptr[param + GET_OFF(to)]);
216         mov(reg_scale, ptr[param + GET_OFF(weights)]);
217         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
219         mov(reg_shift, ptr[param + GET_OFF(bias)]);
221         Label main_loop_label;
222         Label tail_loop_label;
// sse42 processes a block as two 4-float halves, hence two register pairs.
225         int repeats = isa == sse42 ? 2 : 1;
226         for (int i = 0; i < repeats; i++) {
// Flat path: broadcast the single channel scale/shift ...
228                 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
230                     uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
// ... zero shift when the primitive has no bias.
232                     uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
// Blocked path: load per-lane scale/shift vectors.
234                 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
236                     uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
238                     uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
// Scalar copies of scale/shift for the tail loop.
243             uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
245                 uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
247                 uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
// Main loop: one full vector block per iteration (dst = shift + src*scale).
250         L(main_loop_label); {
251             cmp(reg_work_amount, main_loop_step-1);
252             jle(tail_loop_label, T_NEAR);
254             int repeats = isa == sse42 ? 2 : 1;
255             for (int i = 0; i < repeats; i++) {
256                 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
257                 uni_vmovups(vmm_dst, get_shift_reg(i));
258                 uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
259                 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
262             add(reg_from, block_size*sizeof(float));
263             add(reg_to, block_size*sizeof(float));
264             sub(reg_work_amount, main_loop_step);
266             jmp(main_loop_label, T_NEAR);
// Tail loop: one float at a time for the remaining work_amount.
269         L(tail_loop_label); {
270             cmp(reg_work_amount, 0);
271             jle(exit_label, T_NEAR);
273             movss(xmm_src, ptr[reg_from]);
274             uni_vmovups(xmm_dst, xmm_shift);
275             uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
276             movss(ptr[reg_to], xmm_dst);
278             add(reg_from, 1*sizeof(float));
279             add(reg_to, 1*sizeof(float));
280             dec(reg_work_amount);
282             jmp(tail_loop_label, T_NEAR);
// Publish the generated code entry point.
289         ker_ = (decltype(ker_))this->getCode();
// Vector register type per ISA: Xmm / Ymm / Zmm.
293     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
294                                              isa == avx2, Ymm, Zmm>::type;
296     inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
297     inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }
301     Reg64 reg_work_amount = r10;
302     Reg64 reg_scale = r11;
303     Reg64 reg_shift = r12;
305     Vmm vmm_src = Vmm(0);
306     Vmm vmm_dst = Vmm(1);
308     Xmm xmm_src = Xmm(0);
309     Xmm xmm_dst = Xmm(1);
310     Xmm xmm_scale = Xmm(6);
311     Xmm xmm_shift = Xmm(7);
// Standalone JIT kernel for depthwise PReLU: negative elements are scaled by
// the per-channel weight, positive elements pass through. Same overall shape
// as the scale-shift kernel: setup, vector main loop, scalar tail loop.
314 template <cpu_isa_t isa>
315 struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
318     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
319     jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
320         : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
321         assert(desc.alg_kind == alg_kind::depthwise_prelu);
322         assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// Flat (nchw) layout broadcasts one channel weight; blocked loads per lane.
324         bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
326         Reg64 param = abi_param1;
328         const int block_size = isa == avx512_common ? 16 : 8;
329         const int main_loop_step = isFlat ? block_size : 1;
// Load kernel arguments (prelu has no bias argument).
333         mov(reg_from, ptr[param + GET_OFF(from)]);
334         mov(reg_to, ptr[param + GET_OFF(to)]);
335         mov(reg_scale, ptr[param + GET_OFF(weights)]);
336         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
338         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
340         int repeats = isa == sse42 ? 2 : 1;
341         for (int i = 0; i < repeats; i++) {
343                 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
345                 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
// Scalar weight for the tail loop.
350             uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
353         Label main_loop_label;
354         Label tail_loop_label;
// Main loop: per-ISA PReLU select (see prelu_compute_vector for the same
// three instruction sequences in injector form).
357         L(main_loop_label); {
358             cmp(reg_work_amount, main_loop_step-1);
359             jle(tail_loop_label, T_NEAR);
361             for (int i = 0; i < repeats; i++) {
362                 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
// NOTE(review): the sse42 branch header is not visible in this excerpt;
// the pxor/cmpps/blendvps run below is the SSE4.2 path.
365                     pxor(vmm_mask, vmm_mask);
366                     cmpps(vmm_mask, vmm_src, _cmp_gt_os);
367                     movups(vmm_dst, vmm_src);
368                     mulps(vmm_src, get_scale_reg(i));
369                     blendvps(vmm_dst, vmm_src);
370                 } else if (isa == avx2) {
371                     vcmpgtps(vmm_mask, vmm_src, vmm_zero);
372                     vmulps(vmm_dst, vmm_src, get_scale_reg(i));
373                     vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
374                 } else if (isa == avx512_common) {
375                     Opmask kmask = Opmask(7);
376                     vmovups(vmm_dst, vmm_src);
// kmask = (src < 0): only negative lanes get scaled.
377                     vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
378                     vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
381                 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
384             add(reg_from, block_size*sizeof(float));
385             add(reg_to, block_size*sizeof(float));
386             sub(reg_work_amount, main_loop_step);
388             jmp(main_loop_label, T_NEAR);
// Tail loop: scalar PReLU, one float per iteration.
391         L(tail_loop_label); {
392             cmp(reg_work_amount, 0);
393             jle(exit_label, T_NEAR);
395             movss(xmm_src, ptr[reg_from]);
397                 pxor(xmm_mask, xmm_mask);
398                 cmpps(xmm_mask, xmm_src, _cmp_gt_os);
399                 movups(xmm_dst, xmm_src);
400                 mulps(xmm_src, xmm_scale);
401                 blendvps(xmm_dst, xmm_src);
403             movss(ptr[reg_to], xmm_dst);
405             add(reg_from, 1*sizeof(float));
406             add(reg_to, 1*sizeof(float));
407             dec(reg_work_amount);
409             jmp(tail_loop_label, T_NEAR);
// Publish the generated code entry point.
416         ker_ = (decltype(ker_))this->getCode();
420     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
421                                              isa == avx2, Ymm, Zmm>::type;
423     inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }
427     Reg64 reg_work_amount = r10;
428     Reg64 reg_scale = r11;
430     Vmm vmm_mask = Vmm(0);
431     Vmm vmm_src = Vmm(1);
432     Vmm vmm_zero = Vmm(2);
433     Vmm vmm_dst = Vmm(3);
435     Xmm xmm_mask = Xmm(0);
436     Xmm xmm_src = Xmm(1);
437     Xmm xmm_dst = Xmm(3);
438     Xmm xmm_scale = Xmm(4);
// Comparison predicate immediates for cmpps/vcmpps.
440     const unsigned char _cmp_gt_os = 6;
441     const unsigned char _cmp_lt_os = 1;
// Primitive-descriptor init: accept only forward f32 depthwise with matching
// src/dst formats — either flat nchw or the ISA's blocked format (nChw16c for
// AVX-512, nChw8c otherwise) — x-format weights/bias, default attributes, and
// a CPU that supports the requested ISA.
446 template <cpu_isa_t isa>
447 status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
448     using namespace alg_kind;
450     auto desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;
452     assert(engine()->kind() == engine_kind::cpu);
453     bool ok = true && mayiuse(isa)
454         && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
455                 prop_kind::forward_inference)
456         && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
457         && desc()->src_desc.format == desc()->dst_desc.format
458         && utils::one_of(desc()->src_desc.format, desired_blk_fmt, nchw)
459         && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, nchw)
460         && utils::one_of(desc()->weights_desc.format, x)
461         && utils::implication(this->with_bias(), x == desc()->bias_desc.format)
462         && attr()->has_default_values();
464     return ok ? status::success : status::unimplemented;
// Constructor: instantiate the JIT kernel matching the algorithm, and, when
// the channel count is not a multiple of the SIMD width, allocate padded
// weight/bias buffers whose padding tail is zero-filled once here (the live
// prefix is refreshed each execute_forward call).
467 template <cpu_isa_t isa>
468 jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *pd,
469         const input_vector &inputs, const output_vector &outputs)
470     : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr),
471     padded_weights_(nullptr), padded_bias_(nullptr) {
472     const auto &desc = *conf_.desc();
473     switch (desc.alg_kind) {
474         case alg_kind::depthwise_scale_shift:
475             kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd->with_bias()); break;
476         case alg_kind::depthwise_prelu:
477             kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd->with_bias()); break;
478         default: assert(!"unknown depthwise alg_kind");
481     const int simd_w = isa == avx512_common ? 16 : 8;
482     const memory_desc_wrapper data_d(conf_.src_pd());
483     const int c_without_padding = data_d.dims()[1];
484     const int c_padded = rnd_up(c_without_padding, simd_w);
// NOTE(review): malloc(size, 64) appears to be the project's 64-byte-aligned
// allocator, not libc malloc — confirm against the project's utils.
486     if (conf_.want_padded_weights()) {
487         padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
// Zero only the padding tail; real channels are copied in at execution time.
488         for (int oc = c_without_padding; oc < c_padded; ++oc)
489             padded_weights_[oc] = 0;
491         if (conf_.with_bias()) {
492             padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
493             for (int oc = c_without_padding; oc < c_padded; ++oc)
494                 padded_bias_[oc] = 0;
// Destructor: release the padded weights buffer (free(nullptr) is a no-op).
// NOTE(review): releases of kernel_ and padded_bias_ are not visible in this
// excerpt — confirm they are freed on the elided lines, else this leaks.
499 template <cpu_isa_t isa>
500 jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
502     free(padded_weights_);
// Forward execution: optionally copy weights/bias into the zero-padded
// buffers, then run the JIT kernel in parallel over (batch, channel-block,
// row), with each kernel call covering one row of W elements.
506 template <cpu_isa_t isa>
507 void jit_uni_depthwise_fwd_t<isa>::execute_forward() {
508     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
509     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
510     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
511     auto dst = reinterpret_cast<data_t *>(this->memory());
513     const memory_desc_wrapper data_d(conf_.src_pd());
514     const memory_desc_wrapper weights_d(conf_.weights_pd(0));
515     const memory_desc_wrapper bias_d(conf_.weights_pd(1));
517     const int N = data_d.dims()[0];
518     const int C = data_d.dims()[1];
519     const int H = data_d.dims()[2];
520     const int W = data_d.dims()[3];
// Flat nchw is processed channel-by-channel (block of 1); blocked layouts
// process simd_w channels at a time.
522     const int simd_w = isa == avx512_common ? 16 : 8;
523     const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
524     const int CB = div_up(C, ch_block_size);
// Refresh the live prefix of the padded buffers (tail was zeroed in ctor).
526     if (conf_.want_padded_weights()) {
527         for (int oc = 0; oc < C; ++oc)
528             padded_weights_[oc] = weights[oc];
529         weights = padded_weights_;
531         if (conf_.with_bias()) {
532             for (int oc = 0; oc < C; ++oc)
533                 padded_bias_[oc] = bias[oc];
// One kernel invocation per (n, channel block, row); work_amount = W.
538     parallel_nd(N, CB, H,
539         [&](int n, int cb, int h) {
542         arg.from = &src[data_d.blk_off(n, cb, h)];
543         arg.to = &dst[data_d.blk_off(n, cb, h)];
544         arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
546             arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
547         arg.work_amount = (size_t)W;
// Explicit instantiations for all supported ISAs, followed by the offset
// helper used by the depthwise-convolution row kernel below.
553 template struct jit_uni_depthwise_fwd_t<sse42>;
554 template struct jit_uni_depthwise_fwd_t<avx2>;
555 template struct jit_uni_depthwise_fwd_t<avx512_common>;
558 #define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)
// Initialize the ur_w accumulator registers: start from the bias (when
// present) or zero, and pre-add the existing output when fusing a sum
// post-op. sse42 runs two 4-lane halves per channel block (repeats == 2).
560 template <cpu_isa_t isa>
561 void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
562     int repeats = isa == sse42 ? 2 : 1;
563     for (int i = 0; i < repeats; i++) {
564         for (int ow = 0; ow < ur_w; ow++) {
565             Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
567             if (this->jcp.with_bias)
568                 uni_vmovups(vmm_acc, vmmword[reg_bias + i*4*sizeof(float)]);
// No bias: zero the accumulator.
570                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
572             int o_off = ow*jcp.ch_block + i*4;
// Sum post-op: accumulate on top of the current output values.
573             if (this->jcp.with_sum)
574                 uni_vaddps(vmm_acc, vmm_acc,
575                            vmmword[reg_output + o_off*sizeof(float)]);
// Accumulate the 3-row depthwise convolution: three structurally identical
// sections, one per input row (aux_reg_input0/1/2), each guarded by a
// height check (the jl(exit_label) before it) so padded rows are skipped.
// Each section multiplies kw_size filter taps into all ur_w accumulators,
// then advances the filter pointer by one full filter row.
580 template <cpu_isa_t isa>
581 void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
582     int ch_blk = jcp.ch_block;
583     int stride_w = jcp.stride_w;
587     int repeats = isa == sse42 ? 2 : 1;
// Row 0 (skipped when the kh counter says it is out of bounds).
590     jl(exit_label, T_NEAR);
591     for (int i = 0; i < repeats; i++) {
592         for (int kw = 0; kw < kw_size; kw++) {
593             int ker_off = kw * ch_blk + i*4;
595             Vmm vmm_ker = get_ker_reg(0);
596             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
597                                      + ker_off * sizeof(float)]);
599             for (int ow = 0; ow < ur_w; ow++) {
600                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
602                 Vmm vmm_src = get_src_reg(0);
603                 uni_vmovups(vmm_src, ptr[aux_reg_input0
604                                          + inp_off * sizeof(float)]);
606                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
607                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
// Move the filter pointer to the next filter row.
611     add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
// Row 1.
614     jl(exit_label, T_NEAR);
615     for (int i = 0; i < repeats; i++) {
616         for (int kw = 0; kw < kw_size; kw++) {
617             int ker_off = kw * ch_blk + i*4;
619             Vmm vmm_ker = get_ker_reg(0);
620             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
621                                      + ker_off * sizeof(float)]);
623             for (int ow = 0; ow < ur_w; ow++) {
624                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
626                 Vmm vmm_src = get_src_reg(0);
627                 uni_vmovups(vmm_src, ptr[aux_reg_input1
628                                          + inp_off * sizeof(float)]);
630                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
631                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
635     add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
// Row 2.
638     jl(exit_label, T_NEAR);
639     for (int i = 0; i < repeats; i++) {
640         for (int kw = 0; kw < kw_size; kw++) {
641             int ker_off = kw * ch_blk + i*4;
643             Vmm vmm_ker = get_ker_reg(0);
644             uni_vmovups(vmm_ker, ptr[aux_reg_kernel
645                                      + ker_off * sizeof(float)]);
647             for (int ow = 0; ow < ur_w; ow++) {
648                 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*4;
650                 Vmm vmm_src = get_src_reg(0);
651                 uni_vmovups(vmm_src, ptr[aux_reg_input2
652                                          + inp_off * sizeof(float)]);
654                 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
655                 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
// When an eltwise post-op is fused, run the eltwise injector over the
// accumulator registers (indices 4 .. repeats*ur_w + 3, matching get_acc_reg).
663 template <cpu_isa_t isa>
664 void jit_uni_dw_conv_row_f32<isa>::apply_activation(int ur_w) {
665     if (this->jcp.with_eltwise) {
666         int repeats = isa == sse42 ? 2 : 1;
667         eltwise_injector->compute_vector_range(4, repeats * ur_w + 4);
// Write the ur_w accumulator registers back to the output row.
671 template <cpu_isa_t isa>
672 void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w) {
673     int repeats = isa == sse42 ? 2 : 1;
674     for (int i = 0; i < repeats; i++) {
675         for (int ow = 0; ow < ur_w; ow++) {
676             int o_off = ow*jcp.ch_block + i*4;
677             Vmm vmm_dst = get_acc_reg(i*ur_w + ow);
679             uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
// Emit the full width loop: a left-padding iteration with a narrowed filter
// (the kernel pointer is pre-advanced past the first column), an unrolled
// main loop, a tail loop, and a right-padding iteration, each following the
// same load_src -> apply_filter -> apply_activation -> advance-pointers
// pattern.
684 template <cpu_isa_t isa>
685 void jit_uni_dw_conv_row_f32<isa>::loop_body() {
686     Label left_pad_label;
687     Label right_pad_label;
688     Label unrolled_w_label;
// Left padding: drop one filter column (two when the row is a single pixel).
694         int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
696         mov(aux_reg_input0, reg_input0);
697         mov(aux_reg_input1, reg_input1);
698         mov(aux_reg_input2, reg_input2);
// Skip the first filter column that falls on the padding.
699         mov(aux_reg_kernel, reg_kernel);
700         add(aux_reg_kernel, jcp.ch_block*sizeof(float));
703         apply_filter(ur_w, kw);
704         apply_activation(ur_w);
// Input advances by stride_w-1 blocks here: the padded position itself
// consumed no input element.
707         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
708         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
709         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * (jcp.stride_w-1));
711         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
// Main unrolled loop over full-width positions.
716     L(unrolled_w_label); {
721         jle(tail_w_label, T_NEAR);
723         mov(aux_reg_input0, reg_input0);
724         mov(aux_reg_input1, reg_input1);
725         mov(aux_reg_input2, reg_input2);
726         mov(aux_reg_kernel, reg_kernel);
729         apply_filter(ur_w, kw);
730         apply_activation(ur_w);
733         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
734         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
735         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
736         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
739         jmp(unrolled_w_label, T_NEAR);
// Tail loop: remaining positions one at a time until only the right-padded
// position (if any) is left.
748         jle(right_pad_label, T_NEAR);
750         jle(exit_label, T_NEAR);
752         mov(aux_reg_input0, reg_input0);
753         mov(aux_reg_input1, reg_input1);
754         mov(aux_reg_input2, reg_input2);
755         mov(aux_reg_kernel, reg_kernel);
758         apply_filter(ur_w, kw);
759         apply_activation(ur_w);
762         add(reg_input0, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
763         add(reg_input1, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
764         add(reg_input2, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
765         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
768         jmp(tail_w_label, T_NEAR);
// Right padding: narrow the filter by the columns hanging off the row end.
772     L(right_pad_label); {
774         int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);
776         mov(aux_reg_input0, reg_input0);
777         mov(aux_reg_input1, reg_input1);
778         mov(aux_reg_input2, reg_input2);
779         mov(aux_reg_kernel, reg_kernel);
782         apply_filter(ur_w, kw);
783         apply_activation(ur_w);
// Top-level code generation: load all kernel arguments from jit_conv_call_s,
// emit the width loop, and append the eltwise injector's constant table when
// an eltwise post-op is fused.
793 template <cpu_isa_t isa>
794 void jit_uni_dw_conv_row_f32<isa>::generate()
798     mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
799     mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
800     mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
801     mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
802     mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
804     mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
// kh_padding drives the per-row skip checks in apply_filter.
805     mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
806     mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
// Eltwise lookup tables must follow the generated code stream.
812     if (jcp.with_eltwise)
813         eltwise_injector->prepare_table();
// Fill the row-kernel configuration from the caller-supplied geometry.
// Only 3x3 filters are supported; channel block equals the ISA SIMD width
// and bias is always assumed present.
816 template <cpu_isa_t isa>
817 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp,
818         int ic, int ih, int iw, int oh, int ow, int ker_h, int ker_w, int str_h, int str_w, alg_kind_t eltwise_alg,
819         float eltwise_alpha, float eltwise_beta, bool with_sum) {
820     if (!mayiuse(isa)) return status::unimplemented;
821     const int simd_w = isa == avx512_common ? 16 : 8;
825     jcp.ch_block = simd_w;
826     jcp.with_bias = true;
833     jcp.stride_h = str_h;
834     jcp.stride_w = str_w;
// Hard requirement: this kernel is specialized for 3x3 filters.
836     if (jcp.kh != 3 || jcp.kw != 3)
837         return status::unimplemented;
// Eltwise fusion is enabled whenever a concrete algorithm was passed.
841     jcp.with_eltwise  = eltwise_alg != mkldnn_alg_kind_undef;
842     jcp.eltwise_alg   = eltwise_alg;
843     jcp.eltwise_alpha = eltwise_alpha;
844     jcp.eltwise_beta  = eltwise_beta;
845     jcp.with_sum = with_sum;
847     return status::success;
// Explicit instantiations for all supported ISAs.
850 template struct jit_uni_dw_conv_row_f32<avx512_common>;
851 template struct jit_uni_dw_conv_row_f32<avx2>;
852 template struct jit_uni_dw_conv_row_f32<sse42>;