1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_UNI_1x1_CONV_UTILS_HPP
18 #define JIT_UNI_1x1_CONV_UTILS_HPP
20 #include "memory_tracking.hpp"
21 #include "mkldnn_thread.hpp"
23 #include "type_helpers.hpp"
26 #include "jit_generator.hpp"
32 using namespace mkldnn::impl::utils;
/* Bookkeeping for the "reduce to unit stride" (rtus) optimization; embedded
 * in a convolution primitive descriptor as its rtus_ member (see
 * rtus_prepare() / rtus_prepare_space_info() below). */
struct reduce_to_unit_stride_t {
    convolution_desc_t conv_d_; /* copy of the original conv desc, patched
                                   to unit strides by rtus_prepare() */
    size_t space_per_thread_;   /* dense scratch elements per thread, set
                                   by rtus_prepare_space_info() */
40 /* 1x1-kernel does not support non-unit strides so far, so the idea is:
41 * - for fwd or bwd_weights: to copy src to a scratch memory (with strides
42 * equal to 1) and then call the kernel
43 * - for bwd_data: reduce the problem to the one with unit stride by
44 * performing computations in a scratch memory (with strides equal to 1)
45 * and then copy the result to diff_src */
/* Checks whether a strided 1x1 convolution can be reduced to the unit-stride
 * case (see the comment above) and, if so, rewrites the descriptors:
 * sets self->rtus_.reduce_src_ and redirects conv_d / src_d to the patched
 * copy self->rtus_.conv_d_ (unit strides, zero padding, dst spatial dims
 * with the original number of input channels). Otherwise the arguments are
 * left untouched. */
template <typename conv_pd_t>
inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
        const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
    const bool is_bwd_data = self->desc()->prop_kind
        == prop_kind::backward_data;
    const int ndims = src_d->ndims;
    /* applicable only when the conv is actually strided (1d case checks
     * strides[0] and excludes s16/s32 src data types; 2d case checks both
     * strides) and src uses a blocked-channel layout */
    bool rtus_applicable = true
        && utils::pick(ndims - 3,
            (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type,
                data_type::s16, data_type::s32)),
            (conv_d->strides[0] != 1 || conv_d->strides[1] != 1))
        && utils::one_of(src_d->format, memory_format::nCw8c,
            memory_format::nCw16c, memory_format::nChw8c,
            memory_format::nChw16c);
    /* spatial dims: require no padding and src spatial exactly covered by
     * dst spatial times the stride */
    for (int d = 2; d < ndims; ++d) {
        /* TODO: relax these conditions (by improving reducer) */
        rtus_applicable = rtus_applicable
            && conv_d->padding[0][d - 2] == 0
            && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
    if (rtus_applicable) {
        self->rtus_.reduce_src_ = true;
        /* copy the original desc into rtus_ and patch it in place */
        conv_d = &(self->rtus_.conv_d_ = *conv_d);
        self->rtus_.conv_d_.strides[0] = 1;
        self->rtus_.conv_d_.strides[1] = 1;
        utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
        utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
        const int ic = src_d->dims[1]; /* keep the original channel count */
        /* diff_src gets dst's spatial dims with ic channels; NOTE(review):
         * presumably under the is_bwd_data branch computed above — the
         * selecting if/else is not visible in this excerpt, confirm */
        src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
        self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.diff_src_desc);
        /* src desc likewise, preserving the original src data type */
        data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
        self->rtus_.conv_d_.src_desc.dims[1] = ic;
        self->rtus_.conv_d_.src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.src_desc);
/* Books the scratchpad buffer used by the rtus reducer: one dense region of
 * space_per_thread_ elements per thread (factor * spatial size * channel
 * block), sized in bytes of the source data type. */
