1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_depthwise.hpp"
26 #define GET_OFF(field) offsetof(jit_args, field)
32 using namespace Xbyak;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
// Abstract base for JIT-generated depthwise kernels: keeps the op descriptor
// and a raw function pointer to the generated machine code; derived generators
// fill in ker_ after emitting code.
44 struct jit_uni_depthwise_kernel_f32 : public c_compatible {
45 const depthwise_desc_t &desc_;
46 void (*ker_)(const jit_args *);
// Invoking before code generation is a programmer error, hence the assert.
49 void operator()(const jit_args *args) { assert(ker_); ker_(args); }
51 jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
52 : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
53 virtual ~jit_uni_depthwise_kernel_f32() {}
56 template <cpu_isa_t isa>
// Number of auxiliary SIMD registers the injector must reserve for the given
// depthwise algorithm: scale_shift needs one scratch only on SSE4.2 (no
// three-operand forms), prelu needs a mask register plus a temporary.
57 int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
58 switch (depthwise_alg) {
59 case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
60 case alg_kind::depthwise_prelu: return 2;
61 default: assert(!"unsupported depthwise algorithm");
67 template <cpu_isa_t isa>
// Selects which vector registers the injector may clobber and spills them.
// Preference order: registers OUTSIDE the caller's live range
// [start_idx, end_idx); any shortfall is taken from the head of the caller's
// range (those lanes are deferred — start_idx_tail marks the first vec the
// main pass may process without extra save/restore).
68 void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
69 preserved_vecs_count = 0;
70 vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);
// First pass: grab free registers outside the caller's range.
72 for (size_t i = 0; i < vecs_count; i++) {
73 if (preserved_vecs_count >= vecs_to_preserve)
76 if (i < start_idx || i >= end_idx) {
77 preserved_vec_idxs[preserved_vecs_count] = i;
78 preserved_vecs_count++;
// Second pass: take the remainder from the start of the caller's range.
82 start_idx_tail = start_idx;
83 size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
84 for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
85 preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
86 preserved_vecs_count++;
87 start_idx_tail = start_idx + i + 1;
// Spill every chosen register to the stack (vlen bytes each).
90 h->sub(h->rsp, preserved_vecs_count * vlen);
91 for (size_t i = 0; i < preserved_vecs_count; ++i)
92 h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
97 template <cpu_isa_t isa>
// Switches the injector to the "tail" phase: restores the caller vecs that
// were borrowed from [start_idx, start_idx_tail), remaps the aux indices past
// the tail range, and re-spills the remapped registers so the tail lanes can
// be processed. NOTE(review): some interior lines are elided in this excerpt.
98 void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
99 size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
100 int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
102 if (tail_vecs_to_preserve > 0) {
// Reload the borrowed caller registers from their stack slots.
103 h->add(h->rsp, idx_off * vlen);
104 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
105 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);
// Shift the aux register indices beyond the tail lanes...
107 for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
108 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
// ...and spill the newly selected registers in their place.
111 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112 h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
113 h->sub(h->rsp, idx_off * vlen);
119 template <cpu_isa_t isa>
// Restores every spilled vector register from the stack and releases the
// stack space reserved by injector_preamble().
120 void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
121 for (size_t i = 0; i < preserved_vecs_count; ++i)
122 h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
123 h->add(h->rsp, preserved_vecs_count * vlen);
126 template <cpu_isa_t isa>
// Binds the injector's named aux registers to the indices chosen by the
// preamble (slot 0 = mask, slot 1 = scratch).
127 void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
128 vmm_mask = Vmm(preserved_vec_idxs[0]);
129 vmm_aux0 = Vmm(preserved_vec_idxs[1]);
132 template <cpu_isa_t isa>
// Emits src = src * weights[...] + bias[...] for one vector register.
// NOTE(review): the per-isa if/else header lines are elided in this excerpt —
// the first group is the SSE4.2 two-operand path (vmm_mask used as scratch),
// the second uses AVX memory-operand forms.
133 void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
134 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
136 h->movups(vmm_mask, h->ptr[p_weights]);
137 h->mulps(vmm_src, vmm_mask);
138 h->movups(vmm_mask, h->ptr[p_bias]);
139 h->addps(vmm_src, vmm_mask);
141 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
142 h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
146 template <cpu_isa_t isa>
// Emits PReLU for one vector register: src = src > 0 ? src : src * weights.
// Three per-isa code paths; the SSE4.2 branch header is elided in this excerpt.
147 void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
148 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
// cmpps immediate predicates: 6 = greater-than, 1 = less-than.
149 const unsigned char _cmp_gt_os = 6;
150 const unsigned char _cmp_lt_os = 1;
// SSE4.2: build (0 > src)? mask via cmpps, then blendvps (implicit xmm0 mask
// register — vmm_mask must be xmm0 on this path; assumed, confirm in caller).
153 h->pxor(vmm_mask, vmm_mask);
154 h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
155 h->movups(vmm_aux0, h->ptr[p_weights]);
156 h->mulps(vmm_aux0, vmm_src);
157 h->blendvps(vmm_src, vmm_aux0);
158 } else if (isa == avx2) {
// AVX2: mask = src > 0, select src where positive, src*weights elsewhere.
159 h->vxorps(vmm_mask, vmm_mask, vmm_mask);
160 h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
161 h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
162 h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
163 } else if (isa == avx512_common) {
// AVX-512: k-mask of negative lanes, masked multiply applies the slope
// only to those lanes, positive lanes keep their value.
164 h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
165 h->vmovups(vmm_aux0, vmm_src);
166 h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
167 h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
171 template <cpu_isa_t isa>
// Applies the configured depthwise op to every vector register in
// [start_idx, end_idx), dispatching on the algorithm kind.
172 void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
173 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
174 for (size_t idx = start_idx; idx < end_idx; idx++) {
175 switch (depthwise_alg) {
176 case alg_kind::depthwise_scale_shift:
177 scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
178 case alg_kind::depthwise_prelu:
179 prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
180 default: assert(!"unsupported depthwise algorithm");
185 template <cpu_isa_t isa>
// Public entry: processes the caller's vector range in two passes — first the
// vecs not borrowed as aux registers [start_idx_tail, end_idx), then, after
// preamble_tail releases/remaps the borrowed ones, [start_idx, start_idx_tail).
186 void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
187 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
188 injector_preamble(start_idx, end_idx);
189 compute_body(start_idx_tail, end_idx, p_weights, p_bias);
190 injector_preamble_tail(start_idx, end_idx);
191 compute_body(start_idx, start_idx_tail, p_weights, p_bias);
192 injector_postamble();
// Explicit instantiations for every supported ISA.
195 template struct jit_uni_depthwise_injector_f32<avx512_common>;
196 template struct jit_uni_depthwise_injector_f32<avx2>;
197 template struct jit_uni_depthwise_injector_f32<sse42>;
202 template <cpu_isa_t isa>
// JIT kernel computing dst = src * scale + shift over a row of data.
// Handles planar (nchw/ncdhw, "flat"), packed nc, and blocked layouts;
// the whole code sequence is emitted in the constructor.
203 struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
206 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
207 jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
208 : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
209 assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
210 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// Flat = planar layout: one scale/shift broadcast across the row.
212 bool isFlat = (desc.src_desc.format == nchw && desc.dst_desc.format == nchw) ||
213 (desc.src_desc.format == ncdhw && desc.dst_desc.format == ncdhw);
215 Reg64 param = abi_param1;
217 const int block_size = isa == avx512_common ? 16 : 8;
// Flat/nc layouts advance block_size elements per iteration; blocked layouts
// count work in channel blocks (step 1).
218 const int main_loop_step = (isFlat || desc.src_desc.format == nc) ? block_size : 1;
// Load kernel arguments from the jit_args struct.
222 mov(reg_from, ptr[param + GET_OFF(from)]);
223 mov(reg_to, ptr[param + GET_OFF(to)]);
224 mov(reg_scale, ptr[param + GET_OFF(weights)]);
225 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
227 mov(reg_shift, ptr[param + GET_OFF(bias)]);
229 Label main_loop_label;
230 Label tail_loop_label;
231 Label tail_loop_flat_label;
// SSE4.2 processes a 16-float block as two 4-float halves ("repeats").
// Flat path broadcasts a single scale/shift; blocked path loads per-lane
// vectors; shift is zeroed when there is no bias. NOTE(review): several
// branch-header lines are elided in this excerpt.
234 int repeats = isa == sse42 ? 2 : 1;
235 for (int i = 0; i < repeats; i++) {
237 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
239 uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
241 uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
243 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
245 uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
247 uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
// Scalar copies of scale/shift for the flat tail loop.
252 uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
254 uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
256 uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
// Main vector loop: dst = shift + src * scale via FMA, full blocks only.
259 L(main_loop_label); {
260 cmp(reg_work_amount, main_loop_step-1);
261 jle(isFlat ? tail_loop_flat_label : tail_loop_label, T_NEAR);
263 int repeats = isa == sse42 ? 2 : 1;
264 for (int i = 0; i < repeats; i++) {
265 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
266 uni_vmovups(vmm_dst, get_shift_reg(i));
267 uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
268 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
271 add(reg_from, block_size*sizeof(float));
272 add(reg_to, block_size*sizeof(float));
273 sub(reg_work_amount, main_loop_step);
275 jmp(main_loop_label, T_NEAR);
// Per-channel tail: scalar loop that also advances scale/shift pointers.
278 L(tail_loop_label); {
279 cmp(reg_work_amount, 0);
280 jle(exit_label, T_NEAR);
282 movss(xmm_src, ptr[reg_from]);
283 movss(xmm_shift, ptr[reg_shift]);
284 movss(xmm_scale, ptr[reg_scale]);
285 uni_vmovups(xmm_dst, xmm_shift);
286 uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
287 movss(ptr[reg_to], xmm_dst);
289 add(reg_from, 1*sizeof(float));
290 add(reg_to, 1*sizeof(float));
291 add(reg_shift, 1*sizeof(float));
292 add(reg_scale, 1*sizeof(float));
293 dec(reg_work_amount);
295 jmp(tail_loop_label, T_NEAR);
// Flat tail: same scale/shift for every element, pointers don't advance.
298 L(tail_loop_flat_label); {
299 cmp(reg_work_amount, 0);
300 jle(exit_label, T_NEAR);
302 movss(xmm_src, ptr[reg_from]);
303 uni_vmovups(xmm_dst, xmm_shift);
304 uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
305 movss(ptr[reg_to], xmm_dst);
307 add(reg_from, 1*sizeof(float));
308 add(reg_to, 1*sizeof(float));
309 dec(reg_work_amount);
311 jmp(tail_loop_flat_label, T_NEAR);
// Publish the generated code as the callable kernel entry point.
318 ker_ = (decltype(ker_))this->getCode();
// Vector width follows the ISA: Xmm / Ymm / Zmm.
322 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
323 isa == avx2, Ymm, Zmm>::type;
// Register map: scale in Vmm(2..3), shift in Vmm(4..5), src/dst in Vmm(0..1).
325 inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
326 inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }
330 Reg64 reg_work_amount = r10;
331 Reg64 reg_scale = r11;
332 Reg64 reg_shift = r12;
334 Vmm vmm_src = Vmm(0);
335 Vmm vmm_dst = Vmm(1);
337 Xmm xmm_src = Xmm(0);
338 Xmm xmm_dst = Xmm(1);
339 Xmm xmm_scale = Xmm(6);
340 Xmm xmm_shift = Xmm(7);
343 template <cpu_isa_t isa>
// JIT kernel computing PReLU over a row: dst = src > 0 ? src : src * scale.
// Mirrors the scale/shift kernel's flat vs blocked handling; code emitted in
// the constructor.
344 struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
347 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
348 jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
349 : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
350 assert(desc.alg_kind == alg_kind::depthwise_prelu);
351 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// Flat = planar layout: one slope broadcast across the row.
353 bool isFlat = (desc.src_desc.format == nchw && desc.dst_desc.format == nchw) ||
354 (desc.src_desc.format == ncdhw && desc.dst_desc.format == ncdhw);
356 Reg64 param = abi_param1;
358 const int block_size = isa == avx512_common ? 16 : 8;
359 const int main_loop_step = (isFlat || desc.src_desc.format == nc) ? block_size : 1;
// Load kernel arguments.
363 mov(reg_from, ptr[param + GET_OFF(from)]);
364 mov(reg_to, ptr[param + GET_OFF(to)]);
365 mov(reg_scale, ptr[param + GET_OFF(weights)]);
366 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
368 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
// Pre-load slope registers; broadcast for flat, per-lane for blocked.
// NOTE(review): some branch-header lines are elided in this excerpt.
370 int repeats = isa == sse42 ? 2 : 1;
371 for (int i = 0; i < repeats; i++) {
373 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
375 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
380 uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
383 Label main_loop_label;
384 Label tail_loop_label;
385 Label tail_loop_flat_label;
// Main vector loop, one per-isa select sequence per 4/8/16-lane group.
388 L(main_loop_label); {
389 cmp(reg_work_amount, main_loop_step-1);
390 jle(isFlat ? tail_loop_flat_label :tail_loop_label, T_NEAR);
392 for (int i = 0; i < repeats; i++) {
393 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
// SSE4.2: blendvps with the implicit xmm0 mask (vmm_mask is Vmm(0)).
396 pxor(vmm_mask, vmm_mask);
397 cmpps(vmm_mask, vmm_src, _cmp_gt_os);
398 movups(vmm_dst, vmm_src);
399 mulps(vmm_src, get_scale_reg(i));
400 blendvps(vmm_dst, vmm_src);
401 } else if (isa == avx2) {
402 vcmpgtps(vmm_mask, vmm_src, vmm_zero);
403 vmulps(vmm_dst, vmm_src, get_scale_reg(i));
404 vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
405 } else if (isa == avx512_common) {
// AVX-512: masked multiply applies the slope only to negative lanes.
406 Opmask kmask = Opmask(7);
407 vmovups(vmm_dst, vmm_src);
408 vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
409 vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
412 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
415 add(reg_from, block_size*sizeof(float));
416 add(reg_to, block_size*sizeof(float));
417 sub(reg_work_amount, main_loop_step);
419 jmp(main_loop_label, T_NEAR);
// Per-channel tail: scalar PReLU, slope pointer advances each element.
422 L(tail_loop_label); {
423 cmp(reg_work_amount, 0);
424 jle(exit_label, T_NEAR);
426 movss(xmm_src, ptr[reg_from]);
427 movss(xmm_scale, ptr[reg_scale]);
429 pxor(xmm_mask, xmm_mask);
430 cmpps(xmm_mask, xmm_src, _cmp_gt_os);
431 movups(xmm_dst, xmm_src);
432 mulps(xmm_src, xmm_scale);
433 blendvps(xmm_dst, xmm_src);
435 movss(ptr[reg_to], xmm_dst);
437 add(reg_from, 1*sizeof(float));
438 add(reg_to, 1*sizeof(float));
439 add(reg_scale, 1*sizeof(float));
440 dec(reg_work_amount);
442 jmp(tail_loop_label, T_NEAR);
// Flat tail: single broadcast slope, pointer does not advance.
445 L(tail_loop_flat_label); {
446 cmp(reg_work_amount, 0);
447 jle(exit_label, T_NEAR);
449 movss(xmm_src, ptr[reg_from]);
451 pxor(xmm_mask, xmm_mask);
452 cmpps(xmm_mask, xmm_src, _cmp_gt_os);
453 movups(xmm_dst, xmm_src);
454 mulps(xmm_src, xmm_scale);
455 blendvps(xmm_dst, xmm_src);
457 movss(ptr[reg_to], xmm_dst);
459 add(reg_from, 1*sizeof(float));
460 add(reg_to, 1*sizeof(float));
461 dec(reg_work_amount);
463 jmp(tail_loop_flat_label, T_NEAR);
// Publish the generated code as the callable kernel entry point.
470 ker_ = (decltype(ker_))this->getCode();
474 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
475 isa == avx2, Ymm, Zmm>::type;
// Slope registers live in Vmm(4..5); mask MUST be Vmm(0) for SSE blendvps.
477 inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }
481 Reg64 reg_work_amount = r10;
482 Reg64 reg_scale = r11;
484 Vmm vmm_mask = Vmm(0);
485 Vmm vmm_src = Vmm(1);
486 Vmm vmm_zero = Vmm(2);
487 Vmm vmm_dst = Vmm(3);
489 Xmm xmm_mask = Xmm(0);
490 Xmm xmm_src = Xmm(1);
491 Xmm xmm_dst = Xmm(3);
492 Xmm xmm_scale = Xmm(4);
// cmpps immediate predicates: 6 = greater-than, 1 = less-than.
494 const unsigned char _cmp_gt_os = 6;
495 const unsigned char _cmp_lt_os = 1;
500 template <cpu_isa_t isa>
// Primitive-descriptor check: accepts f32 forward depthwise with matching
// src/dst formats that are either the ISA's blocked layout or the planar one,
// 1-D (x) weights/bias, and default attributes.
501 status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
502 using namespace alg_kind;
// Pick the blocked/planar formats this ISA supports for the tensor rank.
504 memory_format_t desired_blk_fmt, desired_pln_fmt;
505 if (desc()->src_desc.ndims == 5) {
506 desired_blk_fmt = isa == avx512_common ? nCdhw16c : nCdhw8c;
507 desired_pln_fmt = ncdhw;
508 } else if (desc()->src_desc.ndims == 4) {
509 desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;
510 desired_pln_fmt = nchw;
512 desired_blk_fmt = nc;
513 desired_pln_fmt = nc;
516 assert(engine()->kind() == engine_kind::cpu);
517 bool ok = true && mayiuse(isa)
518 && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
519 prop_kind::forward_inference)
520 && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
521 && desc()->src_desc.format == desc()->dst_desc.format
522 && utils::one_of(desc()->src_desc.format, desired_blk_fmt, desired_pln_fmt)
523 && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, desired_pln_fmt)
524 && utils::one_of(desc()->weights_desc.format, x)
525 && IMPLICATION(this->with_bias(), x == desc()->bias_desc.format)
526 && attr()->has_default_values();
528 return ok ? status::success : status::unimplemented;
531 template <cpu_isa_t isa>
// Constructor: builds the algorithm-specific JIT kernel and, when the channel
// count is not a multiple of the SIMD width, allocates zero-padded copies of
// weights/bias so the kernel can read full vectors safely.
532 jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
533 const input_vector &inputs, const output_vector &outputs)
534 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
535 padded_weights_(nullptr), padded_bias_(nullptr) {
536 const auto &desc = *pd()->desc();
537 switch (desc.alg_kind) {
538 case alg_kind::depthwise_scale_shift:
539 kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
540 case alg_kind::depthwise_prelu:
541 kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
542 default: assert(!"unknown depthwise alg_kind");
545 const int simd_w = isa == avx512_common ? 16 : 8;
546 const memory_desc_wrapper data_d(pd()->src_pd());
547 const int c_without_padding = data_d.dims()[1];
548 const int c_padded = rnd_up(c_without_padding, simd_w);
// Only the padding tail is zeroed here; live values are copied per-execute.
// NOTE(review): malloc results are not null-checked in the visible code.
550 if (pd()->want_padded_weights()) {
551 padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
552 for (int oc = c_without_padding; oc < c_padded; ++oc)
553 padded_weights_[oc] = 0;
555 if (pd()->with_bias()) {
556 padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
557 for (int oc = c_without_padding; oc < c_padded; ++oc)
558 padded_bias_[oc] = 0;
563 template <cpu_isa_t isa>
// Destructor: releases the padded buffers (padded_bias_/kernel_ cleanup is on
// lines elided from this excerpt — presumed present; verify in full source).
564 jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
566 free(padded_weights_);
570 template <cpu_isa_t isa>
// Forward execution: refreshes the padded weight/bias copies if needed, then
// runs the JIT kernel in parallel over (batch, channel-block, depth, height),
// one spatial row (or one channel block for nc) per call.
571 void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
572 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
573 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
574 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
575 auto dst = reinterpret_cast<data_t *>(this->memory());
577 const memory_desc_wrapper data_d(pd()->src_pd());
578 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
579 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
581 const int MB = pd()->MB();
582 const int C = pd()->C();
583 const int D = pd()->D();
584 const int H = pd()->H();
585 const int W = pd()->W();
// Planar layouts are processed one channel at a time; blocked/nc layouts in
// SIMD-width channel groups.
587 const int simd_w = isa == avx512_common ? 16 : 8;
588 const int ch_block_size = (data_d.format() == nchw) || (data_d.format() == ncdhw) ? 1 : simd_w;
589 const int CB = div_up(C, ch_block_size);
// Copy live channels into the pre-zeroed padded buffers and redirect the
// pointers the kernel will read.
591 if (pd()->want_padded_weights()) {
592 for (int oc = 0; oc < C; ++oc)
593 padded_weights_[oc] = weights[oc];
594 weights = padded_weights_;
596 if (pd()->with_bias()) {
597 for (int oc = 0; oc < C; ++oc)
598 padded_bias_[oc] = bias[oc];
603 parallel_nd(MB, CB, D, H,
604 [&](int mb, int cb, int d, int h) {
605 auto arg = jit_args();
// Offset selection depends on tensor rank: 4-D, 5-D, or 2-D (nc).
607 size_t data_off = data_d.ndims() == 4
608 ? data_d.blk_off(mb, cb, h)
609 : data_d.ndims() == 5
610 ? data_d.blk_off(mb, cb, d, h)
611 : data_d.blk_off(mb, cb * ch_block_size);
613 arg.from = &src[data_off];
614 arg.to = &dst[data_off];
615 arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
617 arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
// nc processes channels of one block (clamped at C); others a W-row.
618 arg.work_amount = data_d.format() == nc ? nstl::min(ch_block_size, C - cb * ch_block_size) : (size_t)W;
// Explicit instantiations for every supported ISA.
624 template struct jit_uni_depthwise_fwd_t<sse42>;
625 template struct jit_uni_depthwise_fwd_t<avx2>;
626 template struct jit_uni_depthwise_fwd_t<avx512_common>;
629 #define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)
631 template <cpu_isa_t isa>
// Initializes the accumulator registers for ur_w output pixels (two register
// banks on SSE4.2, which handles a channel block as two halves).
632 void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
633 int repeats = isa == sse42 ? 2 : 1;
634 for (int i = 0; i < repeats; i++) {
635 for (int ow = 0; ow < ur_w; ow++) {
636 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
638 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
643 template <cpu_isa_t isa>
// Accumulates the depthwise filter over up to three input rows
// (aux_reg_input0..2) for ur_w output pixels and kw_size filter taps.
// For u8 src the math is integer multiply-add; otherwise f32 FMA.
// Row processing is guarded by kh-count checks (the jl(exit_label) before
// each row); some of the cmp/branch lines are elided from this excerpt.
644 void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
// Loader for source pixels: zero-extend bytes for u8, plain load for f32.
645 auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
646 if (jcp.src_dt == data_type::u8) {
647 uni_vpmovzxbd(vmm_src, op);
649 uni_vmovups(vmm_src, op);
// Loader for weights: sign-extend bytes for the u8 path.
653 auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
654 if (jcp.src_dt == data_type::u8) {
655 uni_vpmovsxbd(vmm_ker, op);
657 uni_vmovups(vmm_ker, op);
// Multiply-accumulate: integer for u8, fused f32 otherwise.
661 auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
662 if (jcp.src_dt == data_type::u8) {
663 uni_vpmulld(vmm_src, vmm_src, vmm_ker);
664 uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
666 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
670 int ch_blk = jcp.ch_block;
671 int stride_w = jcp.stride_w;
675 int repeats = isa == sse42 ? 2 : 1;
// Row 0 (skipped when kh padding excludes it).
678 jl(exit_label, T_NEAR);
679 for (int i = 0; i < repeats; i++) {
680 for (int kw = 0; kw < kw_size; kw++) {
681 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
683 Vmm vmm_ker = get_ker_reg(0);
684 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
686 for (int ow = 0; ow < ur_w; ow++) {
687 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
689 Vmm vmm_src = get_src_reg(0);
690 load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);
692 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
693 compute(vmm_acc, vmm_src, vmm_ker);
// Advance the weights pointer to the next filter row.
697 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
// Row 1.
700 jl(exit_label, T_NEAR);
701 for (int i = 0; i < repeats; i++) {
702 for (int kw = 0; kw < kw_size; kw++) {
703 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
705 Vmm vmm_ker = get_ker_reg(0);
706 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
708 for (int ow = 0; ow < ur_w; ow++) {
709 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
711 Vmm vmm_src = get_src_reg(0);
712 load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);
714 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
715 compute(vmm_acc, vmm_src, vmm_ker);
719 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
// Row 2.
722 jl(exit_label, T_NEAR);
723 for (int i = 0; i < repeats; i++) {
724 for (int kw = 0; kw < kw_size; kw++) {
725 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
727 Vmm vmm_ker = get_ker_reg(0);
728 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
730 for (int ow = 0; ow < ur_w; ow++) {
731 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
733 Vmm vmm_src = get_src_reg(0);
734 load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);
736 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
737 compute(vmm_acc, vmm_src, vmm_ker);
745 template <cpu_isa_t isa>
// Loads a value of type_in from op into vmm_in and converts it to f32.
// scalar_load goes through reg_tmp (movsx/movzx + movq) to avoid reading past
// the element; vector loads use the packed zero/sign-extending moves. The
// switch-case labels and some load lines are elided from this excerpt.
746 void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
747 Xmm xmm_in = Xmm(vmm_in.getIdx());
754 movq(xmm_in, reg_tmp_64);
756 uni_vmovups(vmm_in, op);
// s8 path: sign-extend.
761 movsx(reg_tmp_32, op);
762 movq(xmm_in, reg_tmp_64);
764 uni_vpmovsxbd(vmm_in, op);
// u8 path: zero-extend.
769 movzx(reg_tmp_32, op);
770 movq(xmm_in, reg_tmp_64);
772 uni_vpmovzxbd(vmm_in, op);
775 default: assert(!"unsupported data type");
// Integer inputs are converted to f32; f32 input is left untouched.
778 if (type_in != data_type::f32)
779 uni_vcvtdq2ps(vmm_in, vmm_in);
782 template <cpu_isa_t isa>
// Post-conv processing on the accumulators: int->f32 conversion for the u8
// path, bias add, sum post-op (reading back dst with tail-aware loads), then
// the eltwise/depthwise post-op injector chain. Several guard/branch lines
// are elided from this excerpt.
783 void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
784 int repeats = isa == sse42 ? 2 : 1;
786 for (int r = 0; r < repeats; r++) {
787 for (int ow = 0; ow < ur_w; ow++) {
// u8 accumulators are int32 — convert to f32 before adding f32 bias.
788 if (jcp.src_dt == data_type::u8) {
789 uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
793 int b_off = r * (jcp.ch_block / 2);
794 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
795 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
// Sum post-op: add previously stored dst values into the accumulators.
801 for (int r = 0; r < repeats; r++) {
802 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
803 bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
805 for (int ow = 0; ow < ur_w; ow++) {
806 if (is_scalar_store) {
// AVX-512 masks off the channel tail with ktail_mask (zeroing).
807 if (isa == avx512_common) {
808 int o_off = ow * ow_stride_;
810 Vmm vmm_in = vmm_sum | ktail_mask | T_z;
812 cvt2ps(jcp.dst_dt, vmm_in, ptr[reg_output + o_off * jcp.typesize_out], false);
813 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
// SSE/AVX2 tail: load scalars one by one, shifting each into its lane.
815 for (int oc = 0; oc < tail_size; oc++) {
816 int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
818 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
819 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
// Lanes in the upper ymm half need a cross-lane permute first.
821 if (oc >= jcp.ch_block / 2) {
822 vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
824 uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
826 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
// Full-width sum: plain vector read-back and add.
830 int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);
832 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
833 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
835 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
// Run post-ops registered after the fused convolution entry: eltwise
// injectors over all accumulators, depthwise injectors per register bank
// (second bank offset by half a channel block on SSE4.2).
841 const auto &p = attr_.post_ops_;
842 int eltwise_inj_idx = 0;
843 int depthwise_inj_idx = 0;
844 int start_idx = p.find(primitive_kind::convolution) + 1;
845 for (int i = start_idx; i < p.len_; i++) {
846 auto& post_op = p.entry_[i];
847 if (post_op.is_eltwise()) {
// Accumulators live in Vmm(4) onward — see get_acc_reg.
848 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
850 } else if (post_op.is_depthwise()) {
851 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
852 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
854 add(reg_d_weights, reg_oc_off);
855 add(reg_d_bias, reg_oc_off);
857 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);
860 add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
861 add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
863 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
871 template <cpu_isa_t isa>
// Stores vmm_dst to op converted to jcp.dst_dt: f32 stores directly; s8/u8
// pack dword->word->byte (with a vpermq lane fix-up on AVX) and scalar stores
// bounce through reg_tmp_64. Switch-case labels and the scalar-store `if`
// lines are elided from this excerpt.
872 void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
873 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
874 Xmm xmm_dst = Xmm(vmm_dst.getIdx());
876 switch (jcp.dst_dt) {
880 movq(reg_tmp_64, xmm_dst);
883 uni_vmovups(op, vmm_dst);
// Signed-saturating pack for s8 output.
887 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
889 if (isa != sse42 && !scalar_store)
890 vpermq(ymm_dst, ymm_dst, 0x08);
892 uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
895 movq(reg_tmp_64, xmm_dst);
// Unsigned-saturating pack for u8 output.
906 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
908 if (isa != sse42 && !scalar_store)
909 vpermq(ymm_dst, ymm_dst, 0x08);
911 uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
914 movq(reg_tmp_64, xmm_dst);
924 assert(!"unknown dst_dt");
928 template <cpu_isa_t isa>
// Final store stage: applies the attr rounding mode when converting f32
// accumulators to integer dst, then either packs results into bitmasks
// (binarization post-op) or stores via store_dst_typed with channel-tail
// handling. Several guard/branch lines are elided from this excerpt.
929 void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
// SSE4.2 needs the second half-block pass only when oc_step spills into it.
931 int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;
// Build the AVX-512 channel-tail mask once for this oc_step.
933 if (isa == avx512_common && oc_step != jcp.ch_block) {
934 int mask = (1 << oc_step) - 1;
935 mov(reg_tmp_32, mask);
936 kmovw(ktail_mask, reg_tmp_32);
// f32 -> int conversion honoring round_mode (nearest, or floor via roundps
// imm 1 then convert).
939 for (int i = 0; i < repeats; i++) {
940 for (int ow = 0; ow < ur_w; ow++) {
941 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
942 if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
943 if (attr_.round_mode_ == round_mode::nearest)
944 uni_vcvtps2dq(vmm_dst, vmm_dst);
945 else if (attr_.round_mode_ == round_mode::down) {
946 uni_vroundps(vmm_dst, vmm_dst, 1);
947 uni_vcvtps2dq(vmm_dst, vmm_dst);
949 assert(!"unimplemented");
// Binarization path: threshold each accumulator, XNOR with the output mask,
// and emit one packed bit per channel.
954 if (jcp.with_binarization) {
955 int output_step = div_up(ow_stride_, nbits);
957 const auto &p = attr_.post_ops_;
958 int binarization_idx = p.find(primitive_kind::binarization);
962 mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
963 mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
964 add(reg_b_weights, reg_oc_off);
965 add(reg_b_out_mask, reg_oc_off);
967 for (int ow = 0; ow < ur_w; ow++) {
968 for (int i = 0; i < repeats; i++) {
969 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
970 mov(reg_b_mask, (1 << tail_size) - 1);
971 uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
972 uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + i * (jcp.ch_block / 2) * sizeof(float)]);
974 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
976 if (isa == avx512_common) {
977 vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
978 vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
979 kxnorw(bin_mask0, bin_mask0, bin_mask1);
981 uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
982 uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
// Extract the per-lane sign bits into a GPR and clip to tail_size bits.
986 if (isa == avx512_common) {
987 kmovw(reg_tmp_32, bin_mask0);
989 uni_vmovmskps(reg_tmp_32, vmm_dst);
991 and_(reg_tmp_64, reg_b_mask);
// SSE second half-block is OR-ed into the upper bits of the same mask.
993 uni_vmovmskps(reg_tmp2_32, vmm_dst);
994 and_(reg_tmp2_64, reg_b_mask);
996 or_(reg_tmp_32, reg_tmp2_32);
// Store 16 mask bits when the block exceeds one byte, else 8.
999 if (i == repeats - 1) {
1000 const size_t o_off = ow * output_step;
1001 if (isa == avx512_common && oc_step > nbits) {
1002 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
1004 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
// Regular store path: masked store on AVX-512 tails, scalar element loop on
// SSE/AVX2 tails (shifting the next lane down each iteration), vector store
// otherwise.
1012 for (int i = 0; i < repeats; i++) {
1013 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
1014 bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
1015 if (is_scalar_store) {
1016 for (int ow = 0; ow < ur_w; ow++) {
1017 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
1019 if (isa == avx512_common) {
1020 int o_off = ow * ow_stride_;
1022 store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask, false);
1024 for (int oc = 0; oc < tail_size; oc++) {
1025 int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
1026 store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
// Rotate the next element into lane 0 (byte shift on SSE, cross-lane
// align on AVX2).
1029 psrldq(vmm_dst, jcp.typesize_out);
1031 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
1033 vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
1034 vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
1040 for (int ow = 0; ow < ur_w; ow++) {
1041 int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
1042 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
1044 store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
1051 template <cpu_isa_t isa>
// Drives one row of depthwise convolution for a channel group of oc_step
// channels: optional left-pad block, unrolled main loop, tail block, and
// optional right-pad block, each doing filter + postprocessing + store.
// Several label/branch lines are elided from this excerpt.
1052 void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
1053 Label left_pad_label;
1054 Label right_pad_label;
1055 Label unrolled_w_label;
// Binarized output advances in packed-bit units.
1059 int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;
// Left padding: first output uses a narrowed filter (skips the padded tap),
// with the kernel pointer advanced by one channel block.
1061 L(left_pad_label); {
1063 int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
1065 mov(aux_reg_input0, reg_input0);
1066 mov(aux_reg_input1, reg_input1);
1067 mov(aux_reg_input2, reg_input2);
1068 mov(aux_reg_kernel, reg_kernel);
1069 add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);
1072 apply_filter(ur_w, kw);
1073 apply_postprocessing(ur_w, oc_step);
1074 store_dst(ur_w, oc_step);
// Input advance accounts for the stride minus the tap already consumed.
1076 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
1077 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
1078 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
1079 add(reg_output, jcp.typesize_out * ur_w * output_step);
1081 sub(reg_ur_w, ur_w);
// Main loop: full jcp.ur_w-wide iterations with the full filter.
1084 L(unrolled_w_label); {
1085 int ur_w = jcp.ur_w;
1088 cmp(reg_ur_w, ur_w);
1089 jle(tail_w_label, T_NEAR);
1091 mov(aux_reg_input0, reg_input0);
1092 mov(aux_reg_input1, reg_input1);
1093 mov(aux_reg_input2, reg_input2);
1094 mov(aux_reg_kernel, reg_kernel);
1097 apply_filter(ur_w, kw);
1098 apply_postprocessing(ur_w, oc_step);
1099 store_dst(ur_w, oc_step);
1101 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1102 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1103 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1104 add(reg_output, jcp.typesize_out * ur_w * output_step);
1106 sub(reg_ur_w, ur_w);
1107 jmp(unrolled_w_label, T_NEAR);
// Tail: remaining single-width outputs before the right-pad block.
1114 cmp(reg_ur_w, ur_w);
1116 jle(right_pad_label, T_NEAR);
1118 jle(exit_label, T_NEAR);
1120 mov(aux_reg_input0, reg_input0);
1121 mov(aux_reg_input1, reg_input1);
1122 mov(aux_reg_input2, reg_input2);
1123 mov(aux_reg_kernel, reg_kernel);
1126 apply_filter(ur_w, kw);
1127 apply_postprocessing(ur_w, oc_step);
1128 store_dst(ur_w, oc_step);
1130 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1131 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1132 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1133 add(reg_output, jcp.typesize_out * ur_w * output_step);
1135 sub(reg_ur_w, ur_w);
1136 jmp(tail_w_label, T_NEAR);
// Right padding: last output narrows the filter by the overhanging taps.
1140 L(right_pad_label); {
1142 int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);
1144 mov(aux_reg_input0, reg_input0);
1145 mov(aux_reg_input1, reg_input1);
1146 mov(aux_reg_input2, reg_input2);
1147 mov(aux_reg_kernel, reg_kernel);
1150 apply_filter(ur_w, kw);
1151 apply_postprocessing(ur_w, oc_step);
1152 store_dst(ur_w, oc_step);
1154 sub(reg_ur_w, ur_w);
1161 template <cpu_isa_t isa>
1162 void jit_uni_dw_conv_row_f32<isa>::generate() {
1163 const auto &p = attr_.post_ops_;
1164 int start_idx = p.find(primitive_kind::convolution) + 1;
1165 for (int i = start_idx; i < p.len_; i++) {
1166 auto &post_op = p.entry_[i];
1167 if (post_op.is_eltwise()) {
1168 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
1170 post_op.eltwise.alg,
1171 post_op.eltwise.alpha,
1172 post_op.eltwise.beta
1174 } else if (post_op.is_depthwise()) {
1175 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
1177 post_op.depthwise.alg
1184 mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
1185 mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
1186 mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
1187 mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
1188 mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
1190 mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
1191 mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
1192 mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
1193 mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
1194 mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);
1199 cmp(reg_oc_work, jcp.ch_block);
1200 jl(tail_label, T_NEAR);
1202 loop_body(jcp.ch_block);
1203 jmp(exit_label, T_NEAR);
1207 if (jcp.oc % jcp.ch_block != 0)
1208 loop_body(jcp.oc % jcp.ch_block);
1214 for (auto& inj : eltwise_injectors)
1215 inj->prepare_table();
1218 template <cpu_isa_t isa>
1219 bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1220 const auto &p = attr.post_ops_;
1222 int start_idx = p.find(primitive_kind::convolution) + 1;
1224 auto all_post_ops_supported = [&]() {
1227 for (int i = start_idx; i < p.len_; i++) {
1228 ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise,
1229 primitive_kind::binarization);
1233 auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1) != -1; };
1234 auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, start_idx, -1); };
1235 auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, start_idx, -1); };
1237 return all_post_ops_supported() &&
1238 count(primitive_kind::sum) <= 1 &&
1239 count(primitive_kind::binarization) <= 1 &&
1240 IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == start_idx) &&
1241 IMPLICATION(contain(primitive_kind::binarization), position(primitive_kind::binarization) == p.len_-1) &&
1242 IMPLICATION(contain(primitive_kind::binarization), !contain(primitive_kind::sum));
1245 template <cpu_isa_t isa>
1246 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1247 const primitive_attr_t &attr) {
1248 if (!mayiuse(isa)) return status::unimplemented;
1249 const int simd_w = isa == avx512_common ? 16 : 8;
1251 const auto &p = attr.post_ops_;
1253 int dw_conv_ind = p.find(primitive_kind::convolution);
1254 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
1256 jcp_dw.ch_block = simd_w;
1257 jcp_dw.with_bias = true;
1259 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1260 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1263 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1264 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1265 jcp_dw.oh = jcp.dw_conv_oh;
1266 jcp_dw.ow = jcp.dw_conv_ow;
1267 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1268 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1269 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1270 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
1272 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1273 return status::unimplemented;
1275 if (!post_ops_ok(jcp_dw, attr))
1276 return status::unimplemented;
1280 jcp_dw.src_dt = jcp.dst_dt;
1281 jcp_dw.dst_dt = jcp.dw_conv_dst_dt;
1282 jcp_dw.bia_dt = jcp.bia_dt;
1283 jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
1284 jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
1285 jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);
1287 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1288 return status::unimplemented;
1290 return status::success;
1293 template <cpu_isa_t isa>
1294 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1295 const primitive_attr_t &attr) {
1296 if (!mayiuse(isa)) return status::unimplemented;
1297 const int simd_w = isa == avx512_common ? 16 : 8;
1299 const auto &p = attr.post_ops_;
1301 int dw_conv_ind = p.find(primitive_kind::convolution);
1302 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
1304 jcp_dw.ch_block = simd_w;
1305 jcp_dw.with_bias = true;
1307 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1308 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1311 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1312 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1313 jcp_dw.oh = jcp.dw_conv_oh;
1314 jcp_dw.ow = jcp.dw_conv_ow;
1315 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1316 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1317 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1318 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
1320 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1321 return status::unimplemented;
1323 if (!post_ops_ok(jcp_dw, attr))
1324 return status::unimplemented;
1328 jcp_dw.src_dt = jcp.dst_dt;
1329 jcp_dw.dst_dt = jcp.dw_conv_dst_dt;
1330 jcp_dw.bia_dt = jcp.bia_dt;
1331 jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
1332 jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
1333 jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);
1335 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1336 return status::unimplemented;
1338 return status::success;
1341 template <cpu_isa_t isa>
1342 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1343 const primitive_attr_t &attr) {
1344 if (!mayiuse(isa)) return status::unimplemented;
1345 const int simd_w = isa == avx512_common ? 16 : 8;
1347 const auto &p = attr.post_ops_;
1349 int dw_conv_ind = p.find(primitive_kind::convolution);
1350 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
1351 jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;
1353 jcp_dw.ch_block = simd_w;
1354 jcp_dw.with_bias = true;
1356 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1357 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1360 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1361 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1362 jcp_dw.oh = jcp.dw_conv_oh;
1363 jcp_dw.ow = jcp.dw_conv_ow;
1364 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1365 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1366 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1367 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
1369 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1370 return status::unimplemented;
1372 if (!post_ops_ok(jcp_dw, attr))
1373 return status::unimplemented;
1377 jcp_dw.src_dt = mkldnn_f32;
1378 jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
1379 jcp_dw.bia_dt = mkldnn_f32;
1380 jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
1381 jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
1382 jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);
1384 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1385 return status::unimplemented;
1387 return status::success;
// Explicit instantiations for all ISAs supported by the row kernel.
template struct jit_uni_dw_conv_row_f32<avx512_common>;
template struct jit_uni_dw_conv_row_f32<avx2>;
template struct jit_uni_dw_conv_row_f32<sse42>;