1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <common/primitive_attr.hpp>
18 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
24 #include "jit_uni_bin_conv_kernel.hpp"
26 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
32 using namespace mkldnn::impl::prop_kind;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
39 template <cpu_isa_t isa>
// Load a value from `op` and convert it to packed f32 in vmm_in.
// type_in selects the source data type (f32 and sign/zero-extended integer
// paths are visible below); scalar_load selects a single-element load staged
// through a GPR/xmm instead of a full vector load.
// NOTE(review): the surrounding switch-on-type_in structure is elided here.
40 void jit_uni_bin_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
// xmm view of the same physical register as vmm_in (low 128 bits).
41 Xmm xmm_in = Xmm(vmm_in.getIdx());
// scalar f32 path: value staged in reg_tmp_64 is moved into the low lane.
48 movq(xmm_in, reg_tmp_64);
// vector f32 path: plain vector load.
50 uni_vmovups(vmm_in, op);
// scalar signed-int path: sign-extend into reg_tmp_32 (reg_tmp_64 aliases it).
55 movsx(reg_tmp_32, op);
56 movq(xmm_in, reg_tmp_64);
// vector signed-int path: sign-extend bytes to dwords.
58 uni_vpmovsxbd(vmm_in, op);
// scalar unsigned-int path: zero-extend into reg_tmp_32.
63 movzx(reg_tmp_32, op);
64 movq(xmm_in, reg_tmp_64);
// vector unsigned-int path: zero-extend bytes to dwords.
66 uni_vpmovzxbd(vmm_in, op);
69 default: assert(!"unsupported data type");
// Integer inputs still hold dwords at this point; convert to f32.
72 if (type_in != data_type::f32)
73 uni_vcvtdq2ps(vmm_in, vmm_in);
76 template <cpu_isa_t isa>
// Store vmm_dst to memory at `op`, down-converting to the destination data
// type. Two visible down-conversion pipelines (signed: packssdw/packsswb;
// unsigned: packusdw/packuswb) narrow dwords to bytes before the store;
// scalar_store stages the value through a GPR instead of a vector store.
// NOTE(review): the switch on dst_dt enclosing these paths is elided here.
77 void jit_uni_bin_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
// ymm/xmm views of the same physical register as vmm_dst.
78 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
79 Xmm xmm_dst = Xmm(vmm_dst.getIdx());
// scalar path: move the low element to reg_tmp_64 for a GPR store.
85 movq(reg_tmp_64, xmm_dst);
// full-width f32 vector store.
88 uni_vmovups(op, vmm_dst);
// signed narrowing: dwords -> words with saturation.
92 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
// on 256-bit targets the pack interleaves 128-bit lanes; vpermq regroups them.
94 if (isa != sse42 && !scalar_store)
95 vpermq(ymm_dst, ymm_dst, 0x08);
// words -> bytes with signed saturation.
97 uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
100 movq(reg_tmp_64, xmm_dst);
// unsigned narrowing: dwords -> words with unsigned saturation.
111 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
113 if (isa != sse42 && !scalar_store)
114 vpermq(ymm_dst, ymm_dst, 0x08);
// words -> bytes with unsigned saturation.
116 uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
119 movq(reg_tmp_64, xmm_dst);
130 assert(!"unknown dst_dt");
134 template <cpu_isa_t isa>
// Emit the inner binary-convolution compute: for each kernel-width position
// and each unrolled output column, XOR packed input bits with packed weight
// bits, popcount the result via a nibble lookup table (pshufb), and
// accumulate int32 partial sums into the per-(oc block, column) accumulator
// registers. h_padded forces the padding value for the whole row;
// last_icb triggers masking of the tail bits of a padded channel block.
135 void jit_uni_bin_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
136 int ic_blocks, bool last_icb, bool h_padded)
140 int stride_w = jcp.stride_w;
141 int dilate_w = jcp.dilate_w + 1;
142 int ic_blk = jcp.ic_block;
143 int oc_blk = jcp.oc_block;
// sse42 processes an oc block in two half-width passes when oc_step needs it.
145 int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
148 for (int ki = 0; ki < kw; ki++) {
// Columns [jj_start, jj_end) read real input; outside is left/right padding.
149 int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
150 int jj_end = ur_w - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
// exclude_pad skips padded columns entirely; otherwise they are computed
// with the padding value loaded below.
152 int _start = (!jcp.exclude_pad) ? 0 : jj_start;
153 int _end = (!jcp.exclude_pad) ? ur_w : jj_end;
155 for (int ifm2 = 0; ifm2 < ic_blocks; ifm2++) {
156 for (int jj = _start; jj < _end; jj++) {
// Input offset: column position times packed-channel row stride, plus the
// channel-block offset (channels are bit-packed, hence div_up(..., nbits)).
157 int inp_off = ((ki*dilate_w + jj*stride_w - pad_l)*div_up(jcp.ic, nbits) + ifm2 * div_up(ic_blk, nbits)) * jcp.typesize_in;
// Padded position: broadcast the pad-value pattern from the constants table.
159 if (h_padded || jj < jj_start || jj >= jj_end) {
160 uni_vmovups(vmm_src, ptr[reg_table + 256]);
// Real position: broadcast one packed-input dword.
162 uni_vpbroadcastd(vmm_src, ptr[aux1_reg_input + inp_off]);
165 for (int r = 0; r < repeats; r++) {
166 for (int ii = 0; ii < oc_blocks; ii++) {
// Weight offset within the blocked (per-oc-block, per-kw) layout.
167 int ker_off = (ifm2 * kw * div_up(ic_blk, nbits) * oc_blk
168 + ii * jcp.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk
169 + ki * div_up(ic_blk, nbits) * oc_blk + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * jcp.typesize_in;
171 uni_vmovups(vmm_tmp, ptr[aux1_reg_kernel + ker_off]);
// XOR input bits with weight bits (binary-convolution multiply).
173 uni_vpxor(vmm_tmp, vmm_tmp, vmm_src);
// Mask out bits beyond jcp.ic in the last, padded channel block.
174 if (jcp.ic_padded != jcp.ic && last_icb && ifm2 == (ic_blocks - 1))
175 uni_vandps(vmm_tmp, vmm_tmp, ptr[reg_table + 224]);
// Popcount via nibble LUT: split each byte into low/high nibbles ...
// (sse42 variant uses destructive two-operand forms)
178 movups(vmm_tmp1, vmm_tmp);
179 pand(vmm_tmp1, vmm_mask);
181 uni_vandps(vmm_tmp1, vmm_mask, vmm_tmp);
184 uni_vpsrld(vmm_tmp, vmm_tmp, 4);
185 uni_vandps(vmm_tmp, vmm_tmp, vmm_mask);
// ... then table-lookup per-nibble bit counts and add them per byte.
188 movups(vmm_tmp2, vmm_lookup);
189 pshufb(vmm_tmp2, vmm_tmp);
190 movups(vmm_tmp, vmm_lookup);
191 pshufb(vmm_tmp, vmm_tmp1);
192 paddb(vmm_tmp, vmm_tmp2);
194 uni_vpshufb(vmm_tmp, vmm_lookup, vmm_tmp);
195 uni_vpshufb(vmm_tmp1, vmm_lookup, vmm_tmp1);
196 uni_vpaddb(vmm_tmp, vmm_tmp, vmm_tmp1);
// Horizontal widen: bytes -> words -> dwords using multiply-add by ones.
199 uni_vpmaddubsw(vmm_tmp, vmm_tmp, vmm_one_u8)
200 uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one_s16);
// Accumulate into the register assigned to this (repeat, oc block, column).
201 uni_vpaddd(Vmm(1 + r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
202 Vmm(1 + r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
210 template <cpu_isa_t isa>
// Emit one kernel-row step: loop over input-channel blocks (nb_ic) at
// runtime, calling apply_filter for each; the final (tail) block is handled
// separately with last_icb=true so tail-bit masking applies.
211 void jit_uni_bin_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
// Per-ic-block strides in bytes (channels bit-packed into nbits-wide units).
215 int inp_mult = div_up(jcp.ic_block, nbits);
216 int out_mult = jcp.oc_block;
// Work on copies of the row pointers so kh_loop's registers stay intact.
221 mov(aux1_reg_input, aux_reg_input);
222 mov(aux1_reg_kernel, aux_reg_kernel);
224 mov(reg_icb_iter, jcp.nb_ic);
// Main loop runs while more than one ic block remains; the last one falls
// through to the tail call below.
227 cmp(reg_icb_iter, 1);
228 jle(icb_tail, T_NEAR);
230 apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, false, h_padded);
// Advance input by one packed ic block and weights by one kw-row of it.
232 add(aux1_reg_input, inp_mult * jcp.typesize_in);
233 add(aux1_reg_kernel, kw * inp_mult * out_mult * jcp.typesize_in);
234 sub(reg_icb_iter, 1);
235 jmp(icb_main_loop, T_NEAR);
// Tail: last ic block with masking of padded channel bits enabled.
240 apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, true, h_padded);
243 template <cpu_isa_t isa>
// Emit the kernel-height loop. Loads the popcount LUT and helper constants,
// then (when padding is included in the computation, i.e. !exclude_pad)
// emits extra passes for the rows that fall into top (t_overflow) and
// bottom (b_overflow) padding, with h_padded=true so the pad value is used.
244 void jit_uni_bin_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
247 int dilate_h = jcp.dilate_h + 1;
// Bytes to advance the input pointer per kernel row (bit-packed channels).
250 const int inp_mult = dilate_h * div_up(jcp.ic, nbits);
252 Label t_overflow_label, no_t_overflow_label,
253 b_overflow_label, no_b_overflow_label;
255 mov(aux_reg_input, reg_input);
256 mov(aux_reg_kernel, reg_kernel_base);
// Constants from the jitted table: nibble-popcount LUT, 0x0f nibble mask,
// and the 1-vectors used by pmaddubsw/pmaddwd for horizontal widening.
258 uni_vmovups(vmm_lookup, ptr[reg_table]);
259 uni_vmovups(vmm_mask, ptr[reg_table + 32]);
260 uni_vmovups(vmm_one_u8, ptr[reg_table + 160]);
261 uni_vmovups(vmm_one_s16, ptr[reg_table + 192]);
// Top padding rows: t_overflow is a runtime per-call count from the args.
263 if (!jcp.exclude_pad) {
264 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
265 cmp(reg_overflow, 0);
266 je(no_t_overflow_label, T_NEAR);
267 L(t_overflow_label); {
268 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
// Padded rows consume weights but not input rows.
270 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
272 cmp(reg_overflow, 0);
273 jg(t_overflow_label, T_NEAR);
275 L(no_t_overflow_label);
// Main kh loop over the non-padded kernel rows (count in kh_padding).
279 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
// The loop may be entirely skipped when the whole kernel lies in padding.
280 if (!jcp.exclude_pad || (jcp.exclude_pad &&
281 (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
283 je(skip_kh_loop, T_NEAR);
289 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);
291 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
292 add(aux_reg_input, jcp.typesize_in * iw * inp_mult);
296 jg(kh_label, T_NEAR);
// Bottom padding rows, symmetric to the top-overflow handling above.
301 if (!jcp.exclude_pad) {
302 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
303 cmp(reg_overflow, 0);
304 je(no_b_overflow_label, T_NEAR);
305 L(b_overflow_label); {
306 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
308 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
310 cmp(reg_overflow, 0);
311 jg(b_overflow_label, T_NEAR);
313 L(no_b_overflow_label);
317 template <cpu_isa_t isa>
// Emit the full computation for one block of ur_w output columns:
// zero the accumulators, run the kh loop, convert int popcount sums to f32
// with a scale/shift (maps XOR-popcount to the +/-1 dot product), apply
// post-ops (eltwise / depthwise / sum / binarization), and store results.
318 void jit_uni_bin_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
// sse42 covers an oc block in two half-width passes when oc_step requires it.
321 int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
// Zero all accumulator registers (layout: 1 + r*ur_w*nb_oc_blocking + ii*ur_w + jj).
323 for (int r = 0; r < repeats; r++)
324 for (int ii = 0; ii < oc_blocks; ii++)
325 for (int jj = 0; jj < ur_w; jj++)
326 uni_vpxor(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
327 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
328 Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
330 kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
332 const auto &p = attr_.post_ops_;
333 for (int r = 0; r < repeats; r++) {
// tail_size/is_scalar_store: whether a partial oc block forces per-element stores.
334 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
335 bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
// Per-column count of kernel positions that hit real input (exclude_pad mode).
337 int kw_padding[ur_w];
338 if (jcp.exclude_pad) {
// reg_tmp_32 = ic * (number of non-padded kernel rows) for this call.
339 mov(reg_tmp_32, jcp.ic);
340 imul(reg_tmp_32, ptr[param1 + GET_OFF(kh_padding)]);
342 for (int jj = 0; jj < ur_w; jj++)
// Count, per column, the kw positions that land inside the input row.
345 for (int ki = 0; ki < jcp.kw; ki++) {
346 int jj_start = nstl::max(0, div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
347 int jj_end = ur_w - nstl::max(0, div_up(ki * (jcp.dilate_w + 1) + pad_r -
348 (jcp.kw - 1) * (jcp.dilate_w + 1), jcp.stride_w));
349 for (int jj = jj_start; jj < jj_end; jj++) {
// Fixed shift (include-pad mode) and scale constants from the jitted table.
354 uni_vmovups(vmm_shift, ptr[reg_table + 128]);
356 uni_vmovups(vmm_scale, ptr[reg_table + 96]);
358 for (int jj = 0; jj < ur_w; jj++) {
// exclude_pad: per-column shift = kw_padding[jj] * ic * kh_padding,
// broadcast and converted to f32.
359 if (jcp.exclude_pad) {
360 mov(reg_shift, kw_padding[jj]);
361 imul(reg_shift, reg_tmp_32);
362 movq(Xmm(vmm_shift.getIdx()), reg_shift);
363 uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
364 uni_vcvtdq2ps(vmm_shift, vmm_shift);
// acc = acc_f32 * scale + shift (fused multiply-add).
367 for (int ii = 0; ii < oc_blocks; ii++) {
368 uni_vcvtdq2ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
369 uni_vfmadd213ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_scale, vmm_shift);
// Apply post-ops up to (not including) a fused depthwise convolution.
373 int eltwise_inj_idx = 0;
374 int depthwise_inj_idx = 0;
375 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
376 for (int i = 0; i < end_idx; i++) {
377 int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
379 auto& post_op = p.entry_[i];
380 if (post_op.is_eltwise()) {
// Eltwise injector runs over the whole accumulator register range.
381 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
383 } else if (post_op.is_depthwise()) {
// Depthwise post-op: per-channel weights/bias, offset by reg_oc_off.
386 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
387 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
389 add(reg_d_weights, reg_oc_off);
390 add(reg_d_bias, reg_oc_off);
// Second sse42 half-block starts halfway into the per-channel data.
393 add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
394 add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
397 for (int ii = 0; ii < oc_blocks; ii++) {
398 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
399 start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
401 add(reg_d_weights, jcp.oc_block * sizeof(float));
402 add(reg_d_bias, jcp.oc_block * sizeof(float));
408 } else if (post_op.is_sum(false)) {
// Sum post-op: load previous dst values, convert to f32, add in.
409 for (int ii = 0; ii < oc_blocks; ii++) {
410 for (int jj = 0; jj < ur_w; jj++) {
411 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
// Scalar path: gather tail elements one by one and shift each into its lane.
413 if (is_scalar_store) {
414 for (int oc = 0; oc < tail_size; oc++) {
415 int o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
417 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
418 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
// Position the loaded scalar at lane `oc`; lanes >= oc_block/2 need a
// 128-bit lane swap first (vperm2i128) before the byte shift.
420 if (oc < jcp.oc_block / 2) {
421 uni_vpslldq(vmm_sum, vmm_sum, oc * sizeof(float));
423 Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
424 vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
425 vpslldq(vmm_sum, vmm_sum, (oc - jcp.oc_block / 2) * sizeof(float));
428 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
// Vector path: one full load + add per (ii, jj).
431 size_t o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
433 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
434 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
// Binarization output: compare against per-channel thresholds and pack the
// resulting sign bits into a bit-mask byte per (ii, jj).
442 if (jcp.with_binarization) {
443 int binarization_idx = p.find(primitive_kind::binarization);
447 mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
448 add(reg_b_weights, reg_oc_off);
452 for (int ii = 0; ii < oc_blocks; ii++) {
453 for (int jj = 0; jj < ur_w; jj++) {
454 for (int r = 0; r < repeats; r++) {
455 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
// Mask off bits beyond the valid tail channels.
456 mov(reg_b_mask, (1 << tail_size) - 1);
457 uni_vmovups(vmm_thr, ptr[reg_b_weights + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);
459 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
// dst > threshold -> all-ones lanes; movmskps extracts one bit per lane.
461 uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
464 uni_vmovmskps(reg_tmp_32, vmm_dst);
465 and_(reg_tmp_64, reg_b_mask);
467 uni_vmovmskps(reg_tmp2_32, vmm_dst);
468 and_(reg_tmp2_64, reg_b_mask);
// sse42: combine the two half-block masks into one byte.
470 or_(reg_tmp_32, reg_tmp2_32);
473 if (r == repeats - 1) {
// Output is bit-packed: one byte-sized unit per div_up(oc, nbits) stride.
474 const size_t o_off = (ii + jj * div_up(jcp.oc, nbits));
475 mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
// Non-binarized store path.
481 for (int r = 0; r < repeats; r++) {
482 int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
483 bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
484 if (is_scalar_store) {
// Scalar tail store: emit one element at a time, rotating the vector so the
// next element lands in lane 0 (psrldq on sse42, lane-swap+palignr on avx).
485 for (int jj = 0; jj < ur_w; jj++) {
486 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
487 Ymm ymm_dst = Ymm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
489 for (int oc = 0; oc < tail_size; oc++) {
// dw-conv fused output uses a blocked intermediate buffer layout.
491 if (jcp.with_dw_conv)
492 o_off = jj * jcp.oc_block + oc + r * (jcp.oc_block / 2);
494 o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
496 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
499 psrldq(vmm_dst, jcp.typesize_out);
501 vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
502 vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
// Full-vector store path.
507 for (int ii = 0; ii < oc_blocks; ii++) {
508 for (int jj = 0; jj < ur_w; jj++) {
509 Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
512 if (jcp.with_dw_conv)
513 o_off = ((size_t) ii * jcp_dw_conv.kh * jcp.ow + jj) * jcp.oc_block +
514 r * (jcp.oc_block / 2);
516 o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
518 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
526 template <cpu_isa_t isa>
// Tile the output row into width blocks and emit width_blk_step for each
// case: combined left+right padding ("lrpad"), left padding ("lpad"), a
// runtime loop over unpadded middle blocks, right padding ("rpad"), and a
// tail block of ur_w_tail columns.
527 inline void jit_uni_bin_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
530 int ur_w_tail = jcp.ur_w_tail;
531 int n_oi = jcp.ow / ur_w;
534 int dilate_w = jcp.dilate_w + 1;
535 int str_w = jcp.stride_w;
// Per-column strides in bytes; output layout depends on the fused variant.
538 const int inp_mult = div_up(jcp.ic, nbits);
539 const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.with_binarization ? div_up(jcp.oc, nbits) : jcp.oc;
541 int l_pad = jcp.l_pad;
// r_pad: total right padding; r_pad1: right padding seen by the last full block.
542 int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
544 int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
// The last full block becomes the dedicated "rpad" block when it overlaps padding.
546 if (r_pad1 > 0) n_oi--;
548 mov(reg_input, reg_input_base);
549 mov(reg_output, reg_output_base);
// Preserve base pointers across this routine (restored by pops at the end).
551 push(reg_input_base);
552 push(reg_output_base);
// Only one block and it touches both paddings.
558 if (n_oi < 0 && r_pad1 > 0)
559 width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
561 width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
562 add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
563 add(reg_output, jcp.typesize_out * ur_w * out_mult);
// Runtime loop over the remaining unpadded middle blocks.
567 xor_(oi_iter, oi_iter);
572 width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
573 add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
574 add(reg_output, jcp.typesize_out * ur_w * out_mult);
578 jl(ow_loop_label, T_NEAR);
581 if (r_pad1 > 0 && n_oi >=0) {
582 width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
583 add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
584 add(reg_output, jcp.typesize_out * ur_w * out_mult);
588 width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
592 pop(reg_output_base);
596 template <cpu_isa_t isa>
// Kernel entry point: build post-op injectors, load call arguments into
// registers, then process output channels — first the full nb_oc_blocking
// fast path, then single oc blocks in a loop, then the partial-block tail.
597 void jit_uni_bin_conv_fwd_kernel<isa>::generate()
599 const auto &p = attr_.post_ops_;
// Only post-ops before a fused depthwise convolution are handled here.
600 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
601 for (int i = 0; i < end_idx; i++) {
602 auto &post_op = p.entry_[i];
603 if (post_op.is_eltwise()) {
// NOTE(review): injectors are owned raw pointers here; presumably released
// in the kernel destructor (not visible in this chunk).
604 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
607 post_op.eltwise.alpha,
610 } else if (post_op.is_depthwise()) {
611 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
613 post_op.depthwise.alg
// Load kernel arguments from the jit_conv_call_s parameter block.
620 mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
621 mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
622 mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
624 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
625 mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
627 mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
628 mov(reg_table, l_table);
630 Label main_loop_label;
// Fast path: exactly nb_oc_blocking full oc blocks in one shot.
634 cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
635 jne(main_loop_label, T_NEAR);
637 solve_common(jcp.nb_oc_blocking, jcp.oc_block);
639 sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
641 jmp(exit_label, T_NEAR);
// General path: one oc block per iteration.
645 L(main_loop_label); {
646 cmp(reg_oc_work, jcp.oc_block);
647 jl(tail_label, T_NEAR);
649 solve_common(1, jcp.oc_block);
651 sub(reg_oc_work, jcp.oc_block);
// Advance weights by one oc block's worth of packed filter data.
652 add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * div_up(jcp.ic_block, nbits) * jcp.typesize_in);
// Output advance depends on the fused variant's layout.
654 if (jcp.with_dw_conv) {
655 add(reg_output_base, jcp.oc_block * jcp_dw_conv.kh * jcp.ow * jcp.typesize_out);
657 if (jcp.with_binarization)
658 add(reg_output_base, jcp.typesize_out);
660 add(reg_output_base, jcp.oc_block * jcp.typesize_out);
663 add(reg_oc_off, jcp.oc_block * sizeof(float));
665 jmp(main_loop_label, T_NEAR);
// Tail: remaining channels smaller than one oc block.
670 if (jcp.oc % jcp.oc_block != 0)
671 solve_common(1, jcp.oc % jcp.oc_block);
// Emit injector lookup tables after the code body.
679 for (auto& inj : eltwise_injectors)
680 inj->prepare_table();
683 template <cpu_isa_t isa>
// Emit the constants table referenced via reg_table: the per-nibble popcount
// lookup (cvals: bit counts of nibbles 0..15, four per dword), followed by
// broadcast sections (8 dwords each) for the nibble mask, scale, shift,
// pmaddubsw/pmaddwd one-vectors, the ic tail mask, and the pad value.
// Section offsets correspond to the reg_table+N loads elsewhere in the file.
684 void jit_uni_bin_conv_fwd_kernel<isa>::prepare_table() {
685 const unsigned int cvals[] = {
686 0x02010100, // 0 1 1 2
687 0x03020201, // 1 2 2 3
688 0x03020201, // 1 2 2 3
689 0x04030302, // 2 3 3 4
690 0x02010100, // 0 1 1 2
691 0x03020201, // 1 2 2 3
692 0x03020201, // 1 2 2 3
693 0x04030302, // 2 3 3 4
// Each loop below emits one 8-dword (broadcast) table section.
704 for (size_t d = 0; d < 8; ++d) {
708 for (size_t d = 0; d < 8; ++d) {
712 for (size_t d = 0; d < 8; ++d) {
716 for (size_t d = 0; d < 8; ++d) {
// Shift section: the full receptive-field size ic*kw*kh as float bits.
721 for (size_t d = 0; d < 8; ++d) {
722 dd(float2int(jcp.ic * jcp.kw * jcp.kh));
726 for (size_t d = 0; d < 8; ++d) {
730 for (size_t d = 0; d < 8; ++d) {
// Tail mask: keep only the low jcp.ic bits of the padded channel block.
734 for (size_t d = 0; d < 8; ++d) {
735 uint32_t mask = 0xffffffff >> (jcp.ic_padded - jcp.ic);
// Pad value pattern: all-ones when pad_value == 1, zeros otherwise.
739 for (size_t d = 0; d < 8; ++d) {
740 uint32_t val = jcp.pad_value == 1.0f ? 0xffffffff : 0x00000000;
745 template <cpu_isa_t isa>
// Validate the post-op chain attached to this primitive: only specific
// orderings of eltwise/depthwise ("simple"), sum, fused depthwise conv, and
// binarization are supported, enumerated per chain length (0..4) below.
746 bool jit_uni_bin_conv_fwd_kernel<isa>::post_ops_ok(jit_bin_conv_conf_t &jcp, const primitive_attr_t &attr) {
747 const auto &p = attr.post_ops_;
// Small predicates over the post-op at position idx.
749 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
750 auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
751 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
752 auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
753 auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
754 auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };
// Switch on chain length (case labels for 1-3 elided in this view).
757 case 0: return true; // no post_ops
759 return (is_simple(0) || is_sum(0) || is_dw_conv(0) || is_binarization(0));
761 return ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_simple(1)) ||
762 (is_simple(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
763 (is_simple(0) && is_simple(1)) || (is_simple(0) && is_binarization(1)) ||
764 (is_dw_conv(0) && is_binarization(1)) || (is_simple(0) && is_sum(1)));
766 return ((is_simple(0) && is_dw_conv(1) && is_simple(2)) ||
767 (is_dw_conv(0) && is_sum(1) && is_simple(2)) ||
768 (is_sum(0) && is_simple(1) && is_simple(2)) ||
769 (is_simple(0) && is_sum(1) && is_simple(2)) ||
770 (is_simple(0) && is_dw_conv(1) && is_binarization(2)) ||
771 (is_simple(0) && is_simple(1) && is_dw_conv(2)));
772 case 4: return ((is_simple(0) && is_dw_conv(1) && is_sum(2) && is_simple(3)) ||
773 (is_simple(0) && is_dw_conv(1) && is_simple(2) && is_binarization(3)) ||
774 (is_simple(0) && is_simple(1) && is_dw_conv(2) && is_binarization(3)) ||
775 (is_simple(0) && is_simple(1) && is_simple(2) && is_binarization(3)) ||
776 (is_simple(0) && is_simple(1) && is_dw_conv(2) && is_simple(3)));
777 default: return false;
783 template <cpu_isa_t isa>
// Fill jcp from the operation descriptor and memory descriptors, validate
// that this JIT implementation supports the configuration (ISA, groups,
// formats, padding/unroll constraints, post-ops), and derive the blocking
// parameters. Returns status::unimplemented on any unsupported case.
784 status_t jit_uni_bin_conv_fwd_kernel<isa>::init_conf(jit_bin_conv_conf_t &jcp,
785 const binary_convolution_desc_t &cd, const memory_desc_wrapper &src_d,
786 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr)
788 if (!mayiuse(isa)) return status::unimplemented;
790 jcp.prop_kind = cd.prop_kind;
792 jcp.dst_dt = cd.dst_desc.data_type;
// Grouped weights have one extra leading dimension.
794 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
796 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
// Only the non-grouped case is implemented.
798 if (jcp.ngroups != 1)
799 return status::unimplemented;
801 jcp.mb = src_d.dims()[0];
// SIMD width in f32 lanes; also used as the oc block size below.
803 int simd_w = isa == avx512_common ? 16 : 8;
805 jcp.ic = src_d.dims()[1] / jcp.ngroups;
806 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
808 jcp.oc_padded = rnd_up(jcp.oc, simd_w);
// Spatial sizes (NCHW-ordered descriptor dims).
810 jcp.ih = src_d.dims()[2];
811 jcp.iw = src_d.dims()[3];
812 jcp.oh = dst_d.dims()[2];
813 jcp.ow = dst_d.dims()[3];
815 jcp.kh = weights_d.dims()[with_groups + 2];
816 jcp.kw = weights_d.dims()[with_groups + 3];
818 jcp.t_pad = cd.padding[0][0];
819 jcp.l_pad = cd.padding[0][1];
821 jcp.stride_h = cd.strides[0];
822 jcp.stride_w = cd.strides[1];
824 jcp.dilate_h = cd.dilates[0];
825 jcp.dilate_w = cd.dilates[1];
827 jcp.src_fmt = src_d.format();
829 if (!post_ops_ok(jcp, attr))
830 return status::unimplemented;
// pad_value == 0 means padded positions are excluded from the computation.
832 jcp.pad_value = cd.pad_value;
833 jcp.exclude_pad = jcp.pad_value == 0.0f;
// Fused depthwise conv post-op: this kernel then computes into an
// intermediate buffer sized by the dw-conv's input spatial dims.
835 const auto &p = attr.post_ops_;
836 int dw_conv_ind = p.find(primitive_kind::convolution);
837 jcp.with_dw_conv = dw_conv_ind != -1;
838 if (jcp.with_dw_conv) {
839 jcp.dw_conv_oh = jcp.oh;
840 jcp.dw_conv_ow = jcp.ow;
841 jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
842 jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
// sum/binarization only count if they appear before the fused dw conv.
844 jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
845 jcp.with_binarization = p.find(primitive_kind::binarization, 0, dw_conv_ind) != -1;
848 return status::unimplemented;
// Required memory formats: nhwc activations, blocked bit-packed weights.
850 auto desired_weights_format = isa == avx512_common ? OhIw16o32i : OhIw8o32i;
852 && src_d.format() == nhwc
853 && weights_d.format() == desired_weights_format
854 && dst_d.format() == nhwc;
855 if (!args_ok) return status::unimplemented;
857 jcp.ur_h = 1; /* no code-unrolling by h so far */
859 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
860 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
862 jcp.nb_oc_blocking = isa == sse42 ? 2 : 4; /* the optimal value for the kernel */
// Padding/unroll feasibility checks.
865 && jcp.l_pad <= jcp.ur_w
866 && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
867 || (jcp.stride_w == 1 && jcp.stride_h == 1));
868 if (!args_ok) return status::unimplemented;
// Right padding seen by the last full (non-tail) width block.
870 int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
871 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
// If it exceeds the unroll width, widen ur_w and shrink oc blocking so the
// padded region fits inside a single width block.
873 if (r_pad_no_tail > jcp.ur_w) {
874 /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
875 jcp.ur_w = r_pad_no_tail + 1;
876 jcp.nb_oc_blocking = ((16 - 1)-jcp.ur_w)/jcp.ur_w;
877 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
878 /* check again ... */
879 r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
880 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
881 if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
882 return status::unimplemented;
884 if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
// Derived channel blocking.
887 jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
888 jcp.ic_padded = rnd_up(jcp.ic, jcp.ic_block);
890 jcp.oc_block = simd_w;
891 jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
893 jcp.nb_ic_blocking = 1;
// Output type: packed bits when a binarization post-op is fused, else f32.
895 jcp.src_dt = cd.src_desc.data_type;
896 jcp.bia_dt = mkldnn_f32;
897 jcp.dst_dt = jcp.with_binarization ? mkldnn_bin : mkldnn_f32;
899 jcp.typesize_in = types::data_type_size(jcp.src_dt);
900 jcp.typesize_out = types::data_type_size(jcp.dst_dt);
901 jcp.typesize_acc = sizeof(int32_t);
903 return status::success;
906 template <cpu_isa_t isa>
// Reserve scratchpad memory needed when a depthwise convolution is fused:
// a per-thread intermediate row buffer between the two convolutions, and a
// zero-padded bias buffer when oc is not a multiple of the padded width.
907 void jit_uni_bin_conv_fwd_kernel<isa>::init_scratchpad(
908 memory_tracking::registrar_t &scratchpad, const jit_bin_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw_conv) {
909 if (jcp.with_dw_conv) {
910 const int nthreads = mkldnn_get_max_threads();
// One kh x iw x ch_block tile per oc-blocking step, replicated per thread.
911 size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
912 scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
914 if (jcp.oc != jcp.oc_padded)
915 scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc_padded);
919 template struct jit_uni_bin_conv_fwd_kernel<sse42>;
920 template struct jit_uni_bin_conv_fwd_kernel<avx2>;
921 template struct jit_uni_bin_conv_fwd_kernel<avx512_common>;