1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_dw_convolution.hpp"
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
// Forward execution of the JIT depthwise convolution primitive for the
// src/dst data types selected by the template parameters.
//
// Work is split across threads over the flattened index space
//   n in [0, MB) x chb in [0, chb_work) x od in [0, jcp.od) x oh in [0, jcp.oh);
// each work item produces one output row (fixed n, channel-block group, od, oh)
// by issuing JIT kernel calls along ow in three phases:
//   1. ow in [0, l_border): columns whose window starts in the left padding,
//      one kernel call per column;
//   2. one kernel call covering ur_w_step consecutive interior columns;
//   3. the remaining (right-border) columns up to jcp.ow, one call per column.
//
// NOTE(review): this view of the function is incomplete - MB, ur_w_step and ow
// are used but their declarations are not visible, the kernel_params lambda
// has no visible return statement, and several closing braces are missing.
// Presumably phase 2 is guarded by `ur_w_step > 0` and advances ow by
// ur_w_step, and ur_w_step is reset to 1 before phase 3 - confirm against the
// full file. The leading number on each line looks like a listing artifact.
29 template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
// Raw tensor pointers: input 0 = src, input 1 = weights, input 2 = bias,
// output = dst. Bias is addressed bytewise because its element size is only
// known at run time (bia_dt_size below).
31 auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
32 auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
33 auto bias = reinterpret_cast<const char*>(this->input_memory(2));
34 auto dst = reinterpret_cast<dst_data_t*>(this->memory());
// Descriptor wrappers, used only for blk_off() offset computations.
36 const memory_desc_wrapper src_d(pd()->src_pd());
37 const memory_desc_wrapper dst_d(pd()->dst_pd());
38 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
39 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
41 const auto &jcp = kernel_->jcp;
// jcp stores dilation zero-based (0 = dense), hence the + 1.
43 int dil_d = jcp.dilate_d + 1;
44 int dil_h = jcp.dilate_h + 1;
45 int dil_w = jcp.dilate_w + 1;
46 int str_d = jcp.stride_d;
47 int str_h = jcp.stride_h;
48 int str_w = jcp.stride_w;
// Size in bytes of one bias element; 0 when the primitive has no bias.
50 const size_t bia_dt_size = pd()->with_bias()
51 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
53 const auto &oscales = pd()->attr()->output_scales_;
// Number of channel-block groups; each group spans jcp.nb_ch_blocking channel
// blocks (the last group may be partial, see ch_work below).
56 int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
57 const size_t work_amount = MB * chb_work * jcp.od * jcp.oh;
// Builds the JIT kernel arguments for one call starting at output column ow.
// The h/d window bounds (ih/id, kh/kd, kh_padding/kd_padding) come from the
// caller; the w bounds are computed here for the window of column ow:
//   i_l_overflow / i_r_overflow - input columns of the window lying left of
//                                 0 / right of jcp.iw - 1 (0 when none);
//   kw         - index of the first kernel tap that lands inside the input;
//   iw         - input column touched by that tap (clamped to >= 0);
//   kw_padding - number of kernel taps along w that land inside the input.
// ch is the first channel block of the call, ch_num the number of blocks.
59 auto kernel_params = [&](int ur_w_step, int ow, int oh, int od, int ih, int id, int kh, int kd,
60 int kh_padding, int kd_padding, int ch, int ch_num, int n) {
61 auto par_conv = jit_conv_call_s();
63 const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
64 const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
65 + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
67 const int iw = nstl::max((ow*str_w - jcp.l_pad
68 + div_up(i_l_overflow, dil_w)*dil_w), 0);
69 const int kw = div_up(i_l_overflow, dil_w);
71 const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
72 - div_up(i_r_overflow, dil_w);
// Element offsets. jcp.ndims == 5 selects the 3-D spatial form (d, h, w),
// otherwise 2-D (h, w). src/dst are offset by the first channel index
// ch*jcp.ch_block, while weights are offset by the block index ch.
74 size_t src_off = (jcp.ndims == 5) ? src_d.blk_off(n, ch*jcp.ch_block, id, ih, iw)
75 : src_d.blk_off(n, ch*jcp.ch_block, ih, iw);
76 size_t dst_off = (jcp.ndims == 5) ? dst_d.blk_off(n, ch*jcp.ch_block, od, oh, ow)
77 : dst_d.blk_off(n, ch*jcp.ch_block, oh, ow);
78 size_t wei_off = (jcp.ndims == 5) ? weights_d.blk_off(ch, 0, 0, kd, kh, kw)
79 : weights_d.blk_off(ch, 0, 0, kh, kw);
81 par_conv.src = &src[src_off];
82 par_conv.dst = &dst[dst_off];
83 par_conv.filt = &weights[wei_off];
// NOTE(review): the byte scaling is applied *inside* blk_off here. The usual
// pattern is bias_d.blk_off(ch*jcp.ch_block) * bia_dt_size; the two agree
// only if blk_off is the identity on this descriptor (e.g. a dense 1-D bias
// without offset padding) - confirm, otherwise this points at the wrong byte.
84 if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block*bia_dt_size)];
// Negative tap counts (window entirely in padding) are clamped to 0.
86 par_conv.kd_padding = (size_t)nstl::max(0, kd_padding);
87 par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
88 par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
// Number of consecutive output columns this call produces.
90 par_conv.ur_w = (size_t)ur_w_step;
// Channels actually processed; the last group is truncated at jcp.oc.
92 par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
// With jcp.is_oc_scale == 0 every call reads scales_[0]; otherwise the scales
// start at the first channel of this group.
94 par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
// Byte offset of the first channel, in float-sized units; consumed by the
// kernel (presumably for per-channel post-op data - confirm).
95 par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
// Per-thread body: balance211 assigns this thread the range [start, end) of
// the flattened work space; nd_iterator_init/step walk (n, chb, od, oh) over it.
100 auto ker = [&](const int ithr, const int nthr) {
101 size_t start{0}, end{0};
102 balance211(work_amount, nthr, ithr, start, end);
104 size_t n{0}, chb{0}, oh{0}, od{0};
105 nd_iterator_init(start, n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh);
106 for (size_t iwork = start; iwork < end; ++iwork) {
107 int ch = chb * jcp.nb_ch_blocking;
108 int ch_num = jcp.nb_ch_blocking;
// Window bounds along h and d, mirroring the w computation in kernel_params.
// oh/od are size_t, so e.g. jcp.t_pad - oh*str_h is evaluated in unsigned
// arithmetic and only then cast to int. This relies on the conversion wrapping
// back to the negative value (implementation-defined before C++20, modular
// since C++20).
110 const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
111 const int i_b_overflow = nstl::max(jcp.ih,
112 (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
114 const int i_front_overflow = nstl::max(0, (int)(jcp.f_pad - od*str_d));
115 const int i_back_overflow = nstl::max(jcp.id,
116 (int)(od*str_d + (jcp.kd - 1)*dil_d - jcp.f_pad + 1)) - jcp.id;
118 const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
119 + div_up(i_t_overflow, dil_h)*dil_h), 0);
120 const int kh = div_up(i_t_overflow, dil_h);
121 const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
122 - div_up(i_b_overflow, dil_h);
124 const int id = nstl::max((int)(od*str_d - jcp.f_pad
125 + div_up(i_front_overflow, dil_d)*dil_d), 0);
126 const int kd = div_up(i_front_overflow, dil_d);
127 const int kd_padding = jcp.kd - div_up(i_front_overflow, dil_d)
128 - div_up(i_back_overflow, dil_d);
// Phase 1: columns ow < div_up(l_pad, str_w) have windows starting in the
// left padding; each is handled by its own kernel call.
132 int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
134 for (; ow < l_border; ow++) {
135 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
136 kh, kd, kh_padding, kd_padding, ch, ch_num, n);
138 kernel_->jit_ker(&par_conv);
// Phase 2: for a non-negative numerator, the quotient
// (iw - 1 + l_pad - (kw - 1)*dil_w) / stride_w is the last output column whose
// whole window lies inside the input. ur_w_step counts the columns from the
// current ow up to it, and can be <= 0 when none remain (any guard for that
// case is not visible in this view).
142 ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
143 / jcp.stride_w - ow + 1;
145 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
146 kh, kd, kh_padding, kd_padding, ch, ch_num, n);
148 kernel_->jit_ker(&par_conv);
// Phase 3: remaining right-border columns, one kernel call per column
// (presumably with ur_w_step reset to 1 - not visible in this view).
155 for (; ow < jcp.ow; ow++) {
156 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
157 kh, kd, kh_padding, kd_padding, ch, ch_num, n);
159 kernel_->jit_ker(&par_conv);
// Advance (n, chb, od, oh) to the next work item.
162 nd_iterator_step(n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh);
// Run ker across threads. The meaning of the leading 0 (presumably "default
// thread count") is defined by parallel() elsewhere - confirm.
166 parallel(0, work_amount, ker);
169 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
170 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
171 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
172 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
174 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
175 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
176 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
177 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
179 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
180 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
181 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
182 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
184 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
185 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
186 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
187 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;