1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
24 #include "gemm_convolution_utils.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
33 using namespace prop_kind;
34 using namespace data_type;
36 namespace jit_gemm_convolution_utils {
// im2col for 3D convolution: scatters one output-depth slice `od` of image
// `im` into the column buffer `col`, parallelized over input channels.
// Out-of-image (padding) positions are skipped; presumably `col` is zeroed
// by the caller — TODO confirm against the caller.
// NOTE(review): the embedded listing numbers skip (42, 50, 52, 55-58, 60,
// 63-66, 68-72, 74-76, 78-80, 82, ...): declarations of ih_/iw_, several
// branch bodies and closing braces are missing from this excerpt, so the
// text below is not compilable as shown.
38 void im2col_3d(jit_gemm_conv_conf_t &jcp, const float *im, float *col, int od) {
39 const size_t OHW = jcp.oh * jcp.ow;
40 const size_t im_step = jcp.ih * jcp.iw * jcp.id;
41 const size_t col_step = jcp.ks * OHW;
43 parallel_nd(jcp.ic, [&](int ic) {
44 const float *im_loc = im + ic * im_step;
45 float *col_loc = col + ic * col_step;
// input depth coordinate for this output-depth slice (may start negative
// inside the front padding)
46 int id = od * jcp.stride_d - jcp.f_pad;
47 for (int kd = 0; kd < jcp.kd; ++kd) {
48 float *col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
// depth position falls in padding: branch below apparently handles the
// fully-padded plane (gap in listing hides its body)
49 if (id < 0 || id >= jcp.id) {
51 for (int kh = 0; kh < jcp.kh; ++kh) {
53 for (int oh = 0; oh < jcp.oh; ++oh) {
54 if (ih < 0 || ih >= jcp.ih) {
59 for (int kw = 0; kw < jcp.kw; ++kw) {
61 for (int ow = 0; ow < jcp.ow; ++ow) {
62 if (iw < 0 || iw >= jcp.iw) {
67 const size_t col_idx = kw * OHW + oh * jcp.ow
// iw_/ih_ advance by the dilated step; their declarations are in the
// elided lines
73 iw_ += (1 + jcp.dilate_w);
77 ih_ += (1 + jcp.dilate_h);
// in-image depth plane: copy valid (ih, iw) samples into col
81 const float *im_ = im_loc + id * jcp.ih * jcp.iw;
83 for (int kh = 0; kh < jcp.kh; ++kh) {
85 for (int oh = 0; oh < jcp.oh; ++oh) {
86 if (ih < 0 || ih >= jcp.ih) {
91 for (int kw = 0; kw < jcp.kw; ++kw) {
93 for (int ow = 0; ow < jcp.ow; ++ow) {
94 if (iw < 0 || iw >= jcp.iw) {
99 const size_t col_idx = kw * OHW + oh * jcp.ow
101 const size_t im_idx = ih * jcp.iw + iw;
103 col_[col_idx] = im_[im_idx];
106 iw_ += (1 + jcp.dilate_w);
110 ih_ += (1 + jcp.dilate_h);
// advance col_ past one kernel row worth of output pixels
111 col_ += jcp.kw * OHW;
// next kernel-depth tap: input depth moves by the dilated depth step
114 id += (1 + jcp.dilate_d);
// im2col for 2D convolution: col[ic][kh][kw][oh][ow] <- im[ic][ih][iw].
// Two strategies are visible: one parallelized over (kh, oh) and one
// parallelized over input channels; the selection condition between them
// sits in elided lines (listing numbers 120, 124, 133-135, 138 are
// missing) — presumably a jcp.ic == 1 special case, TODO confirm.
// Padding positions are skipped (col presumably pre-zeroed by the caller).
119 void im2col(jit_gemm_conv_conf_t &jcp, const float *im, float *col) {
// strategy 1: parallel over kernel row x output row
121 parallel_nd(jcp.kh, jcp.oh, [&](int kh, int oh) {
122 const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
123 if (ih < 0 || ih >= jcp.ih) return;
125 for (int kw = 0; kw < jcp.kw; ++kw) {
126 for (int ow = 0; ow < jcp.ow; ++ow) {
127 const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
128 if (iw < 0 || iw >= jcp.iw) continue;
130 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
131 const size_t im_idx = ih*jcp.iw + iw;
132 col[col_idx] = im[im_idx];
// strategy 2: parallel over input channels; per-channel base offsets
136 const size_t im_step = jcp.ih * jcp.iw;
137 const size_t col_step = jcp.ks * jcp.os;
139 parallel_nd(jcp.ic, [&](int ic) {
140 const float *im_ = im + ic * im_step;
141 float *col_ = col + ic * col_step;
143 for (int kh = 0; kh < jcp.kh; ++kh) {
144 for (int oh = 0; oh < jcp.oh; ++oh) {
145 const int ih = oh * jcp.stride_h
146 - jcp.t_pad + kh * (1 + jcp.dilate_h);
147 if (ih < 0 || ih >= jcp.ih) continue;
149 for (int kw = 0; kw < jcp.kw; ++kw) {
150 for (int ow = 0; ow < jcp.ow; ++ow) {
151 const int iw = ow * jcp.stride_w
152 - jcp.l_pad + kw * (1 + jcp.dilate_w);
153 if (iw < 0 || iw >= jcp.iw) continue;
// col_idx continuation ("* jcp.ow + ow") is in an elided line (156)
155 const size_t col_idx = ((kh * jcp.kw + kw) * jcp.oh+oh)
157 const size_t im_idx = ih*jcp.iw + iw;
158 col_[col_idx] = im_[im_idx];
165 /* col[oh][ow][kh][kw][ic] <-- im2col_u8(im[ih][iw][ic]) */
// u8 im2col with channels-last layout: col[oh][ow][kh][kw][ic] <- im[ih][iw][ic]
// (per the comment above this function in the file). Parallel over output
// pixels; the innermost loop copies a contiguous run of jcp.ic bytes.
// Note the image stride uses jcp.ngroups * jcp.ic — source presumably holds
// all groups interleaved per pixel, TODO confirm against caller.
// NOTE(review): listing lines 172, 177, 180, 182 and the closers 185-191
// are elided; in particular the `const size_t im_idx` declaration head
// (line 180) is missing, leaving the `= (ih * ...)` continuation orphaned.
166 void im2col_u8(jit_gemm_conv_conf_t &jcp, const uint8_t *im, uint8_t *col) {
167 parallel_nd(jcp.oh, jcp.ow, [&](int oh, int ow) {
168 for (int kh = 0; kh < jcp.kh; ++kh) {
169 const int ih = oh * jcp.stride_h
170 - jcp.t_pad + kh * (1 + jcp.dilate_h);
171 if (ih < 0 || ih >= jcp.ih) continue;
173 for (int kw = 0; kw < jcp.kw; ++kw) {
174 const int iw = ow * jcp.stride_w
175 - jcp.l_pad + kw * (1 + jcp.dilate_w);
176 if (iw < 0 || iw >= jcp.iw) continue;
178 const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + kh)
179 * jcp.kw + kw) * jcp.ic;
181 = (ih * jcp.iw + iw) * jcp.ngroups * jcp.ic;
183 for (int ic = 0; ic < jcp.ic; ++ic) {
184 col[col_idx + ic] = im[im_idx + ic];
192 /* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
// s32 col2im, channels-last: im[ih][iw][ic] += col[oh][ow][kh][kw][ic]
// (per the comment above this function in the file). Work is split over a
// 2D (ih, iw) thread grid; each thread zeroes its own image tile and then
// accumulates only the contributions landing inside [h_s,h_e)x[w_s,w_e),
// so no synchronization is needed between threads.
// NOTE(review): listing lines 203, 206-207, 210, 213-216, 224, 229, 232,
// 234 and the closers 237-244 are elided — the `} else {` around line 203,
// the `const size_t im_idx` declaration head (line 232), and closing braces
// are missing from this excerpt.
193 void col2im_s32(jit_gemm_conv_conf_t &jcp, const int32_t *col, int32_t *im) {
194 parallel(0, [&](const int ithr, const int nthr) {
// factor nthr into an h x w thread grid, clamped to the image extents
195 int h_nthr = nstl::min(jcp.ih, nthr);
196 int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
197 int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
198 if (ithr < h_nthr * w_nthr) {
199 h_ithr = ithr / w_nthr;
200 w_ithr = ithr % w_nthr;
201 balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
202 balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
// surplus thread: empty ranges so the loops below do nothing
204 h_ithr = w_ithr = -ithr;
205 h_s = h_e = w_s = w_e = -1;
// zero this thread's image tile before accumulation
208 for (int ih = h_s; ih < h_e; ++ih) {
209 for (int iw = w_s; iw < w_e; ++iw) {
211 for (int ic = 0; ic < jcp.ic; ++ic) {
212 im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
217 // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
218 for (int oh = 0; oh < jcp.oh; ++oh) {
219 for (int ow = 0; ow < jcp.ow; ++ow) {
220 for (int kh = 0; kh < jcp.kh; ++kh) {
221 const int ih = oh * jcp.stride_h
222 - jcp.t_pad + kh * (1 + jcp.dilate_h);
// ownership test: only accumulate rows inside this thread's tile
223 if (ih < h_s || ih >= h_e) continue;
225 for (int kw = 0; kw < jcp.kw; ++kw) {
226 const int iw = ow * jcp.stride_w
227 - jcp.l_pad + kw * (1 + jcp.dilate_w);
228 if (iw < w_s || iw >= w_e) continue;
230 const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
231 + kh) * jcp.kw + kw) * jcp.ic;
233 = (ih * jcp.iw + iw) * jcp.ic;
235 for (int ic = 0; ic < jcp.ic; ++ic) {
236 im[im_idx + ic] += col[col_idx + ic];
// col2im for 3D convolution: accumulates one output-depth slice `od` of
// the column buffer back into image `im`, parallelized over input channels
// (each channel is touched by exactly one thread, so += is race-free).
// NOTE(review): listing lines 249, 255-257, 259, 265, 271, 275-277 and
// closers 280-284 are elided — in particular the `continue;` + `}` that
// presumably close the padded-depth branch at line 252 are missing.
245 void col2im_3d(jit_gemm_conv_conf_t &jcp, const float *col, float *im, int od) {
246 parallel_nd(jcp.ic, [&](int ic) {
247 const float *col_ = col + (size_t)ic * jcp.ks * jcp.os;
248 float *im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
// input depth coordinate for this output-depth slice
250 int id = od * jcp.stride_d - jcp.f_pad;
251 for (int kd = 0; kd < jcp.kd; ++kd) {
// depth position in padding: skip a whole kernel-plane of col
252 if (id < 0 || id >= jcp.id) {
253 col_ += jcp.kh * jcp.kw * jcp.os;
254 id += (1 + jcp.dilate_d);
258 float *im_ = im_ic + id * jcp.ih * jcp.iw;
260 for (int oh = 0; oh < jcp.oh; ++oh) {
261 for (int kh = 0; kh < jcp.kh; ++kh) {
262 const int ih = oh * jcp.stride_h - jcp.t_pad
263 + kh * (1 + jcp.dilate_h);
264 if (ih < 0 || ih >= jcp.ih) continue;
266 for (int ow = 0; ow < jcp.ow; ++ow) {
267 for (int kw = 0; kw < jcp.kw; ++kw) {
268 const int iw = ow * jcp.stride_w - jcp.l_pad
269 + kw * (1 + jcp.dilate_w);
270 if (iw < 0 || iw >= jcp.iw) continue;
272 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
273 const size_t im_idx = ih*jcp.iw + iw;
274 im_[im_idx] += col_[col_idx];
// advance to the next kernel-depth plane of col
278 col_ += jcp.kh * jcp.kw * jcp.os;
279 id += (1 + jcp.dilate_d);
// NOTE(review): the first line of this function's signature (listing line
// ~284, presumably `void col2im(`) is elided; only the parameter tail is
// visible. 2D col2im: zeroes each channel plane of `im`, then accumulates
// im[ih][iw] += col[kh][kw][oh][ow], parallelized over input channels
// (one thread per channel, so += is race-free).
285 jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
287 const size_t col_step = jcp.ks * jcp.os;
288 const size_t im_step = jcp.ih * jcp.iw;
289 const int iS = jcp.ih * jcp.iw;
291 parallel_nd(jcp.ic, [&](int ic) {
292 float *im_ = im + ic * im_step;
293 const float *col_ = col + ic * col_step;
// zero the destination plane before accumulating
295 for (int is = 0; is < iS; ++is) im_[is] = 0.;
297 for (int kh = 0; kh < jcp.kh; ++kh) {
298 for (int oh = 0; oh < jcp.oh; ++oh) {
299 const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
300 if (ih < 0 || ih >= jcp.ih) continue;
302 for (int kw = 0; kw < jcp.kw; ++kw) {
303 for (int ow = 0; ow < jcp.ow; ++ow) {
304 const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
305 if (iw < 0 || iw >= jcp.iw) continue;
307 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
308 const size_t im_idx = ih*jcp.iw + iw;
309 im_[im_idx] += col_[col_idx];
// NOTE(review): the function name line (listing ~317) is elided; from the
// parameter list this is presumably the conf initializer (init_conf) —
// TODO confirm. Fills `jcp` from the convolution descriptor and memory
// descriptors, handling both 4D (2D spatial) and 5D (3D spatial) cases,
// then chooses the threading strategy (outer vs inner parallelism).
// Several lines are elided (322, 326, 329, 332, 339, 343, 347, 351, 355,
// 357, 362, 369-370, 374-375, 377, 379, 388, 392-393), including the
// heads of the jcp.with_bias / is_depthwise assignments.
318 jit_gemm_conv_conf_t &jcp, const convolution_desc_t &cd,
319 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
320 const memory_desc_wrapper &dst_d, int max_threads,
321 bool with_relu, float relu_negative_slope) {
// with_groups: weights have one extra leading dim (the group dim)
323 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
324 jcp.prop_kind = cd.prop_kind;
325 const int ndims = src_d.ndims();
327 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
328 jcp.mb = src_d.dims()[0];
// per-group channel counts
330 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
331 jcp.ic = src_d.dims()[1] / jcp.ngroups;
// ndims == 4 means 2D spatial: depth dims collapse to 1
333 jcp.id = (ndims == 4) ? 1 : src_d.dims()[2];
334 jcp.ih = src_d.dims()[ndims - 2];
335 jcp.iw = src_d.dims()[ndims - 1];
336 jcp.od = (ndims == 4) ? 1 : dst_d.dims()[2];
337 jcp.oh = dst_d.dims()[ndims - 2];
338 jcp.ow = dst_d.dims()[ndims - 1];
340 jcp.kd = (ndims == 4) ? 1 : weights_d.dims()[with_groups + 2];
341 jcp.kh = weights_d.dims()[with_groups + ndims - 2];
342 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
344 jcp.f_pad = (ndims == 4) ? 0 : cd.padding[0][0];
345 jcp.t_pad = cd.padding[0][ndims - 4];
346 jcp.l_pad = cd.padding[0][ndims - 3];
348 jcp.stride_d = (ndims == 4) ? 1 : cd.strides[0];
349 jcp.stride_h = cd.strides[ndims - 4];
350 jcp.stride_w = cd.strides[ndims - 3];
352 jcp.dilate_d = (ndims == 4) ? 0 : cd.dilates[0];
353 jcp.dilate_h = cd.dilates[ndims - 4];
354 jcp.dilate_w = cd.dilates[ndims - 3];
356 jcp.src_fmt = src_d.format();
// continuation of an elided assignment head (line 357, presumably
// jcp.with_bias = ...)
358 = cd.bias_desc.format != memory_format::undef
359 || cd.diff_bias_desc.format != memory_format::undef;
360 jcp.with_relu = with_relu;
361 jcp.relu_negative_slope = relu_negative_slope;
// derived spatial/kernel sizes
363 jcp.is = jcp.ih * jcp.iw;
364 jcp.os = jcp.oh * jcp.ow;
365 jcp.ks = jcp.kh * jcp.kw * jcp.kd;
// im2col buffer is only needed when the conv is not a pure 1x1/same-size
// case; the ": 0" arm of this ternary is in the elided lines
366 jcp.im2col_sz = !(jcp.oh == jcp.ih && jcp.ow == jcp.iw
367 && jcp.od == jcp.id && jcp.ks == 1)
368 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os
// threading-strategy selection (heuristic thresholds on work per thread)
371 bool do_outer_threading = false;
372 bool is_int8_conv = (cd.src_desc.data_type == u8
373 && cd.weights_desc.data_type == s8);
// continuation of an elided `is_depthwise` declaration head (line 375)
376 utils::everyone_is(1, jcp.ic, jcp.oc) && jcp.ngroups != 1;
378 = (is_depthwise || (jcp.os / max_threads < 64 && jcp.mb != 1));
380 if (utils::one_of(jcp.prop_kind, forward_training, forward_inference))
381 do_outer_threading = jcp.os / max_threads < 512
382 && utils::implication(jcp.od == 1, (jcp.mb != 1 || jcp.ngroups > 2));
383 else if (jcp.prop_kind == backward_data)
384 do_outer_threading = (jcp.mb != 1 || jcp.ngroups > 2);
385 else //(jcp.prop_kind == backward_weights)
386 do_outer_threading = jcp.os / max_threads < 256
387 && (jcp.mb != 1 || jcp.ngroups > 2);
389 jcp.nthr = do_outer_threading ? max_threads : 1;
// weight reduction across threads only makes sense with a syncable
// threading runtime and multiple minibatches/threads
390 jcp.need_wei_reduction = mkldnn_thr_syncable()
391 ? (jcp.mb != 1 && jcp.nthr != 1) : false;
// Allocates a per-thread scratchpad of nthr * size bytes into *scratchpad_,
// or sets it to nullptr (the surrounding if/else — listing lines 396, 399,
// 401 — is elided; presumably the nullptr branch is taken when size == 0).
// Returns out_of_memory on allocation failure, success otherwise.
394 status_t prepare_scratchpad(jit_gemm_conv_conf_t &jcp,
395 scratchpad_t **scratchpad_, size_t size, const int nthr) {
397 *scratchpad_ = create_scratchpad(nthr * size);
398 if (*scratchpad_ == nullptr) return status::out_of_memory;
400 *scratchpad_ = nullptr;
402 return status::success;
// Maps a flat thread id onto a (group, minibatch) 2D thread grid for the
// backward-weights pass: nthr_g x nthr_mb threads, clamped to ngroups/mb.
// Threads beyond the grid get ithr_g = ithr_mb = -1 (idle).
// NOTE(review): the `} else {` (listing line 411) and closing braces
// (414-416) are elided from this excerpt.
405 void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
406 int &nthr_g, int &ithr_mb, int &nthr_mb) {
407 nthr_g = nstl::min(ngroups, nthr);
408 nthr_mb = nstl::min(mb, nthr / nthr_g);
// thread id past the grid: mark idle
409 if (ithr / nthr_mb >= ngroups) {
410 ithr_g = ithr_mb = -1;
// row-major decomposition: group index = ithr / nthr_mb
412 ithr_g = ithr / nthr_mb;
413 ithr_mb = ithr % nthr_mb;
// Parallel reduction of per-thread partial weight gradients: each thread
// owns a balance211 slice [weights_start, weights_end) of one group's
// weights and sums the nthr workspace copies into `weights`. The i == 0
// term overwrites (seeds with 0) so stale contents of `weights` are
// discarded. Runs past the visible end of this excerpt (closers elided).
417 void bwd_weights_reduction_par(int ithr, int nthr, const jit_gemm_conv_conf_t &jcp,
418 const float *weights_reduce_ws, float *weights) {
419 const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
421 size_t weights_start{0}, weights_end{0};
422 balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
// accumulate all threads' partial results over this thread's slice
424 for (int i = 0; i < nthr; ++i) {
425 const float *ws_i = weights_reduce_ws + i * weights_g_size;
426 for (size_t s = weights_start; s < weights_end; ++s)
427 weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];