1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 * Copyright 2018 YANDEX LLC
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
23 #include "type_helpers.hpp"
26 #include "cpu_memory.hpp"
28 #include "jit_avx2_1x1_conv_kernel_f32.hpp"
30 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
36 using namespace mkldnn::impl::prop_kind;
37 using namespace mkldnn::impl::memory_format;
38 using namespace mkldnn::impl::utils;
40 using namespace Xbyak;
42 void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
// Emits the "bcast" (spatial) loop of the 1x1 kernel: walks
// reg_bcast_loop_work in chunks of jcp.bcast_block, generating one reduce
// loop per jcp.ur sub-step, followed by a tail pass of jcp.ur_tail
// elements.
// NOTE(review): this listing has gaps in its embedded line numbering, so
// some lines (braces, labels, else-branches) are not visible here.
44     mov(aux1_reg_bcast_data, reg_bcast_data);
45     mov(aux_reg_output_data, reg_output_data);
46     mov(bcast_loop_iter, reg_bcast_loop_work);
48     Label bcast_loop, bcast_loop_tail;
// Skip the main loop entirely when fewer than jcp.ur elements remain.
50     cmp(bcast_loop_iter, jcp.ur);
51     jl(bcast_loop_tail, T_NEAR);
// A bcast block is an exact multiple of the unroll factor; each substep
// emits one reduce loop over jcp.ur broadcast elements.
54     assert(jcp.bcast_block % jcp.ur == 0);
55     int num_substeps = jcp.bcast_block / jcp.ur;
56     assert(num_substeps > 0 && num_substeps < 10);
57     for (int i = 0; i < num_substeps; i++) {
58         generate_reduce_loop(load_loop_blk, jcp.ur);
59         if (i < num_substeps - 1) {
60             add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
61             add(aux_reg_output_data, jcp.bcast_loop_output_substep);
// Last substep: advance by the full block step, compensating for the
// substep increments already applied in the earlier iterations.
63             add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
64                     - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
65             add(aux_reg_output_data, jcp.bcast_loop_output_step
66                     - (num_substeps - 1) * jcp.bcast_loop_output_substep);
69     sub(bcast_loop_iter, jcp.bcast_block);
70     cmp(bcast_loop_iter, jcp.bcast_block);
71     jge(bcast_loop, T_NEAR);
// Tail: handle the remaining (< jcp.ur) broadcast elements, if any.
76     Label bcast_loop_tail_out;
77     cmp(bcast_loop_iter, 0);
78     jz(bcast_loop_tail_out, T_NEAR);
79     generate_reduce_loop(load_loop_blk, jcp.ur_tail);
80     L(bcast_loop_tail_out);
84 void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
85         int load_loop_blk, int ur)
// Emits the inner reduce loop: accumulates load_loop_blk x ur Ymm
// accumulators over the reduce dimension, with bias/zero init on the first
// chunk, accumulation of previously stored partial output on later chunks,
// post-op injection (eltwise/depthwise) on the last chunk, and the final
// store.
// NOTE(review): this listing has gaps in its embedded numbering; several
// lines (braces, case labels, else-branches, zeroing instructions) are not
// visible here.
// vreg_load(i): weights vectors live in the Ymm registers above the
// ur * load_loop_blk accumulator bank.
87     auto vreg_load = [=](int i) {
88         return Ymm(ur * load_loop_blk + i);
// vreg_accum(i, j): accumulator for load block i, broadcast element j
// (body not visible in this listing).
91     auto vreg_accum = [=](int i, int j) {
// Address of the i-th oc_block of bias values.
95     auto bias_ptr = [=](int i) {
96         return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
// Address of broadcast element (u, j); the layout differs between
// fwd/bwd_data (spatial-major, blocked) and bwd_weights.
99     auto bcast_ptr = [=](int u, int j) {
101         assert(u <= jcp.reduce_loop_unroll);
103         if (one_of(jcp.prop_kind,
104                     forward_training, forward_inference, backward_data))
106             assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
107                     ? jcp.oc_block : jcp.ic_block);
108             auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
// u == reduce_loop_unroll addresses the first element of the NEXT reduce
// chunk (used by fma_block to load ahead).
109             offt = (u == jcp.reduce_loop_unroll)
110                 ? (height + j) * jcp.reduce_loop_unroll
111                 : j * jcp.reduce_loop_unroll + u;
113             offt = u * jcp.ic_block + j;
114         return ptr[aux_reg_bcast_data + sizeof(float) * offt];
// Address of weight vector (u, i); offset formula depends on prop_kind.
117     auto load_ptr = [=](int u, int i) {
119         size_t u0 = u % jcp.reduce_loop_unroll;
120         size_t u1 = u / jcp.reduce_loop_unroll;
121         switch (jcp.prop_kind) {
123             offt = (i * jcp.oc_block + u0) * jcp.ic_block;
125         case backward_weights:
126             offt = (i * jcp.os + u0) * jcp.oc_block;
129             offt = (i * jcp.ic + u0) * jcp.oc_block;
131         return ptr[aux_reg_load_data
132             + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt];
// Address of output element (i, j); bwd_weights uses a runtime stride
// register, and the fwd path with a fused dw-conv writes into the dw-conv
// row buffer (jcp_dw.kh rows of jcp.ow).
135     auto output_ptr = [=](int i, int j) {
136         switch (jcp.prop_kind) {
138             return ptr[aux_reg_output_data +
139                 (i * jcp.is + j) * jcp.ic_block * sizeof(float)];
140         case backward_weights:
141             return ptr[aux_reg_output_data
142                 + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
143                 + sizeof(float) * jcp.oc_block * j];
145             if (jcp.with_dw_conv) {
146                 return ptr[aux_reg_output_data +
147                     (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float)];
149             return ptr[aux_reg_output_data +
150                 (i * jcp.os + j) * jcp.oc_block * sizeof(float)];
// --- accumulator initialization ---------------------------------------
156     Label init_done, init_zero;
// On the first reduce chunk of a forward pass with bias, seed the
// accumulators from the bias values instead of zero.
158     if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
159                 forward_inference)) {
160         test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
163         for (int i = 0; i < load_loop_blk; i++)
164             for (int j = 0; j < ur; ++j)
165                 vmovups(vreg_accum(i, j), bias_ptr(i));
// Zero-init path (the zeroing instruction itself is not visible in this
// listing).
170     for (int i = 0; i < load_loop_blk; ++i)
171         for (int j = 0; j < ur; ++j) {
172             auto r = vreg_accum(i, j);
// Preload the first weights vector per load block and the first broadcast
// scalar before entering the fma loop.
177     for (int i = 0; i < load_loop_blk; ++i)
178         vmovups(vreg_load(i), load_ptr(0, i));
179     vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
// --- store / post-ops --------------------------------------------------
// Unless this is the first reduce chunk, add the partial results already
// stored in the output buffer (accumulation across reduce blocking).
186     test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
187     jnz(store_noadd, T_NEAR);
190     for (int j = 0; j < ur; ++j)
191         for (int i = 0; i < load_loop_blk; ++i) {
192             auto r = vreg_accum(i, j);
193             vaddps(r, r, output_ptr(i, j));
// Post-ops are applied only on the last reduce chunk.
199     test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
200     jz(store_norelu, T_NEAR);
202     int eltwise_inj_idx = 0;
203     int depthwise_inj_idx = 0;
204     const auto &p = attr_.post_ops_;
// With a fused dw-conv, only the post-ops that precede it apply here.
206     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
207     for (int i = 0; i < end_idx; i++) {
208         auto& post_op = p.entry_[i];
209         if (post_op.is_eltwise()) {
// Eltwise is applied in-register over the whole accumulator bank.
210             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
212         } else if (post_op.is_depthwise()) {
213             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
214             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
// reg_oc_off selects the per-channel weights/bias slice for this block.
216             add(reg_d_weights, reg_oc_off);
217             add(reg_d_bias, reg_oc_off);
219             for (int j = 0; j < load_loop_blk; ++j) {
220                 int start_idx = vreg_accum(j, 0).getIdx();
221                 int end_idx = start_idx + ur;
223                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
224                         start_idx, end_idx, reg_d_weights, reg_d_bias);
226                 add(reg_d_weights, jcp.oc_block * sizeof(float));
227                 add(reg_d_bias, jcp.oc_block * sizeof(float));
// Final store of all accumulators to the output buffer.
236     for (int j = 0; j < ur; ++j)
237         for (int i = 0; i < load_loop_blk; ++i) {
238             vmovups(output_ptr(i, j), vreg_accum(i, j));
// --- fma kernel ---------------------------------------------------------
// One reduce chunk: reduce_loop_unroll x ur x load_loop_blk
// multiply-accumulates; loads for step (u + 1) are issued while computing
// step u to hide memory latency.
242     auto fma_block = [=](bool last_block) {
243         for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
244             for (int j = 0; j < ur; ++j) {
245                 for (int i = 0; i < load_loop_blk; ++i) {
// FMA path (its guarding condition is not visible in this listing);
// otherwise emulate fma with mul + add through vtmp on plain AVX.
247                         vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
248                     else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
249                         vmulps(vtmp, vreg_bcast, vreg_load(i));
250                         vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
// Load the next weights vector unless this is the very last iteration of
// the last chunk.
252                     if (j == ur - 1 && !(last_block
253                                 && u == jcp.reduce_loop_unroll - 1))
254                         vmovups(vreg_load(i), load_ptr(u + 1, i));
257                     vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
259             if (!last_block || u < jcp.reduce_loop_unroll - 1)
260                 vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
// --- reduce loop driver --------------------------------------------------
264     Label reduce_loop, reduce_loop_tail;
266     mov(aux_reg_load_data, reg_load_data);
267     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
271     mov(reduce_loop_iter, reg_reduce_loop_work);
272     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
273     jle(reduce_loop_tail, T_NEAR);
277         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
278         add(aux_reg_load_data, jcp.reduce_loop_load_step);
279         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
280         jg(reduce_loop, T_NEAR);
289 void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
// Emits the diff-bias reduction for backward-weights: sums diff_dst rows
// into one Ymm per load block and accumulates the result into the
// diff_bias buffer whose pointer is kept on the stack. No code is emitted
// for other prop kinds or when bias is absent.
// NOTE(review): this listing has gaps in its embedded numbering (the early
// return, braces, and the register-zeroing line are not visible).
291     if (!jcp.with_bias || jcp.prop_kind != backward_weights)
294     Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
295     Label diff_bias_load;
297     auto diff_bias_ptr = [=](int i) {
298         return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
301     auto load_ptr = [=](int u, int i) {
302         return ptr[aux_reg_load_data
303             + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
306     auto diff_bias_reg = [=](int i) { return Ymm(i); };
// A null pointer stored on the stack means "no diff_bias requested".
308     mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
309     cmp(reg_diff_bias_data, 0);
310     je(diff_bias_loop_out, T_NEAR);
// First reduce chunk starts from zeroed registers; later chunks reload
// the partial sums already stored in diff_bias.
312     test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
313     jz(diff_bias_load, T_NEAR);
315     for (int i = 0; i < load_loop_blk; ++i) {
316         auto r = diff_bias_reg(i);
319     jmp(diff_bias_init_out, T_NEAR);
322     for (int i = 0; i < load_loop_blk; ++i)
323         vmovups(diff_bias_reg(i), diff_bias_ptr(i));
325     L(diff_bias_init_out);
326     mov(aux_reg_load_data, reg_load_data);
327     mov(reduce_loop_iter, reg_reduce_loop_work);
// Main reduction: reduce_loop_unroll rows accumulated per iteration;
// reduce_dim must be an exact multiple of the unroll.
329         for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
330             for (int i = 0; i < load_loop_blk; ++i)
331                 vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
332         assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
333         add(aux_reg_load_data, jcp.reduce_loop_load_step);
334         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
335         jnz(diff_bias_loop, T_NEAR);
// Store the partial sums back and advance the stack-held pointer so the
// next load block continues where this one stopped.
338         for (int i = 0; i < load_loop_blk; i++)
339             vmovups(diff_bias_ptr(i), diff_bias_reg(i));
340         add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
341         mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
343     L(diff_bias_loop_out);
346 void jit_avx2_1x1_conv_kernel_f32::generate()
// Kernel entry point: creates the post-op injectors, loads the call-time
// arguments from jit_1x1_conv_call_s (via param1 + GET_OFF), then drives
// the load-dimension loop in blocks of 3/2/1 oc_blocks (24/16/8 channels).
// NOTE(review): this listing has gaps in its embedded numbering (injector
// constructor arguments, some add() targets, preamble/epilogue lines and
// braces are not visible here).
348     const auto &p = attr_.post_ops_;
// With a fused dw-conv, only the post-ops preceding it belong to the 1x1.
349     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
350     for (int i = 0; i < end_idx; i++) {
351         auto &post_op = p.entry_[i];
352         if (post_op.is_eltwise()) {
353             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
356                     post_op.eltwise.alpha,
359         } else if (post_op.is_depthwise()) {
360             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx2>(
362                     post_op.depthwise.alg
// Load kernel arguments from the call struct.
369     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
370     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
371     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
// bwd_weights spills the diff_bias pointer to the stack (the register is
// reused later); other prop kinds keep the bias pointer in a register.
373     if (jcp.prop_kind == backward_weights) {
374         sub(rsp, stack_space_needed);
375         mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
376         mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
378         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
381     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
382     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
383     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
384     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
385     if (jcp.prop_kind == backward_weights)
386         mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
387     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
// One load-loop step: emit the bcast loop, then advance the data
// pointers/counters for load_loop_blk oc_blocks.
389     auto generate_load_loop_body = [=] (int load_loop_blk) {
390         generate_bcast_loop(load_loop_blk);
391         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
392         switch (jcp.prop_kind) {
393         case forward_training:
394         case forward_inference:
395             add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
396             if (jcp.with_dw_conv)
398                         load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
401                         load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
405                     load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
407         case backward_weights:
// Output stride is a runtime value, so advance by repeated addition.
408             for (int i = 0; i < load_loop_blk; i++)
409                 add(reg_output_data, reg_output_stride);
412             assert(!"invalid prop_kind");
414         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
415         add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
// Dispatch on the remaining load dimension: prefer 24-wide blocks, then
// 16, then 8.
418     Label load_loop_blk_8;
419     Label load_loop_blk_16;
420     Label load_loop_blk_24;
421     Label load_loop_blk_end;
423     cmp(reg_load_loop_work, 8);
424     jle(load_loop_blk_8, T_NEAR);
// Exactly 32 remaining is split 16 + 16 rather than 24 + 8.
426     cmp(reg_load_loop_work, 32);
427     je(load_loop_blk_16, T_NEAR);
429     cmp(reg_load_loop_work, 16);
430     jle(load_loop_blk_16, T_NEAR);
432     L(load_loop_blk_24); {
433         generate_diff_bias_loop(3);
434         generate_load_loop_body(3);
435         cmp(reg_load_loop_work, 32);
436         je(load_loop_blk_16);
437         cmp(reg_load_loop_work, 24);
438         jge(load_loop_blk_24);
441     cmp(reg_load_loop_work, 8);
442     jle(load_loop_blk_8, T_NEAR);
444     L(load_loop_blk_16); {
445         generate_diff_bias_loop(2);
446         generate_load_loop_body(2);
447         cmp(reg_load_loop_work, 16);
448         jge(load_loop_blk_16);
451     L(load_loop_blk_8); {
452         cmp(reg_load_loop_work, 0);
453         je(load_loop_blk_end, T_NEAR);
454         generate_diff_bias_loop(1);
455         generate_load_loop_body(1);
458     L(load_loop_blk_end);
// (stack restore for bwd_weights is not visible in this listing)
460     if (jcp.with_bias && jcp.prop_kind == backward_weights)
// Emit the lookup tables the eltwise injectors need after the code body.
465     for (auto& inj : eltwise_injectors)
466         inj->prepare_table();
469 bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
470         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
// Validates the attached post-op chain: sum, eltwise, depthwise and a
// fused depthwise convolution are accepted only in the specific orderings
// enumerated below, up to chain length 4.
471     const auto &p = attr.post_ops_;
473     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
474     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
475     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
476     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
477     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
// (switch over the chain length — the `switch` line itself and the empty
// chain case are not visible in this listing)
481     case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
482     case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
483                    (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
484                    (is_simple(0) && is_simple(1));
485     case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
486                    (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
487                    (is_sum(0) && is_simple(1) && is_simple(2));
488     case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
489     default: return false;
495 status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
496         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
497         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
498         const primitive_attr_t &attr)
// Fills the 1x1 convolution configuration: problem sizes, layout and
// post-op checks, and per-prop_kind blocking/loop-step parameters.
// Returns status::unimplemented for any unsupported case.
// NOTE(review): this listing has gaps in its embedded numbering; some
// lines (braces, else branches, format lists, loop headers) are not
// visible here.
500     if (!mayiuse(avx)) return status::unimplemented;
502     // TODO (Roma): this code is duplicated from the generic kernel; maybe the
503     // configuration struct could do some stuff below
504     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
505     const int ndims = src_d.ndims();
507     jcp.prop_kind = cd.prop_kind;
509     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
510     jcp.mb = src_d.dims()[0];
// Per-group channel counts; ndims == 3 problems have no height dimension.
512     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
513     jcp.oc_without_padding = jcp.oc;
514     jcp.ic = src_d.dims()[1] / jcp.ngroups;
516     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
517     jcp.iw = src_d.dims()[ndims - 1];
518     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
519     jcp.ow = dst_d.dims()[ndims - 1];
521     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
522     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
524     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
525     jcp.l_pad = cd.padding[0][ndims - 3];
527     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
528     jcp.stride_w = cd.strides[ndims - 3];
530     jcp.src_fmt = src_d.format();
531     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
533     if (!post_ops_ok(jcp, attr))
534         return status::unimplemented;
536     const auto &p = attr.post_ops_;
// Detect a fused depthwise convolution among the post-ops; when present,
// the 1x1 kernel computes into an intermediate sized by the dw-conv input.
538     int dw_conv_ind = p.find(primitive_kind::convolution);
539     jcp.with_dw_conv = dw_conv_ind != -1;
// NOTE(review): duplicated assignment — the line below repeats the line
// above verbatim; redundant (harmless, but should be removed).
540     jcp.with_dw_conv = dw_conv_ind != -1;
541     if (jcp.with_dw_conv) {
542         jcp.dw_conv_oh = jcp.oh;
543         jcp.dw_conv_ow = jcp.ow;
544         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
545         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
// Fused dw-conv and the richer post-ops require AVX2; plain AVX supports
// only the relu eltwise post-op.
548     if (jcp.with_dw_conv && !mayiuse(avx2))
549         return status::unimplemented;
551     if (!mayiuse(avx2)) {
552         for (int i = 0; i < p.len_; i++) {
553             auto &post_op = p.entry_[i];
554             if (post_op.is_eltwise()) {
555                 if (post_op.eltwise.alg != alg_kind::eltwise_relu)
556                     return status::unimplemented;
557             } else if (post_op.is_depthwise()) {
558                 return status::unimplemented;
// Only sum post-ops preceding the fused dw-conv count for this kernel.
563     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
565     jcp.src_dt = cd.src_desc.data_type;
566     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
567     jcp.dst_dt = cd.dst_desc.data_type;
569     jcp.os = jcp.oh * jcp.ow;
570     jcp.is = jcp.ih * jcp.iw;
// Expected blocked weights format, picked by ndims and direction.
572     const int is_bwd_d = jcp.prop_kind == backward_data;
573     memory_format_t weights_format = with_groups
574         ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
576         : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
579     const int simd_w = 8;
// Channels are padded up to the 8-wide SIMD block.
581     jcp.oc = rnd_up(jcp.oc, simd_w);
582     jcp.ic = rnd_up(jcp.ic, simd_w);
// Layout checks: 8c-blocked activations and the matching weights format.
586         && one_of(src_d.format(), nCw8c, nChw8c)
587         && weights_d.format() == weights_format
588         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
589         && one_of(dst_d.format(), nCw8c, nChw8c);
590     if (!args_ok) return status::unimplemented;
// This kernel handles only true 1x1, stride-1, pad-0 convolutions.
593         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
594         && jcp.t_pad == 0 && jcp.l_pad == 0
595         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
596         && jcp.kh == 1 && jcp.kw == 1;
597     if (!args_ok) return status::unimplemented;
599     // TODO: remove this restriction
600     // optimized 1x1 bwd_w does not support Intel AVX
601     if (jcp.prop_kind == backward_weights && !mayiuse(avx2))
602         return status::unimplemented;
604     jcp.ic_block = jcp.oc_block = simd_w;
606     jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support
608     int load_blocking{ 0 };
609     int load_blocking_max{ 0 };
610     int bcast_blocking{ 0 };
611     int bcast_blocking_max{ 0 };
612     int reduce_blocking{ 0 };
// Forward: reduce over ic, load oc, broadcast over the spatial dimension.
614     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
615         jcp.reduce_dim = jcp.ic;
616         jcp.reduce_block = jcp.ic_block;
618         jcp.load_dim = jcp.oc;
619         jcp.load_block = jcp.oc_block;
621         jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
622         jcp.bcast_block = jcp.ur;
624         jcp.reduce_loop_unroll = jcp.reduce_block;
625         jcp.reduce_loop_bcast_step
626             = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
627         jcp.reduce_loop_load_step
628             = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
630         jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
631         jcp.bcast_loop_output_substep = -1; // unused
632         jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
633         jcp.bcast_loop_bcast_substep = -1; // unused
635         jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
636         jcp.load_loop_iter_step = jcp.oc_block;
638         load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
639         load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
640         bcast_blocking = 128; // affects load balancing across threads
641         bcast_blocking_max = 192;
642         reduce_blocking = 128; // affects L1$ utilization
// Backward data: reduce over oc, load ic, broadcast over spatial.
643     } else if (jcp.prop_kind == backward_data) {
644         jcp.reduce_dim = jcp.oc;
645         jcp.reduce_block = jcp.oc_block;
647         jcp.load_dim = jcp.ic;
648         jcp.load_block = jcp.oc_block;
650         jcp.bcast_dim = jcp.os;
651         jcp.bcast_block = jcp.ur;
653         jcp.reduce_loop_unroll = jcp.reduce_block;
654         jcp.reduce_loop_bcast_step
655             = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
656         jcp.reduce_loop_load_step
657             = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
659         jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
660         jcp.bcast_loop_output_substep = -1; // unused
661         jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
662         jcp.bcast_loop_bcast_substep = -1; // unused
664         jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
665         jcp.load_loop_iter_step = jcp.ic_block;
667         load_blocking = 96; // assumes the kernel is jcp.ur x 3
668         load_blocking_max = 144;
669         bcast_blocking = 128; // affects load balancing across threads
670         bcast_blocking_max = 196;
671         reduce_blocking = 64; // affects L1$ utilization
// Backward weights: reduce over spatial, load oc, broadcast over ic.
672     } else if (jcp.prop_kind == backward_weights) {
673         jcp.reduce_dim = jcp.os;
674         jcp.reduce_block = 1;
676         jcp.load_dim = jcp.oc;
677         jcp.load_block = jcp.oc_block;
679         jcp.bcast_dim = jcp.ic;
680         jcp.bcast_block = jcp.ic_block;
682         jcp.reduce_loop_unroll = jcp.reduce_block;
683         jcp.reduce_loop_bcast_step
684             = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
685         jcp.reduce_loop_load_step
686             = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
688         jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
689         jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
690         jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
691         jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
693         jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
694         jcp.load_loop_iter_step = jcp.oc_block;
// Repeatedly halve/third the blocking until small enough while keeping it
// an exact divisor of the dimension (loop header not visible here).
698         load_blocking = div_up(jcp.load_dim, jcp.load_block);
700             if (load_blocking <= 32) break;
701             else if (load_blocking % 2 == 0) load_blocking /= 2;
702             else if (load_blocking % 3 == 0) load_blocking /= 3;
705         load_blocking *= jcp.load_block;
706         load_blocking_max = load_blocking;
707         assert(jcp.load_dim % load_blocking == 0);
709         bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
711             if (bcast_blocking <= 9) break;
712             else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
713             else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
716         bcast_blocking *= jcp.bcast_block;
717         bcast_blocking_max = bcast_blocking;
718         assert(jcp.bcast_dim % bcast_blocking == 0);
720         reduce_blocking = 128; // affects L1$ utilization
// (else branch for any other prop_kind)
722         return status::unimplemented;
724     assert(load_blocking);
725     assert(load_blocking_max);
726     assert(bcast_blocking);
727     assert(bcast_blocking_max);
728     assert(reduce_blocking);
730     assert(jcp.bcast_block % jcp.ur == 0);
731     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
// Convert element blockings into the block counts the driver iterates on.
733     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
734     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
735     jcp.nb_load_blocking = load_blocking / jcp.load_block;
736     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
737     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
739     jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
740     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
741     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
743     return status::success;
746 void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
747 memory_tracking::registrar_t &scratchpad,
748 const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
749 using namespace mkldnn::impl::memory_tracking::names;
751 if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
752 scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
754 if (jcp.with_dw_conv) {
755 const int nthreads = mkldnn_get_max_threads();
756 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
757 scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
759 if (jcp.oc != jcp.oc_without_padding)
760 scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);