Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_1x1_conv_utils.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_1x1_CONV_UTILS_HPP
18 #define JIT_UNI_1x1_CONV_UTILS_HPP
19
20 #include "memory_tracking.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "nstl.hpp"
23 #include "type_helpers.hpp"
24 #include "utils.hpp"
25
26 #include "jit_generator.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::utils;
33
34 struct reduce_to_unit_stride_t {
35     convolution_desc_t conv_d_;
36     bool reduce_src_;
37     size_t space_per_thread_;
38 };
39
40 /* 1x1-kernel does not support non-unit strides so far, so the idea is:
41  *  - for fwd or bwd_weights: to copy src to a scratch memory (with strides
42  *    equal to 1) and then call the kernel
43  *  - for bwd_data: reduce the problem to the one with unit stride by
44  *    performing computations in a scratch memory (with strides equal to 1)
45  *    and then copy the result to diff_src */
46 template <typename conv_pd_t>
47 inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
48         const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
49     const bool is_bwd_data = self->desc()->prop_kind
50         == prop_kind::backward_data;
51
52     const int ndims = src_d->ndims;
53     bool rtus_applicable = true
54         && utils::pick(ndims - 3,
55             (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type,
56                 data_type::s16, data_type::s32)),
57             (conv_d->strides[0] != 1 || conv_d->strides[1] != 1))
58         && utils::one_of(src_d->format, memory_format::nCw8c,
59             memory_format::nCw16c, memory_format::nChw8c,
60             memory_format::nChw16c);
61     for (int d = 2; d < ndims; ++d) {
62         /* TODO: relax these conditions (by improving reducer) */
63         rtus_applicable = rtus_applicable
64             && conv_d->padding[0][d - 2] == 0
65             && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
66     }
67
68     if (rtus_applicable) {
69         self->rtus_.reduce_src_ = true;
70         conv_d = &(self->rtus_.conv_d_ = *conv_d);
71         self->rtus_.conv_d_.strides[0] = 1;
72         if (ndims == 4)
73             self->rtus_.conv_d_.strides[1] = 1;
74         utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
75         if (ndims == 4)
76             utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
77         const int ic = src_d->dims[1];
78         if (is_bwd_data) {
79             src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
80             self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
81             memory_desc_wrapper::compute_blocking(
82                     self->rtus_.conv_d_.diff_src_desc);
83         } else {
84             data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
85             src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
86             self->rtus_.conv_d_.src_desc.dims[1] = ic;
87             self->rtus_.conv_d_.src_desc.data_type = data_type;
88             memory_desc_wrapper::compute_blocking(
89                     self->rtus_.conv_d_.src_desc);
90         }
91     }
92 }
93
94 template <typename conv_pd_t>
95 inline void rtus_prepare_space_info(conv_pd_t *self,
96         memory_tracking::registrar_t &scratchpad) {
97     const auto &jcp = self->jcp_;
98
99     const int max_threads = mkldnn_get_max_threads();
100     const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
101             jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
102     size_t typesize = types::data_type_size(
103             conv_prop_agnostic_src_d(self->desc())->data_type);
104
105     self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block;
106     scratchpad.book(memory_tracking::names::key_conv_rtus_space,
107             typesize * max_threads * self->rtus_.space_per_thread_);
108 }
109
110 template <cpu_isa_t isa>
111 struct rtus_driver_t: public jit_generator {
112
113     struct call_params_t {
114         const void *ws; /* reduced image (w/ strides = 1) */
115         const void *src; /* source image (w/ non-unit strides) */
116         size_t icb;
117         size_t os;
118         size_t iw_start;
119     };
120
121     void (*ker_)(const call_params_t *p);
122
123     DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)
124
125     /* cpu specific part */
126     using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
127           Xbyak::Zmm>::type;
128
129     Xbyak::Reg64 reg_ws = abi_param1;
130     Xbyak::Reg64 reg_src = abi_not_param1;
131     Xbyak::Reg64 reg_icb = rdx;
132     Xbyak::Reg64 reg_os = r11;
133     Xbyak::Reg64 reg_iw_start = r8;
134
135     Xbyak::Reg64 reg_cur_os = rax;
136     Xbyak::Reg64 reg_cur_iw = r9;
137     Xbyak::Reg64 reg_cur_src = r10;
138
139     int iw_, stride_w_;
140     int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
141     bool src_to_ws_;
142     size_t typesize_;
143     Vmm reg_zero;
144     Vmm reg_v;
145
146     rtus_driver_t(int iw, int stride_w, int src_step_h,
147             int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize)
148         : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h)
149         , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb)
150         , src_to_ws_(src_to_ws), typesize_(typesize)
151     {
152         using namespace Xbyak;
153         vlen_ = cpu_isa_traits<isa>::vlen;
154         vlen_shift_ = cpu_isa_traits<isa>::vlen_shift;
155         if (typesize_ == 2) {
156             vlen_ /= 2;
157             vlen_shift_--;
158         }
159
160         reg_zero = Vmm(0);
161         reg_v = Vmm(1);
162
163         generate();
164     }
165
166     void loop_is() {
167         using namespace Xbyak;
168
169         mov(reg_cur_src, reg_src);
170         mov(reg_cur_iw, reg_iw_start);
171         mov(reg_cur_os, reg_os);
172
173         Label is_loop, skip_h_step;
174         L(is_loop);
175
176         if (src_to_ws_) {
177             vmovups(reg_v, ptr[reg_cur_src]);
178             vmovups(ptr[reg_ws], reg_v);
179         } else {
180             vmovups(reg_v, ptr[reg_ws]);
181             vmovups(ptr[reg_cur_src], reg_v);
182             for (int w = 1; w < stride_w_; ++w)
183                 vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
184         }
185
186         add(reg_ws, vlen_);
187
188         add(reg_cur_iw, stride_w_);
189         add(reg_cur_src, stride_w_ * vlen_);
190
191         cmp(reg_cur_iw, iw_);
192         jl(skip_h_step);
193         /* for 1d convolution the loop over h should be skipped */
194         if (src_step_icb_ == iw_) jmp(skip_h_step);
195
196         if (src_to_ws_) {
197             add(reg_cur_src, (src_step_h_ - iw_) * vlen_);
198         } else {
199             Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */
200             mov(reg_cur_src_fin, reg_cur_src);
201             add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_);
202             Label ih_loop;
203             L(ih_loop);
204
205             for (int w = 0; w < stride_w_; ++w)
206                 vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
207
208             add(reg_cur_src, stride_w_ * vlen_);
209             cmp(reg_cur_src, reg_cur_src_fin);
210             jl(ih_loop);
211         }
212         xor_(reg_cur_iw, reg_cur_iw);
213
214         L(skip_h_step);
215
216         sub(reg_cur_os, vlen_);
217         jnz(is_loop);
218
219         /* restore dst */
220         sub(reg_ws, reg_os);
221     }
222
223     void generate() {
224         using namespace Xbyak;
225         assert(isa == avx2 || isa == avx512_common
226                 || isa == avx512_core || isa == avx512_mic);
227
228 #if defined(_WIN32)
229         assert(reg_src == abi_not_param1 && abi_not_param1 == rdi);
230         push(rdi);
231 #endif
232
233 #define READ_PARAM(what) \
234         mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)])
235         READ_PARAM(src);
236         READ_PARAM(icb);
237         READ_PARAM(os);
238         READ_PARAM(iw_start);
239
240         assert(reg_ws == abi_param1);
241         READ_PARAM(ws); /* reg_ws should always be read the last */
242 #undef  READ_PARAM
243
244         shl(reg_os, vlen_shift_);
245
246         if (!src_to_ws_)
247             uni_vpxor(reg_zero, reg_zero, reg_zero);
248
249         Label icb_loop;
250         L(icb_loop);
251
252         loop_is();
253
254         add(reg_ws, ws_step_icb_ * vlen_);
255         add(reg_src, src_step_icb_ * vlen_);
256
257         dec(reg_icb);
258         jnz(icb_loop, T_NEAR);
259
260 #if defined(_WIN32)
261         pop(rdi);
262 #endif
263
264         uni_vzeroupper();
265         ret();
266         this->ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
267                     this->getCode()));
268     }
269 };
270
271 template <cpu_isa_t isa, typename conv_t>
272 inline void init_rtus_driver(conv_t *self) {
273     const auto &conf = *self->pd();
274     if (!conf.rtus_.reduce_src_) return;
275
276     const auto &cd = *conf.desc();
277     const int ndims = conf.ndims();
278     const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
279     const int stride_w = cd.strides[ndims - 3];
280
281     const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
282     const auto &src_d = is_bwd_data ? *conf.diff_src_pd()->desc()
283                                     : *conf.src_pd()->desc();
284     assert((isa == avx2 && utils::one_of(src_d.format, memory_format::nCw8c,
285         memory_format::nChw8c)) || (isa == avx512_common && utils::one_of(
286             src_d.format, memory_format::nCw16c, memory_format::nChw16c)));
287
288     const int ih = ndims == 3 ? 1 : src_d.dims[2];
289     const int iw = src_d.dims[ndims - 1];
290
291     const int src_step_h = stride_h * iw;
292     const int src_step_icb = ih * iw;
293     const int ws_step_icb = conf.jcp_.is;
294     const bool src_to_ws = !is_bwd_data;
295     const size_t typesize = types::data_type_size(
296             conv_prop_agnostic_src_d(self->pd()->desc())->data_type);
297
298     self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h,
299             src_step_icb, ws_step_icb, src_to_ws, typesize);
300 }
301
302 inline int best_divider(int value, int min_divider, int max_divider,
303         bool find_max, int step = 1)
304 {
305     max_divider = nstl::max(1, nstl::min(max_divider, value));
306     min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
307
308     auto loss_ratio = [](int total, int chunk)
309     { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); };
310
311     float min_loss = FLT_MAX;
312     int x_divider = max_divider;
313     for (int divider = max_divider; divider >= min_divider; divider -= step) {
314         const float loss = loss_ratio(value, divider);
315         if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
316             min_loss = loss;
317             x_divider = divider;
318         }
319     }
320     return x_divider;
321 }
322
323 }
324 }
325 }
326
327 #endif