template <typename conv_pd_t>
inline void rtus_prepare_space_info(conv_pd_t *self,
        memory_tracking::registrar_t &scratchpad) {
    const auto &jcp = self->jcp_;
    const int max_threads = mkldnn_get_max_threads();
    /* blocking factor chosen per propagation kind from the jit conf
     * (nb_reduce / nb_load_blocking_max / nb_bcast_blocking); NOTE(review):
     * exact prop-kind-to-argument mapping is defined by pick_by_prop_kind —
     * confirm against its implementation */
    const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
            jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
    size_t typesize = types::data_type_size(
            conv_prop_agnostic_src_d(self->desc())->data_type);
    self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block;
    scratchpad.book(memory_tracking::names::key_conv_rtus_space,
            typesize * max_threads * self->rtus_.space_per_thread_);
/* JIT kernel that copies data between the strided source image (src) and a
 * dense unit-stride scratch buffer (ws). Direction is chosen by src_to_ws_:
 * src -> ws for fwd/bwd_weights, ws -> src (with zero-filling of the skipped
 * strided positions) for bwd_data. One vector register (vlen_ bytes) moves
 * one channel-block row element at a time. */
template <cpu_isa_t isa>
struct rtus_driver_t: public jit_generator {
    /* arguments passed to the generated kernel */
    struct call_params_t {
        const void *ws; /* reduced image (w/ strides = 1) */
        const void *src; /* source image (w/ non-unit strides) */
    /* entry point of the generated code */
    void (*ker_)(const call_params_t *p);
    DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)
    /* cpu specific part */
    using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
    /* kernel ABI: ws comes in abi_param1, src in abi_not_param1 */
    Xbyak::Reg64 reg_ws = abi_param1;
    Xbyak::Reg64 reg_src = abi_not_param1;
    Xbyak::Reg64 reg_icb = rdx;       /* channel-block counter */
    Xbyak::Reg64 reg_os = r11;        /* output-space extent (bytes) */
    Xbyak::Reg64 reg_iw_start = r8;   /* starting position within a row */
    /* per-iteration cursors */
    Xbyak::Reg64 reg_cur_os = rax;
    Xbyak::Reg64 reg_cur_iw = r9;
    Xbyak::Reg64 reg_cur_src = r10;
    int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
    /* iw: source row width; stride_w: conv stride along w; src_step_h /
     * src_step_icb / ws_step_icb: element steps between rows and channel
     * blocks; src_to_ws: copy direction; typesize: bytes per element */
    rtus_driver_t(int iw, int stride_w, int src_step_h,
        int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize)
        : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h)
        , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb)
        , src_to_ws_(src_to_ws), typesize_(typesize)
    using namespace Xbyak;
    vlen_ = cpu_isa_traits<isa>::vlen;
    vlen_shift_ = cpu_isa_traits<isa>::vlen_shift;
    if (typesize_ == 2) {
    using namespace Xbyak;
    /* set up the cursors for one pass over the output space */
    mov(reg_cur_src, reg_src);
    mov(reg_cur_iw, reg_iw_start);
    mov(reg_cur_os, reg_os);
    Label is_loop, skip_h_step;
    /* src -> ws: gather one vector from the strided source */
    vmovups(reg_v, ptr[reg_cur_src]);
    vmovups(ptr[reg_ws], reg_v);
    /* ws -> src: scatter back and zero the strided gaps */
    vmovups(reg_v, ptr[reg_ws]);
    vmovups(ptr[reg_cur_src], reg_v);
    for (int w = 1; w < stride_w_; ++w)
        vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
    /* advance by one stride along w */
    add(reg_cur_iw, stride_w_);
    add(reg_cur_src, stride_w_ * vlen_);
    cmp(reg_cur_iw, iw_);
    /* for 1d convolution the loop over h should be skipped */
    if (src_step_icb_ == iw_) jmp(skip_h_step);
    /* hop over the remainder of the row to the next h position */
    add(reg_cur_src, (src_step_h_ - iw_) * vlen_);
    Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */
    mov(reg_cur_src_fin, reg_cur_src);
    add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_);
    /* zero-fill the skipped rows until reg_cur_src_fin is reached */
    for (int w = 0; w < stride_w_; ++w)
        vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
    add(reg_cur_src, stride_w_ * vlen_);
    cmp(reg_cur_src, reg_cur_src_fin);
    xor_(reg_cur_iw, reg_cur_iw); /* restart the row position counter */
    sub(reg_cur_os, vlen_);
    using namespace Xbyak;
    /* only these ISAs are supported by this driver */
    assert(isa == avx2 || isa == avx512_common
            || isa == avx512_core || isa == avx512_mic);
    assert(reg_src == abi_not_param1 && abi_not_param1 == rdi);
/* load a call_params_t field into its matching reg_<field> */
#define READ_PARAM(what) \
    mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)])
    READ_PARAM(iw_start);
    /* reg_ws aliases abi_param1, so reading ws clobbers the params pointer */
    assert(reg_ws == abi_param1);
    READ_PARAM(ws); /* reg_ws should always be read the last */
    shl(reg_os, vlen_shift_); /* elements -> bytes */
    uni_vpxor(reg_zero, reg_zero, reg_zero); /* zero vector for gap filling */
    /* advance to the next channel block */
    add(reg_ws, ws_step_icb_ * vlen_);
    add(reg_src, src_step_icb_ * vlen_);
    jnz(icb_loop, T_NEAR);
    /* publish the generated code as the callable kernel */
    this->ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
/* Instantiates the rtus_driver_t JIT kernel for a convolution whose pd was
 * marked reducible by rtus_prepare(). Computes the row/channel-block steps
 * from the (diff_)src descriptor and the copy direction from the prop kind.
 * No-op when the rtus reduction was not applied. */
template <cpu_isa_t isa, typename conv_t>
inline void init_rtus_driver(conv_t *self) {
    const auto &conf = *self->pd();
    if (!conf.rtus_.reduce_src_) return;
    const auto &cd = *conf.desc();
    const int ndims = conf.ndims();
    /* 3 dims means 1d convolution: no h dimension, unit h stride */
    const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
    const int stride_w = cd.strides[ndims - 3];
    const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
    /* bwd_data reduces into diff_src, the other prop kinds read from src */
    const auto &src_d = is_bwd_data ? *conf.diff_src_pd()->desc()
        : *conf.src_pd()->desc();
    /* the driver only supports blocked formats matching the vector width */
    assert((isa == avx2 && utils::one_of(src_d.format, memory_format::nCw8c,
        memory_format::nChw8c)) || (isa == avx512_common && utils::one_of(
        src_d.format, memory_format::nCw16c, memory_format::nChw16c)));
    const int ih = ndims == 3 ? 1 : src_d.dims[2];
    const int iw = src_d.dims[ndims - 1];
    /* element steps between h rows and channel blocks in src / ws */
    const int src_step_h = stride_h * iw;
    const int src_step_icb = ih * iw;
    const int ws_step_icb = conf.jcp_.is;
    const bool src_to_ws = !is_bwd_data; /* copy direction, see driver */
    const size_t typesize = types::data_type_size(
            conv_prop_agnostic_src_d(self->pd()->desc())->data_type);
    self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h,
            src_step_icb, ws_step_icb, src_to_ws, typesize);
302 inline int best_divider(int value, int min_divider, int max_divider,
303 bool find_max, int step = 1)
305 max_divider = nstl::max(1, nstl::min(max_divider, value));
306 min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
308 auto loss_ratio = [](int total, int chunk)
309 { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); };
311 float min_loss = FLT_MAX;
312 int x_divider = max_divider;
313 for (int divider = max_divider; divider >= min_divider; divider -= step) {
314 const float loss = loss_ratio(value, divider);
315 if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {