/*******************************************************************************
* Copyright 2016-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn_types.h"

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"
#include "cpu_isa_traits.hpp"

#include "gemm_convolution_utils.hpp"
#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
using namespace prop_kind;
using namespace data_type;

namespace jit_gemm_convolution_utils {
/* col[ic][kd][kh][kw][oh][ow] <-- im2col_3d(im[ic][id][ih][iw]), one od slice
 * per call; positions skipped by the bounds checks keep whatever value is
 * already present in col */
template <typename data_type_t>
void im2col_3d(const jit_gemm_conv_conf_t &jcp, const data_type_t *im,
        data_type_t *col, int od)
{
    const size_t OHW = jcp.oh * jcp.ow;
    const size_t im_step = jcp.ih * jcp.iw * jcp.id;
    const size_t col_step = jcp.ks * OHW;

    parallel_nd(jcp.ic, [&](int ic) {
        const data_type_t *__restrict im_loc = im + ic * im_step;
        data_type_t *__restrict col_loc = col + ic * col_step;
        int id = od * jcp.stride_d - jcp.f_pad;
        for (int kd = 0; kd < jcp.kd; ++kd) {
            data_type_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
            if (id < 0 || id >= jcp.id) {
                // this kd slice falls entirely into the front/back padding:
                // zero every position that maps inside the input plane
                int ih_ = -jcp.t_pad;
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    int ih = ih_;
                    for (int oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        int iw_ = -jcp.l_pad;
                        for (int kw = 0; kw < jcp.kw; ++kw) {
                            int iw = iw_;
                            for (int ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }

                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                        + ow;
                                col_[col_idx] = (data_type_t)0;
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            } else {
                const data_type_t *__restrict im_ =
                        im_loc + id * jcp.ih * jcp.iw;
                int ih_ = -jcp.t_pad;
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    int ih = ih_;
                    for (int oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        int iw_ = -jcp.l_pad;
                        for (int kw = 0; kw < jcp.kw; ++kw) {
                            int iw = iw_;
                            for (int ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }

                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                        + ow;
                                const size_t im_idx = ih * jcp.iw + iw;

                                col_[col_idx] = im_[im_idx];
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            }
            id += (1 + jcp.dilate_d);
        }
    });
}
template
void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
        int od);

template
void im2col_3d(const jit_gemm_conv_conf_t &jcp, const mkldnn_bfloat16_t *im,
        mkldnn_bfloat16_t *col, int od);
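
// How the col buffer is consumed (a sketch, assuming the GEMM-based driver in
// gemm_convolution.cpp): for a fixed minibatch image, group, and output depth
// od, the forward pass reduces to a single GEMM
//     dst[oc][oh*ow] (+)= wei[oc][ic*ks] * col[ic*ks][oh*ow]
// with M = oc, N = os = oh*ow, and K = ic*ks.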
/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
template <typename data_type_t>
void im2col(const jit_gemm_conv_conf_t &jcp, const data_type_t *__restrict im,
        data_type_t *__restrict col, int hs, int hb, int ws, int wb) {
    const size_t im_step = jcp.is;
    const size_t col_step = jcp.ks * hb * wb;
    if (jcp.stride_w == 1) {
        // Generated code is more optimized for stride_w == 1
        // because the innermost loop is over width
        auto ker = [&](int ic, int kh, int kw, int oh) {
            const data_type_t *__restrict im_ = im + ic * im_step;
            data_type_t *__restrict col_
                    = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;

            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                    + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih) {
                for (int ow = 0; ow < wb; ++ow)
                    col_[ow] = (data_type_t)0;
            } else {
                for (int ow = 0; ow < wb; ++ow) {
                    const int iw = ow + ws - jcp.l_pad
                            + kw * (1 + jcp.dilate_w);
                    if (iw < 0 || iw >= jcp.iw)
                        col_[ow] = (data_type_t)0;
                    else {
                        const size_t im_idx = ih * jcp.iw + iw;
                        col_[ow] = im_[im_idx];
                    }
                }
            }
        };

        if (jcp.outer_threading) {
            for (int ic = 0; ic < jcp.ic; ic++)
                for (int kh = 0; kh < jcp.kh; kh++)
                    for (int kw = 0; kw < jcp.kw; kw++)
                        for (int oh = 0; oh < hb; oh++)
                            ker(ic, kh, kw, oh);
        } else {
            parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
        }
    } else if (jcp.ic == 1) {
        parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                    + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih)
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    for (int ow = 0; ow < wb; ++ow) {
                        const size_t col_idx
                                = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
                        col[col_idx] = (data_type_t)0;
                    }
                }
            else
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    for (int ow = 0; ow < wb; ++ow) {
                        const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
                                + kw * (1 + jcp.dilate_w);
                        const size_t col_idx
                                = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
                        const size_t im_idx = ih * jcp.iw + iw;
                        if (iw < 0 || iw >= jcp.iw)
                            col[col_idx] = (data_type_t)0;
                        else
                            col[col_idx] = im[im_idx];
                    }
                }
        });
    } else {
        parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
                [&](int ic, int kh, int kw, int oh) {
                    const data_type_t *__restrict im_ = im + ic * im_step;
                    data_type_t *__restrict col_ = col + ic * col_step
                            + ((kh * jcp.kw + kw) * hb + oh) * wb;

                    const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
                            + kh * (1 + jcp.dilate_h);
                    if (ih < 0 || ih >= jcp.ih) {
                        for (int ow = 0; ow < wb; ++ow)
                            col_[ow] = (data_type_t)0;
                    } else {
                        for (int ow = 0; ow < wb; ++ow) {
                            const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
                                    + kw * (1 + jcp.dilate_w);
                            const size_t im_idx = ih * jcp.iw + iw;
                            if (iw < 0 || iw >= jcp.iw)
                                col_[ow] = (data_type_t)0;
                            else
                                col_[ow] = im_[im_idx];
                        }
                    }
                });
    }
}

template
void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
        float *__restrict col, int hs, int hb, int ws, int wb);

template
void im2col(const jit_gemm_conv_conf_t &jcp,
        const mkldnn_bfloat16_t *__restrict im,
        mkldnn_bfloat16_t *__restrict col, int hs, int hb, int ws, int wb);
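
// Usage sketch (hs/hb and ws/wb are the start and size of the current output
// block): the unblocked case is im2col(jcp, im, col, 0, jcp.oh, 0, jcp.ow).
// When init_conf() enables blocking, the driver is expected to walk blocks of
// jcp.oh_block x jcp.ow_block and call im2col once per block, reusing a col
// buffer of jcp.im2col_sz elements per thread.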
/* returns 'value' clamped to the range [low, upper] */
inline int limit(int low, int upper, int value) {
    return nstl::max(low, nstl::min(upper, value));
}
/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */
template <typename T>
void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
        T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws,
        int wb) {
    uint8_t shift = jcp.signed_input ? 128 : 0;
    const int dh = 1 + jcp.dilate_h;
    const int dw = 1 + jcp.dilate_w;
    const int sh = jcp.stride_h;
    const int sw = jcp.stride_w;
    const int im_iw_stride = jcp.ic * jcp.ngroups;
    const int im_ih_stride = jcp.iw * im_iw_stride;
    const int tp = jcp.t_pad;
    const int lp = jcp.l_pad;

    if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
        /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
        const int hp = hs - tp;
        const int wp = ws - lp;
        const int ih_start = limit(0, jcp.ih, hp);
        const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh);
        const int iw_start = limit(0, jcp.iw, wp);
        const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw);

        const int ihb = ih_end - ih_start;
        const int iwb = iw_end - iw_start;

        const int imtr_ic_stride = ihb * iwb;
        const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
        for (int ic = 0; ic < jcp.ic; ic++) {
            const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
            for (int ih = ih_start; ih < ih_end; ih++) {
                const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
                const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
                for (int iw = iw_start; iw < iw_end; iw++)
                    imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
            }
        }

        const int col_ic_str = hb * wb;
        const int col_kw_stride = jcp.ic * col_ic_str;
        const int col_kh_stride = jcp.kw * col_kw_stride;

        const int oh_init = ih_start - hp;
        const int ow_init = iw_start - wp;
        for (int kh = 0; kh < jcp.kh; kh++) {
            const ptrdiff_t col_idx_kh = kh * col_kh_stride;
            const int oh_kh = oh_init - kh;
            const int oh_start = limit(0, hb, oh_kh);
            const int oh_end = limit(0, hb, oh_kh + ihb);
            for (int kw = 0; kw < jcp.kw; kw++) {
                const ptrdiff_t col_idx_kw
                        = col_idx_kh + kw * jcp.ic * col_ic_str;
                const int ow_kw = ow_init - kw;
                const int imtr_shift = oh_kh * iwb + ow_kw;
                const int ow_start = limit(0, wb, ow_kw);
                const int ow_end = limit(0, wb, ow_kw + iwb);
                for (int ic = 0; ic < jcp.ic; ic++) {
                    const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
                    const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
                    for (int oh = 0; oh < oh_start; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        for (int ow = 0; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                    for (int oh = oh_start; oh < oh_end; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
                        for (int ow = 0; ow < ow_start; ++ow)
                            col[col_idx_oh + ow] = shift;
                        for (int ow = ow_start; ow < ow_end; ++ow)
                            col[col_idx_oh + ow]
                                    = imtr[imtr_idx_oh + ow] + shift;
                        for (int ow = ow_end; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                    for (int oh = oh_end; oh < hb; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        for (int ow = 0; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                }
            }
        }
    } else {
        parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
                [&](int kh, int kw, int ic, int oh) {
                    const int hp = tp - kh * dh;
                    const int ih = (oh + hs) * sh - hp;
                    const ptrdiff_t col_idx_base
                            = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh)
                            * wb;
                    if (ih < 0 || ih >= jcp.ih)
                        for (int ow = 0; ow < wb; ow++)
                            col[col_idx_base + ow] = shift;
                    else {
                        const int wp = lp - kw * dw;
                        const int ow_start = limit(0, wb, div_up(wp, sw) - ws);
                        const int ow_end
                                = limit(0, wb, div_up(jcp.iw + wp, sw) - ws);
                        for (int ow = 0; ow < ow_start; ow++)
                            col[col_idx_base + ow] = shift;
                        const int iw_base = ws * sw - wp;
                        const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
                        for (int ow = ow_start; ow < ow_end; ow++) {
                            const int iw = iw_base + ow * sw;
                            const ptrdiff_t im_idx
                                    = im_idx_base + iw * im_iw_stride;
                            col[col_idx_base + ow] = im[im_idx] + shift;
                        }
                        for (int ow = ow_end; ow < wb; ow++)
                            col[col_idx_base + ow] = shift;
                    }
                });
    }
}

template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
        const int8_t *__restrict im, int8_t *__restrict imtr,
        uint8_t *__restrict col, int hs, int hb, int ws, int wb);
template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
        const uint8_t *__restrict im, uint8_t *__restrict imtr,
        uint8_t *__restrict col, int hs, int hb, int ws, int wb);
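
// Note on 'shift': the int8 GEMM kernels consume unsigned 8-bit data, so for
// signed input every source value is biased by +128 into the u8 range above.
// The extra 128 * sum(weights) term this introduces is assumed to be
// compensated on the driver side when the s32 accumulators are post-processed.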
/* col[kd][kh][kw][ic][oh][ow] <-- im2col_u8_3d(im[id][ih][iw][ic]), one od
 * slice per call */
template <typename T>
void im2col_u8_3d(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
        uint8_t *__restrict col, int od) {
    uint8_t shift = jcp.signed_input ? 128 : 0;
    const int dh = 1 + jcp.dilate_h;
    const int dw = 1 + jcp.dilate_w;
    const int dd = 1 + jcp.dilate_d;
    const int sh = jcp.stride_h;
    const int sw = jcp.stride_w;
    const int sd = jcp.stride_d;
    const int im_iw_stride = jcp.ic * jcp.ngroups;
    const int im_ih_stride = jcp.iw * im_iw_stride;
    const int im_id_stride = jcp.ih * im_ih_stride;
    const int tp = jcp.t_pad;
    const int lp = jcp.l_pad;
    const int fp = jcp.f_pad;

    const T *im_loc = im + od * sd * im_id_stride;

    parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, jcp.oh, jcp.ow,
            [&](int kd, int kh, int kw, int ic, int oh, int ow) {
                int im_idx = (kd * dd - fp) * im_id_stride
                        + (kh * dh - tp + oh * sh) * im_ih_stride
                        + (kw * dw - lp + ow * sw) * im_iw_stride
                        + ic;

                int col_idx = kd * jcp.kh * jcp.kw * jcp.ic * jcp.oh * jcp.ow
                        + kh * jcp.kw * jcp.ic * jcp.oh * jcp.ow
                        + kw * jcp.ic * jcp.oh * jcp.ow
                        + ic * jcp.oh * jcp.ow
                        + oh * jcp.ow + ow;

                int id = od * sd + kd * dd - fp;
                int ih = oh * sh + kh * dh - tp;
                int iw = ow * sw + kw * dw - lp;

                if (id < 0 || id >= jcp.id || ih < 0 || ih >= jcp.ih
                        || iw < 0 || iw >= jcp.iw)
                    col[col_idx] = shift;
                else
                    col[col_idx] = im_loc[im_idx] + shift;
            });
}

template void im2col_u8_3d<int8_t>(const jit_gemm_conv_conf_t &jcp,
        const int8_t *__restrict im, uint8_t *__restrict col, int od);
template void im2col_u8_3d<uint8_t>(const jit_gemm_conv_conf_t &jcp,
        const uint8_t *__restrict im, uint8_t *__restrict col, int od);
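
// The 3D u8 variant handles one output depth od per call and lays col out as
// [kd][kh][kw][ic][oh][ow], so the driver is assumed to loop over od and run
// one GEMM per depth slice, mirroring the float im2col_3d path above.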
/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
        int32_t *__restrict im)
{
    parallel(0, (size_t)mkldnn_get_max_threads(),
            [&](const int ithr, const int nthr) {
        // split the input plane between threads as an h_nthr x w_nthr grid
        int h_nthr = nstl::min(jcp.ih, nthr);
        int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
        int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
        if (ithr < h_nthr * w_nthr) {
            h_ithr = ithr / w_nthr;
            w_ithr = ithr % w_nthr;
            balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
            balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
        } else {
            h_ithr = w_ithr = -ithr;
            h_s = h_e = w_s = w_e = -1;
        }

        for (int ih = h_s; ih < h_e; ++ih) {
            for (int iw = w_s; iw < w_e; ++iw) {
                PRAGMA_OMP_SIMD()
                for (int ic = 0; ic < jcp.ic; ++ic) {
                    im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
                }
            }
        }

        // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
        for (int oh = 0; oh < jcp.oh; ++oh) {
            for (int ow = 0; ow < jcp.ow; ++ow) {
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    const int ih = oh * jcp.stride_h
                            - jcp.t_pad + kh * (1 + jcp.dilate_h);
                    if (ih < h_s || ih >= h_e) continue;

                    for (int kw = 0; kw < jcp.kw; ++kw) {
                        const int iw = ow * jcp.stride_w
                                - jcp.l_pad + kw * (1 + jcp.dilate_w);
                        if (iw < w_s || iw >= w_e) continue;

                        const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
                                + kh) * jcp.kw + kw) * jcp.ic;
                        const size_t im_idx
                                = (ih * jcp.iw + iw) * jcp.ic;
                        PRAGMA_OMP_SIMD()
                        for (int ic = 0; ic < jcp.ic; ++ic) {
                            im[im_idx + ic] += col[col_idx + ic];
                        }
                    }
                }
            }
        }
    });
}
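
// col2im is the adjoint of im2col: it scatter-adds col entries back into the
// image instead of gathering them. In the s32 variant above each thread owns
// a disjoint [h_s, h_e) x [w_s, w_e) input tile and skips everything outside
// it, so the += accumulation needs no synchronization.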
void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
        int od)
{
    parallel_nd(jcp.ic, [&](int ic) {
        const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
        float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;

        int id = od * jcp.stride_d - jcp.f_pad;
        for (int kd = 0; kd < jcp.kd; ++kd) {
            if (id < 0 || id >= jcp.id) {
                col_ += jcp.kh * jcp.kw * jcp.os;
                id += (1 + jcp.dilate_d);
                continue;
            }

            float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;

            for (int oh = 0; oh < jcp.oh; ++oh) {
                for (int kh = 0; kh < jcp.kh; ++kh) {
                    const int ih = oh * jcp.stride_h - jcp.t_pad
                            + kh * (1 + jcp.dilate_h);
                    if (ih < 0 || ih >= jcp.ih) continue;

                    for (int ow = 0; ow < jcp.ow; ++ow) {
                        for (int kw = 0; kw < jcp.kw; ++kw) {
                            const int iw = ow * jcp.stride_w - jcp.l_pad
                                    + kw * (1 + jcp.dilate_w);
                            if (iw < 0 || iw >= jcp.iw) continue;

                            const size_t col_idx =
                                    ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow
                                    + ow;
                            const size_t im_idx = ih * jcp.iw + iw;
                            im_[im_idx] += col_[col_idx];
                        }
                    }
                }
            }

            col_ += jcp.kh * jcp.kw * jcp.os;
            id += (1 + jcp.dilate_d);
        }
    });
}
void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
    const size_t col_step = jcp.ks * jcp.os;
    const size_t im_step = jcp.ih * jcp.iw;
    const int iS = jcp.ih * jcp.iw;

    parallel_nd(jcp.ic, [&](int ic) {
        float *__restrict im_ = im + ic * im_step;
        const float *__restrict col_ = col + ic * col_step;

        for (int is = 0; is < iS; ++is) im_[is] = 0.;

        for (int kh = 0; kh < jcp.kh; ++kh) {
            for (int oh = 0; oh < jcp.oh; ++oh) {
                const int ih = oh * jcp.stride_h - jcp.t_pad
                        + kh * (1 + jcp.dilate_h);
                if (ih < 0 || ih >= jcp.ih) continue;

                for (int kw = 0; kw < jcp.kw; ++kw) {
                    for (int ow = 0; ow < jcp.ow; ++ow) {
                        const int iw = ow * jcp.stride_w - jcp.l_pad
                                + kw * (1 + jcp.dilate_w);
                        if (iw < 0 || iw >= jcp.iw) continue;

                        const size_t col_idx
                                = ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow
                                + ow;
                        const size_t im_idx = ih * jcp.iw + iw;
                        im_[im_idx] += col_[col_idx];
                    }
                }
            }
        }
    });
}
status_t init_conf(jit_gemm_conv_conf_t &jcp,
        memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, int max_threads) {
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();
    const int is_1d = ndims == 3;
    const int is_3d = ndims == 5;

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef
            || cd.diff_bias_desc.format != memory_format::undef;

    jcp.is = jcp.ih * jcp.iw;
    jcp.os = jcp.oh * jcp.ow;
    jcp.ks = jcp.kh * jcp.kw * jcp.kd;

    jcp.signed_input = src_d.data_type() == data_type::s8;
    jcp.wei_adj_scale
            = !jcp.signed_input || mayiuse(avx512_core_vnni) ? 1.f : 0.5f;

    // im2col can be skipped only for a plain 1x1 convolution with unit
    // strides, no padding effect, and unsigned input
    jcp.im2col_sz = !everyone_is(true,
            jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
            jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
            jcp.ks == 1, !jcp.signed_input)
        ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;

    jcp.outer_threading = false;

    bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
            && weights_d.data_type() == s8;

    const bool is_bwd_d = jcp.prop_kind == backward_data;
    const bool is_bwd_w = jcp.prop_kind == backward_weights;
    const bool is_fwd = !is_bwd_d && !is_bwd_w;

    bool is_bf16_conv = false
            || (is_fwd && utils::everyone_is(bf16,
                        src_d.data_type(), weights_d.data_type()))
            || (is_bwd_d && utils::everyone_is(bf16,
                        dst_d.data_type(), weights_d.data_type()))
            || (is_bwd_w && utils::everyone_is(bf16,
                        src_d.data_type(), dst_d.data_type()));
    if (is_bf16_conv && !mayiuse(avx512_core))
        return status::unimplemented;

    bool is_bf16_to_bf16_conv = is_bf16_conv
            && ((is_fwd && bf16 == dst_d.data_type())
                    || (is_bwd_d && bf16 == src_d.data_type())
                    || (is_bwd_w && bf16 == weights_d.data_type()));

    const int vlen = mayiuse(avx512_common)
            ? cpu_isa_traits<avx512_common>::vlen
            : mayiuse(avx)
                    ? cpu_isa_traits<avx>::vlen
                    : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
    const int data_size = (is_int8_conv ? 1 : (is_bf16_conv ? 2 : 4));
    const int simd_w = vlen / data_size;

    jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
    jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;

    using namespace memory_tracking::names;
    bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;

    // TODO: maybe mitigate blocking restriction
    const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
    const int L2 = get_cache_size(2, true) / data_size;
    bool is_blocking_applicable = true
            && is_fwd && jcp.im2col_sz
            && jcp.id == 1 && jcp.od == 1
            // This condition was relaxed to support old behaviour
            // && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && !is_depthwise
            && wei_size < L2 / 2;
    if (is_blocking_applicable) {
        // looking for oh and ow blocking
        int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
        const int ic = jcp.ic;
        const int oc = jcp.oc;
        const int iw = jcp.iw;
        const int ow = jcp.ow;
        const int oh = jcp.oh;
        const int os = oh * ow;

        // 1. cache requirement
        // Heuristic rule: gemm needs a lot of memory for internal usage
        int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
        // memory for accumulators
        row_size += oc * ow * sizeof(uint32_t);
        // memory for transposition
        row_size += ic * iw;

        h_block = nstl::max(1, nstl::min(oh, div_up(L2 - wei_size, row_size)));
        if (h_block == 1) {
            int col_size = ic * jcp.ks + 2 * (ic + oc);
            // memory for accumulators
            col_size += oc * sizeof(uint32_t);
            // memory for transposition
            col_size += ic;
            w_block = nstl::max(
                    1, nstl::min(ow, div_up(L2 - wei_size, col_size)));
        }

        // 2. threading requirement
        if (h_block != oh)
            h_block = nstl::max(1, rnd_dn(h_block, 4));
        if (w_block != ow)
            w_block = nstl::max(1, rnd_dn(w_block, simd_w));

        float thr_eff = 0.f;
        float thr_eff_threshold = 0.9f;

        do {
            int nb_h = div_up(oh, h_block);
            size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
            float disb = (float)oh / rnd_up(oh, h_block);
            thr_eff = (float)work / rnd_up(work, max_threads);
            thr_eff = (thr_eff + disb) / 2.f;
            if (thr_eff >= thr_eff_threshold)
                break;
            h_block = rnd_dn(h_block - 4, 4);
        } while (h_block > 0);

        if (thr_eff < thr_eff_threshold) { // we didn't find a suitable h_block
            h_block = 1;
            int nb_h = div_up(oh, h_block);
            do {
                int nb_w = div_up(ow, w_block);
                size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
                float disb = (float)ow / rnd_up(ow, w_block);
                thr_eff = (float)work_amount
                        / rnd_up(work_amount, max_threads);
                thr_eff = (thr_eff + disb) / 2.f;
                if (thr_eff > thr_eff_threshold)
                    break;
                w_block = rnd_dn(w_block - simd_w, simd_w);
            } while (w_block > 0);
        }

        h_block = nstl::max(1, h_block);
        w_block = nstl::max(1, w_block);
        const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
        const float inner_thr_eff
                = (float)inner_work / rnd_up(inner_work, max_threads);
        if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
            jcp.oh_block = h_block;
            jcp.ow_block = w_block;
            jcp.outer_threading = true;
        }
        // updating jcp.im2col_sz
        if (jcp.oh_block != 1)
            jcp.ow_block = ow;
        jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
    }
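
    // A worked example of the scores above (illustrative numbers only):
    // oh = 30, h_block = 8 gives nb_h = 4; with mb = ngroups = od = 1 the
    // outer work is 4 jobs, so on max_threads = 4 thr_eff = 4/4 = 1.0 and
    // disb = 30 / rnd_up(30, 8) = 30/32 ~ 0.94; the blended score
    // (1.0 + 0.94) / 2 ~ 0.97 clears thr_eff_threshold = 0.9, and
    // h_block = 8 would be accepted.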
    // For threading selection we do:
    // 1. Rough estimation of efficiency for inner and outer threading.
    // 2. Gemm size estimation, assuming gemm does not work
    //    so effectively for small sizes.
    // 64K is a heuristic per-thread gemm size threshold.
    const int gemm_thrld = 64 * 1024;

    if (is_int8_conv) {
        if (is_fwd) {
            if (!jcp.outer_threading) {
                bool is_depthwise = jcp.ic == 1 && jcp.oc == 1
                        && jcp.ngroups != 1;
                const size_t outer_work = jcp.ngroups * jcp.mb;
                const float outer_thr_eff
                        = (float)outer_work / rnd_up(outer_work, max_threads);
                const size_t inner_work
                        = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
                const float inner_thr_eff
                        = (float)inner_work / rnd_up(inner_work, max_threads);
                jcp.outer_threading = (is_depthwise
                        || (jcp.is / max_threads < 64 && jcp.mb != 1))
                        && (outer_thr_eff / inner_thr_eff >= 1.f
                                || (jcp.os * jcp.ic * jcp.oc) / max_threads
                                        < gemm_thrld);
            }
            jcp.nthr = jcp.outer_threading ? max_threads : 1;
            scratchpad.book(key_conv_gemm_col,
                    sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
            scratchpad.book(key_conv_int_dat_in_acc_dt,
                    sizeof(int32_t) * jcp.nthr * jcp.oh_block
                            * jcp.ow_block * jcp.oc);
            scratchpad.book(key_conv_gemm_imtr,
                    sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
        } else if (is_bwd_d) {
            bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
            const size_t outer_work = jcp.ngroups * jcp.mb;
            const float outer_thr_eff
                    = (float)outer_work / rnd_up(outer_work, max_threads);
            const size_t inner_work
                    = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
            const float inner_thr_eff
                    = (float)inner_work / rnd_up(inner_work, max_threads);
            jcp.outer_threading = (is_depthwise
                    || (jcp.is / max_threads < 64 && jcp.mb != 1))
                    && (outer_thr_eff / inner_thr_eff >= 1.f
                            || (jcp.is * jcp.ic * jcp.oc) / max_threads
                                    < gemm_thrld);

            jcp.nthr = jcp.outer_threading ? max_threads : 1;
            scratchpad.book(key_conv_gemm_col,
                    sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
            scratchpad.book(key_conv_int_dat_in_acc_dt,
                    sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
        } else if (is_bwd_w) {
            assert(!"unimplemented prop_kind");
            return status::unimplemented;
        }
    } else {
        if (is_fwd) {
            if (!jcp.outer_threading) {
                const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
                const float outer_thr_eff = (float)outer_work_amount
                        / rnd_up(outer_work_amount, max_threads);
                const size_t inner_work_amount
                        = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
                const float inner_thr_eff = (float)inner_work_amount
                        / rnd_up(inner_work_amount, max_threads);
                jcp.outer_threading = jcp.os / max_threads < 512
                        && IMPLICATION(jcp.od == 1,
                                jcp.mb != 1 || jcp.ngroups > 2)
                        && (outer_thr_eff / inner_thr_eff >= 1.f
                                || (jcp.os * jcp.ic * jcp.oc) / max_threads
                                        < gemm_thrld);
            }
        } else if (is_bwd_d) {
            const size_t outer_work_amount = jcp.ngroups * jcp.mb;
            const float outer_thr_eff = (float)outer_work_amount
                    / rnd_up(outer_work_amount, max_threads);
            const size_t inner_work
                    = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
            const float inner_thr_eff = (float)inner_work
                    / rnd_up(inner_work, max_threads);
            jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
                    && (jcp.mb != 1 || jcp.ngroups > 2)
                    && (outer_thr_eff / inner_thr_eff >= 1.f
                            || (jcp.is * jcp.ic * jcp.oc) / max_threads
                                    < gemm_thrld);
        } else if (is_bwd_w)
            jcp.outer_threading = jcp.os / max_threads < 256
                    && (jcp.mb != 1 || jcp.ngroups > 2);

        jcp.nthr = jcp.outer_threading ? max_threads : 1;

        const size_t gemm_col_datatype_size = is_bf16_conv && !is_bwd_d
                ? sizeof(mkldnn_bfloat16_t)
                : sizeof(float);
        scratchpad.book(key_conv_gemm_col,
                gemm_col_datatype_size * jcp.nthr * jcp.im2col_sz);

        const int sizeof_cacheline_float = 16;
        if (is_bwd_w) {
            jcp.need_wei_reduction = mkldnn_thr_syncable()
                    ? jcp.mb != 1 && jcp.nthr != 1 : false;
            scratchpad.book(key_conv_wei_reduction,
                    sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
        }

        if (is_bf16_conv && jcp.with_bias) {
            const size_t ws_size = sizeof(float)
                    * max_threads * rnd_up(jcp.ow, sizeof_cacheline_float);
            scratchpad.book(key_conv_dst_bf16_convert_wsp, ws_size);
        }

        if (is_bf16_to_bf16_conv) {
            size_t conv_acc_buffer_size = 0;
            if (is_fwd)
                conv_acc_buffer_size = sizeof(float) * jcp.nthr
                        * rnd_up(jcp.oc * jcp.oh_block * jcp.ow_block,
                                sizeof_cacheline_float);
            else if (is_bwd_d)
                conv_acc_buffer_size = sizeof(float) * jcp.nthr
                        * rnd_up(jcp.ic * jcp.ih * jcp.iw * jcp.id,
                                sizeof_cacheline_float);
            else if (is_bwd_w)
                conv_acc_buffer_size = sizeof(float) * weights_d.size();
            scratchpad.book(key_conv_int_dat_in_acc_dt, conv_acc_buffer_size);
        }
    }

    return status::success;
}
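
// On the outer/inner threading trade-off used throughout init_conf():
// "outer" threading parallelizes across independent (group, minibatch) jobs,
// each running a single-threaded GEMM, while "inner" threading runs jobs
// serially and gives all threads to the GEMM itself. The *_thr_eff ratios
// estimate which scheme keeps more threads busy for the given shape.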
void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
        int &nthr_g, int &ithr_mb, int &nthr_mb) {
    nthr_g = nstl::min(ngroups, nthr);
    nthr_mb = nstl::min(mb, nthr / nthr_g);
    if (ithr / nthr_mb >= ngroups) {
        ithr_g = ithr_mb = -1;
    } else {
        ithr_g = ithr / nthr_mb;
        ithr_mb = ithr % nthr_mb;
    }
}
void bwd_weights_reduction_par(int ithr, int nthr,
        const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
        float *weights) {
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;

    size_t weights_start{0}, weights_end{0};
    balance211(weights_g_size, nthr, ithr, weights_start, weights_end);

    // each thread reduces its own [weights_start, weights_end) slice across
    // all per-thread workspaces; i == 0 overwrites, the rest accumulate
    for (int i = 0; i < nthr; ++i) {
        const float *ws_i = weights_reduce_ws + i * weights_g_size;
        for (size_t s = weights_start; s < weights_end; ++s)
            weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
    }
}

} // namespace jit_gemm_convolution_utils

} // namespace cpu
} // namespace impl
} // namespace mkldnn