Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21
22 #include "jit_avx512_core_x8s8s32x_convolution.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
32
33 using namespace nstl;
34
35 using jit_conv_ker_t = void (*)(jit_conv_call_s *);
36
37 #define wht_blk_off(d, g, ...) \
38         (pd()->with_groups() \
39          ? (d).blk_off((g), __VA_ARGS__) \
40          : (d).blk_off(__VA_ARGS__))
41
42 template <data_type_t src_type, data_type_t dst_type>
43 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
44 execute_forward() const
45 {
46     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
47     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
48     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
49     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
50
51     const memory_desc_wrapper src_d(pd()->src_pd());
52     const memory_desc_wrapper dst_d(pd()->dst_pd());
53     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
54     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
55
56     const size_t bia_dt_size = pd()->with_bias()
57         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
58
59     const auto &jcp = pd()->jcp_;
60     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
61     assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
62
63     const float *oscales = pd()->attr()->output_scales_.scales_;
64     if (jcp.signed_input && jcp.ver != ver_vnni) {
65         auto local_scales = scratchpad().template get<float>(
66                 key_conv_adjusted_scales);
67         size_t count = pd()->attr()->output_scales_.count_;
68         float factor = 1.f / pd()->jcp_.wei_adj_scale;
69         if (count == 1) {
70             utils::array_set(local_scales, oscales[0] * factor, 16);
71         } else {
72             for (size_t c = 0; c < count; c++)
73                 local_scales[c] = oscales[c] * factor;
74         }
75         oscales = local_scales;
76     }
77
78     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
79     auto w = const_cast<wei_data_t *>(weights);
80     int32_t* compensation = (jcp.signed_input)
81                                 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
82     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
83     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
84     int group_block = jcp.ch_block;
85     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
86
87     parallel(0, [&](const int ithr, const int nthr) {
88
89         int start{0}, end{0};
90         balance211(work_amount, nthr, ithr, start, end);
91
92         auto p = jit_conv_call_s();
93
94         size_t src_h_stride = src_d.blk_off(0, 0, 1);
95         size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
96         size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
97
98         int n{ 0 }, gg{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
99         if (jcp.loop_order == loop_cwgn)
100             nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
101                     nb_groups, n, jcp.mb, oh_s, jcp.oh);
102         else if (jcp.loop_order == loop_gncw)
103             nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
104                     owb, jcp.nb_ow, oh_s, jcp.oh);
105         else if (jcp.loop_order == loop_ngcw)
106             nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
107                     owb, jcp.nb_ow, oh_s, jcp.oh);
108         else if (jcp.loop_order == loop_nhwcg)
109             nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
110                     occ, oc_chunks, gg, nb_groups);
111         else
112             assert(!"unsupported loop order");
113         while (start < end) {
114             int ocb = occ * jcp.nb_oc_blocking;
115             int gb = gg * jcp.nb_ch_blocking;
116             int g = gb * group_block;
117             int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
118
119             int g_ic = g * jcp.nb_ic * jcp.ic_block;
120
121             int work_rem = end - start;
122             int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
123             int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
124             if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; // step instead
125             int ow_s = owb * jcp.ow_block;
126             int iw_s = ow_s * jcp.stride_w;
127
128             auto bias_w = bias
129                 ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
130                 : 0;
131             int32_t *compensation_w = (jcp.signed_input)
132                                                     ? compensation + g_oc : 0;
133
134             auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
135             auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
136             auto wht_w = weights + wht_blk_off(weights_d, gb, ocb, 0);
137
138             auto scales = &oscales[jcp.is_oc_scale * g_oc];
139
140             for (int oj = oh_s, ij = ih_s; oj < oh_e;
141                 ++oj, ij += jcp.stride_h) {
142                 int dilate_h = jcp.dilate_h + 1;
143                 int i_t_overflow = nstl::min(jcp.kh,
144                                                 div_up(max(0, -ij), dilate_h));
145                 int i_b_overflow = nstl::min(jcp.kh, div_up(
146                         max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
147                         dilate_h));
148                 int kh_padding = nstl::max(0,
149                     jcp.kh - i_t_overflow - i_b_overflow);
150
151                 size_t wei_stride = (!jcp.signed_input)
152                                             ? i_t_overflow * wht_h_stride : 0;
153                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
154                 p.dst = dst_w;
155                 p.filt = wht_w + wei_stride;
156                 p.bias = bias_w;
157                 p.compensation = compensation_w;
158                 p.oc_blocks = jcp.is_depthwise ? gb : ocb;
159                 p.kh_padding = kh_padding;
160                 p.scales = scales;
161                 p.t_overflow = i_t_overflow;
162                 p.b_overflow = i_b_overflow;
163                 p.owb = owb;
164
165                 p.oc_off = g_oc * sizeof(float);
166
167                 kernel_->jit_ker(&p);
168
169                 src_w += src_h_stride * jcp.stride_h;
170                 dst_w += dst_h_stride;
171             }
172             if (jcp.loop_order == loop_cwgn)
173                 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
174                         nb_groups, n, jcp.mb, oh_s, jcp.oh);
175             else if (jcp.loop_order == loop_gncw)
176                 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
177                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
178             else if (jcp.loop_order == loop_ngcw)
179                 nd_iterator_jump(start, end, n, jcp.mb, gg, nb_groups, occ,
180                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
181             else if (jcp.loop_order == loop_nhwcg) {
182                 ++start;
183                 nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
184                         oc_chunks, gg, nb_groups);
185             }
186             else
187                 assert(!"unsupported loop order");
188         }
189     });
190 }
191
192 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
193                                                 data_type::s8, data_type::u8>;
194 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
195                                                 data_type::u8, data_type::u8>;
196 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
197                                                 data_type::s8, data_type::s8>;
198 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
199                                                 data_type::u8, data_type::s8>;
200 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
201                                                 data_type::s8, data_type::s32>;
202 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
203                                                 data_type::u8, data_type::s32>;
204 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
205                                                 data_type::s8, data_type::f32>;
206 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
207                                                 data_type::u8, data_type::f32>;
208 }
209 }
210 }
211
212 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s