/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn_types.h"

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "cpu_isa_traits.hpp"

#include "gemm_convolution_utils.hpp"
#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
using namespace prop_kind;
using namespace data_type;
namespace jit_gemm_convolution_utils {
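/* col[ic][kd][kh][kw][oh][ow] <-- im2col_3d(im[ic][id][ih][iw]) for one
 * output-depth index od at a time (layout inferred from the indexing below;
 * compare the 2D im2col comment further down). */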
void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
        int od)
{
    const size_t OHW = jcp.oh * jcp.ow;
    const size_t im_step = jcp.ih * jcp.iw * jcp.id;
    const size_t col_step = jcp.ks * OHW;

    parallel_nd(jcp.ic, [&](int ic) {
        const float *__restrict im_loc = im + ic * im_step;
        float *__restrict col_loc = col + ic * col_step;
        int id = od * jcp.stride_d - jcp.f_pad;
        for (int kd = 0; kd < jcp.kd; ++kd) {
            float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
            if (id < 0 || id >= jcp.id) {
                int ih_ = -jcp.t_pad;
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    int ih = ih_;
                    for (int oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        int iw_ = -jcp.l_pad;
                        for (int kw = 0; kw < jcp.kw; ++kw) {
                            int iw = iw_;
                            for (int ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }
                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                    + ow;
                                col_[col_idx] = 0;
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            } else {
                const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
                int ih_ = -jcp.t_pad;
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    int ih = ih_;
                    for (int oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        int iw_ = -jcp.l_pad;
                        for (int kw = 0; kw < jcp.kw; ++kw) {
                            int iw = iw_;
                            for (int ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }
                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                    + ow;
                                const size_t im_idx = ih * jcp.iw + iw;
                                col_[col_idx] = im_[im_idx];
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            }
            id += (1 + jcp.dilate_d);
        }
    });
}
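// With the col buffer in this layout, the convolution for a single od slice
// reduces to one dense multiplication per minibatch entry, roughly
//   dst[oc][oh*ow] += wei[oc][ic*kd*kh*kw] * col[ic*kd*kh*kw][oh*ow]
// (illustrative shapes only; the actual GEMM calls live in the convolution
// drivers that use these helpers).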
/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
        float *__restrict col, int hs, int hb, int ws, int wb) {
    const size_t im_step = jcp.is;
    const size_t col_step = jcp.ks * hb * wb;
    if (jcp.stride_w == 1) {
        // The generated code is better optimized for stride_w == 1
        // because the innermost loop iterates over width
        auto ker = [&](int ic, int kh, int kw, int oh) {
            const float *__restrict im_ = im + ic * im_step;
            float *__restrict col_
                = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;

            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih) {
                for (int ow = 0; ow < wb; ++ow)
                    col_[ow] = 0.f;
            } else {
                for (int ow = 0; ow < wb; ++ow) {
                    const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w);
                    if (iw < 0 || iw >= jcp.iw)
                        col_[ow] = 0.f;
                    else {
                        const size_t im_idx = ih * jcp.iw + iw;
                        col_[ow] = im_[im_idx];
                    }
                }
            }
        };

        if (jcp.outer_threading) {
            for (int ic = 0; ic < jcp.ic; ic++)
                for (int kh = 0; kh < jcp.kh; kh++)
                    for (int kw = 0; kw < jcp.kw; kw++)
                        for (int oh = 0; oh < hb; oh++)
                            ker(ic, kh, kw, oh);
        }
        else {
            parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
        }
    } else if (jcp.ic == 1) {
        parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih)
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    for (int ow = 0; ow < wb; ++ow) {
                        const size_t col_idx
                            = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
                        col[col_idx] = 0;
                    }
                }
            else
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    for (int ow = 0; ow < wb; ++ow) {
                        const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
                            + kw * (1 + jcp.dilate_w);
                        const size_t col_idx
                            = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
                        const size_t im_idx = ih * jcp.iw + iw;
                        if (iw < 0 || iw >= jcp.iw)
                            col[col_idx] = 0;
                        else
                            col[col_idx] = im[im_idx];
                    }
                }
        });
    } else {
        parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
            [&](int ic, int kh, int kw, int oh) {
            const float *__restrict im_ = im + ic * im_step;
            float *__restrict col_ = col + ic * col_step
                + ((kh * jcp.kw + kw) * hb + oh) * wb;

            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih) {
                for (int ow = 0; ow < wb; ++ow)
                    col_[ow] = 0.f;
            } else {
                for (int ow = 0; ow < wb; ++ow) {
                    const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
                        + kw * (1 + jcp.dilate_w);
                    const size_t im_idx = ih * jcp.iw + iw;
                    if (iw < 0 || iw >= jcp.iw)
                        col_[ow] = 0.f;
                    else
                        col_[ow] = im_[im_idx];
                }
            }
        });
    }
}
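// The int8 variant below additionally remaps signed input to unsigned by
// adding 128 (e.g. s8 value -128 maps to u8 0), since the int8 GEMM consumes
// u8 source data; the callers are expected to compensate for this shift
// after the GEMM.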
/* col[oh][ow][kh][kw][ic] <-- im2col_u8(im[ih][iw][ic]) */
template <typename T>
void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
        uint8_t *__restrict col) {
    uint8_t shift = jcp.signed_input ? 128 : 0;
    const int dh = 1 + jcp.dilate_h;
    const int dw = 1 + jcp.dilate_w;
    const int sh = jcp.stride_h;
    const int sw = jcp.stride_w;
    if (sh == 1 && sw == 1 && jcp.oh > 2 * mkldnn_get_max_threads()) {
        const int ihp = jcp.ih + jcp.t_pad;
        const int iwp = jcp.iw + jcp.l_pad;
        const int col_kw_step = jcp.ic;
        const int col_kh_step = jcp.kw * col_kw_step;
        const int col_ow_step = jcp.kh * col_kh_step;
        const int col_oh_step = jcp.ow * col_ow_step;
        const int im_iw_step = jcp.ngroups * jcp.ic;
        const int im_ih_step = jcp.iw * im_iw_step;

        const int nb_ic = jcp.ic / 4;
        const int ic_blocked = nb_ic * 4;

        parallel_nd(jcp.oh, [&](int oh) {
            const int kh_start = nstl::max(div_up(jcp.t_pad - oh, dh), 0);
            const int kh_end = nstl::min(div_up(ihp - oh, dh), jcp.kh);
            const int ih_start = oh - jcp.t_pad + kh_start * dh;
            const int col_oh_idx = oh * col_oh_step;

            for (int kh = kh_start, ih = ih_start; kh < kh_end; ++kh, ih += dh)
            {
                const int col_kh_idx = col_oh_idx + kh * col_kh_step;
                const int im_kh_idx = ih * im_ih_step;

                for (int kw = 0; kw < jcp.kw; ++kw) {
                    const int ow_start = nstl::max(jcp.l_pad - kw * dw, 0);
                    const int ow_end = nstl::min(iwp - kw * dw, jcp.ow);
                    const int iw_start = ow_start - jcp.l_pad + kw * dw;
                    const int col_kw_idx = col_kh_idx + kw * col_kw_step;

                    const int col_idx_start
                            = col_kw_idx + ow_start * col_ow_step;
                    const int im_idx_start = im_kh_idx + iw_start * im_iw_step;
                    const int col_idx_end = col_kw_idx + ow_end * col_ow_step;

                    for (int col_idx = col_idx_start, im_idx = im_idx_start;
                            col_idx < col_idx_end;
                            col_idx += col_ow_step, im_idx += im_iw_step) {
                        for (int icb = 0; icb < 4 * nb_ic; icb += 4) {
                            PRAGMA_OMP_SIMD()
                            for (int ic = 0; ic < 4; ++ic) {
                                col[col_idx + icb + ic]
                                        = im[im_idx + icb + ic] + shift;
                            }
                        }
                    }
                    if (ic_blocked != jcp.ic) {
                        for (int col_idx = col_idx_start, im_idx = im_idx_start;
                                col_idx < col_idx_end;
                                col_idx += col_ow_step, im_idx += im_iw_step) {
                            PRAGMA_OMP_SIMD()
                            for (int ic = ic_blocked; ic < jcp.ic; ++ic) {
                                col[col_idx + ic] = im[im_idx + ic] + shift;
                            }
                        }
                    }
                }
            }
        });
    } else {
        const size_t col_kh_step = jcp.kw * jcp.ic;
        const size_t col_ow_step = jcp.kh * col_kh_step;
        const size_t col_oh_step = jcp.ow * col_ow_step;
        const size_t im_ih_step = jcp.iw * jcp.ngroups * jcp.ic;
        const size_t im_iw_step = jcp.ngroups * jcp.ic;
        const int ih_pad = jcp.ih + jcp.t_pad;
        const int iw_pad = jcp.iw + jcp.l_pad;
        parallel_nd(jcp.oh, jcp.ow, [&](int oh, int ow) {
            const int ihs = oh * sh;
            const int ihsp = jcp.t_pad - ihs;
            const int kh_start = nstl::max(div_up(ihsp, dh), 0);
            const int kh_end = nstl::min(div_up(ih_pad - ihs, dh), jcp.kh);
            const int ih_start = kh_start * dh - ihsp;
            const int iws = ow * sw;
            const int iwsp = jcp.l_pad - iws;
            const int kw_start = nstl::max(div_up(iwsp, dw), 0);
            const int kw_end = nstl::min(div_up(iw_pad - iws, dw), jcp.kw);
            const int iw_start = kw_start * dw - iwsp;

            uint8_t *__restrict col_base
                    = col + oh * col_oh_step + ow * col_ow_step;
            for (int kh = kh_start, ih = ih_start; kh < kh_end;
                    ++kh, ih += dh) {
                uint8_t *__restrict col_ = col_base + kh * col_kh_step;
                const T *__restrict im_ = im + ih * im_ih_step;

                for (int kw = kw_start, iw = iw_start; kw < kw_end;
                        ++kw, iw += dw) {
                    const size_t col_idx = kw * jcp.ic;
                    const size_t im_idx = iw * im_iw_step;
                    PRAGMA_OMP_SIMD()
                    for (int ic = 0; ic < jcp.ic; ++ic) {
                        col_[col_idx + ic] = im_[im_idx + ic] + shift;
                    }
                }
            }
        });
    }
}
template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
        const int8_t *__restrict im, uint8_t *__restrict col);
template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
        const uint8_t *__restrict im, uint8_t *__restrict col);
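// col2im_s32 below partitions the destination image plane (ih x iw) across
// threads: each thread zeroes and accumulates only its own disjoint tile,
// so the += updates need no atomics.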
/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
        int32_t *__restrict im)
{
    parallel(0, [&](const int ithr, const int nthr) {
        int h_nthr = nstl::min(jcp.ih, nthr);
        int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
        int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
        if (ithr < h_nthr * w_nthr) {
            h_ithr = ithr / w_nthr;
            w_ithr = ithr % w_nthr;
            balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
            balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
        } else {
            h_ithr = w_ithr = -ithr;
            h_s = h_e = w_s = w_e = -1;
        }

        for (int ih = h_s; ih < h_e; ++ih) {
            for (int iw = w_s; iw < w_e; ++iw) {
                PRAGMA_OMP_SIMD()
                for (int ic = 0; ic < jcp.ic; ++ic) {
                    im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
                }
            }
        }

        // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
        for (int oh = 0; oh < jcp.oh; ++oh) {
            for (int ow = 0; ow < jcp.ow; ++ow) {
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    const int ih = oh * jcp.stride_h
                            - jcp.t_pad + kh * (1 + jcp.dilate_h);
                    if (ih < h_s || ih >= h_e) continue;

                    for (int kw = 0; kw < jcp.kw; ++kw) {
                        const int iw = ow * jcp.stride_w
                                - jcp.l_pad + kw * (1 + jcp.dilate_w);
                        if (iw < w_s || iw >= w_e) continue;

                        const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
                                + kh) * jcp.kw + kw) * jcp.ic;
                        const size_t im_idx
                                = (ih * jcp.iw + iw) * jcp.ic;
                        PRAGMA_OMP_SIMD()
                        for (int ic = 0; ic < jcp.ic; ++ic) {
                            im[im_idx + ic] += col[col_idx + ic];
                        }
                    }
                }
            }
        }
    });
}
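// col2im_3d parallelizes over input channels: a thread owns the entire
// ih x iw x id volume of its channel, so accumulation across the (kd, kh, kw)
// taps is race-free.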
void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
        int od)
{
    parallel_nd(jcp.ic, [&](int ic) {
        const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
        float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;

        int id = od * jcp.stride_d - jcp.f_pad;
        for (int kd = 0; kd < jcp.kd; ++kd) {
            if (id < 0 || id >= jcp.id) {
                col_ += jcp.kh * jcp.kw * jcp.os;
                id += (1 + jcp.dilate_d);
                continue;
            }

            float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;

            for (int oh = 0; oh < jcp.oh; ++oh) {
            for (int kh = 0; kh < jcp.kh; ++kh) {
                const int ih = oh * jcp.stride_h - jcp.t_pad
                    + kh * (1 + jcp.dilate_h);
                if (ih < 0 || ih >= jcp.ih) continue;

                for (int ow = 0; ow < jcp.ow; ++ow) {
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    const int iw = ow * jcp.stride_w - jcp.l_pad
                        + kw * (1 + jcp.dilate_w);
                    if (iw < 0 || iw >= jcp.iw) continue;

                    const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
                    const size_t im_idx = ih*jcp.iw + iw;
                    im_[im_idx] += col_[col_idx];
                }
                }
            }
            }

            col_ += jcp.kh * jcp.kw * jcp.os;
            id += (1 + jcp.dilate_d);
        }
    });
}
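// Plain 2D col2im: same scheme, one thread per input channel; the image is
// zero-filled first, then every col entry is added back into the pixel it
// was gathered from.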
void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
    const size_t col_step = jcp.ks * jcp.os;
    const size_t im_step = jcp.ih * jcp.iw;
    const int iS = jcp.ih * jcp.iw;

    parallel_nd(jcp.ic, [&](int ic) {
        float *__restrict im_ = im + ic * im_step;
        const float *__restrict col_ = col + ic * col_step;

        for (int is = 0; is < iS; ++is) im_[is] = 0.;

        for (int kh = 0; kh < jcp.kh; ++kh) {
        for (int oh = 0; oh < jcp.oh; ++oh) {
            const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih) continue;

            for (int kw = 0; kw < jcp.kw; ++kw) {
            for (int ow = 0; ow < jcp.ow; ++ow) {
                const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
                if (iw < 0 || iw >= jcp.iw) continue;

                const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
                const size_t im_idx = ih*jcp.iw + iw;
                im_[im_idx] += col_[col_idx];
            }
            }
        }
        }
    });
}
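/* Fills jcp from the convolution and memory descriptors, decides between
 * outer (per minibatch/group) and inner (within-GEMM) threading, and books
 * the scratchpad: the im2col buffer, the int8 accumulation buffer, and the
 * backward-weights reduction buffer. */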
status_t init_conf(jit_gemm_conv_conf_t &jcp,
        memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, int max_threads) {
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();
    const int is_1d = ndims == 3;
    const int is_3d = ndims == 5;

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef
            || cd.diff_bias_desc.format != memory_format::undef;

    jcp.is = jcp.ih * jcp.iw;
    jcp.os = jcp.oh * jcp.ow;
    jcp.ks = jcp.kh * jcp.kw * jcp.kd;

    jcp.signed_input = src_d.data_type() == data_type::s8;
    jcp.wei_adj_scale =
            !jcp.signed_input || mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
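    // Rationale (as used by this file): without VNNI the s8s8 GEMM is
    // emulated on u8 inputs and can saturate intermediate accumulation, so
    // the weights are pre-scaled by 0.5 and the result is expected to be
    // scaled back by the caller; with avx512_core_vnni no adjustment is
    // needed.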
    jcp.im2col_sz = !everyone_is(true,
            jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
            jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
            jcp.ks == 1, !jcp.signed_input)
        ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;

    jcp.outer_threading = false;
    jcp.oh_block = jcp.oh;
    jcp.ow_block = jcp.ow;

    bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
        && weights_d.data_type() == s8;

    const int vlen = mayiuse(avx512_common)
        ? cpu_isa_traits<avx512_common>::vlen
        : mayiuse(avx)
            ? cpu_isa_traits<avx>::vlen
            : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
    const int simd_w = vlen / (is_int8_conv ? 1 : 4);
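    // E.g. with avx512 vlen is 64 bytes, so simd_w is 16 fp32 lanes or 64
    // int8 lanes; sse42 gives 16 bytes, and 4 is the conservative fallback.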
    const bool is_bwd_d = jcp.prop_kind == backward_data;
    const bool is_bwd_w = jcp.prop_kind == backward_weights;
    const bool is_fwd = !is_bwd_d && !is_bwd_w;

    using namespace memory_tracking::names;
    // For threading selection we do:
    // 1. Rough estimation of efficiency for inner and outer threading.
    // 2. Gemm size estimation, assuming gemm is not so effective for
    //    small sizes.
    // 64K is a heuristic per-thread gemm size threshold.
    const int gemm_threshold = 64 * 1024;
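    // Illustrative numbers: jcp.os = 3136, jcp.ic = jcp.oc = 64 on 28 threads
    // gives 3136 * 64 * 64 / 28 ~= 459K > 64K per thread, so an inner-threaded
    // gemm is large enough to be efficient; smaller shapes fall below the
    // threshold and bias the choice toward outer threading.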
    bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;

    if (is_int8_conv) {
        const int bs = is_fwd ? jcp.os : jcp.is;
        const int ls = is_fwd ? jcp.oc : jcp.ic;
        const size_t outer_work_amount = jcp.ngroups * jcp.mb;
        const float outer_thr_eff = (float)outer_work_amount
                / rnd_up(outer_work_amount, max_threads);
        const size_t inner_work_amount
                = div_up(bs, simd_w) * div_up(ls, simd_w);
        const float inner_thr_eff = (float)inner_work_amount
                / rnd_up(inner_work_amount, max_threads);
        jcp.outer_threading = (is_depthwise
                || (bs / max_threads < 64 && jcp.mb != 1))
                && (outer_thr_eff / inner_thr_eff >= 1.f
                        || (bs * jcp.ic * jcp.oc) / max_threads < gemm_threshold);
        jcp.nthr = jcp.outer_threading ? max_threads : 1;
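        // Each *_thr_eff above is the utilization of an even split, i.e.
        // work / rnd_up(work, max_threads): e.g. outer work of 30 items on
        // 28 threads yields 30 / 56 ~= 0.54, a poor outer split.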
        if (is_fwd) {
            scratchpad.book(key_conv_gemm_col,
                    sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
            scratchpad.book(key_conv_int_dat_in_acc_dt,
                    sizeof(int32_t) * jcp.nthr * jcp.os * jcp.oc);
        } else if (is_bwd_d) {
            scratchpad.book(key_conv_gemm_col,
                    sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
            scratchpad.book(key_conv_int_dat_in_acc_dt,
                    sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
        } else if (is_bwd_w) {
            assert(!"unimplemented prop_kind");
            return status::unimplemented;
        }
    } else {
        if (is_fwd) {
            const int L2 = get_cache_size(2, true) / sizeof(float);
            const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;

            // It makes sense to try blocking for some special cases:
            // when weights size is small and we have to do im2col
            if (wei_size < L2/2 && jcp.im2col_sz && jcp.id == 1 && jcp.od == 1) {
                // looking for oh and ow blocking
                int h_block{ jcp.oh }, w_block{ jcp.ow };
                // 1. cache requirement
                // !!! used memory (assuming strides = 1 and dilate = 0 etc):
                const int row_size = jcp.ic * jcp.kh * jcp.kw * jcp.ow
                        + 2 * jcp.ic * jcp.iw + 2 * jcp.oc * jcp.ow;
                h_block = nstl::max(
                        1, nstl::min(jcp.oh, div_up(L2 - wei_size, row_size)));
                if (h_block == 1) {
                    const int col_size = jcp.ic * jcp.kh * jcp.kw + 2 * jcp.ic
                            + 2 * jcp.oc;
                    w_block = nstl::max(
                            1, nstl::min(jcp.ow, div_up(L2 - wei_size, col_size)));
                }
                // 2. threading requirement
                if (h_block != jcp.oh)
                    h_block = nstl::max(1, rnd_dn(h_block, 4));
                if (w_block != jcp.ow)
                    w_block = nstl::max(1, rnd_dn(w_block, simd_w));

                float thr_eff = 0.f;
                float thr_eff_threshold = 0.9f;
                if (w_block == jcp.ow) {
                    do {
                        int nb_oh = div_up(jcp.oh, h_block);
                        size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_oh;
                        float disb = (float)jcp.oh / rnd_up(jcp.oh, h_block);
                        thr_eff = (float)work
                                / rnd_up(work, max_threads);
                        thr_eff = (thr_eff + disb) / 2.f;
                        if (thr_eff >= thr_eff_threshold)
                            break;
                        h_block = rnd_dn(h_block - 4, 4);
                    } while (h_block > 0);
                }
                if (thr_eff < thr_eff_threshold) // we didn't find a suitable h_block
                {
                    h_block = 1;
                    int nb_oh = div_up(jcp.oh, h_block);
                    do {
                        int nb_ow = div_up(jcp.ow, w_block);
                        size_t work_amount
                                = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
                        float disb = (float)jcp.ow / rnd_up(jcp.ow, w_block);
                        thr_eff = (float)work_amount
                                / rnd_up(work_amount, max_threads);
                        thr_eff = (thr_eff + disb) / 2.f;
                        if (thr_eff > thr_eff_threshold)
                            break;
                        w_block = rnd_dn(w_block - simd_w, simd_w);
                    } while (w_block > 0);
                }
                const size_t inner_work_amount
                        = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
                const float inner_thr_eff = (float)inner_work_amount
                        / rnd_up(inner_work_amount, max_threads);
                if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
                    jcp.oh_block = h_block;
                    jcp.ow_block = w_block;
                    jcp.outer_threading = true;
                }
                // updating jcp.im2col_sz
                if (jcp.oh_block != 1)
                    jcp.ow_block = jcp.ow;
                jcp.im2col_sz
                        = (ptrdiff_t)jcp.ic * jcp.ks * jcp.oh_block * jcp.ow_block;
            } else {
                const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
                const float outer_thr_eff = (float)outer_work_amount
                        / rnd_up(outer_work_amount, max_threads);
                const size_t inner_work_amount
                        = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
                const float inner_thr_eff = (float)inner_work_amount
                        / rnd_up(inner_work_amount, max_threads);
                jcp.outer_threading = jcp.os / max_threads < 512
                        && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
                        && (outer_thr_eff / inner_thr_eff >= 1.f
                                || (jcp.os * jcp.ic * jcp.oc) / max_threads
                                        < gemm_threshold);
            }
        } else if (is_bwd_d) {
            const size_t outer_work_amount = jcp.ngroups * jcp.mb;
            const float outer_thr_eff = (float)outer_work_amount
                    / rnd_up(outer_work_amount, max_threads);
            const size_t inner_work_amount
                    = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
            const float inner_thr_eff = (float)inner_work_amount
                    / rnd_up(inner_work_amount, max_threads);
            jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
                    && (jcp.mb != 1 || jcp.ngroups > 2)
                    && (outer_thr_eff / inner_thr_eff >= 1.f
                            || (jcp.os * jcp.ic * jcp.oc) / max_threads
                                    < gemm_threshold);
        } else if (is_bwd_w)
            jcp.outer_threading = jcp.os / max_threads < 256
                    && (jcp.mb != 1 || jcp.ngroups > 2);
        jcp.nthr = jcp.outer_threading ? max_threads : 1;

        scratchpad.book(key_conv_gemm_col,
                sizeof(float) * jcp.nthr * jcp.im2col_sz);

        if (is_bwd_w) {
            jcp.need_wei_reduction = mkldnn_thr_syncable()
                    ? jcp.mb != 1 && jcp.nthr != 1 : false;

            scratchpad.book(key_conv_wei_reduction,
                    sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
        }
    }

    return status::success;
}
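/* Splits nthr threads across (group, minibatch) work for backward-weights:
 * e.g. nthr = 8, ngroups = 2, mb = 3 gives nthr_g = 2 and nthr_mb = 3;
 * threads 0..5 receive (ithr_g, ithr_mb) slots and threads 6..7 are parked
 * with ithr_g = ithr_mb = -1. */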
void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
        int &nthr_g, int &ithr_mb, int &nthr_mb) {
    nthr_g = nstl::min(ngroups, nthr);
    nthr_mb = nstl::min(mb, nthr / nthr_g);
    if (ithr / nthr_mb >= ngroups) {
        ithr_g = ithr_mb = -1;
    } else {
        ithr_g = ithr / nthr_mb;
        ithr_mb = ithr % nthr_mb;
    }
}
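/* Reduces nthr partial weights-gradient workspaces into `weights`: every
 * thread sums its own [weights_start, weights_end) slice across all
 * workspaces (i == 0 overwrites, later i accumulate), so writes from
 * different threads never overlap. */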
void bwd_weights_reduction_par(int ithr, int nthr,
        const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
        float *weights) {
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;

    size_t weights_start{0}, weights_end{0};
    balance211(weights_g_size, nthr, ithr, weights_start, weights_end);

    for (int i = 0; i < nthr; ++i) {
        const float *ws_i = weights_reduce_ws + i * weights_g_size;
        for (size_t s = weights_start; s < weights_end; ++s)
            weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
    }
}

} // namespace jit_gemm_convolution_utils

} // namespace cpu
} // namespace impl
} // namespace mkldnn