updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_dw_convolution.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
28
29 template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
31     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
32     auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
33     auto bias = reinterpret_cast<const char*>(this->input_memory(2));
34     auto dst = reinterpret_cast<dst_data_t*>(this->memory());
35
36     const memory_desc_wrapper src_d(pd()->src_pd());
37     const memory_desc_wrapper dst_d(pd()->dst_pd());
38     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
39     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
40
41     const auto &jcp = kernel_->jcp;
42
43     int dil_d = jcp.dilate_d + 1;
44     int dil_h = jcp.dilate_h + 1;
45     int dil_w = jcp.dilate_w + 1;
46     int str_d = jcp.stride_d;
47     int str_h = jcp.stride_h;
48     int str_w = jcp.stride_w;
49
50     const size_t bia_dt_size = pd()->with_bias()
51         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
52
53     const auto &oscales = pd()->attr()->output_scales_;
54
55     int MB = jcp.mb;
56     int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
57     const size_t work_amount = MB * chb_work * jcp.od * jcp.oh;
58
59     auto kernel_params = [&](int ur_w_step, int ow, int oh, int od, int ih, int id, int kh, int kd,
60             int kh_padding, int kd_padding, int ch, int ch_num, int n) {
61         auto par_conv = jit_conv_call_s();
62
63         const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
64         const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
65             + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
66
67         const int iw = nstl::max((ow*str_w - jcp.l_pad
68             + div_up(i_l_overflow, dil_w)*dil_w), 0);
69         const int kw = div_up(i_l_overflow, dil_w);
70
71         const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
72             - div_up(i_r_overflow, dil_w);
73
74         size_t src_off = (jcp.ndims == 5) ? src_d.blk_off(n, ch*jcp.ch_block, id, ih, iw)
75                                        : src_d.blk_off(n, ch*jcp.ch_block, ih, iw);
76         size_t dst_off = (jcp.ndims == 5) ? dst_d.blk_off(n, ch*jcp.ch_block, od, oh, ow)
77                                        : dst_d.blk_off(n, ch*jcp.ch_block, oh, ow);
78         size_t wei_off = (jcp.ndims == 5) ? weights_d.blk_off(ch, 0, 0, kd, kh, kw)
79                                           : weights_d.blk_off(ch, 0, 0, kh, kw);
80
81         par_conv.src = &src[src_off];
82         par_conv.dst = &dst[dst_off];
83         par_conv.filt = &weights[wei_off];
84         if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block*bia_dt_size)];
85
86         par_conv.kd_padding = (size_t)nstl::max(0, kd_padding);
87         par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
88         par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
89
90         par_conv.ur_w = (size_t)ur_w_step;
91
92         par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
93
94         par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
95         par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
96
97         return par_conv;
98     };
99
100     auto ker = [&](const int ithr, const int nthr) {
101         size_t start{0}, end{0};
102         balance211(work_amount, nthr, ithr, start, end);
103
104         size_t n{0}, chb{0}, oh{0}, od{0};
105         nd_iterator_init(start, n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh);
106         for (size_t iwork = start; iwork < end; ++iwork) {
107             int ch = chb * jcp.nb_ch_blocking;
108             int ch_num = jcp.nb_ch_blocking;
109
110             const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
111             const int i_b_overflow = nstl::max(jcp.ih,
112                 (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
113
114             const int i_front_overflow = nstl::max(0, (int)(jcp.f_pad - od*str_d));
115             const int i_back_overflow = nstl::max(jcp.id,
116                 (int)(od*str_d + (jcp.kd - 1)*dil_d - jcp.f_pad + 1)) - jcp.id;
117
118             const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
119                 + div_up(i_t_overflow, dil_h)*dil_h), 0);
120             const int kh = div_up(i_t_overflow, dil_h);
121             const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
122                 - div_up(i_b_overflow, dil_h);
123
124             const int id = nstl::max((int)(od*str_d - jcp.f_pad
125                 + div_up(i_front_overflow, dil_d)*dil_d), 0);
126             const int kd = div_up(i_front_overflow, dil_d);
127             const int kd_padding = jcp.kd - div_up(i_front_overflow, dil_d)
128                 - div_up(i_back_overflow, dil_d);
129
130             // left border
131             int ow = 0;
132             int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
133             int ur_w_step = 1;
134             for (; ow < l_border; ow++) {
135                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
136                                             kh, kd, kh_padding, kd_padding, ch, ch_num, n);
137
138                 kernel_->jit_ker(&par_conv);
139             }
140
141             // main loop
142             ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
143                 / jcp.stride_w - ow + 1;
144             if (ur_w_step > 0) {
145                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
146                                             kh, kd, kh_padding, kd_padding, ch, ch_num, n);
147
148                 kernel_->jit_ker(&par_conv);
149
150                 ow += ur_w_step;
151             }
152
153             // right border
154             ur_w_step = 1;
155             for (; ow < jcp.ow; ow++) {
156                 jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, od, ih, id,
157                                             kh, kd, kh_padding, kd_padding, ch, ch_num, n);
158
159                 kernel_->jit_ker(&par_conv);
160             }
161
162             nd_iterator_step(n, MB, chb, chb_work, od, jcp.od, oh, jcp.oh);
163         }
164     };
165
166     parallel(0, work_amount, ker);
167 }
168
169 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
170 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
171 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
172 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
173
174 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
175 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
176 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
177 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
178
179 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
180 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
181 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
182 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
183
184 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
185 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
186 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
187 template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;
188
189 }
190 }
191 }