1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <common/primitive_attr.hpp>
18 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
24 #include "jit_uni_bin_conv_kernel.hpp"
// Byte offset of `field` inside the jit_conv_call_s argument structure;
// used to address runtime kernel-call parameters relative to param1.
26 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
32 using namespace mkldnn::impl::prop_kind;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
39 template <cpu_isa_t isa>
// Loads data of type `type_in` from operand `op` into `vmm_in` and leaves it
// as packed f32 lanes (integer inputs are converted with uni_vcvtdq2ps at the
// end; f32 input is loaded as-is). scalar_load == true loads a single element
// through a GPR (reg_tmp_*) and movq instead of a full-width vector load.
// NOTE(review): the dispatch over type_in (switch/case labels) is only
// partially visible in this view of the file — confirm against full source.
40 void jit_uni_bin_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
41 Xmm xmm_in = Xmm(vmm_in.getIdx()); // xmm alias of vmm_in for scalar moves
48 movq(xmm_in, reg_tmp_64); // scalar path: GPR -> low lane
50 uni_vmovups(vmm_in, op); // vector load (f32/s32 source)
55 movsx(reg_tmp_32, op); // scalar sign-extended load (signed integer source)
56 movq(xmm_in, reg_tmp_64);
58 uni_vpmovsxbd(vmm_in, op); // widen signed bytes to dwords
63 movzx(reg_tmp_32, op); // scalar zero-extended load (unsigned source)
64 movq(xmm_in, reg_tmp_64);
66 uni_vpmovzxbd(vmm_in, op); // widen unsigned bytes to dwords
69 default: assert(!"unsupported data type");
72 if (type_in != data_type::f32)
73 uni_vcvtdq2ps(vmm_in, vmm_in); // int32 lanes -> f32
76 template <cpu_isa_t isa>
// Stores vmm_dst to memory at `op`, converting packed 32-bit lanes to the
// destination data type. scalar_store == true emits a single element through
// a GPR; packed integer paths compress dword -> word -> byte via saturating
// packs plus a cross-lane permute on AVX.
// NOTE(review): the switch over jcp.dst_dt is only partially visible here —
// the signed (packssdw/packsswb) vs. unsigned (packusdw/packuswb) sequences
// presumably correspond to s8 vs. u8 destinations; confirm in full source.
77 void jit_uni_bin_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
78 Ymm ymm_dst = Ymm(vmm_dst.getIdx()); // ymm/xmm aliases of the same register
79 Xmm xmm_dst = Xmm(vmm_dst.getIdx());
85 movq(reg_tmp_64, xmm_dst); // scalar path: low element -> GPR
88 uni_vmovups(op, vmm_dst); // full-vector f32/s32 store
92 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); // signed saturate dword->word
94 if (isa != sse42 && !scalar_store)
95 vpermq(ymm_dst, ymm_dst, 0x08); // gather packed halves into low 128 bits
97 uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); // signed saturate word->byte
100 movq(reg_tmp_64, xmm_dst);
111 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); // unsigned saturate dword->word
113 if (isa != sse42 && !scalar_store)
114 vpermq(ymm_dst, ymm_dst, 0x08);
116 uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); // unsigned saturate word->byte
119 movq(reg_tmp_64, xmm_dst);
130 assert(!"unknown dst_dt");
134 template <cpu_isa_t isa>
// Core compute of the binary convolution for one IC sub-step: for every
// kernel-width tap `ki` and every output column `jj`, XORs the packed input
// bits against the packed weight bits, popcounts the result, and accumulates
// the counts into accumulator registers
//   Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj).
// h_padded == true (or columns outside [jj_start, jj_end)) substitutes the
// pad-value bit pattern from the constant table (reg_table + 8 * vlen).
// last_icb masks off bits of a partially filled last IC block (table slot 7).
// NOTE(review): this view is fragmentary (several braces/else branches are
// not visible) — verify control flow against the full source.
135 void jit_uni_bin_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
136 int ic_blocks, bool last_icb, bool h_padded)
140 int stride_w = jcp.stride_w;
141 int dilate_w = jcp.dilate_w + 1; // stored dilation is the extra gap; +1 gives the step
142 int ic_blk = jcp.ic_block;
143 int oc_blk = jcp.oc_block;
// sse42 processes the oc block in two half-width passes when oc_step spans both halves
145 int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
148 for (int ki = 0; ki < kw; ki++) {
// Output columns whose ki-th tap lands inside the input row (no l/r padding)
149 int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
150 int jj_end = ur_w - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
// With pad_value != 0 (!exclude_pad) padded columns are computed too,
// using the pad-value pattern loaded below.
152 int _start = (!jcp.exclude_pad) ? 0 : jj_start;
153 int _end = (!jcp.exclude_pad) ? ur_w : jj_end;
155 for (int ifm2 = 0; ifm2 < ic_blocks; ifm2++) {
156 for (int jj = _start; jj < _end; jj++) {
// Byte offset of the bit-packed input (nbits channels per byte) for this tap/column
157 int inp_off = ((ki*dilate_w + jj*stride_w - pad_l)*div_up(jcp.ic, nbits) + ifm2 * div_up(ic_blk, nbits)) * jcp.typesize_in;
159 if (h_padded || jj < jj_start || jj >= jj_end) {
160 uni_vmovups(vmm_src, ptr[reg_table + 8 * vlen]); // pad-value bit pattern
162 uni_vpbroadcastd(vmm_src, ptr[aux1_reg_input + inp_off]); // broadcast packed input dword
165 for (int r = 0; r < repeats; r++) {
166 for (int ii = 0; ii < oc_blocks; ii++) {
167 int ker_off = (ifm2 * kw * div_up(ic_blk, nbits) * oc_blk
168 + ii * jcp.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk
169 + ki * div_up(ic_blk, nbits) * oc_blk + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * jcp.typesize_in;
171 uni_vmovups(vmm_tmp, ptr[aux1_reg_kernel + ker_off]);
// Binary "multiply": XOR of input bits with weight bits
173 uni_vpxor(vmm_tmp, vmm_tmp, vmm_src);
// Partially filled last IC block: drop the tail bits so they are not counted
174 if (jcp.ic_padded != jcp.ic && last_icb && ifm2 == (ic_blocks - 1))
175 uni_vandps(vmm_tmp, vmm_tmp, ptr[reg_table + 7 * vlen]);
177 if (mayiuse(avx512_vpopcnt)) {
178 vpopcntd(vmm_tmp, vmm_tmp); // hardware per-dword popcount
179 uni_vpaddd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
180 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
// Software popcount: pshufb nibble lookup. vmm_lookup holds per-nibble
// bit counts (see prepare_table cvals), vmm_mask isolates nibbles.
183 movups(vmm_tmp1, vmm_tmp); // sse42 variant (destructive ops)
184 pand(vmm_tmp1, vmm_mask); // low nibbles
186 uni_vandps(vmm_tmp1, vmm_mask, vmm_tmp); // avx variant: low nibbles
189 uni_vpsrld(vmm_tmp, vmm_tmp, 4);
190 uni_vandps(vmm_tmp, vmm_tmp, vmm_mask); // high nibbles
193 movups(vmm_tmp2, vmm_lookup);
194 pshufb(vmm_tmp2, vmm_tmp); // sse42: LUT[high nibble]
195 movups(vmm_tmp, vmm_lookup);
196 pshufb(vmm_tmp, vmm_tmp1); // sse42: LUT[low nibble]
197 paddb(vmm_tmp, vmm_tmp2); // per-byte popcount
199 uni_vpshufb(vmm_tmp, vmm_lookup, vmm_tmp); // avx: LUT[high nibble]
200 uni_vpshufb(vmm_tmp1, vmm_lookup, vmm_tmp1); // avx: LUT[low nibble]
201 uni_vpaddb(vmm_tmp, vmm_tmp, vmm_tmp1); // per-byte popcount
204 if (mayiuse(avx512_core_vnni)) {
// VNNI: byte-dot-product with a ones vector sums 4 bytes into each dword acc
205 vpdpbusd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp, vmm_one_u8);
207 uni_vpmaddubsw(vmm_tmp, vmm_tmp, vmm_one_u8); // bytes -> words
208 uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one_s16); // words -> dwords
209 uni_vpaddd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
210 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
220 template <cpu_isa_t isa>
// One kernel-height row: iterates over the nb_ic input-channel blocks,
// calling apply_filter per block and advancing the aux input/kernel
// pointers. The final block is handled via the `icb_tail` path with
// last_icb == true so padded tail IC bits can be masked out. h_padded is
// forwarded for rows that fall entirely into top/bottom padding.
// NOTE(review): the Label declarations for icb_main_loop/icb_tail are not
// visible in this fragmentary view — verify against the full source.
221 void jit_uni_bin_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
225 int inp_mult = div_up(jcp.ic_block, nbits); // packed bytes per IC block
226 int out_mult = jcp.oc_block;
231 mov(aux1_reg_input, aux_reg_input);
232 mov(aux1_reg_kernel, aux_reg_kernel);
234 mov(reg_icb_iter, jcp.nb_ic); // runtime IC-block counter
237 cmp(reg_icb_iter, 1);
238 jle(icb_tail, T_NEAR); // last block handled separately below
240 apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, false, h_padded);
242 add(aux1_reg_input, inp_mult * jcp.typesize_in);
243 add(aux1_reg_kernel, kw * inp_mult * out_mult * jcp.typesize_in);
244 sub(reg_icb_iter, 1);
245 jmp(icb_main_loop, T_NEAR);
250 apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, true, h_padded); // tail: last_icb == true
253 template <cpu_isa_t isa>
// Kernel-height loop. Loads the popcount lookup/mask and "ones" constants
// from the table, then — when pad_value != 0 (!exclude_pad) — processes the
// rows falling into top padding (t_overflow) and bottom padding (b_overflow)
// with h_padded == true, and the in-bounds rows via the regular kh loop.
// NOTE(review): the kh_label/skip_kh_loop declarations and some loop-body
// lines are not visible in this fragmentary view.
254 void jit_uni_bin_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
257 int dilate_h = jcp.dilate_h + 1; // row step including dilation
260 const int inp_mult = dilate_h * div_up(jcp.ic, nbits); // input row stride in packed bytes
262 Label t_overflow_label, no_t_overflow_label,
263 b_overflow_label, no_b_overflow_label;
265 mov(aux_reg_input, reg_input);
266 mov(aux_reg_kernel, reg_kernel_base);
// Constants emitted by prepare_table()
268 uni_vmovups(vmm_lookup, ptr[reg_table + 0 * vlen]); // nibble popcount LUT
269 uni_vmovups(vmm_mask, ptr[reg_table + 1 * vlen]); // nibble mask
270 uni_vmovups(vmm_one_u8, ptr[reg_table + 5 * vlen]);
271 uni_vmovups(vmm_one_s16, ptr[reg_table + 6 * vlen]);
273 if (!jcp.exclude_pad) {
// Top-padding rows: t_overflow is a runtime count from the call args
274 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
275 cmp(reg_overflow, 0);
276 je(no_t_overflow_label, T_NEAR);
277 L(t_overflow_label); {
278 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true); // padded row
280 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
282 cmp(reg_overflow, 0);
283 jg(t_overflow_label, T_NEAR);
285 L(no_t_overflow_label);
// In-bounds rows: kh_padding is the runtime number of rows to process
289 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
290 if (!jcp.exclude_pad || (jcp.exclude_pad &&
291 (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
293 je(skip_kh_loop, T_NEAR); // kh_padding may be zero for fully padded strips
299 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);
301 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
302 add(aux_reg_input, jcp.typesize_in * iw * inp_mult);
306 jg(kh_label, T_NEAR);
// Bottom-padding rows, mirroring the top-padding loop above
311 if (!jcp.exclude_pad) {
312 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
313 cmp(reg_overflow, 0);
314 je(no_b_overflow_label, T_NEAR);
315 L(b_overflow_label); {
316 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
318 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
320 cmp(reg_overflow, 0);
321 jg(b_overflow_label, T_NEAR);
323 L(no_b_overflow_label);
327 template <cpu_isa_t isa>
// Computes one ur_w-wide output strip for oc_blocks OC blocks:
//   1) zero the popcount accumulators and run kh_loop;
//   2) convert accumulators to f32 and apply scale/shift from the constant
//      table (with a per-column shift when exclude_pad, i.e. pad_value == 0);
//   3) apply post-ops (eltwise / depthwise / sum);
//   4) either binarize the result to a bit-packed output, or store f32/int
//      output via store_dst (scalar tail and full-vector paths).
// NOTE(review): this view is heavily fragmented (missing braces, else
// branches, and presumably preprocessor guards) — verify all control flow
// against the full source before relying on this documentation.
328 void jit_uni_bin_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
331 int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
// Zero all accumulator registers used by apply_filter
333 for (int r = 0; r < repeats; r++)
334 for (int ii = 0; ii < oc_blocks; ii++)
335 for (int jj = 0; jj < ur_w; jj++)
336 uni_vpxor(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
337 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
338 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
340 kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
// Partial OC block on avx512: build a k-mask with oc_step lanes enabled
342 if (isa == avx512_common && oc_step != jcp.oc_block) {
343 int mask = (1 << oc_step) - 1;
344 mov(reg_tmp_32, mask);
345 kmovw(ktail_mask, reg_tmp_32);
348 const auto &p = attr_.post_ops_;
349 for (int r = 0; r < repeats; r++) {
350 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
351 bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
// Per-column count of valid (non-padded) kernel elements; the two
// declarations below are presumably alternatives under a compiler-specific
// #ifdef (VLA vs. helper) not visible in this view — TODO confirm.
354 auto kw_padding = make_vla<int>(ur_w);
356 int kw_padding[ur_w];
359 if (jcp.exclude_pad) {
// reg_tmp_32 = ic * (valid kh rows); multiplied per column below
360 mov(reg_tmp_32, jcp.ic);
361 imul(reg_tmp_32, ptr[param1 + GET_OFF(kh_padding)]);
363 for (int jj = 0; jj < ur_w; jj++)
// Count, per output column, how many kw taps were in-bounds
366 for (int ki = 0; ki < jcp.kw; ki++) {
367 int jj_start = nstl::max(0, div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
368 int jj_end = ur_w - nstl::max(0, div_up(ki * (jcp.dilate_w + 1) + pad_r -
369 (jcp.kw - 1) * (jcp.dilate_w + 1), jcp.stride_w));
370 for (int jj = jj_start; jj < jj_end; jj++) {
375 uni_vmovups(vmm_shift, ptr[reg_table + 4 * vlen]); // default shift constant
377 uni_vmovups(vmm_scale, ptr[reg_table + 3 * vlen]); // scale constant
379 for (int jj = 0; jj < ur_w; jj++) {
380 if (jcp.exclude_pad) {
// Column-specific shift = kw_padding[jj] * ic * kh_padding, broadcast as f32
381 mov(reg_shift, kw_padding[jj]);
382 imul(reg_shift, reg_tmp_32);
383 movq(Xmm(vmm_shift.getIdx()), reg_shift);
384 uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
385 uni_vcvtdq2ps(vmm_shift, vmm_shift);
// acc = acc * scale + shift (after int -> f32 conversion)
388 for (int ii = 0; ii < oc_blocks; ii++) {
389 uni_vcvtdq2ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
390 uni_vfmadd213ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_scale, vmm_shift);
// Post-ops chain (everything before a fused dw-conv entry, if any)
394 int eltwise_inj_idx = 0;
395 int depthwise_inj_idx = 0;
396 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
397 for (int i = 0; i < end_idx; i++) {
398 int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
400 auto& post_op = p.entry_[i];
401 if (post_op.is_eltwise()) {
402 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
404 } else if (post_op.is_depthwise()) {
// Depthwise post-op: per-channel weights/bias addressed via reg_oc_off
407 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
408 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
410 add(reg_d_weights, reg_oc_off);
411 add(reg_d_bias, reg_oc_off);
// Second sse42 half-pass starts at the upper half of the oc block
414 add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
415 add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
418 for (int ii = 0; ii < oc_blocks; ii++) {
419 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
420 start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
422 add(reg_d_weights, jcp.oc_block * sizeof(float));
423 add(reg_d_bias, jcp.oc_block * sizeof(float));
// Sum post-op: load previous output, convert to f32, add to accumulator
429 } else if (post_op.is_sum(false)) {
430 for (int ii = 0; ii < oc_blocks; ii++) {
431 for (int jj = 0; jj < ur_w; jj++) {
432 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
434 if (is_scalar_store) {
435 if (isa == avx512_common) {
436 int o_off = jj * jcp.oc * jcp.ngroups;
438 Vmm vmm_in = vmm_sum | ktail_mask | T_z; // masked, zeroing load
440 vmovups(vmm_in, ptr[reg_output + o_off * jcp.typesize_out]);
441 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
// Non-avx512 scalar tail: load element-by-element and shift into lane oc
443 for (int oc = 0; oc < tail_size; oc++) {
444 int o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
446 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
447 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
449 if (oc < jcp.oc_block / 2) {
450 uni_vpslldq(vmm_sum, vmm_sum, oc * sizeof(float));
// Upper ymm half: swap 128-bit lanes first, then shift within the lane
452 Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
453 vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
454 vpslldq(vmm_sum, vmm_sum, (oc - jcp.oc_block / 2) * sizeof(float));
457 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
461 size_t o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
463 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
464 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
// Binarization: threshold-compare and pack the resulting bits to the output
472 if (jcp.with_binarization) {
473 int binarization_idx = p.find(primitive_kind::binarization);
477 mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
478 mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
479 add(reg_b_weights, reg_oc_off);
480 add(reg_b_out_mask, reg_oc_off);
484 for (int ii = 0; ii < oc_blocks; ii++) {
485 for (int jj = 0; jj < ur_w; jj++) {
486 for (int r = 0; r < repeats; r++) {
487 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
488 mov(reg_b_mask, (1 << tail_size) - 1); // keep only valid OC bits
489 uni_vmovups(vmm_thr, ptr[reg_b_weights + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);
490 uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);
492 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
494 if (isa == avx512_common) {
// bit = !((dst > thr) XOR out_mask) via kxnorw of the two k-masks
495 vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
496 vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
497 kxnorw(bin_mask0, bin_mask0, bin_mask1);
499 uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
500 uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
// Extract the comparison bits into a GPR
504 if (isa == avx512_common) {
505 kmovw(reg_tmp_32, bin_mask0);
507 uni_vmovmskps(reg_tmp_32, vmm_dst);
509 and_(reg_tmp_64, reg_b_mask);
// sse42 second half-pass: combine with the bits from the first half
511 uni_vmovmskps(reg_tmp2_32, vmm_dst);
512 and_(reg_tmp2_64, reg_b_mask);
514 or_(reg_tmp_32, reg_tmp2_32);
// Store the packed bits (16 bits for a full avx512 block, 8 otherwise)
517 if (r == repeats - 1) {
518 if (isa == avx512_common && oc_step > nbits) {
519 const size_t o_off = (2 * ii + jj * div_up(jcp.oc, nbits));
520 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
522 const size_t o_off = (ii + jj * div_up(jcp.oc, nbits));
523 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
// Non-binarized output: scalar tail path first, then full-vector stores
530 for (int r = 0; r < repeats; r++) {
531 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
532 bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
533 if (is_scalar_store) {
534 for (int jj = 0; jj < ur_w; jj++) {
535 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
537 if (isa == avx512_common) {
// dw-conv fused output goes to a compact intermediate buffer layout
539 if (jcp.with_dw_conv)
540 o_off = jj * jcp.oc_block;
542 o_off = jj * jcp.oc * jcp.ngroups;
544 uni_vmovups(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask);
546 for (int oc = 0; oc < tail_size; oc++) {
548 if (jcp.with_dw_conv)
549 o_off = jj * jcp.oc_block + oc + r * (jcp.oc_block / 2);
551 o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
553 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
// Rotate the next element into lane 0 for the following scalar store
556 psrldq(vmm_dst, jcp.typesize_out);
558 Ymm ymm_dst = Ymm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
560 vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
561 vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
// Full-vector store path
567 for (int ii = 0; ii < oc_blocks; ii++) {
568 for (int jj = 0; jj < ur_w; jj++) {
569 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
572 if (jcp.with_dw_conv)
573 o_off = ((size_t) ii * jcp_dw_conv.kh * jcp.ow + jj) * jcp.oc_block +
574 r * (jcp.oc_block / 2);
576 o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
578 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
586 template <cpu_isa_t isa>
// Drives the output-width loop for one group of OC blocks: splits ow into an
// optional left-padded first step, n_oi unpadded middle steps, an optional
// right-padded step, and a ur_w_tail remainder — each handled by
// width_blk_step with the appropriate pad_l/pad_r.
// NOTE(review): some lines (label declarations, the matching push/pop of
// reg_input_base, loop-counter updates) are not visible in this view.
587 inline void jit_uni_bin_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
590 int ur_w_tail = jcp.ur_w_tail;
591 int n_oi = jcp.ow / ur_w;
594 int dilate_w = jcp.dilate_w + 1;
595 int str_w = jcp.stride_w;
598 const int inp_mult = div_up(jcp.ic, nbits); // packed input bytes per pixel
599 const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.with_binarization ? div_up(jcp.oc, nbits) : jcp.oc;
601 int l_pad = jcp.l_pad;
// Right padding of the full row and of the last non-tail step
602 int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
604 int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
606 if (r_pad1 > 0) n_oi--; // last non-tail step needs the padded variant
608 mov(reg_input, reg_input_base);
609 mov(reg_output, reg_output_base);
611 push(reg_input_base); // preserved across the width loop
612 push(reg_output_base);
618 if (n_oi < 0 && r_pad1 > 0)
619 width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
621 width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
622 add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
623 add(reg_output, jcp.typesize_out * ur_w * out_mult);
627 xor_(oi_iter, oi_iter); // runtime middle-step counter
632 width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
633 add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
634 add(reg_output, jcp.typesize_out * ur_w * out_mult);
638 jl(ow_loop_label, T_NEAR);
641 if (r_pad1 > 0 && n_oi >=0) {
642 width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
643 add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
644 add(reg_output, jcp.typesize_out * ur_w * out_mult);
648 width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
652 pop(reg_output_base);
656 template <cpu_isa_t isa>
// Kernel entry point. Instantiates the eltwise/depthwise post-op injectors
// from the attribute's post-op chain (stopping before a fused dw-conv entry),
// loads the call arguments from param1, and processes the output channels:
// first groups of nb_oc_blocking * oc_block, then single oc_block steps, then
// a final partial block of jcp.oc % jcp.oc_block channels. Finishes by
// emitting the injectors' and this kernel's constant tables.
// NOTE(review): prologue/epilogue (preamble/postamble) and several label
// declarations are not visible in this fragmentary view.
657 void jit_uni_bin_conv_fwd_kernel<isa>::generate()
659 const auto &p = attr_.post_ops_;
660 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
661 for (int i = 0; i < end_idx; i++) {
662 auto &post_op = p.entry_[i];
663 if (post_op.is_eltwise()) {
664 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
667 post_op.eltwise.alpha,
670 } else if (post_op.is_depthwise()) {
671 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
673 post_op.depthwise.alg
// Load kernel arguments from the jit_conv_call_s struct
680 mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
681 mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
682 mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
684 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
685 mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
687 mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
688 mov(reg_table, l_table); // base of the constant table emitted below
690 Label main_loop_label;
// Fast path: exactly one full nb_oc_blocking group of OC blocks
694 cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
695 jne(main_loop_label, T_NEAR);
697 solve_common(jcp.nb_oc_blocking, jcp.oc_block);
699 sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
701 jmp(exit_label, T_NEAR);
// General path: one oc_block at a time, advancing the base pointers
705 L(main_loop_label); {
706 cmp(reg_oc_work, jcp.oc_block);
707 jl(tail_label, T_NEAR);
709 solve_common(1, jcp.oc_block);
711 sub(reg_oc_work, jcp.oc_block);
712 add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * div_up(jcp.ic_block, nbits) * jcp.typesize_in);
// Output advance depends on the output layout (dw-conv buffer /
// bit-packed binarized / plain channels)
714 if (jcp.with_dw_conv) {
715 add(reg_output_base, jcp.oc_block * jcp_dw_conv.kh * jcp.ow * jcp.typesize_out);
717 if (jcp.with_binarization)
718 add(reg_output_base, div_up(jcp.oc_block, nbits) * jcp.typesize_out);
720 add(reg_output_base, jcp.oc_block * jcp.typesize_out);
723 add(reg_oc_off, jcp.oc_block * sizeof(float));
725 jmp(main_loop_label, T_NEAR);
// Remaining partial OC block, if any
730 if (jcp.oc % jcp.oc_block != 0)
731 solve_common(1, jcp.oc % jcp.oc_block);
// Emit lookup tables used by the injectors (this kernel's own table is
// emitted by prepare_table, not visible at this point in the view)
739 for (auto& inj : eltwise_injectors)
740 inj->prepare_table();
743 template <cpu_isa_t isa>
// Emits the constant table addressed via reg_table (vlen-sized slots):
//   slot 0: per-nibble popcount lookup bytes (cvals below: popcounts of
//           0x0..0xF are 0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4);
//   slots 1-6: nibble mask, scale, shift and u8/s16 "ones" constants — the
//           exact per-slot values are only partially visible in this view
//           (slot 3 = scale, slot 4 = shift, slot 5 = one_u8, slot 6 =
//           one_s16 per the loads in kh_loop/width_blk_step);
//   slot 7: tail mask for a partially filled last IC block;
//   slot 8: pad-value bit pattern (all-ones when pad_value == 1.0f).
744 void jit_uni_bin_conv_fwd_kernel<isa>::prepare_table() {
745 const unsigned int cvals[] = {
746 0x02010100, // 0 1 1 2
747 0x03020201, // 1 2 2 3
748 0x03020201, // 1 2 2 3
749 0x04030302, // 2 3 3 4
757 size_t simd_w = vlen / sizeof(int32_t); // dwords per vector slot
762 for (size_t d = 0; d < simd_w; ++d) {
766 for (size_t d = 0; d < simd_w; ++d) {
770 for (size_t d = 0; d < simd_w; ++d) {
774 for (size_t d = 0; d < simd_w; ++d) {
// Scale constant: total number of kernel input elements, as f32 bits
779 for (size_t d = 0; d < simd_w; ++d) {
780 dd(float2int(jcp.ic * jcp.kw * jcp.kh));
784 for (size_t d = 0; d < simd_w; ++d) {
788 for (size_t d = 0; d < simd_w; ++d) {
// Tail mask: keep only the jcp.ic valid low bits of the padded IC block
792 for (size_t d = 0; d < simd_w; ++d) {
793 uint32_t mask = 0xffffffff >> (jcp.ic_padded - jcp.ic);
// Pad-value pattern: all bits set for pad_value == 1, clear otherwise
797 for (size_t d = 0; d < simd_w; ++d) {
798 uint32_t val = jcp.pad_value == 1.0f ? 0xffffffff : 0x00000000;
803 template <cpu_isa_t isa>
804 bool jit_uni_bin_conv_fwd_kernel<isa>::post_ops_ok(jit_bin_conv_conf_t &jcp, const primitive_attr_t &attr) {
805 const auto &p = attr.post_ops_;
807 int dw_conv_idx = p.find(primitive_kind::convolution);
808 bool with_dw_conv = dw_conv_idx != -1;
810 auto all_post_ops_supported = [&]() {
813 int end_idx = with_dw_conv ? dw_conv_idx : p.len_;
814 for (int i = 0; i < end_idx; i++) {
815 ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise,
816 primitive_kind::binarization);
820 auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx) != -1; };
821 auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx); };
822 auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, 0, dw_conv_idx); };
824 return all_post_ops_supported() &&
825 count(primitive_kind::sum) <= 1 &&
826 count(primitive_kind::binarization) <= 1 &&
827 IMPLICATION(contain(primitive_kind::binarization), position(primitive_kind::binarization) == p.len_-1 &&
828 !contain(primitive_kind::sum)) &&
829 IMPLICATION(with_dw_conv, !contain(primitive_kind::sum) && !contain(primitive_kind::binarization));
832 template <cpu_isa_t isa>
// Fills the kernel configuration `jcp` from the binary-convolution
// descriptor and memory descriptors, and validates that this implementation
// supports the case: ISA available, single group, nhwc src/dst, the
// ISA-specific blocked weights format, supported post-ops, and padding
// within the limits the width loop can handle. Returns
// status::unimplemented on any unsupported combination.
// NOTE(review): several lines (ic_block assignment, parts of the args_ok
// checks) are not visible in this fragmentary view.
833 status_t jit_uni_bin_conv_fwd_kernel<isa>::init_conf(jit_bin_conv_conf_t &jcp,
834 const binary_convolution_desc_t &cd, const memory_desc_wrapper &src_d,
835 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr)
837 if (!mayiuse(isa)) return status::unimplemented;
839 jcp.prop_kind = cd.prop_kind;
// Grouped weights carry an extra leading dimension
841 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
843 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
845 if (jcp.ngroups != 1)
846 return status::unimplemented;
848 jcp.mb = src_d.dims()[0];
850 int simd_w = isa == avx512_common ? 16 : 8; // f32 lanes per vector
852 jcp.ic = src_d.dims()[1] / jcp.ngroups;
853 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
855 jcp.oc_padded = rnd_up(jcp.oc, simd_w);
857 jcp.ih = src_d.dims()[2];
858 jcp.iw = src_d.dims()[3];
859 jcp.oh = dst_d.dims()[2];
860 jcp.ow = dst_d.dims()[3];
862 jcp.kh = weights_d.dims()[with_groups + 2];
863 jcp.kw = weights_d.dims()[with_groups + 3];
865 jcp.t_pad = cd.padding[0][0];
866 jcp.l_pad = cd.padding[0][1];
868 jcp.stride_h = cd.strides[0];
869 jcp.stride_w = cd.strides[1];
871 jcp.dilate_h = cd.dilates[0];
872 jcp.dilate_w = cd.dilates[1];
874 jcp.src_fmt = src_d.format();
876 if (!post_ops_ok(jcp, attr))
877 return status::unimplemented;
879 jcp.pad_value = cd.pad_value;
880 jcp.exclude_pad = jcp.pad_value == 0.0f; // zero padding is excluded from counts
882 jcp.src_dt = cd.src_desc.data_type;
883 jcp.bia_dt = mkldnn_f32;
884 jcp.dst_dt = cd.dst_desc.data_type;
// Fused depthwise conv: this kernel then produces the dw-conv's input
886 const auto &p = attr.post_ops_;
887 int dw_conv_ind = p.find(primitive_kind::convolution);
888 jcp.with_dw_conv = dw_conv_ind != -1;
889 if (jcp.with_dw_conv) {
890 jcp.dw_conv_oh = jcp.oh;
891 jcp.dw_conv_ow = jcp.ow;
892 jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
893 jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
895 jcp.dw_conv_dst_dt = jcp.dst_dt;
896 jcp.dst_dt = p.entry_[dw_conv_ind].dw_conv.in_dt;
// Only pre-dw-conv sum/binarization affect this kernel
898 jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
899 jcp.with_binarization = p.find(primitive_kind::binarization, 0, dw_conv_ind) != -1;
902 return status::unimplemented;
904 auto desired_weights_format = isa == avx512_common ? OhIw16o32i : OhIw8o32i;
906 && src_d.format() == nhwc
907 && weights_d.format() == desired_weights_format
908 && dst_d.format() == nhwc;
909 if (!args_ok) return status::unimplemented;
// Width unroll factor and its remainder
912 jcp.ur_w = isa == avx512_common ? 4 : 2;
913 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
914 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
917 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
918 jcp.ic_padded = rnd_up(jcp.ic, jcp.ic_block);
920 jcp.oc_block = simd_w;
921 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
923 jcp.nb_ic_blocking = 1;
924 jcp.nb_oc_blocking = nstl::min(isa == sse42 ? 2 : isa == avx2 ? 4 : 6, jcp.nb_oc);
926 jcp.typesize_in = types::data_type_size(jcp.src_dt);
927 jcp.typesize_out = types::data_type_size(jcp.dst_dt);
928 jcp.typesize_acc = sizeof(int32_t);
// Padding limits the width loop can handle
931 && jcp.l_pad <= jcp.ur_w
932 && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
933 || (jcp.stride_w == 1 && jcp.stride_h == 1));
934 if (!args_ok) return status::unimplemented;
936 int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
937 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
938 if (r_pad_no_tail > jcp.ur_w)
939 return status::unimplemented;
941 if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
943 return status::success;
946 template <cpu_isa_t isa>
// Books scratchpad memory needed only when a depthwise convolution is
// fused: a per-thread intermediate row buffer for the dw-conv input, and a
// padded bias buffer when oc is not a multiple of the padded oc size.
947 void jit_uni_bin_conv_fwd_kernel<isa>::init_scratchpad(
948 memory_tracking::registrar_t &scratchpad, const jit_bin_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw_conv) {
949 if (jcp.with_dw_conv) {
950 const int nthreads = mkldnn_get_max_threads();
// kh rows of the dw-conv input, per OC blocking group, per thread
951 size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
952 scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
954 if (jcp.oc != jcp.oc_padded)
955 scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc_padded);
// Explicit instantiations for all supported ISAs.
959 template struct jit_uni_bin_conv_fwd_kernel<sse42>;
960 template struct jit_uni_bin_conv_fwd_kernel<avx2>;
961 template struct jit_uni_bin_conv_fwd_kernel<avx512_common>;