Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_dw_convolution.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
28
29 template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
31     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
32     auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
33     auto bias = reinterpret_cast<const char*>(this->input_memory(2));
34     auto dst = reinterpret_cast<dst_data_t*>(this->memory());
35
36     const memory_desc_wrapper src_d(pd()->src_pd());
37     const memory_desc_wrapper dst_d(pd()->dst_pd());
38     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
39     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
40
41     const auto &jcp = kernel_->jcp;
42
43     int dil_h = jcp.dilate_h + 1;
44     int dil_w = jcp.dilate_w + 1;
45     int str_h = jcp.stride_h;
46     int str_w = jcp.stride_w;
47
48     const size_t bia_dt_size = pd()->with_bias()
49         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
50
51     const auto &oscales = pd()->attr()->output_scales_;
52
53     int MB = jcp.mb;
54     int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
55     const size_t work_amount = MB * chb_work * jcp.oh;
56
57     auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
58             int kh_padding, int ch, int ch_num, int n) {
59         auto par_conv = jit_conv_call_s();
60
61         const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
62         const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
63             + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
64
65         const int iw = nstl::max((ow*str_w - jcp.l_pad
66             + div_up(i_l_overflow, dil_w)*dil_w), 0);
67         const int kw = div_up(i_l_overflow, dil_w);
68
69         const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
70             - div_up(i_r_overflow, dil_w);
71
72         int src_off = src_d.blk_off(n, ch*jcp.ch_block, ih, iw);
73         int dst_off = dst_d.blk_off(n, ch*jcp.ch_block, oh, ow);
74
75         par_conv.src = &src[src_off];
76         par_conv.dst = &dst[dst_off];
77
78         par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
79         if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block*bia_dt_size)];
80
81         par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
82         par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
83
84         par_conv.ur_w = (size_t)ur_w_step;
85
86         par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
87
88         par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
89         par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
90
91         return par_conv;
92     };
93
94     auto ker = [&](const int ithr, const int nthr) {
95         size_t start{0}, end{0};
96         balance211(work_amount, nthr, ithr, start, end);
97
98         size_t n{0}, chb{0}, oh{0};
99         nd_iterator_init(start, n, MB, chb, chb_work, oh, jcp.oh);
100         for (size_t iwork = start; iwork < end; ++iwork) {
101             int ch = chb * jcp.nb_ch_blocking;
102             int ch_num = jcp.nb_ch_blocking;
103
104             const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
105             const int i_b_overflow = nstl::max(jcp.ih,
106                 (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
107
108             const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
109                 + div_up(i_t_overflow, dil_h)*dil_h), 0);
110             const int kh = div_up(i_t_overflow, dil_h);
111             const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
112                 - div_up(i_b_overflow, dil_h);
113
114             // left border
115             int ow = 0;
116             int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
117             int ur_w_step = 1;
118             for (; ow < l_border; ow++) {
119                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
120                                             kh, kh_padding, ch, ch_num, n);
121
122                 kernel_->jit_ker(&par_conv);
123             }
124
125             // main loop
126             ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
127                 / jcp.stride_w - ow + 1;
128             if (ur_w_step > 0) {
129                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
130                                             kh, kh_padding, ch, ch_num, n);
131
132                 kernel_->jit_ker(&par_conv);
133
134                 ow += ur_w_step;
135             }
136
137             // right border
138             ur_w_step = 1;
139             for (; ow < jcp.ow; ow++) {
140                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
141                                             kh, kh_padding, ch, ch_num, n);
142
143                 kernel_->jit_ker(&par_conv);
144             }
145
146             nd_iterator_step(n, MB, chb, chb_work, oh, jcp.oh);
147         }
148     };
149
150     parallel(0, ker);
151 }
152
153 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
154 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
155 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
156 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
157
158 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
159 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
160 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
161 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
162
163 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
164 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
165 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
166 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
167
168 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
169 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
170 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
171 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;
172
173 }
174 }
175 }