1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_dw_convolution.hpp"
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
29 template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
31 auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
32 auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
33 auto bias = reinterpret_cast<const char*>(this->input_memory(2));
34 auto dst = reinterpret_cast<dst_data_t*>(this->memory());
36 const memory_desc_wrapper src_d(pd()->src_pd());
37 const memory_desc_wrapper dst_d(pd()->dst_pd());
38 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
39 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
41 const auto &jcp = kernel_->jcp;
43 int dil_h = jcp.dilate_h + 1;
44 int dil_w = jcp.dilate_w + 1;
45 int str_h = jcp.stride_h;
46 int str_w = jcp.stride_w;
48 const size_t bia_dt_size = pd()->with_bias()
49 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
51 const auto &oscales = pd()->attr()->output_scales_;
54 int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
55 const size_t work_amount = MB * chb_work * jcp.oh;
57 auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
58 int kh_padding, int ch, int ch_num, int n) {
59 auto par_conv = jit_conv_call_s();
61 const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
62 const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
63 + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
65 const int iw = nstl::max((ow*str_w - jcp.l_pad
66 + div_up(i_l_overflow, dil_w)*dil_w), 0);
67 const int kw = div_up(i_l_overflow, dil_w);
69 const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
70 - div_up(i_r_overflow, dil_w);
72 int src_off = src_d.blk_off(n, ch*jcp.ch_block, ih, iw);
73 int dst_off = dst_d.blk_off(n, ch*jcp.ch_block, oh, ow);
75 par_conv.src = &src[src_off];
76 par_conv.dst = &dst[dst_off];
78 par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
79 if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block*bia_dt_size)];
81 par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
82 par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
84 par_conv.ur_w = (size_t)ur_w_step;
86 par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
88 par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
89 par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
94 auto ker = [&](const int ithr, const int nthr) {
95 size_t start{0}, end{0};
96 balance211(work_amount, nthr, ithr, start, end);
98 size_t n{0}, chb{0}, oh{0};
99 nd_iterator_init(start, n, MB, chb, chb_work, oh, jcp.oh);
100 for (size_t iwork = start; iwork < end; ++iwork) {
101 int ch = chb * jcp.nb_ch_blocking;
102 int ch_num = jcp.nb_ch_blocking;
104 const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
105 const int i_b_overflow = nstl::max(jcp.ih,
106 (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
108 const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
109 + div_up(i_t_overflow, dil_h)*dil_h), 0);
110 const int kh = div_up(i_t_overflow, dil_h);
111 const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
112 - div_up(i_b_overflow, dil_h);
116 int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
118 for (; ow < l_border; ow++) {
119 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
120 kh, kh_padding, ch, ch_num, n);
122 kernel_->jit_ker(&par_conv);
126 ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
127 / jcp.stride_w - ow + 1;
129 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
130 kh, kh_padding, ch, ch_num, n);
132 kernel_->jit_ker(&par_conv);
139 for (; ow < jcp.ow; ow++) {
140 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
141 kh, kh_padding, ch, ch_num, n);
143 kernel_->jit_ker(&par_conv);
146 nd_iterator_step(n, MB, chb, chb_work, oh, jcp.oh);
// Explicit instantiations for every supported (isa, src data type,
// dst data type) combination: avx2 and sse42, with u8 or s8 sources and
// u8/s8/s32/f32 destinations.
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;

template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;

template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;

template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;