1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <common/primitive_attr.hpp>
18 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
24 #include "jit_uni_planar_conv_kernel_f32.hpp"
25 #include "cpu_isa_traits.hpp"
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
// Initializes the scalar-path accumulators (one float per output row, held in
// Xmm/Vmm(kk)): from current dst contents when this is not the first
// input-channel chunk, otherwise from bias (if present) or zero.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::load_src_scalar(int ur_h) {
    Label init_done_label;
    Label init_first_label;

    // Flags indicate whether this invocation covers the first/last IC chunk.
    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);

    test(reg_ci_flag, FLAG_IC_FIRST);
    jne(init_first_label, T_NEAR);

    // Not the first IC chunk: resume accumulation from what is already in dst.
    for (int kk = 0; kk < ur_h; kk++) {
        size_t offt = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step);
        movss(Xmm(kk), make_safe_addr(reg_output, offt, reg_long_offt));

    // With a sum post-op the bias must still be added exactly once; the je
    // below skips the add unless FLAG_IC_FIRST is set.
    if (jcp.with_sum && jcp.with_bias) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        je(init_done_label, T_NEAR);

        movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
        for (int kk = 0; kk < ur_h; kk++) {
            uni_vaddps(Vmm(kk), Vmm(kk), vmm_tmp);

    jmp(init_done_label, T_NEAR);

    // First IC chunk: seed accumulators with bias when present...
    if (this->jcp.with_bias) {
        movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
        for (int kk = 0; kk < ur_h; kk++) {
            uni_vmovups(Vmm(kk), vmm_tmp);

    // ...or zero them out otherwise.
    for (int kk = 0; kk < ur_h; kk++) {
        uni_vpxor(Vmm(kk), Vmm(kk), Vmm(kk));
// Emits the innermost kw loop of the scalar path: for every input channel of
// the block and every handled output row, FMA one input element against one
// weight, then step both pointers to the next kw tap.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter_scalar(int ur_h) {
    Label iter_exit_label;

    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;

    // Nothing to accumulate when the effective kw count is zero.
    je(iter_exit_label, T_NEAR);

    mov(aux_reg_input_w, aux_reg_input_h);
    mov(aux_reg_kernel_w, aux_reg_kernel_h);
    mov(kw_iter, reg_kw);

    for (size_t ifm2 = 0; ifm2 < (size_t)ic_blk; ifm2++) {
        for (int kk = 0; kk < ur_h; kk++) {
            // Input: channel plane stride id*ih*iw; row stride iw per output row.
            size_t inp_off = sizeof(float) * (ifm2 * id * ih * iw + kk * jcp.iw * jcp.oh_block_step);
            movss(xmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt));

            // Weights: one kd*kh*kw plane per input channel.
            size_t ker_off = sizeof(float) * (ifm2 * kd * kh * kw);
            movss(xmm_ker, ptr[aux_reg_kernel_w + ker_off]);

            uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker);

    // Advance one kw tap; input advances by the dilation stride.
    add(aux_reg_kernel_w, sizeof(float));
    add(aux_reg_input_w, dilate_w * sizeof(float));

    jg(kw_label, T_NEAR);
131 template <cpu_isa_t isa>
132 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_filter_scalar(int ur_h) {
135 int dilate_h = jcp.dilate_h + 1;
136 int dilate_d = jcp.dilate_h + 1;
137 const int inp_mult_h = dilate_h;
138 const int inp_mult_d = dilate_d;
140 Label skip_kh_loop, skip_kd_loop, kd_label;
141 if (jcp.ndims == 5) {
145 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
146 mov(aux_reg_ker_d, aux_reg_kernel_h);
147 mov(aux_reg_inp_d, aux_reg_input_h);
150 je(skip_kd_loop, T_NEAR);
153 mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]);
155 mov(kh_iter, reg_kh);
158 if (jcp.ndims == 5) {
159 mov(aux_reg_input_h, aux_reg_inp_d);
160 mov(aux_reg_kernel_h, aux_reg_ker_d);
164 je(skip_kh_loop, T_NEAR);
171 add(aux_reg_kernel_h, sizeof(float) * kw);
172 add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h);
176 jg(kh_label, T_NEAR);
181 if (jcp.ndims == 5) {
182 add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh);
183 add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d);
187 jg(kd_label, T_NEAR);
// Applies the eltwise post-ops to the scalar accumulators, but only on the
// last input-channel chunk (FLAG_IC_LAST); otherwise jumps straight to the
// plain-store path.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_postprocess_scalar(int ur_h) {
    Label regular_store_label;

    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    test(reg_ci_flag, FLAG_IC_LAST);
    je(regular_store_label, T_NEAR);

    int eltwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    // Walk the post-op chain in order, dispatching each eltwise entry to its
    // pre-built injector over the accumulator registers Vmm(0..ur_h-1).
    for (int i = 0; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_h);

    L(regular_store_label);
// Writes the low float of each scalar accumulator back to dst, one output
// row apart (row stride jcp.ow, scaled by the oh block step).
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::store_dst_scalar(int ur_h) {
    for (int kk = 0; kk < ur_h; kk++) {
        size_t o_off = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step);
        movss(make_safe_addr(reg_output, o_off, reg_long_offt), Xmm(kk));
// Vectorized counterpart of load_src_scalar: initializes ur_h * ur_w full
// vector accumulators from dst (non-first IC chunk), broadcast bias, or zero.
// Register layout: Vmm(kk * ur_w + jj) holds row kk, ow-block jj.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::load_src(int ur_h, int ur_w) {
    Label init_done_label;
    Label init_first_label;

    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);

    test(reg_ci_flag, FLAG_IC_FIRST);
    jne(init_first_label, T_NEAR);

    // Not the first IC chunk: reload partial sums from dst.
    for (int kk = 0; kk < ur_h; kk++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step);
            uni_vmovups(Vmm(kk * ur_w + jj), make_safe_addr(reg_output, offt, reg_long_offt));

    // sum+bias: the bias is added once, on the first IC chunk only.
    if (jcp.with_sum && jcp.with_bias) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        je(init_done_label, T_NEAR);

        uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
        for (int kk = 0; kk < ur_h; kk++) {
            for (int jj = 0; jj < ur_w; jj++) {
                uni_vaddps(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), vmm_tmp);

    jmp(init_done_label, T_NEAR);

    // First IC chunk: seed with broadcast bias when present, else zeros.
    if (this->jcp.with_bias) {
        uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
        for (int kk = 0; kk < ur_h; kk++) {
            for (int jj = 0; jj < ur_w; jj++) {
                uni_vmovups(Vmm(kk * ur_w + jj), vmm_tmp);

    for (int kk = 0; kk < ur_h; kk++) {
        for (int jj = 0; jj < ur_w; jj++) {
            uni_vpxor(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj));
280 template <cpu_isa_t isa>
281 void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter_unrolled(int ur_h, int ur_w) {
285 int stride_w = jcp.stride_w;
286 int dilate_w = jcp.dilate_w + 1;
287 int ic_blk = jcp.ic_block;
291 int ow_blk = jcp.ow_block;
293 for (int ki = 0; ki < kw; ki++) {
294 for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
295 for (int kk = 0; kk < ur_h; kk++) {
296 for (int jj = 0; jj < ur_w; jj++) {
297 size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + ki * dilate_w +
298 jj * stride_w * ow_blk + kk * jcp.ow * jcp.oh_block_step);
299 uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_h, inp_off, reg_long_offt));
301 int ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw + ki);
302 uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_h + ker_off]);
304 uni_vfmadd231ps(Vmm(kk * ur_w + jj), vmm_src, vmm_ker);
// Runtime kw loop of the vectorized path (used when full unrolling does not
// apply): one full-vector FMA per (ic, output row) pair per kw tap.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter(int ur_h) {
    Label iter_exit_label;

    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;

    // Nothing to accumulate when the effective kw count is zero.
    je(iter_exit_label, T_NEAR);

    mov(aux_reg_input_w, aux_reg_input_h);
    mov(aux_reg_kernel_w, aux_reg_kernel_h);
    mov(kw_iter, reg_kw);

    for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
        for (int kk = 0; kk < ur_h; kk++) {
            // NOTE(review): per-row term uses jcp.ow here, but jcp.iw in
            // filter_scalar's input offset — confirm the intended row stride.
            size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + kk * jcp.ow * jcp.oh_block_step);
            uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt));

            size_t ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw);
            uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_w + ker_off]);

            uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker);

    // Advance one kw tap; input advances by the dilation stride.
    add(aux_reg_kernel_w, sizeof(float));
    add(aux_reg_input_w, dilate_w * sizeof(float));

    jg(kw_label, T_NEAR);
