/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "cpu_memory.hpp"

#include "jit_avx2_conv_kernel_f32.hpp"
#define GET_OFF(field) offsetof(jit_conv_call_s, field)
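// GET_OFF computes the byte offset of a field inside the jit_conv_call_s
// argument structure. The driver passes a pointer to that structure in the
// ABI argument register (param1), so the generated code reads its arguments
// as, e.g., mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]).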
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace Xbyak;
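// Register tiling used by the forward kernel (a reading aid, not part of the
// original source): an ur_w x oc_blocks tile of output accumulators lives in
// Ymm(ur_w * ii + jj), one broadcast input value per output column lives in
// Ymm(oc_blocks * ur_w + jj), and ymm15 holds the current weights vector.
// On plain Intel AVX (no FMA) one extra temporary, ytmp, is also needed.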
void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
        int pad_l, int pad_r, int oc_blocks)
{
    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int nb_ic = jcp.nb_ic;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
        int jj_end = ur_w
            - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off;
                if (one_of(jcp.src_fmt, ncw, nchw, ncdhw))
                    inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw
                            + (ki*dilate_w + jj*stride_w - pad_l));
                else
                    inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w
                                - pad_l)*ic_blk + ifm2);
                vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                        make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
            }

            for (int ii = 0; ii < oc_blocks; ii++) {
                int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk
                        + ki * ic_blk * oc_blk + ifm2 * oc_blk;
                vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]);
                for (int jj = jj_start; jj < jj_end; jj++)
                    if (mayiuse(avx2))
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(oc_blocks * ur_w + jj), ymm15);
                    else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                    }
            }
        }
    }
}
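// oh_step_nopad is the variant used when the filter width is large (kw >= 5)
// and the current width block needs no left/right padding: instead of fully
// unrolling over kw as above, it loops over kw at run time (ki_iter), which
// keeps the generated code small.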
void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
        int pad_l, int pad_r, char pad_tag,
        int oc_blocks, char oc_blocks_tag)
{
    Label kw_loop;

    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int nb_ic = jcp.nb_ic;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;

    xor_(ki_iter, ki_iter);
    L(kw_loop);
    {
        int jj_start = 0;
        int jj_end = ur_w;
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off;
                if (one_of(jcp.src_fmt, ncw, nchw, ncdhw))
                    inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw
                            + (jj * stride_w - pad_l));
                else
                    inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk
                            + ifm2);
                vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                        make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
            }
            for (int ii = 0; ii < oc_blocks; ii++) {
                int aux_kernel_offset =
                    ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk;
                vmovups(ymm15, ptr[aux_reg_kernel
                        + sizeof(float) * aux_kernel_offset]);
                for (int jj = jj_start; jj < jj_end; jj++)
                    if (mayiuse(avx2))
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(oc_blocks * ur_w + jj), ymm15);
                    else { // Intel AVX support
                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                    }
            }
        }
        add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
        add(aux_reg_input, sizeof(float) * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
                ? dilate_w : ic_blk * dilate_w));

        inc(ki_iter);
        cmp(ki_iter, kw);
        jl(kw_loop, T_NEAR);
    }
}
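// width_blk_step produces one ur_w-wide block of output columns: it primes
// the accumulators (from dst, bias, or zero), runs the kd/kh/kw filter loops,
// applies the fused post operations on the last input-channel chunk, and
// stores the results back to dst.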
void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
        int pad_l, int pad_r, char pad_tag,
        int oc_blocks, char oc_blocks_tag)
{
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ow = jcp.ow;
    int oh = jcp.oh;
    int od = jcp.od;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;
    const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
        ? 1 : ic_blk;
    const int inp_off = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
        ? dilate_w : ic_blk * dilate_w;

    Label init_done, init_first;

    if (!jcp.with_sum) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        jne(init_first, T_NEAR);
    }

    for (int ii = 0; ii < oc_blocks; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt;
            if (jcp.with_dw_conv)
                offt = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
            else
                offt = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
            vmovups(Ymm(ur_w * ii + jj),
                    make_safe_addr(reg_output, offt, reg_long_offt));
        }
    }

    if (jcp.with_sum && jcp.with_bias) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        je(init_done, T_NEAR);

        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                        yword[reg_bias + sizeof(float) * ii * oc_blk]);
    }

    jmp(init_done, T_NEAR);

    L(init_first);
    if (this->jcp.with_bias) {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                vmovups(Ymm(ur_w * ii + jj),
                        yword[reg_bias + sizeof(float) * ii * oc_blk]);
    } else {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj));
    }

    L(init_done);
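    // The accumulators are now primed; what follows is the filter loop:
    // kd (5D only) and kh are runtime loops, while kw is either fully
    // unrolled (oh_step_unroll_kw) or a runtime loop (oh_step_nopad).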
    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
    }

    Label skip_kh_loop, skip_kd_loop, kd_loop;
    if (jcp.ndims == 5) {
        push(reg_output);
        push(oi_iter);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_input);

        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_loop);
        mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_input, aux_reg_inp_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(kj, 0);
        je(skip_kh_loop, T_NEAR);
    }

    Label kh_loop;
    L(kh_loop);
    if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
        oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
                oc_blocks_tag);
        sub(aux_reg_input, sizeof(float) * kw * inp_off);
        add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
    } else {
        oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
        add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
        add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
    }

    dec(kj);
    cmp(kj, 0);
    jg(kh_loop, T_NEAR);

    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d,
                sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult);
        add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
        pop(reg_output);
    }

    Label regular_store;

    test(reg_ci_flag, FLAG_IC_LAST);
    je(regular_store, T_NEAR);
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, oc_blocks * ur_w);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
            add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);

            for (int ii = 0; ii < oc_blocks; ii++) {
                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                        ur_w * ii, ur_w * ii + ur_w, reg_d_weights, reg_d_bias);

                add(reg_d_weights, jcp.oc_block * sizeof(float));
                add(reg_d_bias, jcp.oc_block * sizeof(float));
            }

            depthwise_inj_idx++;
        }
    }

    L(regular_store);
    for (int ii = 0; ii < oc_blocks; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t o_off;
            if (jcp.with_dw_conv)
                o_off = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
            else
                o_off = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;

            Ymm reg_out = Ymm(ur_w * ii + jj);
            vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
        }
    }
}
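// solve_common decomposes the output row into up to four width blocks, each
// compiled with its own compile-time padding: a left-padded block, n_oi
// unpadded "middle" blocks executed in a runtime loop, a right-padded block,
// and a remainder ("tail") block of ur_w_tail columns. As an illustrative
// example (not from the source): ow = 17 and ur_w = 3 give n_oi = 5 and
// ur_w_tail = 2; with l_pad > 0 the first block absorbs the left padding.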
inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
        int oc_blocks, char oc_blocks_tag)
{
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;
    int dilate_w = jcp.dilate_w + 1;
    int str_w = jcp.stride_w;
    const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : ic_blk;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1));
    int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1);
    if (r_pad1 > 0) n_oi--;

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1,
                    'l', oc_blocks, oc_blocks_tag); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0,
                    'l', oc_blocks, oc_blocks_tag); // "lpad"
        add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);
    }

    Label ow_loop;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop);

        width_blk_step(ur_w, 0, 0,
                'm', oc_blocks, oc_blocks_tag); // "middle"
        add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop, T_NEAR);
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1,
                'r', oc_blocks, oc_blocks_tag); // "rpad"
        add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad,
                't', oc_blocks, oc_blocks_tag); // "tail"
}
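// generate() is the kernel entry point: it instantiates the post-op injector
// objects, emits the prologue, loads the arguments from jit_conv_call_s, and
// dispatches between the full nb_oc_blocking path and the tail path depending
// on how many output-channel blocks this call has to produce.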
void jit_avx2_conv_fwd_kernel_f32::generate()
{
    const auto &p = attr_.post_ops_;
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx2>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias)
        mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);

    int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;

    Label tail, exit;

    if (jcp.nb_oc > jcp.nb_oc_blocking) {
        cmp(reg_oc_blocks, jcp.nb_oc_blocking);
        jne(nb_oc_tail ? tail : exit, T_NEAR);

        solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
        jmp(exit, T_NEAR);

        if (nb_oc_tail) {
            L(tail);
            cmp(reg_oc_blocks, nb_oc_tail);
            jne(exit, T_NEAR);
            solve_common(nb_oc_tail, '0' + nb_oc_tail);
        }

        L(exit);
    } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
        solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
    } else {
        solve_common(nb_oc_tail, '0' + nb_oc_tail);
    }

    this->postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}
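// post_ops_ok whitelists the attribute post-op chains this kernel can fuse:
// at most one sum, an optional fused depthwise convolution, and "simple"
// (eltwise/depthwise) ops, in the specific orders enumerated below.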
bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    switch (p.len_) {
    case 0: return true; // no post_ops
    case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
                   (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
                   (is_simple(0) && is_simple(1));
    case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
                   (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
                   (is_sum(0) && is_simple(1) && is_simple(2));
    case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
    default: return false;
    }
}
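// init_conf validates the problem shape against the kernel's constraints and
// fills jit_conv_conf_t with the blocking parameters (ic/oc block sizes,
// ur_w row unrolling, nb_oc_blocking) that generate() compiles against.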
status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr)
{
    if (!mayiuse(avx)) return status::unimplemented;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
    jcp.ow = dst_d.dims()[ndims-1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
    jcp.kw = weights_d.dims()[with_groups + ndims-1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    if (jcp.with_dw_conv) {
        jcp.dw_conv_oh = jcp.oh;
        jcp.dw_conv_ow = jcp.ow;
        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
    }

    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);

    if (jcp.with_dw_conv && !mayiuse(avx2))
        return status::unimplemented;

    if (jcp.with_dw_conv && jcp.ndims == 5)
        return status::unimplemented;
    if (!mayiuse(avx2)) {
        for (int i = 0; i < p.len_; i++) {
            auto &post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                if (post_op.eltwise.alg != alg_kind::eltwise_relu)
                    return status::unimplemented;
            } else if (post_op.is_depthwise()) {
                return status::unimplemented;
            }
        }
    }

    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;

    jcp.src_dt = cd.src_desc.data_type;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    const int simd_w = 8;
    const bool flat = jcp.ic < simd_w;
    const bool mimo = !flat;
    /* Grouped channel offset to support 'non-blocked data' format for
     * convolution sizes with '(input_channel / ngroups) < simd' */
    jcp.nonblk_group_off
            = (one_of(src_d.format(), ncw, nchw, ncdhw) && jcp.ngroups > 1) ?
            jcp.ic : 1;

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo)
            jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true
        && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc,
                ncdhw, ndhwc)
            && one_of(weights_d.format(), Owi8o, gOwi8o, Ohwi8o, gOhwi8o,
                Odhwi8o, gOdhwi8o))
        && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c, nCdhw8c)
            && one_of(weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
                gOIhw8i8o, OIdhw8i8o, gOIdhw8i8o))
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && one_of(dst_d.format(), nCw8c, nChw8c, nCdhw8c);
    if (!args_ok) return status::unimplemented;
    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.ur_w = 3;

    jcp.oc_block = simd_w;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */

    // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
    // Thus, we can only assign 14 or 15 YMMs for data storage
    const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
    if (!mayiuse(avx2)) {
        if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
            // current register assignment requires more YMMs than available
            // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad
            if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
                jcp.ur_w -= 1;
            else
                for (int b = 3; b > 1; b--)
                    if (jcp.nb_oc % b == 0) {
                        jcp.nb_oc_blocking = b;
                        break;
                    }
        }
    }

    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
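    // Illustration of the adjustment above (an example, not from the source):
    // on plain Intel AVX only 14 YMMs are available, so with
    // nb_oc_blocking = 4 and ur_w = 3 the tile needs (4 + 1) * 3 = 15
    // registers; the code first shrinks ur_w to 2 (10 registers) when that
    // still covers l_pad, otherwise it lowers nb_oc_blocking to the largest
    // divisor of nb_oc between 2 and 3.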
    args_ok = true
        && jcp.oc % simd_w == 0
        && jcp.l_pad <= jcp.ur_w
        && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
                || (jcp.stride_w == 1 && jcp.stride_h == 1))
        && IMPLICATION(mimo, jcp.ic % simd_w == 0);
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));

    if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
        jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
                nstl::min(jcp.ow, num_avail_regs / 2));
        jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
        /* check again ... */
        r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
                + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
        if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
            return status::unimplemented;
    }
    assert(jcp.nb_oc_blocking > 0);
    assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);

    jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.nb_ic_blocking = 12;
        jcp.nb_ic_blocking_max = 16;
    } else {
        jcp.nb_ic_blocking = 1;
        jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
    }

    return status::success;
}
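// init_scratchpad books per-primitive temporary memory: a padded bias buffer
// when the channel count was rounded up, and, for a fused depthwise
// convolution, a per-thread row buffer that carries intermediate rows
// between the main convolution and the fused depthwise stage.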
void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const jit_conv_conf_t &jcp_dw) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);

    if (jcp.with_dw_conv) {
        const int nthreads = mkldnn_get_max_threads();
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw
                * jcp_dw.ch_block * jcp.nb_oc_blocking;
        scratchpad.book(key_dw_conv_buffer,
                sizeof(float) * dw_conv_buffer_size_ * nthreads);

        if (jcp.oc != jcp.oc_without_padding)
            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
    }
}
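// The backward-by-data kernel computes diff_src from diff_dst and the
// weights. It uses the same register-tiling idea as the forward kernel with
// the channel roles swapped: the accumulators hold an ur_w x nb_ic_blocking
// tile of diff_src, and diff_dst values are broadcast.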
void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
        int r_overflow)
{
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int ow = jcp.ow;

    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int nb_ic_block = jcp.nb_ic_blocking;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;

    Label kd_loop, skip_kd_loop;
    Label oc_loop, skip_oc_loop;

    for (int ii = 0; ii < nb_ic_block; ii++)
        for (int jj = 0; jj < ur_w; jj++) {
            uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                    Ymm(ur_w * ii + jj));
        }

    if (one_of(jcp.ndims, 3, 4)) {
        cmp(reg_channel_work, 0);
        jle(skip_oc_loop, T_NEAR);
        xor_(reg_channel, reg_channel);

        mov(aux_reg_ddst_oc_loop, reg_ddst);
        mov(aux_reg_kernel_oc_loop, reg_kernel);

        L(oc_loop);
        mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
        mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
    }
    if (jcp.ndims == 5) {
        assert(jcp.nb_oc_blocking == 1);
        push(oi_iter);

        mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_dst_d, reg_ddst);
        mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);

        L(kd_loop);
        mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_ddst, aux_reg_dst_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    Label kh_loop, skip_kh_loop;
    cmp(kj, 0);
    jle(skip_kh_loop, T_NEAR);
    L(kh_loop); {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = get_iw_start(ki, l_overflow); // 0;
            int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
            for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {

                for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                    int aux_output_offset
                        = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
                    vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
                            ptr[aux_reg_ddst
                            + sizeof(float) * aux_output_offset]);
                }

                for (int ii = 0; ii < nb_ic_block; ii++) {
                    int aux_kernel_offset
                        = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
                        + ki * jcp.ic_block * jcp.oc_block
                        + ofm2 * jcp.ic_block;
                    vmovups(ymm15, ptr[aux_reg_kernel
                            + sizeof(float) * aux_kernel_offset]);
                    for (int jj = jj_start; jj < jj_end; jj += stride_w)
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
                }
            }
        }
        add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block
                * ic_block);
        sub(aux_reg_ddst, sizeof(float) * ow * oc_block);

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }
    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        sub(aux_reg_dst_d,
                sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
        add(aux_reg_ker_d,
                sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
    }

    if (one_of(jcp.ndims, 3, 4)) {
        int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
                * jcp.oc_block;
        int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
                * jcp.ic * jcp.oc_block;

        add(aux_reg_ddst_oc_loop, ddst_oc_shift);
        add(aux_reg_kernel_oc_loop, kernel_oc_shift);

        inc(reg_channel);
        cmp(reg_channel, reg_channel_work);
        jl(oc_loop, T_NEAR);

        L(skip_oc_loop);
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    }
    Label no_update_label;
    cmp(reg_channel, 0);
    je(no_update_label, T_NEAR);
    for (int ii = 0; ii < nb_ic_block; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt =
                sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
            vmovups(Ymm(15),
                    make_safe_addr(reg_dsrc, offt, reg_long_offt));
            vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                    Ymm(15));
        }
    }
    L(no_update_label);

    for (int ii = 0; ii < nb_ic_block; ii++)
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt =
                sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
            vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt),
                    Ymm(ur_w * ii + jj));
        }
}
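// generate() for backward-by-data mirrors the forward driver: it walks the
// diff_src row in ur_w-wide blocks, with dedicated blocks for the columns
// where the filter hangs over the left or right border (l_overflow /
// r_overflow), and a tail block for the remainder.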
void jit_avx2_conv_bwd_data_kernel_f32::generate() {
    preamble();

    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);

    int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
    int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;

    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
    int r_overflow = nstl::max(0, (jcp.kw - 1
                - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
    int r_overflow1 = nstl::max(0, (jcp.kw - 1
                - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);

    int n_oi = jcp.iw / jcp.ur_w;
    if (r_overflow1 > 0)
        n_oi--;

    if (jcp.ur_w == jcp.iw) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow);
    } else if (n_oi == 0) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow1);
        add(reg_dsrc, dsrc_shift);
        add(reg_ddst, ddst_shift);
        if (jcp.ur_w_tail != 0)
            compute_loop(jcp.ur_w_tail, 0, r_overflow);
    } else {
        xor_(oi_iter, oi_iter);
        if (l_overflow > 0) {
            compute_loop(jcp.ur_w, l_overflow, 0);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);

            inc(oi_iter);
        }

        if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
            Label ow_loop;
            L(ow_loop); {
                compute_loop(jcp.ur_w, 0, 0);
                add(reg_dsrc, dsrc_shift);
                add(reg_ddst, ddst_shift);

                inc(oi_iter);
                cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
            }
        }

        if (r_overflow1 > 0) {
            compute_loop(jcp.ur_w, 0, r_overflow1);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);
        }

        if (jcp.ur_w_tail != 0)
            compute_loop(jcp.ur_w_tail, 0, r_overflow);
    }

    this->postamble();
}
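// init_conf for backward-by-data: besides the shape checks, it searches for
// the (ur_w, nb_ic_blocking) pair that maximizes the number of FMA
// instructions issued per loaded diff_dst value within the register budget.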
status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d)
{
    if (!mayiuse(avx2)) return status::unimplemented;

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    int ndims = diff_src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
    jcp.iw = diff_src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    const int simd_w = 8;

    /* derivatives */
    jcp.idp = jcp.id + 2 * jcp.f_pad;
    jcp.ihp = jcp.ih + 2 * jcp.t_pad;
    jcp.iwp = jcp.iw + 2 * jcp.l_pad;
    jcp.ohp = jcp.oh; /* do we really need */
    jcp.owp = jcp.ow; /* padded output ??? */

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    /* gemm-based convolution performs better in these cases */
    if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
        return status::unimplemented;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    jcp.oc_block = simd_w;
    if (jcp.oc % jcp.oc_block) return status::unimplemented;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.nb_ic_blocking = 1;
    jcp.nb_oc_blocking = 1;
    jcp.ur_w = 1;

    if (one_of(ndims, 3, 4) && jcp.ow < 40)
        jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
    jcp.src_fmt = diff_src_d.format();

    bool args_ok = true
        && one_of(diff_src_d.format(), nCw8c, nChw8c, nCdhw8c)
        && one_of(weights_d.format(), gOIw8o8i, OIw8o8i, gOIhw8o8i, OIhw8o8i,
                gOIdhw8o8i, OIdhw8o8i)
        && one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
        && jcp.stride_w == jcp.stride_h
        && jcp.stride_d == 1
        && jcp.dilate_d == 0
        && jcp.dilate_h == 0
        && jcp.dilate_w == 0
        && jcp.ic % simd_w == 0
        && jcp.oc % simd_w == 0
        && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1
        && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
        && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
    if (!args_ok) return status::unimplemented;

    jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);

    const int max_regs = 15; /* Maximum number of registers available for
                                result accumulation and delta dst data.
                                One additional register is reserved for weights
                                data. */

    /* Find the best blocking with maximum number of fma instructions
       per ur_w * nb_ic_blocking compute loops. Number of required registers
       is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
       ur_w must be divisible by stride_w */
    if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
                                        distribution exceeds max_regs */
        return status::unimplemented;

    int best_nfmas = 0;
    for (int b = 1; b <= 4; b++) {
        if (jcp.nb_ic % b != 0)
            continue;

        for (int u = jcp.stride_w;
                u * b + u / jcp.stride_w <= max_regs
                    && u < jcp.iw + jcp.stride_w;
                u += jcp.stride_w) {
            int ur_w = nstl::min(u, jcp.iw);
            /* maximum 1 step with l_overflow so far */
            if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
                continue;

            int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
            if (nfmas > best_nfmas
                    || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
                jcp.ur_w = ur_w;
                jcp.nb_ic_blocking = b;
                best_nfmas = nfmas;
            }
        }
    }
    if (best_nfmas == 0) /* can't find appropriate blocking */
        return status::unimplemented;
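    // Illustrative walk-through of the search above (numbers are an example,
    // not from the source): for iw = 17, stride_w = 1, and nb_ic divisible
    // by 2, b = 1 allows u up to 7 (7*1 + 7 = 14 <= 15, nfmas = 7), while
    // b = 2 allows u up to 5 (5*2 + 5 = 15, nfmas = 10), so the b = 2
    // blocking wins.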
    jcp.ur_w_tail = jcp.iw % jcp.ur_w;

    int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1
                - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
    /* maximum 1 ur_w block with r_overflow so far */
    if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
        return status::unimplemented;

    if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
        return status::unimplemented;

    return status::success;
}
void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    UNUSED(scratchpad);
    UNUSED(jcp);
}
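// The backward-by-weights kernel accumulates diff_weights; diff_bias is
// handled by the driver. generate() is a thin wrapper: everything happens in
// compute_oh_loop_common(), which walks the output rows and dispatches to
// the ic-block step routines below.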
void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);

    compute_oh_loop_common();

    this->postamble();
}
status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &diff_weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;

    const bool flat = jcp.ic == 3;
    const bool mimo = !flat;

    const int simd_w = 8;

    int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
            - jcp.f_pad);
    if (jcp.f_pad != 0 || back_pad != 0)
        return status::unimplemented;

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo)
            jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true
        && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc, ncdhw,
                ndhwc)
            && one_of(diff_weights_d.format(), Owi8o, gOwi8o, Ohwi8o,
                gOhwi8o, Odhwi8o, gOdhwi8o))
        && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c, nCdhw8c)
            && one_of(diff_weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
                gOIhw8i8o, OIdhw8i8o, gOIdhw8i8o))
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
        && IMPLICATION(mimo, jcp.ic % simd_w == 0)
        && jcp.oc % simd_w == 0
        && jcp.kw < 14
        && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
        && jcp.kh <= jcp.ih /* [bwd_w:r2] */
        && jcp.kd <= jcp.f_pad + jcp.id
        && jcp.kd <= jcp.id
        && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
        && jcp.dilate_d == 0
        && jcp.dilate_h == 0
        && jcp.dilate_w == 0;
    if (!args_ok) return status::unimplemented;

    jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    jcp.oc_block = simd_w;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    return status::success;
}
void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
}
inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
{
    Label kd_comeback_loop;
    mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
    L(kd_comeback_loop); {
        const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
            ? 1 : jcp.ic_block;
        sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult);
        sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block
                * jcp.oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kd_comeback_loop, T_NEAR);
    }
}
inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
{
    mov(kj, reg_kh);
    Label kh_comeback_loop;
    L(kh_comeback_loop); {
        const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
            ? 1 : jcp.ic_block;
        sub(reg_input, sizeof(float) * jcp.iw * inp_mult);
        sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_comeback_loop, T_NEAR);
    }
}
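// compute_ic_block_step is the innermost building block of the weights
// gradient: it keeps a kw x ic_block_step tile of weight accumulators in
// YMM registers, streams ur_w diff_dst vectors through one register, and
// broadcasts matching input scalars through another, so it needs
// kw * ic_block_step + 2 YMM registers in total.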
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
        int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
        int kernel_offset, int output_offset)
{
    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            size_t off
                = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
                + kernel_offset;
            vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]);
        }

    for (int i_ur = 0; i_ur < ur_w; i_ur++) {
        vmovups(Ymm(kw * ic_block_step + 0),
                yword[reg_output
                + sizeof(float) * i_ur * oc_block + output_offset]);

        for (int i_kw = 0; i_kw < kw; i_kw++) {
            int i_iw = i_ur * jcp.stride_w + i_kw;
            if (i_iw - pad_l < 0
                    || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
                continue;
            for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                size_t i_off = (size_t)input_offset + sizeof(float)*(
                        one_of(jcp.src_fmt, ncw, nchw, ncdhw)
                        ? (i_iw - pad_l) + i_ic
                            * ((size_t)jcp.id * jcp.ih * jcp.iw)
                        : (i_iw - pad_l) * ic_block + i_ic);
                vbroadcastss(Ymm(kw * ic_block_step + 1),
                        make_safe_addr(reg_input, i_off, reg_long_offt));
                vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
                        Ymm(kw * ic_block_step + 0),
                        Ymm(kw * ic_block_step + 1));
            }
        }
    }

    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            size_t off
                = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
                + kernel_offset;
            vmovups(yword[reg_kernel + off],
                    Ymm(i_kw * ic_block_step + i_ic));
        }
}
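// compute_oh_step_disp picks ic_block_step so that the register demand of
// compute_ic_block_step (kw * ic_block_step + 2 YMMs) stays within the 16
// architectural YMM registers: e.g. kw <= 7 with step 2 needs at most 16,
// kw <= 3 with step 4 needs 14, and kw == 1 with step 8 needs 10.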
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp()
{
    int ic_block_step;
    if (one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
        ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
    } else {
        ic_block_step = jcp.kw > 7 ? 1
            : jcp.kw > 3 ? 2
            : jcp.kw > 1 ? 4 : 8;
    }

    const int max_ur_w = jcp.ow > 56 ? 14 : 28;

    if (jcp.ow <= max_ur_w)
        compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
    else
        compute_oh_step_common(ic_block_step, max_ur_w);

    if (jcp.ndims == 5) {
        od_step_comeback_pointers();
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    } else {
        oh_step_comeback_pointers();
    }
}
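// Two variants of the output-width walk follow: compute_oh_step_unroll_ow
// processes the whole row in one compute_ic_block_step call (small ow),
// while compute_oh_step_common splits the row into max_ur_w-wide trips with
// separate handling for the left-padded head and right-padded tail.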
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
        int ic_block_step, int max_ur_w)
{
    UNUSED(max_ur_w);

    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    int inp_mul = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
    Label kd_loop;

    const int r_pad = nstl::max(0,
            (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);

    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop); {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop); {
            compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0,
                    0, 0);
            size_t inp_icblk_stride = sizeof(float) * ic_block_step
                * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
                ? jcp.id*jcp.ih*jcp.iw : 1);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
            add(b_ic, ic_block_step);
            cmp(b_ic, ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        if (one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
            size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
            safe_sub(reg_input, offt, reg_long_offt);
            add(reg_input, sizeof(float) * jcp.iw);
        } else {
            add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
        }
        add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
                * oc_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
        int ic_block_step, int max_ur_w)
{
    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    const int stride_w = jcp.stride_w;
    int inp_mul = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
    Label kd_loop;

    const int r_pad = nstl::max(0,
            (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);

    int ur_w = nstl::min(jcp.ow, max_ur_w);
    int ur_w_trips = jcp.ow / ur_w;
    int ur_w_tail = jcp.ow % ur_w;
    if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
        if (ur_w_trips > 1) {
            ur_w_tail += ur_w;
            ur_w_trips--;
        } else {
            ur_w_tail += (ur_w - ur_w / 2);
            ur_w = ur_w / 2;
        }
    }
    const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : ic_block;

    int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult;
    int output_comeback = ur_w_trips * ur_w * oc_block;

    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop); {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop); {
            if (jcp.l_pad != 0) {
                ur_w_trips--;
                compute_ic_block_step(ur_w,
                        jcp.l_pad, 0, ic_block_step, 0, 0, 0);
                add(reg_input, sizeof(float)
                        * (ur_w * stride_w - jcp.l_pad) * inp_mult);
                add(reg_output, sizeof(float) * ur_w * oc_block);
            }

            if (ur_w_trips > 0) {
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                Label ow_block_loop;
                L(ow_block_loop); {
                    compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
                    add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult);
                    add(reg_output, sizeof(float) * ur_w * oc_block);

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_trips);
                    jl(ow_block_loop, T_NEAR);
                }
            }

            if (ur_w_tail > 0)
                compute_ic_block_step(ur_w_tail,
                        0, r_pad, ic_block_step, 0, 0, 0);

            sub(reg_input, sizeof(float) * input_comeback);
            sub(reg_output, sizeof(float) * output_comeback);

            size_t inp_icblk_stride = sizeof(float) * ic_block_step
                * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
                ? jcp.id*jcp.ih*jcp.iw : 1);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, sizeof(float) * ic_block_step * oc_block);

            add(b_ic, ic_block_step);
            cmp(b_ic, jcp.ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        if (one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
            size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
            safe_sub(reg_input, offt, reg_long_offt);
            add(reg_input, sizeof(float) * jcp.iw);
        } else {
            add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
        }
        add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
                * oc_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}
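// compute_oh_loop_common drives the row loop in three phases: rows whose
// filter window still overlaps the top padding (oh_tpad_loop, with a
// shrunken effective kh), the steady-state middle rows (oh_loop), and rows
// overlapping the bottom padding (oh_bpad_loop, shrinking kh as the window
// slides out of the input).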
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common()
{
    const int icoc_block = jcp.ic_block * jcp.oc_block;
    const int t_pad = jcp.t_pad;
    const int stride_h = jcp.stride_h;
    const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
        ? 1 : jcp.ic_block;
    int b_pad
        = nstl::max(0, (jcp.oh - 1) * stride_h + jcp.kh - jcp.ih - t_pad);

    Label oh_tpad_loop, oh_loop, oh_loop_end;

    mov(reg_kh, jcp.kh);
    xor_(reg_ih_count, reg_ih_count);
    xor_(reg_oj, reg_oj);
    if (t_pad > 0) {
        assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
        mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
        add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block);

        L(oh_tpad_loop); {
            compute_oh_step_disp();
            add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
            sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block);

            inc(reg_oj);
            add(reg_ih_count, stride_h);
            add(reg_kh, stride_h);

            /* the overlap between input and kernel may not reach kernel size.
             * so far we do not support that (until we put constant here) */
            const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
            cmp(reg_kh, final_inp_ker_overlap);
            jl(oh_tpad_loop, T_NEAR);
        }

        if (t_pad % stride_h != 0) {
            int inp_corr = stride_h - t_pad % stride_h;
            add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block);
            add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult);
        }
    }
    cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
    jge(oh_loop_end, T_NEAR);
    cmp(reg_oj, jcp.oh);
    jge(oh_loop_end, T_NEAR);

    mov(reg_kh, jcp.kh);
    L(oh_loop); {
        compute_oh_step_disp();
        add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
        add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);

        inc(reg_oj);
        add(reg_ih_count, stride_h);

        cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
        jge(oh_loop_end, T_NEAR);

        cmp(reg_oj, jcp.oh);
        jl(oh_loop, T_NEAR);
    }
    L(oh_loop_end);

    UNUSED(b_pad);

    Label oh_bpad_loop, oh_bpad_loop_end;
    cmp(reg_oj, jcp.oh);
    jge(oh_bpad_loop_end, T_NEAR);

    mov(reg_kh, jcp.ih + t_pad);
    sub(reg_kh, reg_ih_count);
    L(oh_bpad_loop); {
        compute_oh_step_disp();
        add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
        add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);

        sub(reg_kh, stride_h);
        cmp(reg_kh, 0);
        jle(oh_bpad_loop_end, T_NEAR);

        inc(reg_oj);
        cmp(reg_oj, jcp.oh);
        jl(oh_bpad_loop, T_NEAR);
    }
    L(oh_bpad_loop_end);
}
}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s