1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
20 #include "cpu_memory.hpp"
22 #include "jit_sse42_conv_kernel_f32.hpp"
24 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
30 using namespace mkldnn::impl::prop_kind;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
34 using namespace Xbyak;
// Emit the multiply-accumulate code for one kernel row, with the loop over
// kernel columns (ki) fully unrolled at JIT-generation time.
//
// For every kernel column ki, input channel ifm2 of the ic block and output
// column jj of the current ur_w-wide chunk, one input float is broadcast to
// all four lanes (movss + shufps 0x0), multiplied by a weight vector and
// accumulated into Xmm(ur_w * ii + jj + 1) for each of the oc_blocks
// output-channel blocks ii.
//
// Register usage visible below:
//   Xmm(1 .. oc_blocks * ur_w)       accumulators, one per (ii, jj)
//   Xmm(oc_blocks * ur_w + jj + 1)   broadcast input value for column jj
//   xmm0                             scratch operand of mulps; presumably
//                                    loaded from aux_reg_kernel + ker_off
//                                    (the load itself is not visible here)
//
// pad_l / pad_r: padded input columns to the left / right of the chunk.
// Output columns whose tap would land in padding are skipped by narrowing
// [jj_start, jj_end) per ki, so padded memory is never read.
36 void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
37 int pad_l, int pad_r, int oc_blocks)
43 int nb_ic = jcp.nb_ic;
44 int stride_w = jcp.stride_w;
// jcp stores dilation with 0 meaning "dense" (presumably the mkl-dnn
// convention); +1 turns it into the distance between kernel taps.
45 int dilate_w = jcp.dilate_w + 1;
46 int ic_blk = jcp.ic_block;
47 int oc_blk = jcp.oc_block;
49 for (int ki = 0; ki < kw; ki++) {
// jj_start: first output column whose input column
// jj*stride_w + ki*dilate_w - pad_l is >= 0 (outside the left padding).
// jj_end: subtracts the number of trailing columns whose tap for this ki
// would fall into the right padding.
50 int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
52 - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
53 for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
54 for (int jj = jj_start; jj < jj_end; jj++) {
// Input offset in floats relative to aux_reg_input:
//   nchw      planar; channel ifm2 is ih*iw floats away, columns contiguous
//   otherwise channel-interleaved (nhwc / nChw8c); ic_blk floats per column
56 if (jcp.src_fmt == nchw)
57 inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
59 inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
// Broadcast the scalar input into all four lanes of its register.
61 movss(Xmm(oc_blocks * ur_w + jj + 1),
62 ptr[aux_reg_input + sizeof(float) * inp_off]);
63 shufps(Xmm(oc_blocks * ur_w + jj + 1),
64 Xmm(oc_blocks * ur_w + jj + 1), 0x0);
// Weight offset in floats: consecutive oc blocks are nb_ic*kh*kw*ic_blk*oc_blk
// floats apart. Within a block, ki selects an ic_blk x oc_blk tile and ifm2
// selects one row of oc_blk floats in that tile.
67 for (int ii = 0; ii < oc_blocks; ii++) {
68 int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
69 + ki * ic_blk * oc_blk + ifm2 * oc_blk;
71 for (int jj = jj_start; jj < jj_end; jj++)
74 ptr[aux_reg_kernel + sizeof(float) * ker_off]);
75 mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
76 addps(Xmm(ur_w * ii + jj + 1), xmm0);
// Emit the multiply-accumulate code for one kernel row as a runtime loop over
// kernel columns (label "kw", counter register ki_iter). oh_step_unroll_kw
// unrolls over kw instead. width_blk_step selects this variant only when
// kw >= 5 and the chunk has no left/right padding.
//
// Each loop iteration handles one kernel column. For every input channel
// ifm2 of the ic block and output column jj, it broadcasts one input float and
// accumulates input * weight into Xmm(ur_w * ii + jj + 1) for each oc block ii.
// At the end of the iteration:
//   - aux_reg_kernel advances by one ic_blk x oc_blk weight tile;
//   - aux_reg_input advances by one dilated input column.
// As a result, the emitted loop body needs no ki-dependent offsets.
//
// pad_tag / oc_blocks_tag are folded into the label name, presumably so each
// emitted instance of this loop gets a distinct label. pad_r is not used in
// the visible lines.
83 void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
84 int pad_l, int pad_r, char pad_tag,
85 int oc_blocks, char oc_blocks_tag)
87 jit_tagged_label kw_label("kw", pad_tag, oc_blocks_tag);
93 int nb_ic = jcp.nb_ic;
94 int stride_w = jcp.stride_w;
95 int dilate_w = jcp.dilate_w + 1;
96 int ic_blk = jcp.ic_block;
97 int oc_blk = jcp.oc_block;
// ki_iter is zeroed here; the loop is closed by jl(kw_label) below.
99 xor_(ki_iter, ki_iter);
104 for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
105 for (int jj = jj_start; jj < jj_end; jj++) {
// Input offset in floats; the kernel-column term is absent because
// aux_reg_input itself is advanced after every ki iteration.
107 if (jcp.src_fmt == nchw)
108 inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
110 inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
// Broadcast the scalar input into all four lanes of its register.
112 movss(Xmm(oc_blocks * ur_w + jj + 1),
113 ptr[aux_reg_input + sizeof(float) * inp_off]);
114 shufps(Xmm(oc_blocks * ur_w + jj + 1),
115 Xmm(oc_blocks * ur_w + jj + 1), 0x0);
// Consecutive oc blocks are nb_ic*kh*kw*ic_blk*oc_blk floats apart in the
// weights; xmm0 is the scratch operand of mulps.
117 for (int ii = 0; ii < oc_blocks; ii++) {
118 int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
120 for (int jj = jj_start; jj < jj_end; jj++) {
122 ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
123 mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
124 addps(Xmm(ur_w * ii + jj + 1), xmm0);
// Step to the next kernel column:
//   - weights advance by one ic_blk x oc_blk tile;
//   - input advances by dilate_w columns (1 float per column in nchw,
//     ic_blk floats per column in channel-interleaved layouts).
128 add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
129 add(aux_reg_input, sizeof(float) * (jcp.src_fmt == nchw ?
130 dilate_w : ic_blk * dilate_w));
// Repeat while ki_iter is below the loop bound (presumably kw).
134 jl(kw_label, T_NEAR);
// Emit code for one ur_w-wide chunk of output columns in one output row,
// covering oc_blocks blocks of oc_block output channels (oc_block is set to
// simd_w = 8 in init_conf).
//
// An 8-channel block spans two 4-float SSE vectors. The code between
// init_simd_iter_label and the closing jl() handles one 4-channel half:
//   1. Initialise the accumulators Xmm(1 .. oc_blocks*ur_w):
//        FLAG_IC_FIRST clear -> load partial sums from dst (reg_output);
//        FLAG_IC_FIRST set   -> load bias (with_bias) or zero them.
//   2. Accumulate over the kernel rows (kh loop), using oh_step_nopad or
//      oh_step_unroll_kw for each row.
//   3. With a fused eltwise, apply it only when FLAG_IC_LAST is set.
//   4. Store the accumulators back to dst.
// Between halves, the output and bias pointers and the kernel base move by
// 4 floats. After the loop, reg_output / reg_bias are moved back by 8 floats.
// The half loop presumably runs twice (its compare is not visible), which
// matches that 8-float rewind.
//
// pad_l / pad_r: left / right padding that applies to this chunk.
// pad_tag / oc_blocks_tag only make the generated label names distinct.
138 void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
139 int pad_l, int pad_r, char pad_tag,
140 int oc_blocks, char oc_blocks_tag)
146 int dilate_h = jcp.dilate_h + 1;
147 int dilate_w = jcp.dilate_w + 1;
148 int ic_blk = jcp.ic_block;
149 int oc_blk = jcp.oc_block;
// inp_mult: floats per input column times the vertical dilation, so
//           iw * inp_mult is the input step between kernel rows.
// inp_off:  input step in floats between kernel columns.
150 const int inp_mult = jcp.src_fmt == nchw ? dilate_h : ic_blk * dilate_h;
151 const int inp_off = jcp.src_fmt == nchw ? dilate_w : ic_blk * dilate_w;
153 xor_(simd_iter, simd_iter);
155 mov(aux_reg_input, reg_input);
156 mov(aux_reg_kernel, reg_kernel);
158 jit_tagged_label init_simd_iter_label("simd_iter", pad_tag, oc_blocks_tag);
159 jit_tagged_label init_done_label("init", pad_tag, oc_blocks_tag);
160 jit_tagged_label init_first_label("first", pad_tag, oc_blocks_tag);
162 L(init_simd_iter_label);
// First ic chunk: accumulators start from bias / zero (init_first_label).
165 test(reg_ci_flag, FLAG_IC_FIRST);
166 jne(init_first_label, T_NEAR);
// Later ic chunks: continue from the partial sums already stored in dst.
// Output offset in floats:
//   (ii * rows * ow + jj) * oc_blk, where rows = oh, or jcp.dw_conv_ker_h
//   when a depthwise convolution is fused. In the fused case dst is
//   presumably an intermediate buffer of dw_conv_ker_h rows.
169 for (int ii = 0; ii < oc_blocks; ii++)
170 for (int jj = 0; jj < ur_w; jj++) {
172 if (jcp.with_dw_conv)
173 o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
175 o_off = (ii * oh * ow + jj) * oc_blk;
177 movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
178 + sizeof(float) * o_off]);
// Sum post-op together with bias: the bias is added on top of the values
// loaded from dst, but only when FLAG_IC_FIRST is set.
// NOTE(review): on the fall-through path of the jne above, FLAG_IC_FIRST is
// clear, so this addps code is reachable only if that jne is emitted
// conditionally (e.g. only when !jcp.with_sum). The guarding lines are not
// visible in this listing; confirm.
181 if (jcp.with_sum && jcp.with_bias) {
182 test(reg_ci_flag, FLAG_IC_FIRST);
183 je(init_done_label, T_NEAR);
185 for (int ii = 0; ii < oc_blocks; ii++)
186 for (int jj = 0; jj < ur_w; jj++)
187 addps(Xmm(ur_w * ii + jj + 1),
188 xword[reg_bias + sizeof(float) * ii * oc_blk]);
191 jmp(init_done_label);
// First-chunk initialisation: broadcast the bias vector of each oc block
// into all ur_w accumulators of that block, or zero the accumulators.
194 if (this->jcp.with_bias) {
195 for (int ii = 0; ii < oc_blocks; ii++)
196 for (int jj = 0; jj < ur_w; jj++)
197 movups(Xmm(ur_w * ii + jj + 1),
198 xword[reg_bias + sizeof(float) * ii * oc_blk]);
200 for (int ii = 0; ii < oc_blocks; ii++)
201 for (int jj = 0; jj < ur_w; jj++)
202 pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
// If the kernel can lie entirely inside the top padding, the number of
// valid kernel rows may be zero. In that case the row loop is skipped. The
// row count presumably comes from reg_kh (kh_padding); the comparison that
// feeds je is not visible here.
209 if (jcp.kh <= jcp.t_pad) {
211 je(skip_kh_loop, T_NEAR);
213 jit_tagged_label kh_label("kh", pad_tag, oc_blocks_tag);
// Wide kernels without horizontal padding use the runtime kw loop.
// oh_step_nopad has already advanced aux_reg_kernel / aux_reg_input by kw
// columns, so the input is rewound by kw * inp_off before stepping one
// (dilated) row down. Otherwise the kw loop is unrolled, and the kernel and
// input pointers are advanced here by one kernel row / one dilated input row.
216 if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
217 oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
219 sub(aux_reg_input, sizeof(float) * kw * inp_off);
220 add(aux_reg_input, sizeof(float) * iw * inp_mult);
222 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
223 add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
224 add(aux_reg_input, sizeof(float) * iw * inp_mult);
229 jg(kh_label, T_NEAR);
234 jit_tagged_label done_label("done", pad_tag, oc_blocks_tag);
235 jit_tagged_label regular_store_label("store", pad_tag, oc_blocks_tag);
// Fused eltwise is applied only on the last ic chunk (FLAG_IC_LAST); the
// partial sums of earlier chunks are stored untouched.
// NOTE(review): generate() hands Xmm0, Xmm13, Xmm14 and Xmm15 to the eltwise
// injector, while the accumulators occupy Xmm1 .. Xmm(oc_blocks*ur_w). This
// assert still admits 13 or 14 accumulators, which would overlap Xmm13/14
// if the injector uses them. init_conf's sharedVecsCount budget is meant to
// prevent that; confirm the two checks agree.
237 if (jcp.with_eltwise) {
238 assert(oc_blocks * ur_w < 15);
239 test(reg_ci_flag, FLAG_IC_LAST);
240 je(regular_store_label, T_NEAR);
242 inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
244 // TODO (dmitrygo): need to find appropriate way to share labels.
245 mov(imm_addr64, l_table);
246 for (int ii = 0; ii < oc_blocks; ii++) {
247 for (int jj = 0; jj < ur_w; jj++) {
248 Xmm reg_out = Xmm(ur_w * ii + jj + 1);
250 inject(eltwise_generator.computeVector(reg_out, reg_out));
254 L(regular_store_label);
// Store the accumulators using the same output addressing as the load above.
257 for (int ii = 0; ii < oc_blocks; ii++) {
258 for (int jj = 0; jj < ur_w; jj++) {
260 if (jcp.with_dw_conv)
261 o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
263 o_off = (ii * oh * ow + jj) * oc_blk;
265 Xmm reg_out = Xmm(ur_w * ii + jj + 1);
266 movups(xword[reg_output + sizeof(float) * o_off], reg_out);
// Prepare the second 4-channel half:
//   - restart input/kernel from the chunk base;
//   - offset the kernel by 4 floats (next 4 output channels of each
//     oc_blk-wide weight row);
//   - advance output and bias by 4 floats.
272 mov(aux_reg_kernel, reg_kernel);
273 mov(aux_reg_input, reg_input);
274 add(aux_reg_kernel, sizeof(float) * 4);
275 add(reg_output, sizeof(float) * 4);
276 add(reg_bias, sizeof(float) * 4);
280 jl(init_simd_iter_label, T_NEAR);
// Undo the per-half advances so the caller sees the original pointers.
282 sub(reg_output, sizeof(float) * 8);
283 sub(reg_bias, sizeof(float) * 8);
// Emit code for one full output row and oc_blocks output-channel blocks.
// The ow output columns are split into ur_w-wide chunks, and width_blk_step
// is called for each chunk with the padding that applies to it:
//   'l'  first chunk with left padding l_pad. It also gets right padding
//        r_pad1 when no full chunk remains after it (n_oi < 0).
//   'm'  unpadded middle chunks, emitted once and executed in a runtime
//        loop over oi_iter (label "ow").
//   'r'  last full chunk when it reaches into the right padding (r_pad1 > 0).
//   't'  the remaining ur_w_tail columns, with the row's right padding r_pad.
// After each chunk, reg_input / reg_output advance to the start of the next
// chunk. Some guard conditions around the 'l', 'm' and 't' steps are not
// visible in this listing.
286 inline void jit_sse42_conv_fwd_kernel_f32::solve_common(
287 int oc_blocks, char oc_blocks_tag)
290 int ur_w_tail = jcp.ur_w_tail;
291 int n_oi = jcp.ow / ur_w;
294 int ic_blk = jcp.ic_block;
295 int oc_blk = jcp.oc_block;
296 int dilate_w = jcp.dilate_w + 1;
297 int str_w = jcp.stride_w;
// Floats per input column: 1 in nchw, ic_blk in channel-interleaved layouts.
298 const int inp_mult = jcp.src_fmt == nchw ? 1 : ic_blk;
300 int l_pad = jcp.l_pad;
// r_pad: how far the rightmost tap of the last output column extends past
// the input row (clamped at 0; the rest of the expression is not visible).
// r_pad1: the same quantity for the last full ur_w chunk. When positive,
// that chunk cannot use the unpadded middle code and is peeled off.
301 int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
303 int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
305 if (r_pad1 > 0) n_oi--;
309 if (n_oi < 0 && r_pad1 > 0)
310 width_blk_step(ur_w, l_pad, r_pad1,
311 'l', oc_blocks, oc_blocks_tag); // "lrpad"
313 width_blk_step(ur_w, l_pad, 0,
314 'l', oc_blocks, oc_blocks_tag); // "lpad"
// The first chunk's input window starts l_pad columns before column 0. The
// next chunk therefore starts ur_w*str_w - l_pad columns after reg_input.
315 add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
316 add(reg_output, sizeof(float) * ur_w * oc_blk);
319 jit_tagged_label ow_loop_label("ow", oc_blocks_tag);
320 xor_(oi_iter, oi_iter);
325 width_blk_step(ur_w, 0, 0,
326 'm', oc_blocks, oc_blocks_tag); // "middle"
327 add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
328 add(reg_output, sizeof(float) * ur_w * oc_blk);
332 jl(ow_loop_label, T_NEAR);
// n_oi >= 0 excludes the case already handled by the "lrpad" chunk above.
335 if (r_pad1 > 0 && n_oi >=0) {
336 width_blk_step(ur_w, 0, r_pad1,
337 'r', oc_blocks, oc_blocks_tag); // "rpad"
338 add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
339 add(reg_output, sizeof(float) * ur_w * oc_blk);
343 width_blk_step(ur_w_tail, 0, r_pad,
344 't', oc_blocks, oc_blocks_tag); // "tail"
// JIT entry point: emits the whole convolution kernel.
//   - With a fused eltwise, the eltwise injector is initialised with Xmm0,
//     Xmm13, Xmm14, Xmm15 and imm_addr64 as registers it may share with the
//     convolution code.
//   - Loads the call arguments (src, dst, filt, bias, kh_padding, flags,
//     oc_blocks) from the jit_conv_call_s structure pointed to by param1.
//   - Emits two specialisations of the row code:
//       * one for jcp.nb_oc_blocking oc blocks;
//       * one for the remainder nb_oc_tail, when nb_oc is not a multiple of
//         nb_oc_blocking.
//     The runtime oc_blocks argument selects which one runs. Any other value
//     jumps straight to the exit label.
//   - With a fused eltwise, the injector's constant table is emitted after
//     the code, and the injector is then released.
347 void jit_sse42_conv_fwd_kernel_f32::generate()
349 if (jcp.with_eltwise) {
350 nstl::vector<int> shared_vecs;
351 shared_vecs.push_back(0);
352 shared_vecs.push_back(13);
353 shared_vecs.push_back(14);
354 shared_vecs.push_back(15);
356 nstl::vector<Reg64> shared_regs;
357 shared_regs.push_back(imm_addr64);
359 eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
364 mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
365 mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
366 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
368 mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
369 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
370 mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
371 mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
// NOTE(review): this is an integer modulo by jcp.nb_oc_blocking at code
// generation time. It would divide by zero if init_conf ever left
// nb_oc_blocking == 0.
373 int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
374 const char *tail_label = ".tail";
375 const char *exit_label = ".exit";
377 cmp(reg_oc_blocks, jcp.nb_oc_blocking);
378 jne(nb_oc_tail ? tail_label : exit_label, T_NEAR);
// '0' + count is used only as a one-character tag that keeps the labels of
// the two specialisations distinct.
380 solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
381 jmp(exit_label, T_NEAR);
385 cmp(reg_oc_blocks, nb_oc_tail);
386 jne(exit_label, T_NEAR);
387 solve_common(nb_oc_tail, '0' + nb_oc_tail);
394 if (jcp.with_eltwise) {
395 // TODO (dmitrygo): need to find appropriate way to share labels.
398 inject(eltwise_generator.prepareTable());
399 eltwise_generator.release();
// Return whether the post-op chain in attr can be fused into this kernel.
// Accepted chains, by number of post-op entries:
//   0: none
//   1: sum | eltwise | dw_conv
//   2: sum->eltwise, dw_conv->eltwise, eltwise->dw_conv, dw_conv->sum
//   3: eltwise->dw_conv->eltwise, dw_conv->sum->eltwise
//   4: eltwise->dw_conv->sum->eltwise
// Any other length is rejected.
// Every non-empty chain is also rejected when jcp.with_eltwise is already
// set. init_conf sets that flag from the legacy with_relu argument before
// calling this function, so a fused ReLU and post-ops are mutually exclusive.
403 bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
404 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
405 const auto &p = attr.post_ops_;
407 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
408 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
409 auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
412 case 0: return true; // no post_ops
414 return true // sum OR eltwise OR dw_conv
415 && !jcp.with_eltwise && (is_eltwise(0) || is_sum(0) || is_dw_conv(0));
417 return true // sum->eltwise or dw_conv->eltwise or eltwise->dw_conv or dw_conv->sum
418 && !jcp.with_eltwise && ((is_sum(0) && is_eltwise(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
419 (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)));
421 return true // eltwise->dw_conv->eltwise or dw_conv->sum->eltwise
422 && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
423 (is_dw_conv(0) && is_sum(1) && is_eltwise(2)));
424 case 4: return true // eltwise->dw_conv->sum->eltwise
425 && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
426 default: return false;
// Fill jcp for a forward SSE4.2 f32 convolution, or return
// status::unimplemented when the problem is outside what this kernel supports.
//   cd, src_d, weights_d, dst_d  convolution descriptor and memory descriptors
//                                (4-D activations; weights get one extra
//                                leading dim when grouped)
//   attr                         post-ops: sum / eltwise / fused depthwise conv
//   with_relu, relu_negative_slope
//                                legacy fused-ReLU request; mutually exclusive
//                                with post-ops (see post_ops_ok)
// Returns status::success when jcp is usable by generate().
432 status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
433 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
434 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
435 const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
437 if (!mayiuse(sse42)) return status::unimplemented;
439 jcp.prop_kind = cd.prop_kind;
// Problem shape. oc/ic are per group.
441 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
443 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
444 jcp.mb = src_d.dims()[0];
446 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
447 jcp.ic = src_d.dims()[1] / jcp.ngroups;
449 jcp.ih = src_d.dims()[2];
450 jcp.iw = src_d.dims()[3];
451 jcp.oh = dst_d.dims()[2];
452 jcp.ow = dst_d.dims()[3];
454 jcp.kh = weights_d.dims()[with_groups + 2];
455 jcp.kw = weights_d.dims()[with_groups + 3];
457 jcp.t_pad = cd.padding[0][0];
458 jcp.l_pad = cd.padding[0][1];
460 jcp.stride_h = cd.strides[0];
461 jcp.stride_w = cd.strides[1];
463 jcp.dilate_h = cd.dilates[0];
464 jcp.dilate_w = cd.dilates[1];
466 jcp.src_fmt = src_d.format();
467 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
// Legacy fused-ReLU path. eltwise_beta / eltwise_scale are not assigned
// here in the visible lines; NOTE(review): confirm jcp is zero-initialised
// by the caller.
468 jcp.with_eltwise = with_relu;
469 jcp.eltwise_alg = mkldnn_eltwise_relu;
470 jcp.eltwise_alpha = relu_negative_slope;
472 if (!post_ops_ok(jcp, attr))
473 return status::unimplemented;
// Parse post-ops. A depthwise-conv entry splits the chain into ops applied
// by this kernel (before it) and ops applied by the dw conv (after it).
// p.find(kind, 0, dw_conv_ind) presumably searches only entries before the
// dw conv, and p.find(kind, dw_conv_ind) only entries after it; confirm
// against post_ops_t::find.
475 const auto &p = attr.post_ops_;
476 jcp.with_dw_conv = false;
477 int dw_conv_ind = p.find(primitive_kind::convolution);
478 if (dw_conv_ind != -1) {
479 jcp.with_dw_conv = true;
480 jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
481 jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
482 jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
483 jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
484 jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
485 jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
486 jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
487 jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
490 if (!jcp.with_eltwise) {
491 int eltwise_ind = p.find(primitive_kind::eltwise, 0, dw_conv_ind);
492 if (eltwise_ind != -1) {
493 jcp.with_eltwise = true;
494 jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg;
495 jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
496 jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta;
497 jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
// NOTE(review): dw_conv_with_eltwise / dw_conv_with_sum are only ever set
// to true in the visible lines, never reset to false; confirm jcp starts
// zeroed.
501 if (jcp.with_dw_conv) {
502 int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
503 if (dw_conv_eltwise_ind != -1) {
504 jcp.dw_conv_with_eltwise = true;
505 jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
506 jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
507 jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
511 jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
512 if (jcp.with_dw_conv) {
513 jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
// With a fused dw conv, this kernel's output shape is the dw conv's input
// shape instead of dst_d's.
516 if (jcp.with_dw_conv) {
517 jcp.oh = jcp.dw_conv_in_h;
518 jcp.ow = jcp.dw_conv_in_w;
// Layout requirements:
//   flat (ic == 1 or 3): src nchw/nhwc, weights (g)Ohwi8o
//   otherwise (mimo):    src nChw8c,    weights (g)OIhw8i8o
//   dst is always nChw8c; bias absent, "any" or x.
521 const bool flat = jcp.ic == 3 || jcp.ic == 1;
522 const bool mimo = !flat;
525 && implication(flat, one_of(src_d.format(), nchw, nhwc)
526 && one_of(weights_d.format(), Ohwi8o, gOhwi8o))
527 && implication(mimo, src_d.format() == nChw8c
528 && one_of(weights_d.format(), OIhw8i8o, gOIhw8i8o))
529 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
530 && dst_d.format() == nChw8c;
531 if (!args_ok) return status::unimplemented;
533 const int simd_w = 8; // 2 SSE vectors processing at once
535 jcp.ur_h = 1; /* no code-unrolling by h so far */
537 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
538 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
540 jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
543 && jcp.oc % simd_w == 0
544 && jcp.l_pad <= jcp.ur_w
545 && implication(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
546 || (jcp.stride_w == 1 && jcp.stride_h == 1))
547 && implication(mimo, jcp.ic % simd_w == 0);
548 if (!args_ok) return status::unimplemented;
// Right padding of the last full ur_w chunk. The kernel handles at most
// ur_w padded columns in one chunk, so ur_w is enlarged when needed.
550 int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
551 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
553 if (r_pad_no_tail > jcp.ur_w) {
554 /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
555 jcp.ur_w = r_pad_no_tail + 1;
// Register budget: the width-block code uses
//   - Xmm(1 .. nb_oc_blocking*ur_w) as accumulators;
//   - Xmm(nb_oc_blocking*ur_w + 1 .. (nb_oc_blocking+1)*ur_w) for broadcast
//     inputs;
//   - xmm0 as scratch.
// That requires (nb_oc_blocking + 1) * ur_w <= 15.
// NOTE(review): for ur_w >= 8 this evaluates to 0. That case looks reachable
// with asymmetric padding, e.g. kw = 11, l_pad = 3, right pad 7, stride 1,
// ow % 8 == 0. generate() then computes jcp.nb_oc % jcp.nb_oc_blocking, a
// division by zero. Consider returning status::unimplemented when
// nb_oc_blocking < 1.
556 jcp.nb_oc_blocking = ((16 - 1)-jcp.ur_w)/jcp.ur_w;
557 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
558 /* check again ... */
559 r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
560 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
561 if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
562 return status::unimplemented;
564 if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
// Flat inputs (ic 1 or 3) use a single ic block holding all channels.
566 jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
567 jcp.nb_ic = jcp.ic / jcp.ic_block;
569 jcp.oc_block = simd_w;
570 jcp.nb_oc = jcp.oc / jcp.oc_block;
// nb_ic_blocking is not used in this file's visible code; presumably it
// controls how many ic blocks the driver passes per kernel call.
572 if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
573 jcp.nb_ic_blocking = 12;
574 jcp.nb_ic_blocking_max = 16;
576 jcp.nb_ic_blocking = 1;
577 jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
// With a fused eltwise, the injector needs sharedVecsCount(alg) vector
// registers. The accumulators (ur_w * nb_oc_blocking) must fit in the rest,
// so nb_oc_blocking is reduced until they do; otherwise the case is rejected.
580 if (jcp.with_eltwise) {
581 int nvecs_elt = jit_uni_eltwise_vector_f32<sse42>::sharedVecsCount(jcp.eltwise_alg);
582 int nvecs_conv = 16 - nvecs_elt;
583 while (jcp.ur_w * jcp.nb_oc_blocking > nvecs_conv) {
584 if (jcp.nb_oc_blocking <= 1) {
588 jcp.nb_oc_blocking -= 1;
591 if (jcp.ur_w * jcp.nb_oc_blocking > nvecs_conv)
592 return status::unimplemented;
595 return status::success;