357 template <cpu_isa_t isa>
358 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_filter(int ur_h, int ur_w) {
361 int dilate_h = jcp.dilate_h + 1;
362 int dilate_d = jcp.dilate_h + 1;
363 const int inp_mult_h = dilate_h;
364 const int inp_mult_d = dilate_d;
366 Label skip_kh_loop, skip_kd_loop, kd_label;
367 if (jcp.ndims == 5) {
371 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
372 mov(aux_reg_ker_d, aux_reg_kernel_h);
373 mov(aux_reg_inp_d, aux_reg_input_h);
376 je(skip_kd_loop, T_NEAR);
379 mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]);
381 mov(kh_iter, reg_kh);
384 if (jcp.ndims == 5) {
385 mov(aux_reg_input_h, aux_reg_inp_d);
386 mov(aux_reg_kernel_h, aux_reg_ker_d);
390 je(skip_kh_loop, T_NEAR);
395 if (ur_w == jcp.nb_ow_blocking)
396 filter_unrolled(ur_h, ur_w);
400 add(aux_reg_kernel_h, sizeof(float) * kw);
401 add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h);
405 jg(kh_label, T_NEAR);
410 if (jcp.ndims == 5) {
411 add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh);
412 add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d);
416 jg(kd_label, T_NEAR);
// Vectorized counterpart of apply_postprocess_scalar: applies eltwise
// post-ops over all ur_h * ur_w accumulators, only on the last IC chunk.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_h, int ur_w) {
    Label regular_store_label;

    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    test(reg_ci_flag, FLAG_IC_LAST);
    je(regular_store_label, T_NEAR);

    int eltwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    // Dispatch each eltwise entry in post-op order over Vmm(0..ur_w*ur_h-1).
    for (int i = 0; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_w * ur_h);

    L(regular_store_label);
