1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <mkldnn_types.h>
18 #include "mkldnn_types.h"
19 #include "mkldnn_thread.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_uni_depthwise.hpp"
// Byte offset of `field` inside jit_args; the kernels add it to the argument
// pointer to read individual call parameters.
26 #define GET_OFF(field) offsetof(jit_args, field)
32 using namespace Xbyak;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
// Common base of the JIT depthwise kernels: keeps a reference to the descriptor
// and the entry point (ker_) that a derived kernel fills in once its code exists.
44 struct jit_uni_depthwise_kernel_f32 : public c_compatible {
45 const depthwise_desc_t &desc_;
// Entry point of the generated code; stays nullptr until a derived class sets it.
46 void (*ker_)(const jit_args *);
// Calls the generated kernel with `args`; asserts that ker_ has been set.
49 void operator()(const jit_args *args) { assert(ker_); ker_(args); }
// Stores the descriptor reference and the bias flag; ker_ starts out as nullptr.
51 jit_uni_depthwise_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
52 : desc_(desc), ker_(nullptr), with_bias_(with_bias) {}
53 virtual ~jit_uni_depthwise_kernel_f32() {}
// Returns the number of auxiliary vector registers the injector needs for
// `depthwise_alg`: scale_shift needs one on sse42 and none on the other ISAs,
// prelu always needs two. Any other algorithm triggers the assert.
56 template <cpu_isa_t isa>
57 int jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(alg_kind_t depthwise_alg) {
58 switch (depthwise_alg) {
59 case alg_kind::depthwise_scale_shift: return isa == sse42 ? 1 : 0;
60 case alg_kind::depthwise_prelu: return 2;
61 default: assert(!"unsupported depthwise algorithm");
// Chooses the auxiliary vector registers for this injection and spills their
// current contents to the stack. [start_idx, end_idx) is the range of vector
// registers that hold the data to be processed.
67 template <cpu_isa_t isa>
68 void jit_uni_depthwise_injector_f32<isa>::injector_preamble(size_t start_idx, size_t end_idx) {
69 preserved_vecs_count = 0;
70 vecs_to_preserve = (size_t)jit_uni_depthwise_injector_f32<isa>::aux_vecs_count(depthwise_alg);
// First pass: prefer registers outside the data range, so that no data
// register has to be displaced.
72 for (size_t i = 0; i < vecs_count; i++) {
// NOTE(review): presumably stops scanning once enough registers were found - confirm.
73 if (preserved_vecs_count >= vecs_to_preserve)
76 if (i < start_idx || i >= end_idx) {
77 preserved_vec_idxs[preserved_vecs_count] = i;
78 preserved_vecs_count++;
// Second pass: if the first pass came up short, borrow the missing registers
// from the head of the data range. start_idx_tail ends up one past the last
// borrowed register (or equal to start_idx when nothing was borrowed).
82 start_idx_tail = start_idx;
83 size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
84 for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
85 preserved_vec_idxs[preserved_vecs_count] = start_idx + i;
86 preserved_vecs_count++;
87 start_idx_tail = start_idx + i + 1;
// Reserve one vlen-sized stack slot per chosen register and store them there.
90 h->sub(h->rsp, preserved_vecs_count * vlen);
91 for (size_t i = 0; i < preserved_vecs_count; ++i)
92 h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
// Moves the auxiliary registers that were borrowed from the head of the data
// range onto other registers, reusing the stack slots reserved by the preamble.
// Does nothing when no register was borrowed.
97 template <cpu_isa_t isa>
98 void jit_uni_depthwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx, size_t end_idx) {
// Number of borrowed registers, and the position of the first borrowed entry
// inside preserved_vec_idxs (borrowed entries were appended last).
99 size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
100 int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
102 if (tail_vecs_to_preserve > 0) {
// Point rsp at the slots of the borrowed registers and reload their saved values.
103 h->add(h->rsp, idx_off * vlen);
104 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
105 h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), h->ptr[h->rsp + i * vlen]);
// Shift every borrowed index forward by the number of borrowed registers.
107 for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
108 preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
// Spill the newly selected registers into the same slots, then put rsp back.
111 for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
112 h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[idx_off + i]));
113 h->sub(h->rsp, idx_off * vlen);
// Reloads every preserved register from its stack slot and releases the stack
// space that was reserved for them.
119 template <cpu_isa_t isa>
120 void jit_uni_depthwise_injector_f32<isa>::injector_postamble() {
121 for (size_t i = 0; i < preserved_vecs_count; ++i)
122 h->uni_vmovups(Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);
123 h->add(h->rsp, preserved_vecs_count * vlen);
// Binds the injector's working registers to the first two preserved indices:
// entry 0 becomes vmm_mask, entry 1 becomes vmm_aux0.
126 template <cpu_isa_t isa>
127 void jit_uni_depthwise_injector_f32<isa>::assign_regs() {
128 vmm_mask = Vmm(preserved_vec_idxs[0]);
129 vmm_aux0 = Vmm(preserved_vec_idxs[1]);
// Emits vmm_src = vmm_src * [p_weights] + [p_bias] for one vector register.
132 template <cpu_isa_t isa>
133 void jit_uni_depthwise_injector_f32<isa>::scale_shift_compute_vector(const Vmm &vmm_src,
134 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
// Variant that stages weights and bias in vmm_mask before multiplying/adding.
// NOTE(review): presumably the isa == sse42 branch - confirm against the guard.
136 h->movups(vmm_mask, h->ptr[p_weights]);
137 h->mulps(vmm_src, vmm_mask);
138 h->movups(vmm_mask, h->ptr[p_bias]);
139 h->addps(vmm_src, vmm_mask);
// Variant that uses the memory operands directly, without a temporary register.
141 h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_weights]);
142 h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_bias]);
// Emits PReLU for one vector register: lanes selected by the comparison against
// zero are replaced with vmm_src * [p_weights]; p_bias is not referenced by the
// instructions below.
146 template <cpu_isa_t isa>
147 void jit_uni_depthwise_injector_f32<isa>::prelu_compute_vector(const Vmm &vmm_src,
148 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
// Comparison predicate immediates passed to cmpps / vcmpps.
149 const unsigned char _cmp_gt_os = 6;
150 const unsigned char _cmp_lt_os = 1;
// Variant built from two-operand instructions: compare zero against src,
// compute weights * src in vmm_aux0, then blend that into vmm_src.
// NOTE(review): two-operand blendvps takes its selector implicitly from xmm0,
// so this path appears to rely on vmm_mask being register 0 - confirm.
153 h->pxor(vmm_mask, vmm_mask);
154 h->cmpps(vmm_mask, vmm_src, _cmp_gt_os);
155 h->movups(vmm_aux0, h->ptr[p_weights]);
156 h->mulps(vmm_aux0, vmm_src);
157 h->blendvps(vmm_src, vmm_aux0);
158 } else if (isa == avx2) {
// mask = src > 0; keep src where the mask is set, src * weights elsewhere.
159 h->vxorps(vmm_mask, vmm_mask, vmm_mask);
160 h->vcmpgtps(vmm_mask, vmm_src, vmm_mask);
161 h->vmulps(vmm_aux0, vmm_src, h->ptr[p_weights]);
162 h->vblendvps(vmm_src, vmm_aux0, vmm_src, vmm_mask);
163 } else if (isa == avx512_common) {
// k_mask = src < 0; only the masked lanes are overwritten with src * weights.
164 h->vxorpd(vmm_mask, vmm_mask, vmm_mask);
165 h->vmovups(vmm_aux0, vmm_src);
166 h->vcmpps(k_mask, vmm_src, vmm_mask, _cmp_lt_os);
167 h->vmulps(vmm_src | k_mask, vmm_aux0, h->ptr[p_weights]);
// Applies the configured depthwise algorithm to every vector register with an
// index in [start_idx, end_idx); an unknown algorithm triggers the assert.
171 template <cpu_isa_t isa>
172 void jit_uni_depthwise_injector_f32<isa>::compute_body(size_t start_idx, size_t end_idx,
173 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
174 for (size_t idx = start_idx; idx < end_idx; idx++) {
175 switch (depthwise_alg) {
176 case alg_kind::depthwise_scale_shift:
177 scale_shift_compute_vector(Vmm(idx), p_weights, p_bias); break;
178 case alg_kind::depthwise_prelu:
179 prelu_compute_vector(Vmm(idx), p_weights, p_bias); break;
180 default: assert(!"unsupported depthwise algorithm");
// Entry point of the injector: applies the depthwise operation to vector
// registers [start_idx, end_idx), reading parameters through p_weights / p_bias.
185 template <cpu_isa_t isa>
186 void jit_uni_depthwise_injector_f32<isa>::compute_vector_range(int start_idx, int end_idx,
187 const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias) {
// Pick and save the auxiliary registers, then process the registers that were
// not borrowed as auxiliaries.
188 injector_preamble(start_idx, end_idx);
189 compute_body(start_idx_tail, end_idx, p_weights, p_bias);
// Re-target the borrowed auxiliaries, then process the remaining head
// [start_idx, start_idx_tail) of the range.
190 injector_preamble_tail(start_idx, end_idx);
191 compute_body(start_idx, start_idx_tail, p_weights, p_bias);
// Restore every register that was saved on the stack.
192 injector_postamble();
// Explicit instantiations of the depthwise injector for each ISA used in this file.
195 template struct jit_uni_depthwise_injector_f32<avx512_common>;
196 template struct jit_uni_depthwise_injector_f32<avx2>;
197 template struct jit_uni_depthwise_injector_f32<sse42>;
// JIT kernel for depthwise scale-shift: the generated code reads floats from
// `from`, computes shift + src * scale, and writes the result to `to`.
202 template <cpu_isa_t isa>
203 struct jit_uni_scale_shift_kernel_f32 : public jit_uni_depthwise_kernel_f32,
206 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_scale_shift_kernel_f32)
// The whole kernel is emitted while the object is being constructed.
207 jit_uni_scale_shift_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
208 : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
209 assert(desc.alg_kind == alg_kind::depthwise_scale_shift);
210 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// "Flat" means that both source and destination use the plain nchw format.
212 bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
214 Reg64 param = abi_param1;
// block_size is the number of floats advanced per main-loop iteration;
// main_loop_step is how much work_amount is reduced per iteration.
216 const int block_size = isa == avx512_common ? 16 : 8;
217 const int main_loop_step = isFlat ? block_size : 1;
// Fetch the call arguments from the jit_args structure.
221 mov(reg_from, ptr[param + GET_OFF(from)]);
222 mov(reg_to, ptr[param + GET_OFF(to)]);
223 mov(reg_scale, ptr[param + GET_OFF(weights)]);
224 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
226 mov(reg_shift, ptr[param + GET_OFF(bias)]);
228 Label main_loop_label;
229 Label tail_loop_label;
// sse42 covers a block as two 4-float halves; the other ISAs need one pass.
232 int repeats = isa == sse42 ? 2 : 1;
233 for (int i = 0; i < repeats; i++) {
// Scale/shift registers are filled either by broadcasting a single float or by
// loading a vector; the shift register may be zeroed instead of loaded.
// NOTE(review): the selecting conditions are presumably isFlat and the bias
// flag - confirm against the guards.
235 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
237 uni_vbroadcastss(get_shift_reg(i), ptr[reg_shift]);
239 uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
241 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
243 uni_vmovups(get_shift_reg(i), ptr[reg_shift + i*4*sizeof(float)]);
245 uni_vpxor(get_shift_reg(i), get_shift_reg(i), get_shift_reg(i));
// xmm_scale / xmm_shift are the operands consumed by the scalar tail loop.
250 uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
252 uni_vbroadcastss(xmm_shift, ptr[reg_shift]);
254 uni_vpxor(xmm_shift, xmm_shift, xmm_shift);
// Main loop: falls through to the tail once work_amount <= main_loop_step - 1.
257 L(main_loop_label); {
258 cmp(reg_work_amount, main_loop_step-1);
259 jle(tail_loop_label, T_NEAR);
261 int repeats = isa == sse42 ? 2 : 1;
262 for (int i = 0; i < repeats; i++) {
// dst = shift; dst += src * scale; store dst.
263 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
264 uni_vmovups(vmm_dst, get_shift_reg(i));
265 uni_vfmadd231ps(vmm_dst, vmm_src, get_scale_reg(i));
266 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
269 add(reg_from, block_size*sizeof(float));
270 add(reg_to, block_size*sizeof(float));
271 sub(reg_work_amount, main_loop_step);
273 jmp(main_loop_label, T_NEAR);
// Tail loop: one float per iteration until work_amount drops to zero.
276 L(tail_loop_label); {
277 cmp(reg_work_amount, 0);
278 jle(exit_label, T_NEAR);
280 movss(xmm_src, ptr[reg_from]);
281 uni_vmovups(xmm_dst, xmm_shift);
282 uni_vfmadd231ps(xmm_dst, xmm_src, xmm_scale);
283 movss(ptr[reg_to], xmm_dst);
285 add(reg_from, 1*sizeof(float));
286 add(reg_to, 1*sizeof(float));
287 dec(reg_work_amount);
289 jmp(tail_loop_label, T_NEAR);
// Publish the generated code as the callable kernel entry point.
296 ker_ = (decltype(ker_))this->getCode();
// Vector register type for the ISA: Xmm for sse42, Ymm for avx2, Zmm otherwise.
300 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
301 isa == avx2, Ymm, Zmm>::type;
// Scale registers start at vector index 2, shift registers at vector index 4.
303 inline Vmm get_scale_reg(int idx) { return Vmm(idx + 2); }
304 inline Vmm get_shift_reg(int idx) { return Vmm(idx + 4); }
// General-purpose register assignment.
308 Reg64 reg_work_amount = r10;
309 Reg64 reg_scale = r11;
310 Reg64 reg_shift = r12;
// Vector registers used by the main loop.
312 Vmm vmm_src = Vmm(0);
313 Vmm vmm_dst = Vmm(1);
// Xmm registers used by the tail loop.
315 Xmm xmm_src = Xmm(0);
316 Xmm xmm_dst = Xmm(1);
317 Xmm xmm_scale = Xmm(6);
318 Xmm xmm_shift = Xmm(7);
// JIT kernel for depthwise PReLU: lanes selected by a comparison against zero
// are replaced with src * scale, the other lanes pass through unchanged.
321 template <cpu_isa_t isa>
322 struct jit_uni_prelu_kernel_f32 : public jit_uni_depthwise_kernel_f32,
325 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_prelu_kernel_f32)
// The whole kernel is emitted while the object is being constructed.
326 jit_uni_prelu_kernel_f32(const depthwise_desc_t &desc, bool with_bias)
327 : jit_uni_depthwise_kernel_f32(desc, with_bias), jit_generator() {
328 assert(desc.alg_kind == alg_kind::depthwise_prelu);
329 assert(isa == sse42 || isa == avx2 || isa == avx512_common);
// "Flat" means that both source and destination use the plain nchw format.
331 bool isFlat = desc.src_desc.format == nchw && desc.dst_desc.format == nchw;
333 Reg64 param = abi_param1;
// block_size is the number of floats advanced per main-loop iteration;
// main_loop_step is how much work_amount is reduced per iteration.
335 const int block_size = isa == avx512_common ? 16 : 8;
336 const int main_loop_step = isFlat ? block_size : 1;
// Fetch the call arguments from the jit_args structure.
340 mov(reg_from, ptr[param + GET_OFF(from)]);
341 mov(reg_to, ptr[param + GET_OFF(to)]);
342 mov(reg_scale, ptr[param + GET_OFF(weights)]);
343 mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
// vmm_zero holds the constant zero used by the avx2 / avx512 comparisons.
345 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
// sse42 covers a block as two 4-float halves; the other ISAs need one pass.
347 int repeats = isa == sse42 ? 2 : 1;
348 for (int i = 0; i < repeats; i++) {
// The scale register is filled either by broadcasting one float or by loading
// a vector. NOTE(review): presumably selected by isFlat - confirm.
350 uni_vbroadcastss(get_scale_reg(i), ptr[reg_scale]);
352 uni_vmovups(get_scale_reg(i), ptr[reg_scale + i*4*sizeof(float)]);
// xmm_scale is the operand consumed by the scalar tail loop.
357 uni_vbroadcastss(xmm_scale, ptr[reg_scale]);
360 Label main_loop_label;
361 Label tail_loop_label;
// Main loop: falls through to the tail once work_amount <= main_loop_step - 1.
364 L(main_loop_label); {
365 cmp(reg_work_amount, main_loop_step-1);
366 jle(tail_loop_label, T_NEAR);
368 for (int i = 0; i < repeats; i++) {
369 uni_vmovups(vmm_src, ptr[reg_from + i*4*sizeof(float)]);
// Two-operand variant: compare zero against src, keep a copy of src in dst,
// scale src, then blend the scaled value into dst. vmm_mask is register 0,
// which two-operand blendvps uses as its implicit selector.
372 pxor(vmm_mask, vmm_mask);
373 cmpps(vmm_mask, vmm_src, _cmp_gt_os);
374 movups(vmm_dst, vmm_src);
375 mulps(vmm_src, get_scale_reg(i));
376 blendvps(vmm_dst, vmm_src);
377 } else if (isa == avx2) {
// mask = src > 0; dst = src * scale; take src back wherever the mask is set.
378 vcmpgtps(vmm_mask, vmm_src, vmm_zero);
379 vmulps(vmm_dst, vmm_src, get_scale_reg(i));
380 vblendvps(vmm_dst, vmm_dst, vmm_src, vmm_mask);
381 } else if (isa == avx512_common) {
// kmask = src < 0; only the masked lanes of dst receive src * scale.
382 Opmask kmask = Opmask(7);
383 vmovups(vmm_dst, vmm_src);
384 vcmpps(kmask, vmm_src, vmm_zero, _cmp_lt_os);
385 vmulps(vmm_dst | kmask, vmm_src, get_scale_reg(i));
388 uni_vmovups(ptr[reg_to + i*4*sizeof(float)], vmm_dst);
391 add(reg_from, block_size*sizeof(float));
392 add(reg_to, block_size*sizeof(float));
393 sub(reg_work_amount, main_loop_step);
395 jmp(main_loop_label, T_NEAR);
// Tail loop: one float per iteration until work_amount drops to zero.
398 L(tail_loop_label); {
399 cmp(reg_work_amount, 0);
400 jle(exit_label, T_NEAR);
402 movss(xmm_src, ptr[reg_from]);
// Same compare / scale / blend sequence as the two-operand main-loop variant,
// applied to xmm registers.
404 pxor(xmm_mask, xmm_mask);
405 cmpps(xmm_mask, xmm_src, _cmp_gt_os);
406 movups(xmm_dst, xmm_src);
407 mulps(xmm_src, xmm_scale);
408 blendvps(xmm_dst, xmm_src);
410 movss(ptr[reg_to], xmm_dst);
412 add(reg_from, 1*sizeof(float));
413 add(reg_to, 1*sizeof(float));
414 dec(reg_work_amount);
416 jmp(tail_loop_label, T_NEAR);
// Publish the generated code as the callable kernel entry point.
423 ker_ = (decltype(ker_))this->getCode();
// Vector register type for the ISA: Xmm for sse42, Ymm for avx2, Zmm otherwise.
427 using Vmm = typename utils::conditional3<isa == sse42, Xmm,
428 isa == avx2, Ymm, Zmm>::type;
// Scale registers start at vector index 4.
430 inline Vmm get_scale_reg(int idx) { return Vmm(idx + 4); }
// General-purpose register assignment.
434 Reg64 reg_work_amount = r10;
435 Reg64 reg_scale = r11;
// Vector registers used by the main loop.
437 Vmm vmm_mask = Vmm(0);
438 Vmm vmm_src = Vmm(1);
439 Vmm vmm_zero = Vmm(2);
440 Vmm vmm_dst = Vmm(3);
// Xmm registers used by the tail loop.
442 Xmm xmm_mask = Xmm(0);
443 Xmm xmm_src = Xmm(1);
444 Xmm xmm_dst = Xmm(3);
445 Xmm xmm_scale = Xmm(4);
// Comparison predicate immediates passed to cmpps / vcmpps.
447 const unsigned char _cmp_gt_os = 6;
448 const unsigned char _cmp_lt_os = 1;
// Decides whether this JIT implementation can handle the descriptor.
// Returns status::success when every check below passes, otherwise
// status::unimplemented.
453 template <cpu_isa_t isa>
454 status_t jit_uni_depthwise_fwd_t<isa>::pd_t::init() {
455 using namespace alg_kind;
// Blocked layout accepted for this ISA: 16-channel blocks on avx512_common,
// 8-channel blocks otherwise.
457 auto desired_blk_fmt = isa == avx512_common ? nChw16c : nChw8c;
459 assert(engine()->kind() == engine_kind::cpu);
// Requirements: ISA available, forward propagation, f32 source and destination,
// identical source/destination format that is either the blocked layout or
// nchw, weights (and bias, when present) in format x, default attributes.
460 bool ok = true && mayiuse(isa)
461 && utils::one_of(desc()->prop_kind, prop_kind::forward_training,
462 prop_kind::forward_inference)
463 && utils::everyone_is(data_type::f32, desc()->src_desc.data_type, desc()->dst_desc.data_type)
464 && desc()->src_desc.format == desc()->dst_desc.format
465 && utils::one_of(desc()->src_desc.format, desired_blk_fmt, nchw)
466 && utils::one_of(desc()->dst_desc.format, desired_blk_fmt, nchw)
467 && utils::one_of(desc()->weights_desc.format, x)
468 && IMPLICATION(this->with_bias(), x == desc()->bias_desc.format)
469 && attr()->has_default_values();
471 return ok ? status::success : status::unimplemented;
// Creates the JIT kernel matching the descriptor's algorithm and, when the
// primitive descriptor asks for padded weights, allocates the padded
// weights / bias buffers.
474 template <cpu_isa_t isa>
475 jit_uni_depthwise_fwd_t<isa>::jit_uni_depthwise_fwd_t(const pd_t *apd,
476 const input_vector &inputs, const output_vector &outputs)
477 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr),
478 padded_weights_(nullptr), padded_bias_(nullptr) {
479 const auto &desc = *pd()->desc();
// One kernel type per supported algorithm; anything else triggers the assert.
480 switch (desc.alg_kind) {
481 case alg_kind::depthwise_scale_shift:
482 kernel_ = new jit_uni_scale_shift_kernel_f32<isa>(desc, pd()->with_bias()); break;
483 case alg_kind::depthwise_prelu:
484 kernel_ = new jit_uni_prelu_kernel_f32<isa>(desc, pd()->with_bias()); break;
485 default: assert(!"unknown depthwise alg_kind");
// Channel count (dimension 1 of the source) rounded up to a multiple of the
// SIMD width.
488 const int simd_w = isa == avx512_common ? 16 : 8;
489 const memory_desc_wrapper data_d(pd()->src_pd());
490 const int c_without_padding = data_d.dims()[1];
491 const int c_padded = rnd_up(c_without_padding, simd_w);
// Only the padding tail [c_without_padding, c_padded) is zeroed here.
// NOTE(review): the two-argument malloc is a project helper; the second
// argument is presumably the alignment in bytes - confirm.
493 if (pd()->want_padded_weights()) {
494 padded_weights_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
495 for (int oc = c_without_padding; oc < c_padded; ++oc)
496 padded_weights_[oc] = 0;
498 if (pd()->with_bias()) {
499 padded_bias_ = (data_t *)malloc(sizeof(data_t) * c_padded, 64);
500 for (int oc = c_without_padding; oc < c_padded; ++oc)
501 padded_bias_[oc] = 0;
// Destructor: releases the padded weights buffer allocated by the constructor.
506 template <cpu_isa_t isa>
507 jit_uni_depthwise_fwd_t<isa>::~jit_uni_depthwise_fwd_t() {
509 free(padded_weights_);
// Forward execution: walks the source tensor over (n, channel block, h) in
// parallel and prepares kernel arguments for each row of W elements.
513 template <cpu_isa_t isa>
514 void jit_uni_depthwise_fwd_t<isa>::execute_forward() const {
// Inputs 0, 1 and 2 are source, weights and bias; the output is the destination.
515 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
516 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
517 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
518 auto dst = reinterpret_cast<data_t *>(this->memory());
520 const memory_desc_wrapper data_d(pd()->src_pd());
521 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
522 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Source dimensions, read in the order N, C, H, W.
524 const int N = data_d.dims()[0];
525 const int C = data_d.dims()[1];
526 const int H = data_d.dims()[2];
527 const int W = data_d.dims()[3];
// For nchw every channel is its own block; otherwise a block is simd_w channels.
529 const int simd_w = isa == avx512_common ? 16 : 8;
530 const int ch_block_size = data_d.format() == nchw ? 1 : simd_w;
531 const int CB = div_up(C, ch_block_size);
// When padded weights are wanted, copy the C real values into the padded
// buffers and use the padded weights from here on.
533 if (pd()->want_padded_weights()) {
534 for (int oc = 0; oc < C; ++oc)
535 padded_weights_[oc] = weights[oc];
536 weights = padded_weights_;
538 if (pd()->with_bias()) {
539 for (int oc = 0; oc < C; ++oc)
540 padded_bias_[oc] = bias[oc];
545 parallel_nd(N, CB, H,
546 [&](int n, int cb, int h) {
547 auto arg = jit_args();
// Source and destination share the same offset computation; weights and bias
// are addressed by the first channel of the current block.
549 arg.from = &src[data_d.blk_off(n, cb, h)];
550 arg.to = &dst[data_d.blk_off(n, cb, h)];
551 arg.weights = &weights[weights_d.blk_off(cb * ch_block_size)];
553 arg.bias = &bias[bias_d.blk_off(cb * ch_block_size)];
554 arg.work_amount = (size_t)W;
// Explicit instantiations of the depthwise forward primitive for each ISA used in this file.
560 template struct jit_uni_depthwise_fwd_t<sse42>;
561 template struct jit_uni_depthwise_fwd_t<avx2>;
562 template struct jit_uni_depthwise_fwd_t<avx512_common>;
// Byte offset of `field` inside jit_conv_call_s; used by the depthwise-convolution
// row kernel to read its call parameters.
565 #define GET_OFF_DW(field) offsetof(jit_conv_call_s, field)
// Prepares the accumulator registers for ur_w output positions; the visible
// code zeroes each of the repeats * ur_w accumulators.
567 template <cpu_isa_t isa>
568 void jit_uni_dw_conv_row_f32<isa>::load_src(int ur_w) {
// sse42 processes a channel block as two halves, hence two accumulator sets.
569 int repeats = isa == sse42 ? 2 : 1;
570 for (int i = 0; i < repeats; i++) {
571 for (int ow = 0; ow < ur_w; ow++) {
572 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
574 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
// Accumulates kernel * input products into the accumulator registers for ur_w
// output positions and kw_size kernel columns, using up to three input rows
// (aux_reg_input0, aux_reg_input1, aux_reg_input2).
579 template <cpu_isa_t isa>
580 void jit_uni_dw_conv_row_f32<isa>::apply_filter(int ur_w, int kw_size) {
// Loads source values: zero-extends bytes to dwords for u8 input, otherwise a
// plain vector load.
581 auto load_src = [=](Vmm vmm_src, const Xbyak::Address &op) {
582 if (jcp.src_dt == data_type::u8) {
583 uni_vpmovzxbd(vmm_src, op);
585 uni_vmovups(vmm_src, op);
// Loads kernel values: sign-extends bytes to dwords when the source type is u8,
// otherwise a plain vector load.
589 auto load_ker = [=](Vmm vmm_ker, const Xbyak::Address &op) {
590 if (jcp.src_dt == data_type::u8) {
591 uni_vpmovsxbd(vmm_ker, op);
593 uni_vmovups(vmm_ker, op);
// acc += src * ker: integer multiply/add for u8 input (this overwrites
// vmm_src), fused multiply-add otherwise.
597 auto compute = [=](Vmm vmm_acc, Vmm vmm_src, Vmm vmm_ker) {
598 if (jcp.src_dt == data_type::u8) {
599 uni_vpmulld(vmm_src, vmm_src, vmm_ker);
600 uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
602 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
606 int ch_blk = jcp.ch_block;
607 int stride_w = jcp.stride_w;
// sse42 processes a channel block as two halves.
611 int repeats = isa == sse42 ? 2 : 1;
// First input row. NOTE(review): the comparison feeding this jl is not shown;
// presumably it checks how many kernel rows are left - confirm.
614 jl(exit_label, T_NEAR);
615 for (int i = 0; i < repeats; i++) {
616 for (int kw = 0; kw < kw_size; kw++) {
// Kernel offset: kernel column kw, plus half a block for the second sse42 pass.
617 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
619 Vmm vmm_ker = get_ker_reg(0);
620 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
622 for (int ow = 0; ow < ur_w; ow++) {
// Input offset: output position ow scaled by the stride, shifted by column kw.
623 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
625 Vmm vmm_src = get_src_reg(0);
626 load_src(vmm_src, ptr[aux_reg_input0 + inp_off * jcp.typesize_in]);
628 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
629 compute(vmm_acc, vmm_src, vmm_ker);
// Advance the kernel pointer by jcp.kw channel blocks (one full kernel row).
633 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
// Second input row: same computation, reading from aux_reg_input1.
636 jl(exit_label, T_NEAR);
637 for (int i = 0; i < repeats; i++) {
638 for (int kw = 0; kw < kw_size; kw++) {
639 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
641 Vmm vmm_ker = get_ker_reg(0);
642 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
644 for (int ow = 0; ow < ur_w; ow++) {
645 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
647 Vmm vmm_src = get_src_reg(0);
648 load_src(vmm_src, ptr[aux_reg_input1 + inp_off * jcp.typesize_in]);
650 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
651 compute(vmm_acc, vmm_src, vmm_ker);
655 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
// Third input row: same computation, reading from aux_reg_input2.
658 jl(exit_label, T_NEAR);
659 for (int i = 0; i < repeats; i++) {
660 for (int kw = 0; kw < kw_size; kw++) {
661 int ker_off = kw * ch_blk + i*(jcp.ch_block / 2);
663 Vmm vmm_ker = get_ker_reg(0);
664 load_ker(vmm_ker, ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
666 for (int ow = 0; ow < ur_w; ow++) {
667 int inp_off = ow * stride_w * ch_blk + kw * ch_blk + i*(jcp.ch_block / 2);
669 Vmm vmm_src = get_src_reg(0);
670 load_src(vmm_src, ptr[aux_reg_input2 + inp_off * jcp.typesize_in]);
672 Vmm vmm_acc = get_acc_reg(i*ur_w + ow);
673 compute(vmm_acc, vmm_src, vmm_ker);
// Loads a value of type `type_in` from `op` into vmm_in and converts it to f32
// when the input type is not already f32.
// NOTE(review): `scalar_load` presumably selects the single-element path that
// goes through reg_tmp_32 / reg_tmp_64 instead of a vector load - confirm.
681 template <cpu_isa_t isa>
682 void jit_uni_dw_conv_row_f32<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
683 Xmm xmm_in = Xmm(vmm_in.getIdx());
// Either move a value from reg_tmp_64 into the low part of the register, or do
// a plain vector load.
690 movq(xmm_in, reg_tmp_64);
692 uni_vmovups(vmm_in, op);
// Sign-extending forms: through a general-purpose register, or as a vector
// byte-to-dword load.
697 movsx(reg_tmp_32, op);
698 movq(xmm_in, reg_tmp_64);
700 uni_vpmovsxbd(vmm_in, op);
// Zero-extending forms: through a general-purpose register, or as a vector
// byte-to-dword load.
705 movzx(reg_tmp_32, op);
706 movq(xmm_in, reg_tmp_64);
708 uni_vpmovzxbd(vmm_in, op);
711 default: assert(!"unsupported data type");
// Integer inputs are converted to single-precision floats.
714 if (type_in != data_type::f32)
715 uni_vcvtdq2ps(vmm_in, vmm_in);
// Post-processes the accumulators of ur_w output positions: converts them to
// float for u8 input, adds bias, adds previously stored output values, and then
// applies the eltwise / depthwise post-ops that follow the convolution entry.
// oc_step is the number of channels handled in this call.
718 template <cpu_isa_t isa>
719 void jit_uni_dw_conv_row_f32<isa>::apply_postprocessing(int ur_w, int oc_step) {
// sse42 processes a channel block as two halves.
720 int repeats = isa == sse42 ? 2 : 1;
722 for (int r = 0; r < repeats; r++) {
723 for (int ow = 0; ow < ur_w; ow++) {
// u8 input is accumulated in integers, so convert the accumulator to float first.
724 if (jcp.src_dt == data_type::u8) {
725 uni_vcvtdq2ps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow));
// Load the bias for this half of the block as float and add it.
729 int b_off = r * (jcp.ch_block / 2);
730 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias + b_off * jcp.typesize_bia], false);
731 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_bias);
// Add the values already present in the output buffer to the accumulators.
// NOTE(review): presumably guarded by the sum post-op flag - confirm.
737 for (int r = 0; r < repeats; r++) {
// Channels to handle in this pass; a partial block is read element by element.
738 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - r * jcp.ch_block / 2) : oc_step;
739 bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
741 for (int ow = 0; ow < ur_w; ow++) {
742 if (is_scalar_store) {
743 for (int oc = 0; oc < tail_size; oc++) {
744 int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2) + oc;
746 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
747 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
// The element was loaded into the lowest lane; swap the 128-bit halves for
// channels in the upper half of the block, then byte-shift it to its lane.
749 if (oc >= jcp.ch_block / 2) {
750 vperm2i128(Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), Ymm(vmm_sum.getIdx()), 0x01);
752 uni_vpslldq(vmm_sum, vmm_sum, jcp.typesize_out * (oc % (jcp.ch_block / 2)));
754 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
// Full-vector variant of the same accumulation.
757 int o_off = ow * ow_stride_ + r * (jcp.ch_block / 2);
759 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
760 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
762 uni_vaddps(get_acc_reg(r * ur_w + ow), get_acc_reg(r * ur_w + ow), vmm_sum);
// Walk the post-ops that come after the convolution entry.
768 const auto &p = attr_.post_ops_;
769 int eltwise_inj_idx = 0;
770 int depthwise_inj_idx = 0;
771 int start_idx = p.find(primitive_kind::convolution) + 1;
772 for (int i = start_idx; i < p.len_; i++) {
773 auto& post_op = p.entry_[i];
774 if (post_op.is_eltwise()) {
// Accumulators occupy vector registers 4 .. 4 + repeats * ur_w - 1.
775 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(4, 4 + repeats * ur_w);
777 } else if (post_op.is_depthwise()) {
// Point at this post-op's weights / biases, offset by reg_oc_off.
778 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
779 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
781 add(reg_d_weights, reg_oc_off);
782 add(reg_d_bias, reg_oc_off);
784 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4, 4 + ur_w, reg_d_weights, reg_d_bias);
// Second half of the channel block: advance the parameters by half a block and
// process the second accumulator set.
// NOTE(review): presumably only emitted for isa == sse42 - confirm.
787 add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
788 add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
790 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(4 + ur_w, 4 + 2 * ur_w, reg_d_weights, reg_d_bias);
// Stores vmm_dst to `op` in the destination data type jcp.dst_dt, narrowing the
// register contents first where the type requires it.
// NOTE(review): `scalar_store` presumably selects a single-element store through
// reg_tmp_64 - confirm against the case bodies.
798 template <cpu_isa_t isa>
799 void jit_uni_dw_conv_row_f32<isa>::store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
// Ymm / Xmm views of the same register index as vmm_dst.
800 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
801 Xmm xmm_dst = Xmm(vmm_dst.getIdx());
803 switch (jcp.dst_dt) {
807 movq(reg_tmp_64, xmm_dst);
810 uni_vmovups(op, vmm_dst);
// Signed narrowing: dwords to words with saturation, then words to bytes.
814 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
// For non-sse42 vector stores, gather the packed quadwords into the low half.
816 if (isa != sse42 && !scalar_store)
817 vpermq(ymm_dst, ymm_dst, 0x08);
819 uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
822 movq(reg_tmp_64, xmm_dst);
// Unsigned narrowing: same sequence with the unsigned-saturating pack instructions.
833 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
835 if (isa != sse42 && !scalar_store)
836 vpermq(ymm_dst, ymm_dst, 0x08);
838 uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
841 movq(reg_tmp_64, xmm_dst);
851 assert(!"unknown dst_dt");
// Writes the accumulators of ur_w output positions to the output buffer:
// rounds/converts them for integer destinations, then either emits binarized
// bit masks or stores the values in the destination type. oc_step is the
// number of channels handled in this call.
855 template <cpu_isa_t isa>
856 void jit_uni_dw_conv_row_f32<isa>::store_dst(int ur_w, int oc_step) {
// On sse42 the second half-block is only needed when oc_step exceeds half a block.
857 int repeats = isa == sse42 && oc_step > (jcp.ch_block / 2) ? 2 : 1;
859 for (int i = 0; i < repeats; i++) {
860 for (int ow = 0; ow < ur_w; ow++) {
861 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
// Destinations other than f32 / bin get a float-to-int32 conversion using the
// configured rounding mode.
862 if (jcp.dst_dt != data_type::f32 && jcp.dst_dt != data_type::bin) {
863 if (attr_.round_mode_ == round_mode::nearest)
864 uni_vcvtps2dq(vmm_dst, vmm_dst);
865 else if (attr_.round_mode_ == round_mode::down) {
866 uni_vroundps(vmm_dst, vmm_dst, 1);
867 uni_vcvtps2dq(vmm_dst, vmm_dst);
869 assert(!"unimplemented");
// Binarization: compare against per-channel thresholds and store one bit per channel.
874 if (jcp.with_binarization) {
// Output step in bytes when channels are packed eight to a byte.
875 int output_step = div_up(ow_stride_, 8);
877 const auto &p = attr_.post_ops_;
878 int binarization_idx = p.find(primitive_kind::binarization);
// Threshold pointer for the current channel offset.
880 mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
881 add(reg_b_weights, reg_oc_off);
883 for (int ow = 0; ow < ur_w; ow++) {
884 for (int i = 0; i < repeats; i++) {
// Bit mask with the low tail_size bits set, to drop lanes beyond the valid channels.
885 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
886 mov(reg_b_mask, (1 << tail_size) - 1);
887 uni_vmovups(vmm_thr, ptr[reg_b_weights + i * (jcp.ch_block / 2) * sizeof(float)]);
889 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
891 uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
// Collect the per-lane sign bits of the comparison result and mask them; the
// second form is OR-ed into the first.
// NOTE(review): the second form is presumably the second sse42 half - confirm.
894 uni_vmovmskps(reg_tmp_32, vmm_dst);
895 and_(reg_tmp_64, reg_b_mask);
897 uni_vmovmskps(reg_tmp2_32, vmm_dst);
898 and_(reg_tmp2_64, reg_b_mask);
900 or_(reg_tmp_32, reg_tmp2_32);
// Store the combined byte once the last half has been processed.
903 if (i == repeats - 1) {
904 const size_t o_off = ow * output_step;
905 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
// Regular store path.
910 for (int i = 0; i < repeats; i++) {
911 int tail_size = isa == sse42 ? nstl::min(jcp.ch_block / 2, oc_step - i * jcp.ch_block / 2) : oc_step;
912 bool is_scalar_store = isa == sse42 ? tail_size < jcp.ch_block / 2 : tail_size < jcp.ch_block;
// Partial block: store element by element, shifting the register down by one
// element after each store.
913 if (is_scalar_store) {
914 for (int ow = 0; ow < ur_w; ow++) {
915 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
916 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
918 for (int oc = 0; oc < tail_size; oc++) {
919 int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2) + oc;
920 store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
// 128-bit form of the shift, followed by the 256-bit form that also carries
// bytes across the two 128-bit halves.
923 psrldq(vmm_dst, jcp.typesize_out);
925 vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
926 vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
// Full block: one vector store per output position.
931 for (int ow = 0; ow < ur_w; ow++) {
932 int o_off = ow * ow_stride_ + i * (jcp.ch_block / 2);
933 Vmm vmm_dst = get_acc_reg(i * ur_w + ow);
935 store_dst_typed(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
// Emits the code that produces one output row for oc_step channels. Each of the
// four sections below follows the same pattern: copy the row pointers into the
// aux registers, then apply_filter, apply_postprocessing and store_dst.
942 template <cpu_isa_t isa>
943 void jit_uni_dw_conv_row_f32<isa>::loop_body(int oc_step) {
944 Label left_pad_label;
945 Label right_pad_label;
946 Label unrolled_w_label;
// Output advance per position; with binarization, channels are packed eight to a byte.
950 int output_step = jcp.with_binarization ? div_up(ow_stride_, 8) : ow_stride_;
// First section: uses a reduced number of kernel columns and starts one channel
// block into the kernel row.
// NOTE(review): presumably the left-padding case - confirm.
954 int kw = jcp.iw == 1 ? jcp.kw - 2 : jcp.kw - 1;
956 mov(aux_reg_input0, reg_input0);
957 mov(aux_reg_input1, reg_input1);
958 mov(aux_reg_input2, reg_input2);
959 mov(aux_reg_kernel, reg_kernel);
960 add(aux_reg_kernel, jcp.ch_block*jcp.typesize_in);
963 apply_filter(ur_w, kw);
964 apply_postprocessing(ur_w, oc_step);
965 store_dst(ur_w, oc_step);
// The input pointers advance by (stride_w - 1) positions here, unlike the
// full stride_w advance used by the sections below.
967 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
968 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
969 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * (jcp.stride_w-1));
970 add(reg_output, jcp.typesize_out * ur_w * output_step);
// Unrolled loop over the output width.
975 L(unrolled_w_label); {
980 jle(tail_w_label, T_NEAR);
982 mov(aux_reg_input0, reg_input0);
983 mov(aux_reg_input1, reg_input1);
984 mov(aux_reg_input2, reg_input2);
985 mov(aux_reg_kernel, reg_kernel);
988 apply_filter(ur_w, kw);
989 apply_postprocessing(ur_w, oc_step);
990 store_dst(ur_w, oc_step);
992 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
993 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
994 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
995 add(reg_output, jcp.typesize_out * ur_w * output_step);
998 jmp(unrolled_w_label, T_NEAR);
// Tail loop over the remaining width: leaves for the right-padding section or
// the exit depending on how reg_ur_w compares with ur_w.
1005 cmp(reg_ur_w, ur_w);
1007 jle(right_pad_label, T_NEAR);
1009 jle(exit_label, T_NEAR);
1011 mov(aux_reg_input0, reg_input0);
1012 mov(aux_reg_input1, reg_input1);
1013 mov(aux_reg_input2, reg_input2);
1014 mov(aux_reg_kernel, reg_kernel);
1017 apply_filter(ur_w, kw);
1018 apply_postprocessing(ur_w, oc_step);
1019 store_dst(ur_w, oc_step);
1021 add(reg_input0, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1022 add(reg_input1, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1023 add(reg_input2, jcp.typesize_in * ur_w * jcp.ch_block * jcp.stride_w);
1024 add(reg_output, jcp.typesize_out * ur_w * output_step);
1026 sub(reg_ur_w, ur_w);
1027 jmp(tail_w_label, T_NEAR);
// Right-padding section: the number of kernel columns is reduced by 1 for
// stride 1, otherwise by iw % stride_w.
1031 L(right_pad_label); {
1033 int kw = jcp.kw - ((jcp.stride_w == 1) ? 1 : jcp.iw % jcp.stride_w);
1035 mov(aux_reg_input0, reg_input0);
1036 mov(aux_reg_input1, reg_input1);
1037 mov(aux_reg_input2, reg_input2);
1038 mov(aux_reg_kernel, reg_kernel);
1041 apply_filter(ur_w, kw);
1042 apply_postprocessing(ur_w, oc_step);
1043 store_dst(ur_w, oc_step);
1045 sub(reg_ur_w, ur_w);
// Generates the depthwise-convolution row kernel: creates the post-op
// injectors, loads the call arguments, emits the body for a full channel block
// and for the channel tail, and finally emits the eltwise injector tables.
1052 template <cpu_isa_t isa>
1053 void jit_uni_dw_conv_row_f32<isa>::generate() {
// One injector per eltwise / depthwise post-op that follows the convolution entry.
1054 const auto &p = attr_.post_ops_;
1055 int start_idx = p.find(primitive_kind::convolution) + 1;
1056 for (int i = start_idx; i < p.len_; i++) {
1057 auto &post_op = p.entry_[i];
1058 if (post_op.is_eltwise()) {
1059 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
1061 post_op.eltwise.alg,
1062 post_op.eltwise.alpha,
1063 post_op.eltwise.beta
1065 } else if (post_op.is_depthwise()) {
1066 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
1068 post_op.depthwise.alg
// Load the call arguments from the jit_conv_call_s structure.
1075 mov(reg_input0, ptr[this->param1 + GET_OFF_DW(src_row0)]);
1076 mov(reg_input1, ptr[this->param1 + GET_OFF_DW(src_row1)]);
1077 mov(reg_input2, ptr[this->param1 + GET_OFF_DW(src_row2)]);
1078 mov(reg_output, ptr[this->param1 + GET_OFF_DW(dst)]);
1079 mov(reg_kernel, ptr[this->param1 + GET_OFF_DW(filt)]);
1081 mov(reg_bias, ptr[this->param1 + GET_OFF_DW(bias)]);
1082 mov(reg_kh, ptr[this->param1 + GET_OFF_DW(kh_padding)]);
1083 mov(reg_ur_w, ptr[this->param1 + GET_OFF_DW(ur_w)]);
1084 mov(reg_oc_work, ptr[this->param1 + GET_OFF_DW(oc_work)]);
1085 mov(reg_oc_off, ptr[this->param1 + GET_OFF_DW(oc_off)]);
// Runtime dispatch: a full channel block runs the full-width body; fewer
// channels jump to the tail body.
1090 cmp(reg_oc_work, jcp.ch_block);
1091 jl(tail_label, T_NEAR);
1093 loop_body(jcp.ch_block);
1094 jmp(exit_label, T_NEAR);
// The tail body is only generated when the channel count is not a multiple of the block.
1098 if (jcp.oc % jcp.ch_block != 0)
1099 loop_body(jcp.oc % jcp.ch_block);
// Emit the constant tables required by the eltwise injectors.
1105 for (auto& inj : eltwise_injectors)
1106 inj->prepare_table();
// Checks whether the post-ops that follow the convolution entry form a chain
// this kernel supports. "Simple" means eltwise or depthwise. Accepted chains:
//   (none)
//   simple | sum | binarization
//   sum, simple | simple, simple | simple, binarization
//   sum, simple, simple
// jcp is not referenced by the visible checks.
1109 template <cpu_isa_t isa>
1110 bool jit_uni_dw_conv_row_f32<isa>::post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1111 const auto &p = attr.post_ops_;
// Predicates on the post-op at position idx.
1113 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
1114 auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
1115 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1116 auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
1117 auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };
// Only the entries after the convolution post-op are examined.
1119 int start_idx = p.find(primitive_kind::convolution) + 1;
1121 switch (p.len_ - start_idx) {
1122 case 0: return true; // no post_ops
1123 case 1: return is_simple(start_idx) || is_sum(start_idx) || is_binarization(start_idx);
1124 case 2: return (is_sum(start_idx) && is_simple(start_idx+1)) || (is_simple(start_idx) && is_simple(start_idx+1)) ||
1125 (is_simple(start_idx) && is_binarization(start_idx+1));
1126 case 3: return (is_sum(start_idx) && is_simple(start_idx+1) && is_simple(start_idx+2));
1127 default: return false;
// Fills jcp_dw, the configuration of the fused depthwise convolution, from the
// convolution post-op entry in `attr` and from the parent 1x1 convolution
// configuration `jcp`. Returns status::unimplemented when the ISA is not
// available, the kernel is not 3x3, the post-op chain is not supported, or the
// source type is neither f32 nor u8; otherwise status::success.
1133 template <cpu_isa_t isa>
1134 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1135 const primitive_attr_t &attr) {
1136 if (!mayiuse(isa)) return status::unimplemented;
1137 const int simd_w = isa == avx512_common ? 16 : 8;
1139 const auto &p = attr.post_ops_;
// Locate the convolution post-op; a sum post-op is searched starting from that index.
1141 int dw_conv_ind = p.find(primitive_kind::convolution);
1142 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
// The channel block equals the SIMD width; bias is always enabled.
1144 jcp_dw.ch_block = simd_w;
1145 jcp_dw.with_bias = true;
// Kernel size, input size, strides and data pointers come from the post-op
// entry; the output size comes from the parent configuration.
1147 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1148 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1151 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1152 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1153 jcp_dw.oh = jcp.dw_conv_oh;
1154 jcp_dw.ow = jcp.dw_conv_ow;
1155 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1156 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1157 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1158 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
// Only 3x3 kernels are supported.
1160 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1161 return status::unimplemented;
1163 if (!post_ops_ok(jcp_dw, attr))
1164 return status::unimplemented;
// Data types and element sizes are inherited from the parent convolution.
1168 jcp_dw.src_dt = jcp.src_dt;
1169 jcp_dw.dst_dt = jcp.dst_dt;
1170 jcp_dw.bia_dt = jcp.bia_dt;
1171 jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
1172 jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
1173 jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);
// Only f32 and u8 sources are supported.
1175 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1176 return status::unimplemented;
1178 return status::success;
1181 template <cpu_isa_t isa>
1182 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1183 const primitive_attr_t &attr) {
1184 if (!mayiuse(isa)) return status::unimplemented;
1185 const int simd_w = isa == avx512_common ? 16 : 8;
1187 const auto &p = attr.post_ops_;
1189 int dw_conv_ind = p.find(primitive_kind::convolution);
1190 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
1192 jcp_dw.ch_block = simd_w;
1193 jcp_dw.with_bias = true;
1195 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1196 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1199 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1200 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1201 jcp_dw.oh = jcp.dw_conv_oh;
1202 jcp_dw.ow = jcp.dw_conv_ow;
1203 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1204 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1205 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1206 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
1208 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1209 return status::unimplemented;
1211 if (!post_ops_ok(jcp_dw, attr))
1212 return status::unimplemented;
1216 jcp_dw.src_dt = jcp.dst_dt;
1217 jcp_dw.dst_dt = jcp.dst_dt;
1218 jcp_dw.bia_dt = jcp.bia_dt;
1219 jcp_dw.typesize_in = (int)types::data_type_size(jcp.src_dt);
1220 jcp_dw.typesize_bia = (int)types::data_type_size(jcp.bia_dt);
1221 jcp_dw.typesize_out = (int)types::data_type_size(jcp.dst_dt);
1223 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1224 return status::unimplemented;
1226 return status::success;
1229 template <cpu_isa_t isa>
1230 status_t jit_uni_dw_conv_row_f32<isa>::init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw,
1231 const primitive_attr_t &attr) {
1232 if (!mayiuse(isa)) return status::unimplemented;
1233 const int simd_w = isa == avx512_common ? 16 : 8;
1235 const auto &p = attr.post_ops_;
1237 int dw_conv_ind = p.find(primitive_kind::convolution);
1238 jcp_dw.with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
1239 jcp_dw.with_binarization = p.find(primitive_kind::binarization, dw_conv_ind) != -1;
1241 jcp_dw.ch_block = simd_w;
1242 jcp_dw.with_bias = true;
1244 jcp_dw.kh = p.entry_[dw_conv_ind].dw_conv.ker_h;
1245 jcp_dw.kw = p.entry_[dw_conv_ind].dw_conv.ker_w;
1248 jcp_dw.ih = p.entry_[dw_conv_ind].dw_conv.in_h;
1249 jcp_dw.iw = p.entry_[dw_conv_ind].dw_conv.in_w;
1250 jcp_dw.oh = jcp.dw_conv_oh;
1251 jcp_dw.ow = jcp.dw_conv_ow;
1252 jcp_dw.stride_h = p.entry_[dw_conv_ind].dw_conv.str_h;
1253 jcp_dw.stride_w = p.entry_[dw_conv_ind].dw_conv.str_w;
1254 jcp_dw.conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
1255 jcp_dw.conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
1257 if (jcp_dw.kh != 3 || jcp_dw.kw != 3)
1258 return status::unimplemented;
1260 if (!post_ops_ok(jcp_dw, attr))
1261 return status::unimplemented;
1265 jcp_dw.src_dt = mkldnn_f32;
1266 jcp_dw.dst_dt = jcp_dw.with_binarization ? mkldnn_bin : mkldnn_f32;
1267 jcp_dw.bia_dt = mkldnn_f32;
1268 jcp_dw.typesize_in = (int)types::data_type_size(jcp_dw.src_dt);
1269 jcp_dw.typesize_bia = (int)types::data_type_size(jcp_dw.bia_dt);
1270 jcp_dw.typesize_out = (int)types::data_type_size(jcp_dw.dst_dt);
1272 if (jcp_dw.src_dt != mkldnn_f32 && jcp_dw.src_dt != mkldnn_u8)
1273 return status::unimplemented;
1275 return status::success;
// Explicit instantiations of the depthwise row kernel, one per ISA value
// used with it in this file (avx512_common, avx2, sse42).
template struct jit_uni_dw_conv_row_f32<avx512_common>;
template struct jit_uni_dw_conv_row_f32<avx2>;
template struct jit_uni_dw_conv_row_f32<sse42>;