1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_sse42_1x1_conv_kernel_f32.hpp"
// Byte offset of `field` inside jit_1x1_conv_call_s. generate() adds it to
// param1 to read the individual kernel call arguments.
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
35 using namespace Xbyak;
// Emits the loop over the broadcast dimension for one load-loop blocking.
// Each full iteration consumes jcp.bcast_block elements and is emitted as
// jcp.bcast_block / jcp.ur calls to generate_reduce_loop() with unroll
// jcp.ur. Leftover work, if any, is handled by a single
// generate_reduce_loop() call with unroll jcp.ur_tail.
//
// load_loop_blk: load-loop blocking factor, forwarded unchanged to
//                generate_reduce_loop().
37 void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
// Only the aux copies are advanced below; reg_bcast_data and reg_output_data
// are not modified by this function.
39 mov(aux1_reg_bcast_data, reg_bcast_data);
40 mov(aux_reg_output_data, reg_output_data);
41 mov(bcast_loop_iter, reg_bcast_loop_work);
44 Label bcast_loop_tail;
// Less than jcp.ur elements of work: jump straight to the tail handling.
46 cmp(bcast_loop_iter, jcp.ur);
47 jl(bcast_loop_tail, T_NEAR);
// One reduce loop is emitted per jcp.ur-sized substep of a bcast block.
50 assert(jcp.bcast_block % jcp.ur == 0);
51 int num_substeps = jcp.bcast_block / jcp.ur;
52 assert(num_substeps > 0 && num_substeps < 10);
53 for (int i = 0; i < num_substeps; i++) {
54 generate_reduce_loop(load_loop_blk, jcp.ur);
// Between substeps the pointers move by the substep strides.
55 if (i < num_substeps - 1) {
56 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
57 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
// Advance by a full step minus the (num_substeps - 1) substeps already
// applied, so the net movement per block equals exactly one step.
59 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
60 - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
61 add(aux_reg_output_data, jcp.bcast_loop_output_step
62 - (num_substeps - 1) * jcp.bcast_loop_output_substep);
// Consume one bcast block and repeat while another full block remains.
65 sub(bcast_loop_iter, jcp.bcast_block);
66 cmp(bcast_loop_iter, jcp.bcast_block);
67 jge(bcast_loop, T_NEAR);
// Tail: skip the extra reduce loop when no work is left.
72 Label bcast_loop_tail_out;
73 cmp(bcast_loop_iter, 0);
74 jz(bcast_loop_tail_out, T_NEAR);
75 generate_reduce_loop(load_loop_blk, jcp.ur_tail);
76 L(bcast_loop_tail_out);
// Emits the reduction loop for a tile of load_loop_blk x ur accumulators.
// The helper lambdas below map logical indices to Xmm registers and memory
// operands. The index n selects a 4-float chunk (operands are offset by
// n * 4 * sizeof(float)); the code below uses n = 0 and n = 1.
//
// load_loop_blk: number of load blocks handled at once (index i below).
// ur:            number of broadcast positions handled at once (index j).
80 void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
81 int load_loop_blk, int ur)
// Xmm holding chunk n of the i-th loaded vector. These indices start right
// after the 2 * ur * load_loop_blk accumulator registers.
83 auto reg_load = [=](int i, int n) {
84 return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
// Accumulator Xmm for load block i, bcast position j, chunk n
// (indices 1 .. 2 * ur * load_loop_blk).
87 auto reg_accum = [=](int i, int j, int n) {
88 return Xmm(n*load_loop_blk*ur + i*ur + j + 1);
// Memory operand for chunk n of the bias of load block i.
91 auto bias_ptr = [=](int i, int n) {
92 return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
// Memory operand for broadcast element (u, j), relative to
// aux_reg_bcast_data. u may equal jcp.reduce_loop_unroll (one past the
// current unrolled block), which is handled explicitly below.
95 auto bcast_ptr = [=](int u, int j) {
97 assert(u <= jcp.reduce_loop_unroll);
99 if (one_of(jcp.prop_kind,
100 forward_training, forward_inference, backward_data)) {
// NOTE(review): `==` binds tighter than `?:`, so this evaluates
// (unroll == (prop_kind == backward_data)) ? oc_block : ic_block and does not
// compare unroll against the selected block size - confirm the intent.
101 assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
102 ? jcp.oc_block : jcp.ic_block);
103 auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
104 offt = (u == jcp.reduce_loop_unroll)
105 ? (height + j) * jcp.reduce_loop_unroll
106 : j * jcp.reduce_loop_unroll + u;
108 offt = u * jcp.ic_block + j;
109 return ptr[aux_reg_bcast_data + sizeof(float) * offt];
// Memory operand for chunk n of load block i at reduction index u. u is split
// into u0 (position inside an unrolled block) and u1 (number of whole
// unrolled blocks, scaled by jcp.reduce_loop_load_step).
112 auto load_ptr = [=](int u, int i, int n) {
114 size_t u0 = u % jcp.reduce_loop_unroll;
115 size_t u1 = u / jcp.reduce_loop_unroll;
116 switch (jcp.prop_kind) {
118 offt = (i * jcp.oc_block + u0) * jcp.ic_block;
120 case backward_weights:
121 offt = (i * jcp.os + u0) * jcp.oc_block;
124 offt = (i * jcp.ic + u0) * jcp.oc_block;
126 return ptr[aux_reg_load_data
127 + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
// Memory operand for chunk n of output element (i, j), relative to
// aux_reg_output_data; the addressing depends on jcp.prop_kind.
130 auto output_ptr = [=](int i, int j, int n) {
131 switch (jcp.prop_kind) {
133 return ptr[aux_reg_output_data +
134 (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
135 case backward_weights:
136 return ptr[aux_reg_output_data
137 + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
138 + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
// With a fused depthwise convolution the per-block stride is
// jcp_dw.kh * jcp.ow elements instead of jcp.os.
140 if (jcp.with_dw_conv)
141 return ptr[aux_reg_output_data +
142 (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
144 return ptr[aux_reg_output_data +
145 (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
// Bias is only considered for forward propagation with jcp.with_bias set;
// the FLAG_REDUCE_FIRST bit of reg_reduce_pos_flag is tested at run time.
153 if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
154 forward_inference)) {
155 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
// Initialize every accumulator of load block i with that block's bias.
158 for (int i = 0; i < load_loop_blk; i++)
159 for (int j = 0; j < ur; ++j) {
160 movups(reg_accum(i, j, 0), bias_ptr(i, 0));
161 movups(reg_accum(i, j, 1), bias_ptr(i, 1));
167 for (int i = 0; i < load_loop_blk; ++i)
168 for (int j = 0; j < ur; ++j) {
169 auto r0 = reg_accum(i, j, 0);
170 auto r1 = reg_accum(i, j, 1);
// Prime the first iteration: fetch the first vectors of every load block and
// splat the first broadcast scalar across all lanes of reg_bcast.
178 for (int i = 0; i < load_loop_blk; ++i) {
179 movups(reg_load(i, 0), load_ptr(0, i, 0));
180 movups(reg_load(i, 1), load_ptr(0, i, 1));
183 movss(reg_bcast, bcast_ptr(0, 0));
184 shufps(reg_bcast, reg_bcast, 0);
// When FLAG_REDUCE_FIRST is set, control jumps to store_noadd (presumably
// placed after the additions below, so the old output values are not added).
191 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
192 jnz(store_noadd, T_NEAR);
// Add the values currently stored in the output to the accumulators.
195 for (int j = 0; j < ur; ++j)
196 for (int i = 0; i < load_loop_blk; ++i) {
197 auto r0 = reg_accum(i, j, 0);
198 auto r1 = reg_accum(i, j, 1);
199 addps(r0, output_ptr(i, j, 0));
200 addps(r1, output_ptr(i, j, 1));
// Post-ops are bypassed (jump to store_no_postops) unless FLAG_REDUCE_LAST
// is set in reg_reduce_pos_flag.
205 Label store_no_postops;
206 test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
207 jz(store_no_postops, T_NEAR);
209 int eltwise_inj_idx = 0;
210 int depthwise_inj_idx = 0;
211 const auto &p = attr_.post_ops_;
// Only post-ops that precede a fused depthwise convolution entry are applied
// here; without one, all p.len_ entries are visited.
213 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
214 for (int i = 0; i < end_idx; i++) {
215 auto& post_op = p.entry_[i];
216 if (post_op.is_eltwise()) {
// Covers Xmm indices starting at 1 up to 2 * ur * load_loop_blk, i.e. the
// accumulators - assuming the injector treats the end index as exclusive.
217 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
219 } else if (post_op.is_depthwise()) {
220 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
221 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
// Shift the depthwise weights/bias pointers by the current reg_oc_off.
223 add(reg_d_weights, reg_oc_off);
224 add(reg_d_bias, reg_oc_off);
// Process each load block one 4-float chunk at a time, moving the
// weights/bias pointers forward by 4 floats after each chunk.
226 for (int j = 0; j < load_loop_blk; ++j) {
227 for (int k = 0; k < 2; k++) {
// NOTE: this end_idx shadows the post-op loop bound declared above.
228 int start_idx = reg_accum(j, 0, k).getIdx();
229 int end_idx = reg_accum(j, ur, k).getIdx();
231 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
232 start_idx, end_idx, reg_d_weights, reg_d_bias);
234 add(reg_d_weights, 4 * sizeof(float));
235 add(reg_d_bias, 4 * sizeof(float));
// Write all accumulators to the output.
245 for (int j = 0; j < ur; ++j)
246 for (int i = 0; i < load_loop_blk; ++i) {
247 movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
248 movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
// One unrolled reduction block: for every (u, j, i) scale the loaded vectors
// by the broadcast value and add them to the accumulators. last_block
// suppresses the loads/broadcasts that would prefetch data for a block after
// the final one.
// NOTE(review): mulps overwrites reg_load, which is reloaded only when
// j == ur - 1; for ur > 1 later j iterations would see an already scaled
// value - confirm that ur is restricted accordingly.
252 auto fma_block = [=](bool last_block) {
253 for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
254 for (int j = 0; j < ur; ++j) {
255 for (int i = 0; i < load_loop_blk; ++i) {
256 mulps(reg_load(i, 0), reg_bcast);
257 mulps(reg_load(i, 1), reg_bcast);
258 addps(reg_accum(i, j, 0), reg_load(i, 0));
259 addps(reg_accum(i, j, 1), reg_load(i, 1));
// After the last bcast position, fetch the vectors for reduction index u + 1.
261 if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
262 movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
263 movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
267 movss(reg_bcast, bcast_ptr(u, j + 1));
268 shufps(reg_bcast, reg_bcast, 0);
271 if (!last_block || u < jcp.reduce_loop_unroll - 1) {
272 movss(reg_bcast, bcast_ptr(u + 1, 0));
273 shufps(reg_bcast, reg_bcast, 0);
275 } // for reduce_loop_unroll
279 Label reduce_loop_tail;
281 mov(aux_reg_load_data, reg_load_data);
282 mov(aux_reg_bcast_data, aux1_reg_bcast_data);
// If the work does not exceed one unrolled block, go directly to the tail.
286 mov(reduce_loop_iter, reg_reduce_loop_work);
287 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
288 jle(reduce_loop_tail, T_NEAR);
// Advance both input pointers by one unrolled block and loop while work
// remains.
292 add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
293 add(aux_reg_load_data, jcp.reduce_loop_load_step);
294 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
295 jg(reduce_loop, T_NEAR);
// Emits the code that accumulates diff_bias for load_loop_blk load blocks by
// summing the load data over the reduction dimension.
// Relevant only when jcp.with_bias is set and jcp.prop_kind is
// backward_weights (see the guard at the top).
//
// load_loop_blk: number of load blocks whose bias gradient is computed.
304 void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
306 if (!jcp.with_bias || jcp.prop_kind != backward_weights)
309 Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
310 Label diff_bias_load;
// Memory operand for chunk n (4 floats) of the diff_bias of load block i.
312 auto diff_bias_ptr = [=](int i, int n) {
313 return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
// Memory operand for chunk n of load block i at reduction index u.
316 auto load_ptr = [=](int u, int i, int n) {
317 return ptr[aux_reg_load_data
318 + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
// Running sums live in Xmm(1) .. Xmm(2 * load_loop_blk).
321 auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };
// The diff_bias pointer is kept in a stack slot; a null pointer skips the
// whole computation.
323 mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
324 cmp(reg_diff_bias_data, 0);
325 je(diff_bias_loop_out, T_NEAR);
// When FLAG_REDUCE_FIRST is not set, jump to diff_bias_load (presumably the
// path that starts from the values already stored in diff_bias).
327 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
328 jz(diff_bias_load, T_NEAR);
330 for (int i = 0; i < load_loop_blk; ++i) {
331 auto r0 = diff_bias_reg(i, 0);
332 auto r1 = diff_bias_reg(i, 1);
336 jmp(diff_bias_init_out, T_NEAR);
// Load the current diff_bias values into the running-sum registers.
339 for (int i = 0; i < load_loop_blk; ++i) {
340 movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
341 movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
344 L(diff_bias_init_out);
345 mov(aux_reg_load_data, reg_load_data);
346 mov(reduce_loop_iter, reg_reduce_loop_work);
// Add jcp.reduce_loop_unroll rows of load data to the running sums.
348 for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
349 for (int i = 0; i < load_loop_blk; ++i) {
350 addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
351 addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
// The loop below terminates only on an exact zero, which is presumably why
// the reduce dimension must be a multiple of the unroll factor.
353 assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
354 add(aux_reg_load_data, jcp.reduce_loop_load_step);
355 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
356 jnz(diff_bias_loop, T_NEAR);
// Write the sums back to diff_bias.
359 for (int i = 0; i < load_loop_blk; i++) {
360 movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
361 movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
// Advance the pointer past the processed blocks and save it back to its
// stack slot for the next invocation of this code.
364 add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
365 mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
367 L(diff_bias_loop_out);
// Emits the complete kernel: creates the post-op injectors, loads the call
// arguments from the structure addressed by param1, and runs the load loop
// with a blocking of 3, 2 or 1 depending on the remaining load work.
370 void jit_sse42_1x1_conv_kernel_f32::generate()
// One injector is created per eltwise/depthwise post-op that precedes a
// fused depthwise convolution entry (or per such post-op when there is none).
// NOTE(review): the injectors are allocated with raw `new`; where they are
// released is not visible in this function - confirm ownership.
372 const auto &p = attr_.post_ops_;
373 int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
374 for (int i = 0; i < end_idx; i++) {
375 auto &post_op = p.entry_[i];
376 if (post_op.is_eltwise()) {
377 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
380 post_op.eltwise.alpha,
383 } else if (post_op.is_depthwise()) {
384 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
386 post_op.depthwise.alg
// Load the data pointers from the call arguments.
393 mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
394 mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
395 mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
// For backward_weights the bias_data argument is the diff_bias pointer and is
// kept in a stack slot; otherwise it goes to reg_bias_data.
397 if (jcp.prop_kind == backward_weights) {
398 sub(rsp, stack_space_needed);
399 mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
400 mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
402 mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
// Load the loop extents, the first/last flag and the channel offset.
405 mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
406 mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
407 mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
408 mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
409 if (jcp.prop_kind == backward_weights)
410 mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
411 mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
// Emits one load-loop step: the bcast loop, followed by advancing the
// load/bias/output pointers, the remaining load work and reg_oc_off by
// load_loop_blk blocks.
413 auto generate_load_loop_body = [=] (int load_loop_blk) {
414 generate_bcast_loop(load_loop_blk);
415 add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
416 switch (jcp.prop_kind) {
417 case forward_training:
418 case forward_inference:
419 add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
420 if (jcp.with_dw_conv)
422 load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
425 load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
429 load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
431 case backward_weights:
// Xbyak cannot scale by an arbitrary immediate here, so the stride is added
// load_loop_blk times.
432 for (int i = 0; i < load_loop_blk; i++)
433 add(reg_output_data, reg_output_stride);
436 assert(!"invalid prop_kind");
438 sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
439 add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
442 Label load_loop_blk_8;
443 Label load_loop_blk_16;
444 Label load_loop_blk_24;
445 Label load_loop_blk_end;
// Initial dispatch on the remaining load work:
//   <= 8        -> blocking 1 (load_loop_blk_8)
//   == 32, <=16 -> blocking 2 (load_loop_blk_16)
//   otherwise   -> blocking 3 (load_loop_blk_24, fall through)
447 cmp(reg_load_loop_work, 8);
448 jle(load_loop_blk_8, T_NEAR);
450 cmp(reg_load_loop_work, 32);
451 je(load_loop_blk_16, T_NEAR);
453 cmp(reg_load_loop_work, 16);
454 jle(load_loop_blk_16, T_NEAR);
456 L(load_loop_blk_24); {
457 generate_diff_bias_loop(3);
458 generate_load_loop_body(3);
// Exactly 32 left is handled as two steps of blocking 2 rather than 3 + 1.
459 cmp(reg_load_loop_work, 32);
460 je(load_loop_blk_16);
461 cmp(reg_load_loop_work, 24);
462 jge(load_loop_blk_24);
465 cmp(reg_load_loop_work, 8);
466 jle(load_loop_blk_8, T_NEAR);
468 L(load_loop_blk_16); {
469 generate_diff_bias_loop(2);
470 generate_load_loop_body(2);
471 cmp(reg_load_loop_work, 16);
472 jge(load_loop_blk_16);
475 L(load_loop_blk_8); {
// Nothing left: skip the final single-block step.
476 cmp(reg_load_loop_work, 0);
477 je(load_loop_blk_end, T_NEAR);
478 generate_diff_bias_loop(1);
479 generate_load_loop_body(1);
482 L(load_loop_blk_end);
// Release the stack space reserved for the diff_bias pointer.
// NOTE(review): this is guarded by with_bias && backward_weights; verify the
// matching sub(rsp, ...) above runs under the same condition.
484 if (jcp.with_bias && jcp.prop_kind == backward_weights)
485 add(rsp, stack_space_needed);
// Let every eltwise injector emit its table via prepare_table().
489 for (auto& inj : eltwise_injectors)
490 inj->prepare_table();
// Returns whether the post-op chain in `attr` is one of the combinations this
// kernel accepts. Chains are matched by length and by the kind of each entry;
// any length not listed below is rejected.
//
// jcp:  kernel configuration (not referenced in the visible part of the body).
// attr: primitive attributes whose post_ops_ are inspected.
493 bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
494 jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
495 const auto &p = attr.post_ops_;
// Per-entry kind predicates; "simple" means eltwise or depthwise.
497 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
498 auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
499 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
500 auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
501 auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
// Accepted chains of one to four post-ops.
505 case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
506 case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
507 (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
508 (is_simple(0) && is_simple(1));
509 case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
510 (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
511 (is_sum(0) && is_simple(1) && is_simple(2));
512 case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
513 default: return false;
// Fills `jcp` from the convolution descriptor, the memory descriptors and the
// attributes, and derives the blocking parameters used by the generator.
//
// Returns status::success when the configuration is supported and
// status::unimplemented otherwise (unsupported post-ops, memory formats,
// padding, strides, kernel size or propagation kind).
519 status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
520 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
521 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
522 const primitive_attr_t &attr)
525 return status::unimplemented;
527 // TODO (Roma): this code is duplicated from the generic kernel; maybe the
528 // configuration struct could do some stuff below
// Grouped weights carry one extra leading dimension.
529 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
530 const int ndims = src_d.ndims();
532 jcp.prop_kind = cd.prop_kind;
534 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
535 jcp.mb = src_d.dims()[0];
// Channel counts are per group.
537 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
538 jcp.oc_without_padding = jcp.oc;
539 jcp.ic = src_d.dims()[1] / jcp.ngroups;
// For ndims == 3 there is no height dimension: ih/oh/kh and stride_h are set
// to 1 and t_pad to 0; the width is always the last dimension.
541 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
542 jcp.iw = src_d.dims()[ndims - 1];
543 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
544 jcp.ow = dst_d.dims()[ndims - 1];
546 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
547 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
549 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
550 jcp.l_pad = cd.padding[0][ndims - 3];
552 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
553 jcp.stride_w = cd.strides[ndims - 3];
555 jcp.src_fmt = src_d.format();
556 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
558 if (!post_ops_ok(jcp, attr))
559 return status::unimplemented;
561 const auto &p = attr.post_ops_;
// With a fused depthwise convolution post-op, oh/ow are replaced by the
// depthwise input size; the original values are kept in dw_conv_oh/ow.
563 int dw_conv_ind = p.find(primitive_kind::convolution);
564 jcp.with_dw_conv = dw_conv_ind != -1;
565 if (jcp.with_dw_conv) {
566 jcp.dw_conv_oh = jcp.oh;
567 jcp.dw_conv_ow = jcp.ow;
568 jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
569 jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
// The sum post-op search is passed the index of the depthwise convolution
// entry (-1 when there is none) as its last argument.
572 jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
574 jcp.src_dt = cd.src_desc.data_type;
575 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
576 jcp.dst_dt = cd.dst_desc.data_type;
// Flattened spatial sizes.
578 jcp.os = jcp.oh * jcp.ow;
579 jcp.is = jcp.ih * jcp.iw;
// Expected weights format, selected by dimensionality and by whether this is
// backward_data.
581 const int is_bwd_d = jcp.prop_kind == backward_data;
582 memory_format_t weights_format = with_groups
583 ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
585 : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
// Source and destination must use the 8-channel blocked layouts.
590 && one_of(src_d.format(), nCw8c, nChw8c)
591 && weights_d.format() == weights_format
592 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
593 && one_of(dst_d.format(), nCw8c, nChw8c);
594 if (!args_ok) return status::unimplemented;
// Channels are rounded up to a multiple of 2 * simd_w = 8, which is also the
// channel block size.
596 const int simd_w = 4;
598 jcp.oc = rnd_up(jcp.oc, simd_w*2);
599 jcp.ic = rnd_up(jcp.ic, simd_w*2);
601 jcp.ic_block = jcp.oc_block = simd_w*2;
// Only unpadded, unit-stride, 1x1 convolutions are supported.
604 && jcp.oc % jcp.oc_block == 0
605 && jcp.ic % jcp.ic_block == 0
606 && jcp.t_pad == 0 && jcp.l_pad == 0
607 && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
608 && jcp.kh == 1 && jcp.kw == 1;
609 if (!args_ok) return status::unimplemented;
613 int load_blocking{ 0 };
614 int load_blocking_max{ 0 };
615 int bcast_blocking{ 0 };
616 int bcast_blocking_max{ 0 };
617 int reduce_blocking{ 0 };
// Forward: reduce over ic, load over oc, broadcast over the spatial
// dimension (iw only when a depthwise convolution is fused).
619 if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
620 jcp.reduce_dim = jcp.ic;
621 jcp.reduce_block = jcp.ic_block;
623 jcp.load_dim = jcp.oc;
624 jcp.load_block = jcp.oc_block;
626 jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
627 jcp.bcast_block = jcp.ur;
629 jcp.reduce_loop_unroll = jcp.reduce_block;
630 jcp.reduce_loop_bcast_step
631 = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
632 jcp.reduce_loop_load_step
633 = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
635 jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
636 jcp.bcast_loop_output_substep = -1; // unused
637 jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
638 jcp.bcast_loop_bcast_substep = -1; // unused
640 jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
641 jcp.load_loop_iter_step = jcp.oc_block;
// With a fused depthwise convolution the load blocking is capped at three
// load blocks (and at oc).
643 load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
644 load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
645 bcast_blocking = 128; // affects load balancing across threads
646 bcast_blocking_max = 192;
647 reduce_blocking = 128; // affects L1$ utilization
// Backward data: reduce over oc, load over ic, broadcast over the output
// spatial dimension.
648 } else if (jcp.prop_kind == backward_data) {
649 jcp.reduce_dim = jcp.oc;
650 jcp.reduce_block = jcp.oc_block;
// NOTE(review): load_dim is ic but load_block is taken from oc_block; both
// blocks are set to simd_w*2 above, so the value is the same - confirm intent.
652 jcp.load_dim = jcp.ic;
653 jcp.load_block = jcp.oc_block;
655 jcp.bcast_dim = jcp.os;
656 jcp.bcast_block = jcp.ur;
658 jcp.reduce_loop_unroll = jcp.reduce_block;
659 jcp.reduce_loop_bcast_step
660 = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
661 jcp.reduce_loop_load_step
662 = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
664 jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
665 jcp.bcast_loop_output_substep = -1; // unused
666 jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
667 jcp.bcast_loop_bcast_substep = -1; // unused
669 jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
670 jcp.load_loop_iter_step = jcp.ic_block;
672 load_blocking = 96; // assumes the kernel is jcp.ur x 3
673 load_blocking_max = 144;
674 bcast_blocking = 128; // affects load balancing across threads
// NOTE(review): the forward branch uses 192 for this limit; confirm that 196
// is intended here.
675 bcast_blocking_max = 196;
676 reduce_blocking = 64; // affects L1$ utilization
// Backward weights: reduce over the output spatial dimension one element at
// a time, load over oc, broadcast over ic.
677 } else if (jcp.prop_kind == backward_weights) {
678 jcp.reduce_dim = jcp.os;
679 jcp.reduce_block = 1;
681 jcp.load_dim = jcp.oc;
682 jcp.load_block = jcp.oc_block;
684 jcp.bcast_dim = jcp.ic;
685 jcp.bcast_block = jcp.ic_block;
687 jcp.reduce_loop_unroll = jcp.reduce_block;
688 jcp.reduce_loop_bcast_step
689 = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
690 jcp.reduce_loop_load_step
691 = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
// Unlike the other propagation kinds, the substeps are used here.
693 jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
694 jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
695 jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
696 jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
698 jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
699 jcp.load_loop_iter_step = jcp.oc_block;
// Start from the number of load blocks and divide it by 2 or 3 until it is
// at most 32 (the enclosing loop statement is not visible in this view).
703 load_blocking = div_up(jcp.load_dim, jcp.load_block);
705 if (load_blocking <= 32) break;
706 else if (load_blocking % 2 == 0) load_blocking /= 2;
707 else if (load_blocking % 3 == 0) load_blocking /= 3;
710 load_blocking *= jcp.load_block;
711 load_blocking_max = load_blocking;
712 assert(jcp.load_dim % load_blocking == 0);
// Same reduction for the bcast dimension, with a limit of 9 blocks.
714 bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
716 if (bcast_blocking <= 9) break;
717 else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
718 else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
721 bcast_blocking *= jcp.bcast_block;
722 bcast_blocking_max = bcast_blocking;
723 assert(jcp.bcast_dim % bcast_blocking == 0);
725 reduce_blocking = 128; // affects L1$ utilization
727 return status::unimplemented;
// Every supported branch above must have set all blocking values.
729 assert(load_blocking);
730 assert(load_blocking_max);
731 assert(bcast_blocking);
732 assert(bcast_blocking_max);
733 assert(reduce_blocking);
// ur_tail is the part of the bcast dimension not covered by whole ur steps.
735 assert(jcp.bcast_block % jcp.ur == 0);
736 jcp.ur_tail = jcp.bcast_dim % jcp.ur;
// Convert the element blockings into block counts.
738 jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
739 jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
740 jcp.nb_load_blocking = load_blocking / jcp.load_block;
741 jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
742 jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
// Total number of blocks per dimension; with a fused depthwise convolution
// nb_bcast is set to ih.
744 jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
745 jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
746 jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
748 return status::success;
// Books the scratchpad buffers this kernel needs.
//
// scratchpad: registrar that receives the bookings.
// jcp:        configuration of the 1x1 convolution.
// jcp_dw:     configuration of the fused depthwise convolution; read only
//             when jcp.with_dw_conv is set.
751 void jit_sse42_1x1_conv_kernel_f32::init_scratchpad(
752 memory_tracking::registrar_t &scratchpad,
753 const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
754 using namespace mkldnn::impl::memory_tracking::names;
// Bias buffer of jcp.oc floats, booked when oc differs from its unpadded
// value; not booked for backward_data.
756 if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
757 scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
// Depthwise convolution buffer: kh * iw * ch_block floats per oc block,
// multiplied by the maximum number of threads.
759 if (jcp.with_dw_conv) {
760 const int nthreads = mkldnn_get_max_threads();
761 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
762 scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
// A second bias buffer of jcp.oc floats for the depthwise convolution.
764 if (jcp.oc != jcp.oc_without_padding)
765 scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);