// Writes the ur_h * ur_w vector accumulators back to dst; ow-blocks are
// jcp.ow_block floats apart, rows jcp.ow * oh_block_step floats apart.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::store_dst(int ur_h, int ur_w) {
    for (int kk = 0; kk < ur_h; kk++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t o_off = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step);
            uni_vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), Vmm(kk * ur_w + jj));
// Emits the output-width loop: a scalar loop over outputs overlapping the
// left padding, an unrolled vectorized main region, a plain vectorized
// region for the remainder, then a scalar loop for the right border.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::solve_common(int ur_h) {
    // One output step: init accumulators, convolve, post-process, store,
    // then advance the input/output pointers by step_w output columns.
    // NOTE(review): both scalar and vectorized variants appear here;
    // presumably selected by ur_w/step_w — confirm the guarding condition.
    auto solve_loop = [&](int ur_w, int step_w) {
        load_src_scalar(ur_h);
        apply_filter_scalar(ur_h);
        apply_postprocess_scalar(ur_h);
        store_dst_scalar(ur_h);
        load_src(ur_h, ur_w);
        apply_filter(ur_h, ur_w);
        apply_postprocess(ur_h, ur_w);
        store_dst(ur_h, ur_w);
        add(reg_input, sizeof(float) * step_w * jcp.stride_w);
        add(reg_output, sizeof(float) * step_w);

    Label left_border_label;
    Label main_loop_unrolled_label;
    Label main_loop_label;
    Label right_border_label;

    // reg_ow counts emitted output columns; bias the input pointer left so
    // padded taps line up (the border loops skip out-of-range taps below).
    xor_(reg_ow, reg_ow);
    sub(reg_input, sizeof(float) * jcp.l_pad);

    // Skip kernel taps that would read before the first input column for
    // outputs overlapping the left padding (advances weights + input).
    auto adjust_indexes_left = [&]() {
        Label border_indexes_label;
        Label border_indexes_exit_label;

        mov(reg_wj, jcp.l_pad);
        L(border_indexes_label);
        jle(border_indexes_exit_label, T_NEAR);
        add(aux_reg_kernel_h, sizeof(float));
        add(aux_reg_input_h, sizeof(float) * (jcp.dilate_w + 1));
        sub(reg_wj, jcp.dilate_w + 1);
        jmp(border_indexes_label);
        L(border_indexes_exit_label);

    // Trim kernel taps that would read past the last input column:
    // overhang = ow*stride_w + (kw-1)*(dilate_w+1) - l_pad + 1 - iw.
    auto adjust_indexes_right = [&]() {
        Label border_indexes_right_label;
        Label border_indexes_right_exit_label;

        imul(reg_wj, reg_ow, jcp.stride_w);
        add(reg_wj, (jcp.kw-1) * (jcp.dilate_w+1) - jcp.l_pad+1 - jcp.iw);
        L(border_indexes_right_label);
        jle(border_indexes_right_exit_label, T_NEAR);
        sub(reg_wj, jcp.dilate_w + 1);
        jmp(border_indexes_right_label);
        L(border_indexes_right_exit_label);

    // --- Left border: scalar, one output column at a time. ---
    int left_border_end = nstl::min(div_up(jcp.l_pad, jcp.stride_w), jcp.ow);
    L(left_border_label); {
        cmp(reg_ow, left_border_end);
        jge(main_loop_unrolled_label, T_NEAR);

        mov(aux_reg_input_h, reg_input);
        mov(aux_reg_kernel_h, reg_kernel);
        adjust_indexes_left();
        adjust_indexes_right();
        solve_loop(1, 1); // scalar
        jmp(left_border_label, T_NEAR);

    // --- Main region: first the unrolled (nb_ow_blocking blocks per
    // iteration) loop, then single-block iterations for the remainder. ---
    int main_loop_end = (jcp.iw - (jcp.kw - 1)*(jcp.dilate_w + 1) + jcp.l_pad - 1) / jcp.stride_w + 1;
    L(main_loop_unrolled_label); {
        cmp(reg_ow, main_loop_end - jcp.nb_ow_blocking * jcp.ow_block);
        jg(main_loop_label, T_NEAR);

        mov(aux_reg_input_h, reg_input);
        mov(aux_reg_kernel_h, reg_kernel);
        solve_loop(jcp.nb_ow_blocking, jcp.nb_ow_blocking * jcp.ow_block);
        add(reg_ow, jcp.nb_ow_blocking * jcp.ow_block);
        jmp(main_loop_unrolled_label, T_NEAR);

    L(main_loop_label); {
        cmp(reg_ow, main_loop_end - jcp.ow_block);
        jg(right_border_label, T_NEAR);

        mov(aux_reg_input_h, reg_input);
        mov(aux_reg_kernel_h, reg_kernel);
        solve_loop(1, jcp.ow_block); // vectorized
        add(reg_ow, jcp.ow_block);
        jmp(main_loop_label, T_NEAR);

    // --- Right border: scalar again, with both-side tap trimming. ---
    int right_border_end = jcp.ow;
    L(right_border_label); {
        cmp(reg_ow, right_border_end);
        jge(exit_label, T_NEAR);

        mov(aux_reg_input_h, reg_input);
        mov(aux_reg_kernel_h, reg_kernel);
        adjust_indexes_left();
        adjust_indexes_right();
        solve_loop(1, 1); // scalar
        jmp(right_border_label, T_NEAR);
// Top-level code generation entry: builds the eltwise injectors, loads the
// kernel arguments from jit_conv_call_s, and emits the injector tables.
template <cpu_isa_t isa>
void jit_uni_planar_conv_fwd_kernel_f32<isa>::generate() {
    // One injector per eltwise entry in the post-ops chain.
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            // NOTE(review): raw `new` — ownership is not visible here;
            // confirm eltwise_injectors entries are freed in the destructor.
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    post_op.eltwise.alpha,

    // Load runtime arguments passed in the jit_conv_call_s block.
    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_oh_blocks, ptr[this->param1 + GET_OFF(oh_blocks)]);

    // Constant tables must be emitted after the code body.
    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
// Validates the post-ops chain: at most three entries, where a sum (if any)
// comes first and every other entry is eltwise ("simple").
template <cpu_isa_t isa>
bool jit_uni_planar_conv_fwd_kernel_f32<isa>::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    // "simple" currently means eltwise only.
    auto is_simple = [&](int idx) { return is_eltwise(idx); };

    case 0: return true; // no post_ops
    return true // sum OR eltwise OR depthwise
        && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
    return true // sum->relu
        && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
                (is_simple(0) && is_simple(1)));
    return true // sum->relu
        && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
    default: return false;
664 template <cpu_isa_t isa>
665 status_t jit_uni_planar_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
666 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
667 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
668 const primitive_attr_t &attr) {
669 if (!mayiuse(isa)) return status::unimplemented;
671 jcp.prop_kind = cd.prop_kind;
673 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
674 int ndims = src_d.ndims();
677 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
678 jcp.mb = src_d.dims()[0];
680 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
681 jcp.oc_without_padding = jcp.oc;
682 jcp.ic = src_d.dims()[1] / jcp.ngroups;
684 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
685 jcp.ih = src_d.dims()[ndims-2];
686 jcp.iw = src_d.dims()[ndims-1];
687 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
688 jcp.oh = dst_d.dims()[ndims-2];
689 jcp.ow = dst_d.dims()[ndims-1];
690 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
691 jcp.kh = weights_d.dims()[with_groups + ndims-2];
692 jcp.kw = weights_d.dims()[with_groups + ndims-1];
694 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
695 jcp.t_pad = cd.padding[0][ndims-4];
696 jcp.l_pad = cd.padding[0][ndims-3];
697 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
698 jcp.stride_h = cd.strides[ndims-4];
699 jcp.stride_w = cd.strides[ndims-3];
701 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
702 jcp.dilate_h = cd.dilates[ndims-4];
703 jcp.dilate_w = cd.dilates[ndims-3];
705 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
706 - (jcp.ih + jcp.t_pad - 1);
708 jcp.src_fmt = src_d.format();
709 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
710 jcp.with_eltwise = false;
712 if (!post_ops_ok(jcp, attr))
713 return status::unimplemented;
715 const auto &p = attr.post_ops_;
716 jcp.with_sum = p.find(primitive_kind::sum) != -1;
718 const int simd_w = isa == avx512_common ? 16 : 8;
721 && one_of(src_d.format(), nchw, ncdhw)
722 && one_of(weights_d.format(), oihw, oidhw)
723 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
724 && one_of(dst_d.format(), nchw, ncdhw);
725 if (!args_ok) return status::unimplemented;
727 // This convolution implementation was introduced as workaround to provide competitive performance on MSD topology.
728 // The conditions below are needed to bound applicability scope.
729 args_ok = jcp.ngroups == 1 &&
731 jcp.stride_d == 1 && jcp.stride_h == 1 && jcp.stride_w == 1;
733 if (!args_ok) return status::unimplemented;
737 jcp.ow_block = simd_w;
738 jcp.nb_ow_blocking = isa == avx512_common ? 3 : 3;
741 jcp.nb_oh_blocking = 1;
742 jcp.oh_block_step = 1; // (jcp.dilate_h + 1);
745 jcp.nb_oc = jcp.oc / jcp.oc_block;
746 jcp.nb_oc_blocking = 1;
749 jcp.nb_ic = jcp.ic / jcp.ic_block;
750 jcp.nb_ic_blocking = 1;
752 return status::success;
// Explicit instantiations for the ISAs this kernel supports.
template struct jit_uni_planar_conv_fwd_kernel_f32<avx512_common>;
template struct jit_uni_planar_conv_fwd_kernel_f32<avx2>;