1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_sse42_1x1_conv_kernel_f32.hpp"
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
35 using namespace Xbyak;
/* Emits the outer "broadcast" loop: iterates over the broadcast (spatial)
 * dimension in chunks of jcp.bcast_block, calling reduce_loop() once per
 * jcp.ur sub-step, then handles the remaining jcp.ur_tail elements.
 *
 * NOTE(review): this chunk is missing source lines (the second parameter of
 * the signature, the opening brace, the closing brace of the inner `if`, and
 * an apparent `else`/brace sequence between the substep adds and the step
 * adds) — do not treat it as the complete function body.
 */
void jit_sse42_1x1_conv_kernel_f32::bcast_loop(int load_loop_blk,
    // Re-seed the per-invocation cursors from the loop-invariant registers.
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    // Labels are tagged with load_loop_tag so each unrolled caller gets
    // unique label names inside the generated code buffer.
    jit_tagged_label bcast_loop("bcast_loop", load_loop_tag);
    jit_tagged_label bcast_loop_tail("bcast_loop_tail", load_loop_tag);

    // If fewer than one full ur-step of work remains, go straight to the tail.
    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    // A bcast block is processed as num_substeps calls of jcp.ur each.
    assert(jcp.bcast_block % jcp.ur == 0);
    int num_substeps = jcp.bcast_block / jcp.ur;
    assert(num_substeps > 0 && num_substeps < 10);
    for (int i = 0; i < num_substeps; i++) {
        // '0' + i makes a distinct single-char tag per substep (< 10 by assert).
        reduce_loop(load_loop_blk, jcp.ur, load_loop_tag, '0' + i);
        if (i < num_substeps - 1) {
            // Between substeps: advance by the (smaller) substep strides.
            add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
            add(aux_reg_output_data, jcp.bcast_loop_output_substep);
        // After the last substep: advance by the full block stride, minus the
        // substep offsets already applied above.
        add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
                - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
        add(aux_reg_output_data, jcp.bcast_loop_output_step
                - (num_substeps - 1) * jcp.bcast_loop_output_substep);
    // Loop while at least one more full bcast block of work remains.
    sub(bcast_loop_iter, jcp.bcast_block);
    cmp(bcast_loop_iter, jcp.bcast_block);
    jge(bcast_loop, T_NEAR);

    // Tail: handle the final partial step (jcp.ur_tail elements), if any.
    jit_tagged_label bcast_loop_tail_out(
            "bcast_loop_tail_out", load_loop_tag);
    cmp(bcast_loop_iter, 0);
    jz(bcast_loop_tail_out, T_NEAR);
    reduce_loop(load_loop_blk, jcp.ur_tail, load_loop_tag, '1');
    L(bcast_loop_tail_out);
/* Emits the innermost reduction loop for one (load_loop_blk x ur) output
 * tile: initializes accumulators (from bias or zero), runs the FMA-style
 * multiply-accumulate over the reduce dimension, applies post-ops
 * (eltwise / depthwise) on the last reduce chunk, and stores the results.
 *
 * SSE register layout: each logical 8-wide block is split into two Xmm
 * halves, selected by the `n` (0/1) argument of the lambdas below.
 *
 * NOTE(review): several source lines are missing from this chunk (lambda
 * closing braces, `case`/`default` labels in the switches, the init/store
 * section labels, `L(reduce_loop)` and the tail emission near the end) —
 * the comments describe only what is visible.
 */
void jit_sse42_1x1_conv_kernel_f32::reduce_loop(int load_loop_blk, int ur,
        char load_loop_tag, char bcast_loop_tag)
    // Xmm register holding one half of load block i (weights/input reuse).
    auto reg_load = [=](int i, int n) {
        return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
    // Xmm accumulator for load block i, bcast position j, half n.
    auto reg_accum = [=](int i, int j, int n) {
        return Xmm(n*load_loop_blk*ur + i*ur + j + 1);
    // Address of the bias values for load block i, half n (4 floats per half).
    auto bias_ptr = [=](int i, int n) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
    // Address of the broadcast scalar for reduce index u, bcast position j.
    auto bcast_ptr = [=](int u, int j) {
        assert(u <= jcp.reduce_loop_unroll);
        if (one_of(jcp.prop_kind,
                    forward_training, forward_inference, backward_data)) {
            assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
                    ? jcp.oc_block : jcp.ic_block);
            auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
            // u == reduce_loop_unroll addresses the first element of the
            // *next* reduce chunk (used for prefetch-style early loads).
            offt = (u == jcp.reduce_loop_unroll)
                ? (height + j) * jcp.reduce_loop_unroll
                : j * jcp.reduce_loop_unroll + u;
            offt = u * jcp.ic_block + j;
        return ptr[aux_reg_bcast_data + sizeof(float) * offt];
    // Address of the load operand for reduce index u, load block i, half n.
    // u is split into position within the current unroll (u0) and the number
    // of whole unroll chunks (u1), the latter scaled by reduce_loop_load_step.
    auto load_ptr = [=](int u, int i, int n) {
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
            offt = (i * jcp.oc_block + u0) * jcp.ic_block;
        case backward_weights:
            offt = (i * jcp.os + u0) * jcp.oc_block;
            offt = (i * jcp.ic + u0) * jcp.oc_block;
        return ptr[aux_reg_load_data
            + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
    // Address of the output element for load block i, position j, half n;
    // layout depends on propagation kind and on dw_conv fusion.
    auto output_ptr = [=](int i, int j, int n) {
        switch (jcp.prop_kind) {
            return ptr[aux_reg_output_data +
                    (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
        case backward_weights:
            return ptr[aux_reg_output_data
                + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
                + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
            if (jcp.with_dw_conv)
                // Fused depthwise conv consumes a row buffer of height
                // dw_conv_ker_h, hence the different row stride.
                return ptr[aux_reg_output_data +
                        (i * jcp.dw_conv_ker_h * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
            return ptr[aux_reg_output_data +
                    (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];

    // --- init: load accumulators with bias on the first reduce chunk of a
    // forward pass, otherwise (code not fully visible here) zero them. ---
    jit_tagged_label init_done("init_done", load_loop_tag, bcast_loop_tag);
    jit_tagged_label init_zero("init_zero", load_loop_tag, bcast_loop_tag);

    if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
                forward_inference)) {
        // FLAG_REDUCE_FIRST marks the first chunk of the reduce dimension.
        test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
        for (int i = 0; i < load_loop_blk; i++)
            for (int j = 0; j < ur; ++j) {
                movups(reg_accum(i, j, 0), bias_ptr(i, 0));
                movups(reg_accum(i, j, 1), bias_ptr(i, 1));
        // Zero-init path (the xorps instructions are not visible in this chunk).
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r0 = reg_accum(i, j, 0);
                auto r1 = reg_accum(i, j, 1);

    // Pre-load the first load operands and broadcast scalar so the FMA loop
    // can overlap loads with arithmetic.
    for (int i = 0; i < load_loop_blk; ++i) {
        movups(reg_load(i, 0), load_ptr(0, i, 0));
        movups(reg_load(i, 1), load_ptr(0, i, 1));
    movss(reg_bcast, bcast_ptr(0, 0));
    shufps(reg_bcast, reg_bcast, 0); // replicate scalar across all 4 lanes

    // --- store: accumulate into memory (unless first chunk), apply
    // post-ops on the last chunk, then write the tile out. ---
    jit_tagged_label store_done(
            "store_done", load_loop_tag, bcast_loop_tag);
    jit_tagged_label store_noadd(
            "store_noadd", load_loop_tag, bcast_loop_tag);

    // On the first reduce chunk do not add previous partial sums.
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jnz(store_noadd, T_NEAR);

    for (int j = 0; j < ur; ++j)
        for (int i = 0; i < load_loop_blk; ++i) {
            auto r0 = reg_accum(i, j, 0);
            auto r1 = reg_accum(i, j, 1);
            addps(r0, output_ptr(i, j, 0));
            addps(r1, output_ptr(i, j, 1));

    // Post-ops are only applied once, on the last reduce chunk.
    jit_tagged_label store_norelu(
            "store_norelu", load_loop_tag, bcast_loop_tag);
    test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
    jz(store_norelu, T_NEAR);

    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    // Legacy with_eltwise path: a single injector with no post-op entries.
    if (p.len_ == 0 && eltwise_injectors.size() == 1) {
        eltwise_injectors[0]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);

    // With fused dw_conv, only post-ops preceding the convolution entry are
    // applied by this kernel; the rest belong to the dw_conv kernel.
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            // Registers 1 .. 2*ur*load_loop_blk hold the accumulators.
            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
        } else if (post_op.is_depthwise()) {
            // Per-channel scale/shift tables; reg_oc_off selects the slice
            // for the output channels this tile covers.
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
            add(reg_d_weights, reg_oc_off);
            add(reg_d_bias, reg_oc_off);

            for (int j = 0; j < load_loop_blk; ++j) {
                for (int k = 0; k < 2; k++) {
                    // Accumulators for one (block, half) pair are contiguous
                    // in register-index space: [reg_accum(j,0,k), reg_accum(j,ur,k)).
                    int start_idx = reg_accum(j, 0, k).getIdx();
                    int end_idx = reg_accum(j, ur, k).getIdx();
                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                            start_idx, end_idx, reg_d_weights, reg_d_bias);
                    // Advance the per-channel tables by one Xmm half (4 floats).
                    add(reg_d_weights, 4 * sizeof(float));
                    add(reg_d_bias, 4 * sizeof(float));

    // Write the finished tile back to the output buffer.
    for (int j = 0; j < ur; ++j)
        for (int i = 0; i < load_loop_blk; ++i) {
            movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
            movups(output_ptr(i, j, 1), reg_accum(i, j, 1));

    // One unrolled reduce chunk of multiply-accumulates. mulps destroys
    // reg_load, so the next iteration's operands are reloaded right after
    // their last use (software-pipelined loads), except on the very last
    // iteration of the last chunk.
    auto fma_block = [=](bool last_block) {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    mulps(reg_load(i, 0), reg_bcast);
                    mulps(reg_load(i, 1), reg_bcast);
                    addps(reg_accum(i, j, 0), reg_load(i, 0));
                    addps(reg_accum(i, j, 1), reg_load(i, 1));

                    if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
                        movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
                        movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
                // Broadcast the next scalar (next j within this u, or j=0 of u+1).
                movss(reg_bcast, bcast_ptr(u, j + 1));
                shufps(reg_bcast, reg_bcast, 0);
            if (!last_block || u < jcp.reduce_loop_unroll - 1) {
                movss(reg_bcast, bcast_ptr(u + 1, 0));
                shufps(reg_bcast, reg_bcast, 0);
        } // for reduce_loop_unroll

    jit_tagged_label reduce_loop("reduce_loop", load_loop_tag, bcast_loop_tag);
    jit_tagged_label reduce_loop_tail(
            "reduce_loop_tail", load_loop_tag, bcast_loop_tag);

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    // If only one unroll chunk of reduce work exists, skip to the tail.
    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    // Main reduce loop body (fma_block emission not visible in this chunk);
    // advance the data cursors one unroll chunk per iteration.
    add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
    add(aux_reg_load_data, jcp.reduce_loop_load_step);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jg(reduce_loop, T_NEAR);
/* backward_weights only: accumulates the bias gradient by summing the
 * diff_dst values over the reduce (spatial) dimension for load_loop_blk
 * output-channel blocks, then stores the partial sums back and advances
 * the stack-spilled diff-bias pointer.
 *
 * NOTE(review): this chunk is missing lines (the second parameter of the
 * signature, the opening brace, the early `return`, the xorps zero-init
 * body, `L(diff_bias_loop)` / `L(diff_bias_load)` label placements and
 * several closing braces) — comments cover only what is visible.
 */
void jit_sse42_1x1_conv_kernel_f32::diff_bias_loop(int load_loop_blk,
    // Nothing to do unless this is backward-weights with bias.
    if (!jcp.with_bias || jcp.prop_kind != backward_weights)

    jit_tagged_label diff_bias_loop("diff_bias_loop", load_loop_tag);
    jit_tagged_label diff_bias_loop_out("diff_bias_loop_out", load_loop_tag);
    jit_tagged_label diff_bias_init_out("diff_bias_init_out", load_loop_tag);
    jit_tagged_label diff_bias_load("diff_bias_load", load_loop_tag);

    // Address of one half (4 floats) of the diff-bias block i.
    auto diff_bias_ptr = [=](int i, int n) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
    // Address of diff_dst for spatial index u, oc block i, half n.
    auto load_ptr = [=](int u, int i, int n) {
        return ptr[aux_reg_load_data
            + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
    // Xmm accumulator for diff-bias block i, half n.
    auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };

    // The diff-bias pointer is kept on the stack (see generate()); a null
    // pointer means bias gradient is not requested for this call.
    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    // On the first reduce chunk start from zero, otherwise reload the
    // partial sums accumulated by previous chunks.
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r0 = diff_bias_reg(i, 0);
        auto r1 = diff_bias_reg(i, 1);
    jmp(diff_bias_init_out, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
        movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);

    // Sum diff_dst over the unrolled spatial positions for each oc block.
    for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
        for (int i = 0; i < load_loop_blk; ++i) {
            addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
            addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
    // The loop relies on the reduce dim being a multiple of the unroll.
    assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
    add(aux_reg_load_data, jcp.reduce_loop_load_step);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jnz(diff_bias_loop, T_NEAR);

    // Store the accumulated bias gradients and advance the spilled pointer
    // past this group of oc blocks.
    for (int i = 0; i < load_loop_blk; i++) {
        movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
        movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));

    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
/* Top-level code generation entry point: builds the post-op injectors,
 * emits the prologue (loads kernel arguments from jit_1x1_conv_call_s),
 * dispatches into unrolled load-loop variants (3/2/1 oc blocks), and emits
 * the epilogue plus injector tables.
 *
 * NOTE(review): missing lines in this chunk include the preamble(),
 * injector-constructor arguments, several closing braces, case labels in
 * load_loop_body's switch, and the postamble()/ret.
 */
void jit_sse42_1x1_conv_kernel_f32::generate()
    // Legacy fused-ReLU path: one eltwise injector built from jcp fields.
    if (jcp.with_eltwise) {
        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0

    // Build injectors for the post-ops this kernel applies; with fused
    // dw_conv only ops before the convolution entry belong to this kernel.
    const auto &p = attr_.post_ops_;
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
                    post_op.eltwise.alpha,
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
                    post_op.depthwise.alg

    // --- prologue: load call arguments into their dedicated registers. ---
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    if (jcp.prop_kind == backward_weights) {
        // Spill the diff-bias pointer to the stack; diff_bias_loop() reads
        // and updates it there.
        sub(rsp, stack_space_needed);
        mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
        mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);

    // One iteration over a group of load_loop_blk oc blocks: run the bcast
    // loop, then advance the data/bias/output pointers past those blocks.
    auto load_loop_body = [=] (int load_loop_blk, char bcast_loop_tag) {
        bcast_loop(load_loop_blk, bcast_loop_tag);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
        case forward_training:
        case forward_inference:
            add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
            if (jcp.with_dw_conv)
                    // dw_conv consumes a row buffer, hence the ow-based stride.
                    load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
                    load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
                    load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
        case backward_weights:
            for (int i = 0; i < load_loop_blk; i++)
                add(reg_output_data, reg_output_stride);
            assert(!"invalid prop_kind");
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));

    // --- dispatch: pick the widest unroll (3, 2 or 1 oc blocks) that the
    // remaining load-dim work allows; reg_load_loop_work counts oc units. ---
    const char *load_loop_blk_8 = "load_loop_blk_8";
    const char *load_loop_blk_16 = "load_loop_blk_16";
    const char *load_loop_blk_24 = "load_loop_blk_24";
    const char *load_loop_blk_end = "load_loop_blk_end";

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24); {
        diff_bias_loop(3, '3');
        load_loop_body(3, '3');
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16); {
        diff_bias_loop(2, '2');
        load_loop_body(2, '2');
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);

    L(load_loop_blk_8); {
        cmp(reg_load_loop_work, 0);
        je(load_loop_blk_end, T_NEAR);
        diff_bias_loop(1, '1');
        load_loop_body(1, '1');

    L(load_loop_blk_end);

    // NOTE(review): the stack release below is guarded by
    // `with_bias && backward_weights`, while the matching sub(rsp, ...)
    // above is guarded only by `backward_weights` — if that asymmetry also
    // exists in the full source (lines are missing here), backward_weights
    // without bias would leak stack; confirm against the complete file.
    if (jcp.with_bias && jcp.prop_kind == backward_weights)
        add(rsp, stack_space_needed);

    // Emit lookup tables required by the eltwise injectors after the code.
    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
/* Validates the attached post-ops chain against the combinations this
 * kernel (plus the fused depthwise-conv path) can execute. Returns false
 * for any unsupported chain so init_conf() can bail out with
 * `unimplemented`.
 *
 * NOTE(review): the `switch (p.len_)` header and the `case 1:`/`case 2:`/
 * `case 3:` labels are missing from this chunk; the bodies below clearly
 * correspond to chain lengths 0..4.
 */
bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
        jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    // Predicates over the i-th post-op entry.
    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
    // "simple" = applied in-register by this kernel (eltwise or depthwise).
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    case 0: return true; // no post_ops
        // Length 1: a single sum, eltwise, depthwise or dw_conv; the legacy
        // with_eltwise flag must not be combined with attr post-ops.
        return true // sum OR eltwise OR dw_conv
                && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
        // Length 2: supported pairs of sum / simple / dw_conv.
        return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
                    // eltwise->depthwise OR depthwise->depthwise
                && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
                                         (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
                                         (is_simple(0) && is_simple(1)));
        // Length 3: supported triples.
        return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
                    // sum->depthwise->eltwise OR sum->depthwise->depthwise
                && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
                                         (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
                                         (is_sum(0) && is_simple(1) && is_simple(2)));
    case 4: return true // eltwise->dw_conv->sum->eltwise
                && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
    default: return false;
/* Fills the jit_1x1_conv_conf_t descriptor from the convolution desc and
 * memory descriptors, validates that this SSE4.2 1x1 kernel supports the
 * problem (formats, 1x1 kernel, unit stride, no padding, channel blocking),
 * and computes the blocking parameters for the load/bcast/reduce loops.
 * Returns status::unimplemented on any unsupported configuration.
 *
 * NOTE(review): this chunk is missing lines throughout (the opening brace,
 * the mayiuse(sse42) check, several closing braces, the args_ok format
 * checks' first line, jcp.ur assignment, and the headers of the
 * `for (;;)` blocking-reduction loops) — comments describe what is visible.
 */
status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
        return status::unimplemented;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could do some stuff below
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    // --- problem geometry, read straight from the descriptors. ---
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    // Legacy fused-ReLU flags (pre-post-ops API).
    jcp.with_eltwise = with_relu;
    jcp.eltwise_alg = mkldnn_eltwise_relu;
    jcp.eltwise_alpha = relu_negative_slope;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    // --- fused depthwise-convolution post-op: copy its geometry into jcp. ---
    const auto &p = attr.post_ops_;
    jcp.with_dw_conv = false;
    int dw_conv_ind = p.find(primitive_kind::convolution);
    if (dw_conv_ind != -1) {
        jcp.with_dw_conv = true;
        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;

    // Eltwise that follows the fused dw_conv belongs to the dw_conv stage.
    if (jcp.with_dw_conv) {
        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
        if (dw_conv_eltwise_ind != -1) {
            jcp.dw_conv_with_eltwise = true;
            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;

    // Sum before the dw_conv entry belongs to this kernel; after it, to dw_conv.
    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    if (jcp.with_dw_conv) {
        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;

    // With fused dw_conv this kernel produces the dw_conv *input*, so its
    // effective output size is the dw_conv input size.
    if (jcp.with_dw_conv) {
        jcp.oh = jcp.dw_conv_in_h;
        jcp.ow = jcp.dw_conv_in_w;

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;

    // Required weights layout, indexed by [with_groups][is_backward_data].
    constexpr memory_format_t weights_formats[2][2] = {
        { OIhw8i8o, OIhw8o8i },
        { gOIhw8i8o, gOIhw8o8i }
    memory_format_t weights_format
        = weights_formats[with_groups][jcp.prop_kind == backward_data];

    // Format checks (the first condition of args_ok is not visible here).
        && src_d.format() == nChw8c
        && weights_d.format() == weights_format
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && dst_d.format() == nChw8c;
    if (!args_ok) return status::unimplemented;

    // SSE works on 4-wide vectors; channel blocks are two vectors (8) wide.
    const int simd_w = 4;

    jcp.oc = rnd_up(jcp.oc, simd_w*2);
    jcp.ic = rnd_up(jcp.ic, simd_w*2);

    jcp.ic_block = jcp.oc_block = simd_w*2;

    // Supported-problem checks: full channel blocks, no padding, unit
    // strides, true 1x1 kernel.
        && jcp.oc % jcp.oc_block == 0
        && jcp.ic % jcp.ic_block == 0
        && jcp.t_pad == 0 && jcp.l_pad == 0
        && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
        && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    // --- blocking parameters, selected per propagation kind. ---
    int load_blocking{ 0 };
    int load_blocking_max{ 0 };
    int bcast_blocking{ 0 };
    int bcast_blocking_max{ 0 };
    int reduce_blocking{ 0 };

    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        // Forward: reduce over ic, load oc, broadcast over spatial.
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        // Backward data: reduce over oc, load ic, broadcast over spatial.
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        // Backward weights: reduce over spatial, load oc, broadcast over ic.
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        // Shrink load blocking by factors of 2/3 until <= 32 blocks
        // (the surrounding `for (;;)` header is not visible in this chunk).
        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        if (load_blocking <= 32) break;
        else if (load_blocking % 2 == 0) load_blocking /= 2;
        else if (load_blocking % 3 == 0) load_blocking /= 3;

        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        // Same factor-reduction scheme for the bcast blocking (<= 9 blocks).
        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        if (bcast_blocking <= 9) break;
        else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
        else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;

        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        reduce_blocking = 128; // affects L1$ utilization
        return status::unimplemented;

    // Sanity: all blocking parameters must have been set above.
    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = jcp.bcast_dim % jcp.ur;

    // Translate byte/element blockings into block counts used by the driver.
    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;

    // With fused dw_conv the driver iterates over rows (ih) instead of the
    // full bcast dimension.
    jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;