1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
20 #include "cpu_memory.hpp"
22 #include "jit_sse42_conv_kernel_f32.hpp"
24 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
30 using namespace mkldnn::impl::prop_kind;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
35 using namespace Xbyak;
// Emit the inner compute for one output row, fully unrolling the kernel-width
// (kw) loop. For every kernel tap ki and input channel ifm2, the code emitted
// broadcasts one input scalar (movss + shufps 0x0), multiplies it against the
// weight vector, and accumulates into the per-(oc block, output pixel)
// accumulator registers.
//
// Visible register convention (from the Xmm index arithmetic below):
//   Xmm(ur_w*ii + jj + 1)        -- accumulator for oc block ii, out pixel jj
//   Xmm(oc_blocks*ur_w + jj + 1) -- broadcast input value for out pixel jj
//   xmm0                          -- scratch for the weight load / mulps
// jj_start/jj_end clip the output pixels so that padded (out-of-bounds) input
// columns are skipped for this tap.
//
// NOTE(review): this listing is an excerpt -- several statements (e.g. the
// declarations of kw/kh/ih/iw, jj_end, inp_off, the weight movups into xmm0,
// and closing braces) are elided here; comments describe only what is visible.
37 void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
38         int pad_l, int pad_r, int oc_blocks)
44     int nb_ic = jcp.nb_ic;
45     int stride_w = jcp.stride_w;
46     int dilate_w = jcp.dilate_w + 1;
47     int ic_blk = jcp.ic_block;
48     int oc_blk = jcp.oc_block;
50     for (int ki = 0; ki < kw; ki++) {
        // First/last output pixel whose input column is inside the row for
        // this tap, given left/right padding and dilation.
51         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
53             - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
54         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
55             for (int jj = jj_start; jj < jj_end; jj++) {
                // Input offset depends on layout: plain ncw/nchw vs. a
                // channel-blocked layout (ic_blk innermost).
57                 if (one_of(jcp.src_fmt, ncw, nchw))
58                     inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
60                     inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
                // Load one input float and broadcast it across the vector.
62                 movss(Xmm(oc_blocks * ur_w + jj + 1),
63                         ptr[aux_reg_input + sizeof(float) * inp_off]);
64                 shufps(Xmm(oc_blocks * ur_w + jj + 1),
65                         Xmm(oc_blocks * ur_w + jj + 1), 0x0);
68             for (int ii = 0; ii < oc_blocks; ii++) {
                // Weight offset for oc block ii, tap ki, input channel ifm2
                // in the (oc-block, ic-block)-blocked weight layout.
69                 int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
70                         + ki * ic_blk * oc_blk + ifm2 * oc_blk;
72                 for (int jj = jj_start; jj < jj_end; jj++)
75                             ptr[aux_reg_kernel + sizeof(float) * ker_off]);
                    // Fused multiply-accumulate built from mulps + addps
                    // (SSE4.2 has no FMA instruction).
76                     mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
77                     addps(Xmm(ur_w * ii + jj + 1), xmm0);
// Emit the inner compute for one output row when no horizontal padding is in
// effect (caller only uses this for pad_l == 0 && pad_r == 0). Instead of
// unrolling kw in the generated code (as oh_step_unroll_kw does), this emits
// a runtime loop over kernel width driven by the ki_iter register, advancing
// aux_reg_kernel/aux_reg_input per tap -- producing smaller code for wide
// kernels. The Xmm register convention matches oh_step_unroll_kw.
//
// NOTE(review): excerpt -- the kw-loop label/compare/jump around the body,
// the declarations of kw/kh/ih/iw, jj_start/jj_end, inp_off, the weight load
// into xmm0, and closing braces are elided here.
84 void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
85         int pad_l, int pad_r, int oc_blocks)
93     int nb_ic = jcp.nb_ic;
94     int stride_w = jcp.stride_w;
95     int dilate_w = jcp.dilate_w + 1;
96     int ic_blk = jcp.ic_block;
97     int oc_blk = jcp.oc_block;
    // ki_iter := 0 (xor of a register with itself) -- runtime kw-loop counter.
99     xor_(ki_iter, ki_iter);
104     for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
105         for (int jj = jj_start; jj < jj_end; jj++) {
            // Layout-dependent input offset; no pad/tap term here because the
            // per-tap advance is done on aux_reg_input at the loop bottom.
107             if (one_of(jcp.src_fmt, ncw, nchw))
108                 inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
110                 inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
            // Broadcast one input float across the vector lane-wise.
112             movss(Xmm(oc_blocks * ur_w + jj + 1),
113                     ptr[aux_reg_input + sizeof(float) * inp_off]);
114             shufps(Xmm(oc_blocks * ur_w + jj + 1),
115                     Xmm(oc_blocks * ur_w + jj + 1), 0x0);
117         for (int ii = 0; ii < oc_blocks; ii++) {
118             int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
120             for (int jj = jj_start; jj < jj_end; jj++) {
122                     ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
                // mulps + addps emulate FMA on SSE4.2.
123                 mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
124                 addps(Xmm(ur_w * ii + jj + 1), xmm0);
    // Advance to the next kernel tap: one (ic_blk x oc_blk) weight tile, and
    // one (possibly dilated, layout-dependent) input column.
128         add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
129         add(aux_reg_input, sizeof(float) * (one_of(jcp.src_fmt, ncw, nchw) ?
130                 dilate_w : ic_blk * dilate_w));
// Emit the code for one width block of ur_w output pixels and oc_blocks
// output-channel blocks: initialize accumulators (from dst, bias, or zero),
// run the kernel-height loop (delegating the width compute to oh_step_nopad
// or oh_step_unroll_kw), apply post-ops on the last ic chunk, and store the
// results. The whole sequence is wrapped in a runtime loop over simd_iter:
// each pass handles one 4-float SSE half of the 8-wide oc block (see the
// sizeof(float)*4 pointer bumps near the end).
//
// NOTE(review): excerpt -- kh-loop labels/branches, several Label
// declarations (init_first/init_done/skip_kh_loop/regular_store/etc.), the
// simd_iter increment/compare, and various closing braces are elided here.
138 void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
139         int pad_l, int pad_r, int oc_blocks)
145     int dilate_h = jcp.dilate_h + 1;
146     int dilate_w = jcp.dilate_w + 1;
147     int ic_blk = jcp.ic_block;
148     int oc_blk = jcp.oc_block;
    // Per-row / per-column input strides in elements, layout dependent.
149     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw)
150         ? dilate_h : ic_blk * dilate_h;
151     const int inp_off = one_of(jcp.src_fmt, ncw, nchw)
152         ? dilate_w : ic_blk * dilate_w;
154     xor_(simd_iter, simd_iter);
156     mov(aux_reg_input, reg_input);
157     mov(aux_reg_kernel, reg_kernel);
159     Label init_simd_iter_loop;
163     L(init_simd_iter_loop);
    // If this is the first ic chunk (FLAG_IC_FIRST), jump to fresh
    // initialization; otherwise fall through and reload partial sums from dst.
166     test(reg_ci_flag, FLAG_IC_FIRST);
167     jne(init_first, T_NEAR);
170     for (int ii = 0; ii < oc_blocks; ii++)
171         for (int jj = 0; jj < ur_w; jj++) {
            // dst offset differs when fused depthwise conv owns the output
            // buffer layout (jcp_dw.kh rows per oc block).
173             if (jcp.with_dw_conv)
174                 o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
176                 o_off = (ii * oh * ow + jj) * oc_blk;
178             movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
179                 + sizeof(float) * o_off]);
    // sum post-op together with bias: only the first ic chunk adds bias on
    // top of the reloaded dst values.
182     if (jcp.with_sum && jcp.with_bias) {
183         test(reg_ci_flag, FLAG_IC_FIRST);
184         je(init_done, T_NEAR);
186         for (int ii = 0; ii < oc_blocks; ii++)
187             for (int jj = 0; jj < ur_w; jj++)
188                 addps(Xmm(ur_w * ii + jj + 1),
189                         xword[reg_bias + sizeof(float) * ii * oc_blk]);
    // Fresh initialization path: bias if present, else zero (pxor reg,reg).
195     if (this->jcp.with_bias) {
196         for (int ii = 0; ii < oc_blocks; ii++)
197             for (int jj = 0; jj < ur_w; jj++)
198                 movups(Xmm(ur_w * ii + jj + 1),
199                         xword[reg_bias + sizeof(float) * ii * oc_blk]);
201         for (int ii = 0; ii < oc_blocks; ii++)
202             for (int jj = 0; jj < ur_w; jj++)
203                 pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
    // The kh loop may execute zero times (fully padded row / large dilation):
    // guard it with a runtime check so the loop body can be skipped.
210     if ((jcp.dilate_h >= jcp.ih)
211             || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
213         je(skip_kh_loop, T_NEAR);
    // Width compute: runtime kw loop for wide unpadded kernels (smaller code),
    // full kw unroll otherwise; then advance input to the next kernel row.
218     if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
219         oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
220         sub(aux_reg_input, sizeof(float) * kw * inp_off);
221         add(aux_reg_input, sizeof(float) * iw * inp_mult);
223         oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
224         add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
225         add(aux_reg_input, sizeof(float) * iw * inp_mult);
    // Post-ops are applied only on the last ic chunk (FLAG_IC_LAST); other
    // chunks store raw partial sums.
238     test(reg_ci_flag, FLAG_IC_LAST);
239     je(regular_store, T_NEAR);
241     int eltwise_inj_idx = 0;
242     int depthwise_inj_idx = 0;
243     const auto &p = attr_.post_ops_;
    // With a fused depthwise conv, only post-ops preceding it apply here.
245     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
246     for (int i = 0; i < end_idx; i++) {
247         auto& post_op = p.entry_[i];
248         if (post_op.is_eltwise()) {
            // Apply eltwise in-place over all accumulator registers [1, N].
249             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, oc_blocks * ur_w + 1);
251         } else if (post_op.is_depthwise()) {
            // Host pointers to per-channel scale/shift are baked into the
            // generated code as immediates, then offset by reg_oc_off.
252             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
253             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
255             add(reg_d_weights, reg_oc_off);
256             add(reg_d_bias, reg_oc_off);
258             for (int ii = 0; ii < oc_blocks; ii++) {
259                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
260                         ur_w * ii + 1, ur_w * ii + ur_w + 1, reg_d_weights, reg_d_bias);
262                 add(reg_d_weights, oc_blk * sizeof(float));
263                 add(reg_d_bias, oc_blk * sizeof(float));
    // Store results back to the output buffer (same offset scheme as above).
272     for (int ii = 0; ii < oc_blocks; ii++) {
273         for (int jj = 0; jj < ur_w; jj++) {
275             if (jcp.with_dw_conv)
276                 o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
278                 o_off = (ii * oh * ow + jj) * oc_blk;
280             Xmm reg_out = Xmm(ur_w * ii + jj + 1);
281             movups(xword[reg_output + sizeof(float) * o_off], reg_out);
    // Move to the second 4-float half of the 8-wide oc block and loop.
285     mov(aux_reg_kernel, reg_kernel);
286     mov(aux_reg_input, reg_input);
287     add(aux_reg_kernel, sizeof(float) * 4);
288     add(reg_output, sizeof(float) * 4);
289     add(reg_bias, sizeof(float) * 4);
290     add(reg_oc_off, sizeof(float) * 4);
294     jl(init_simd_iter_loop, T_NEAR);
    // Undo both halves' pointer bumps (2 * 4 floats) after the loop.
296     sub(reg_output, sizeof(float) * 8);
297     sub(reg_bias, sizeof(float) * 8);
298     sub(reg_oc_off, sizeof(float) * 8);
// Emit the full output-width sweep for a given number of oc blocks by
// partitioning the row into (up to) four cases, each compiled as its own
// width_blk_step: a left-padded block ("lpad" / combined "lrpad"), n_oi
// middle blocks with no padding, a right-padded block ("rpad"), and a
// ur_w_tail remainder ("tail"). Between blocks, reg_input/reg_output are
// advanced by one block's worth of elements.
//
// NOTE(review): excerpt -- kw/ur_w declarations, the tails of the r_pad /
// r_pad1 expressions, the l_pad>0 / middle-loop labels and compare/jump, and
// closing braces are elided here.
301 inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks)
304     int ur_w_tail = jcp.ur_w_tail;
305     int n_oi = jcp.ow / ur_w;
308     int ic_blk = jcp.ic_block;
309     int oc_blk = jcp.oc_block;
310     int dilate_w = jcp.dilate_w + 1;
311     int str_w = jcp.stride_w;
312     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw) ? 1 : ic_blk;
314     int l_pad = jcp.l_pad;
    // r_pad: right padding of the whole row; r_pad1: right overhang of the
    // last full (non-tail) block -- if positive, that block becomes "rpad".
315     int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
317     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
319     if (r_pad1 > 0) n_oi--;
    // n_oi < 0 means the single leading block carries both paddings.
323         if (n_oi < 0 && r_pad1 > 0)
324             width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
326             width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
327         add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
328         add(reg_output, sizeof(float) * ur_w * oc_blk);
    // Runtime loop (oi_iter) over the n_oi unpadded middle blocks.
332         xor_(oi_iter, oi_iter);
337         width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
338         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
339         add(reg_output, sizeof(float) * ur_w * oc_blk);
346     if (r_pad1 > 0 && n_oi >=0) {
347         width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
348         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
349         add(reg_output, sizeof(float) * ur_w * oc_blk);
353         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
// Top-level code generation entry point: build the post-op injectors, emit
// the prologue that loads all kernel arguments from the jit_conv_call_s
// parameter block (via GET_OFF), dispatch at runtime between the full
// nb_oc_blocking path and the oc-tail path, and finally emit the injectors'
// lookup tables after the code body.
//
// NOTE(review): excerpt -- the preamble/postamble, the tail/exit Label
// declarations, remaining injector constructor arguments, and several
// braces/jumps are elided here.
356 void jit_sse42_conv_fwd_kernel_f32::generate()
358     const auto &p = attr_.post_ops_;
    // With a fused depthwise conv, only post-ops before it belong to this
    // kernel; the rest are handled by the fused conv.
359     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
360     for (int i = 0; i < end_idx; i++) {
361         auto &post_op = p.entry_[i];
362         if (post_op.is_eltwise()) {
            // NOTE(review): raw new with ownership presumably transferred to
            // the injector vector; destructor not visible here -- confirm the
            // class releases these (potential leak otherwise).
363             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
366                     post_op.eltwise.alpha,
369         } else if (post_op.is_depthwise()) {
370             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
372                     post_op.depthwise.alg
    // Load kernel arguments from the call-parameter struct into registers.
379     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
380     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
381     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
383         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
384     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
385     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
386     mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
387     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
    // Runtime dispatch: full-blocking body vs. oc-tail body (if any tail).
389     int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
392     cmp(reg_oc_blocks, jcp.nb_oc_blocking);
393     jne(nb_oc_tail ? tail : exit, T_NEAR);
395     solve_common(jcp.nb_oc_blocking);
400         cmp(reg_oc_blocks, nb_oc_tail);
402         solve_common(nb_oc_tail);
    // Constant tables must be emitted after the code that references them.
409     for (auto& inj : eltwise_injectors)
410         inj->prepare_table();
// Validate the attached post-op chain: returns true only for the specific
// combinations of sum / eltwise / depthwise / fused-dw-conv entries (in the
// listed orders and counts) that the generated kernel supports.
//
// NOTE(review): excerpt -- the `switch` on the post-op count (and its
// length-0 case, if any) is elided between the lambda definitions and the
// visible `case` labels.
413 bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
414         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
415     const auto &p = attr.post_ops_;
    // Small predicates over the i-th post-op entry, for readable case logic.
417     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
418     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
419     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
420     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
    // "simple" = applied element-wise in registers (eltwise or depthwise).
421     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
425     case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
426     case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
427                    (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
428                    (is_simple(0) && is_simple(1));
429     case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
430                    (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
431                    (is_sum(0) && is_simple(1) && is_simple(2));
432     case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
    // Any longer or unlisted combination is unsupported.
433     default: return false;
// Populate the jit_conv_conf_t from the convolution descriptor and memory
// descriptors, and decide whether this SSE4.2 kernel can handle the problem.
// Returns status::unimplemented for unsupported ISA, formats, post-op chains,
// or blocking constraints; status::success otherwise. Handles both 1D (ncw)
// and 2D (nchw) cases by treating ndims == 3 as height-1.
//
// NOTE(review): excerpt -- parts of several boolean expressions (args_ok,
// ok_to_pad_channels), the start of the args_ok initializers, some braces and
// else-branches are elided here; comments describe only visible lines.
439 status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
440         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
441         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
442         const primitive_attr_t &attr)
444     if (!mayiuse(sse42)) return status::unimplemented;
446     jcp.prop_kind = cd.prop_kind;
    // Grouped weights have one extra leading dimension.
448     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
449     const int ndims = src_d.ndims();
452     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
453     jcp.mb = src_d.dims()[0];
    // Channel counts are per group.
455     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
456     jcp.oc_without_padding = jcp.oc;
457     jcp.ic = src_d.dims()[1] / jcp.ngroups;
459     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
460     jcp.iw = src_d.dims()[ndims - 1];
461     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
462     jcp.ow = dst_d.dims()[ndims - 1];
464     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
465     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
467     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
468     jcp.l_pad = cd.padding[0][ndims - 3];
470     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
471     jcp.stride_w = cd.strides[ndims - 3];
473     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0];
474     jcp.dilate_w = cd.dilates[ndims - 3];
    // Bottom padding derived so output height is consistent with ih/t_pad.
475     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
476             - (jcp.ih + jcp.t_pad - 1);
478     jcp.src_fmt = src_d.format();
479     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
481     if (!post_ops_ok(jcp, attr))
482         return status::unimplemented;
484     const auto &p = attr.post_ops_;
    // Fused depthwise conv: this kernel then produces that conv's input, so
    // oh/ow are overridden by the fused conv's input size (originals saved
    // in dw_conv_oh/dw_conv_ow).
486     int dw_conv_ind = p.find(primitive_kind::convolution);
487     jcp.with_dw_conv = dw_conv_ind != -1;
488     if (jcp.with_dw_conv) {
489         jcp.dw_conv_oh = jcp.oh;
490         jcp.dw_conv_ow = jcp.ow;
491         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
492         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
    // Only a sum post-op located before the fused conv counts for this kernel.
495     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
497     jcp.src_dt = cd.src_desc.data_type;
498     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
499     jcp.dst_dt = cd.dst_desc.data_type;
    // "flat" = unblocked src layout (1 or 3 input channels); "mimo" = blocked.
501     const bool flat = jcp.ic == 3 || jcp.ic == 1;
502     const bool mimo = !flat;
505         && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc)
506                 && one_of(weights_d.format(), Owi8o, gOwi8o, Ohwi8o, gOhwi8o))
507         && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c)
508                 && one_of(weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
510         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
511         && one_of(dst_d.format(), nCw8c, nChw8c);
512     if (!args_ok) return status::unimplemented;
514     bool ok_to_pad_channels = true
517     const int simd_w = 8; // 2 SSE vectors processing at once
518     if (ok_to_pad_channels) {
519         jcp.oc = rnd_up(jcp.oc, simd_w);
521         jcp.ic = rnd_up(jcp.ic, simd_w);
524     jcp.ur_h = 1; /* no code-unrolling by h so far */
526     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
527     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
529     jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
532         && jcp.oc % simd_w == 0
533         && jcp.l_pad <= jcp.ur_w
534         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
535                 || (jcp.stride_w == 1 && jcp.stride_h == 1))
536         && IMPLICATION(mimo, jcp.ic % simd_w == 0);
537     if (!args_ok) return status::unimplemented;
    // Right padding of the last non-tail width block (see solve_common).
539     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
540             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
542     // kernel needs 1 temporary YMM register
543     const int num_avail_regs = 15;
544     if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
545         /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
546         jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
547                 nstl::min(jcp.ow, num_avail_regs / 2));
548         jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
549         jcp.ur_w_tail = jcp.ow % jcp.ur_w;
550         /* check again ... */
551         r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
552                 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
553         if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
554             return status::unimplemented;
    // Register budget: ur_w * (nb_oc_blocking accumulators + 1 broadcast row).
556     assert(jcp.nb_oc_blocking > 0);
557     assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
    // Flat (ic==1/3) inputs keep the whole ic as one "block".
559     jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
560     jcp.nb_ic = jcp.ic / jcp.ic_block;
562     jcp.oc_block = simd_w;
563     jcp.nb_oc = jcp.oc / jcp.oc_block;
565     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
566         jcp.nb_ic_blocking = 12;
567         jcp.nb_ic_blocking_max = 16;
569         jcp.nb_ic_blocking = 1;
570         jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
573     return status::success;
// Register the scratchpad buffers this primitive needs: a zero-filled padded
// bias when oc was rounded up, and -- for a fused depthwise conv -- one
// per-thread intermediate row buffer (kh rows of the dw conv's input) plus a
// padded bias for the fused conv.
//
// NOTE(review): excerpt -- the end of the function (closing braces) is
// elided after the last visible line.
576 void jit_sse42_conv_fwd_kernel_f32::init_scratchpad(
577         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
578     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
579         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
581     if (jcp.with_dw_conv) {
        // One buffer per thread: kh * iw * ch_block elements per oc block.
582         const int nthreads = mkldnn_get_max_threads();
583         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
584         scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
586         if (jcp.oc != jcp.oc_without_padding)
587             scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);