1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_PRIMITIVE_CONF_HPP
18 #define JIT_PRIMITIVE_CONF_HPP
22 #include "common/primitive_attr.hpp"
29 enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_4vnni,
/* Loop-order selector for convolution configurations; stored in the
 * loop_order field of jit_conv_conf_t and jit_bin_conv_conf_t.
 * NOTE(review): the letters in each enumerator presumably name the
 * dimensions being iterated (c, g, n, h, w) -- confirm against the drivers. */
31 enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn,
32 loop_ngcw, loop_nhwcg, loop_nwcg};
33 enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr,
/* Kernel-variant selector; stored in jit_conv_conf_t::kernel_kind.
 * NOTE(review): names suggest embedded vs. explicit broadcast -- confirm. */
35 enum conv_kernel_kind_t {embd_bcast, expl_bcast};
/* Reduction-harness selector; stored in jit_conv_conf_t::harness. */
36 enum conv_harness_t {harness_2d_reduction, harness_3d_reduction,
37 harness_mb_reduction};
40 FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1,
41 FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3,
42 FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5,
43 FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7,
44 FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9,
45 FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips
46 loading weights-data from memory; this
47 needs to happen on the first Group/16
49 FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skips
50 loading bias data from memory */
51 FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution
55 struct jit_conv_conf_t {
56 prop_kind_t prop_kind;
58 conv_loop_order_t loop_order;
59 conv_harness_t harness;
64 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
66 int id, ih, iw, od, oh, ow;
67 int f_pad, l_pad, t_pad;
68 int back_pad, r_pad, b_pad;
70 int stride_d, stride_h, stride_w;
71 int dilate_d, dilate_h, dilate_w;
72 memory_format_t src_fmt;
73 memory_format_t dst_fmt;
78 bool with_binarization;
80 post_ops_t::entry_t::eltwise_t eltwise;
82 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
84 int idp, ihp, iwp, ohp, owp;
86 const float* conv_weights;
87 const float* conv_biases;
88 int dw_conv_oh, dw_conv_ow;
89 data_type_t dw_conv_dst_dt;
94 int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking
95 into account vector registers distribution */
96 int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
98 int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
105 int nonblk_group_off;
106 /* fma avx512_core */
107 conv_kernel_kind_t kernel_kind;
110 int tr_src_num_guard_elems;
123 int ur_ow_max, ur_ow, ur_ow_tail;
127 /* bf16 data-type for output */
131 /* avx512: max possible value is nregs(32) - aux_regs(4) */
137 int max_regs_ur; // maximum accumulation registers
139 int nb_ch, ch_block, nb_ch_blocking;
140 bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
158 struct jit_conv_conf_2x3_wino_t {
167 int ngroups, ic, oc, oc_without_padding;
172 int stride_h, stride_w;
173 int dilate_h, dilate_w;
178 int w_block_size, h_block_size;
189 memory_format_t src_fmt;
200 int m_block, n_block, k_block;
201 int n2_block, n_chunks;
202 int k2_block, k_chunks;
206 size_t size_wino_src, size_wino_wei, size_wino_dst;
212 Winograd sched policy:
226 Note: 'i' and 'o' are omitted if
227 i. not combined with t or
228 ii. with discrete transforms
230 Current policies supported:
232 enum winograd_sched_t {
235 /* Forward & backward-data */
236 /* W_S_G_D implements discrete transforms */
238 /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/
241 /* Backward-weights */
244 WSCHED_WEI_S_D_Giot_W,
248 struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
252 int ic_simd_block=16;
253 int tile_4fma_padding;
255 int oc_simd_block=16;
260 int nb_tile_block_ur;
262 bool double_buffering;
263 bool with_relu_postsum;
285 winograd_sched_t sched_policy;
288 struct jit_bin_conv_conf_t {
289 prop_kind_t prop_kind;
291 conv_loop_order_t loop_order;
295 int ngroups, ic, oc, oc_padded, ic_padded;
296 int id, ih, iw, od, oh, ow;
297 int f_pad, l_pad, t_pad;
298 int back_pad, r_pad, b_pad;
300 int stride_d, stride_h, stride_w;
301 int dilate_d, dilate_h, dilate_w;
302 memory_format_t src_fmt;
307 bool with_binarization;
314 data_type_t dw_conv_dst_dt;
318 int nb_ic_blocking, nb_oc_blocking; // blocking of nb_ic and nb_oc
330 struct jit_def_conv_conf_t {
331 prop_kind_t prop_kind;
336 int ngroups, ic, oc, oc_padded, ic_padded;
337 int id, ih, iw, od, oh, ow;
338 int f_pad, l_pad, t_pad;
339 int back_pad, r_pad, b_pad;
341 int stride_d, stride_h, stride_w;
342 int dilate_d, dilate_h, dilate_w;
343 memory_format_t src_fmt;
349 int nb_ic_blocking, nb_oc_blocking;
362 struct jit_conv_call_s {
363 const void *src; /* hack, non-const for backward_data */
364 const void *dst; /* hack, non-const for forward */
365 const void *filt; /* hack, non-const for backward_weights */
366 const void *bias; /* hack, non-const for backward_bias */
369 const void *filt_prf;
370 const void *bias_prf;
373 const void *compensation;
375 size_t kd_offset_prf;
377 size_t kh_offset_prf;
378 size_t os_index_begin;
379 size_t os_index_begin_prf;
381 size_t os_index_end_prf;
383 size_t kd_padding_prf;
385 size_t kh_padding_prf;
399 size_t front_overflow;
400 size_t back_overflow;
404 const void *src_row0; /* hack, non-const for backward_data */
405 const void *src_row1; /* hack, non-const for backward_data */
406 const void *src_row2; /* hack, non-const for backward_data */
412 struct jit_deconv_call_s {
413 const void *src; /* hack, non-const for backward_data */
414 const void *dst; /* hack, non-const for forward */
415 const void *filt; /* hack, non-const for backward_weights */
416 const void *bias; /* hack, non-const for backward_bias */
418 const void *compensation;
425 struct jit_dw_conv_call_s {
433 size_t filter_pad_off;
435 exec_flags; /* Flags passed by driver execution to inner kernel */
438 struct jit_wino_transform_call_s {
440 size_t tile_block_ur;
441 size_t nb_tile_block_ur;
454 struct jit_1x1_conv_conf_t {
455 prop_kind_t prop_kind;
459 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
464 int stride_h, stride_w;
465 memory_format_t src_fmt;
466 memory_format_t dst_fmt;
472 post_ops_t::entry_t::eltwise_t eltwise;
475 int ic_block, oc_block;
479 int reduce_dim, reduce_block, nb_reduce,
480 nb_reduce_blocking, nb_reduce_blocking_max;
481 int load_dim, load_block, nb_load,
482 nb_load_blocking, nb_load_blocking_max, nb_load_chunk;
483 int bcast_dim, bcast_block, nb_bcast,
484 nb_bcast_blocking, nb_bcast_blocking_max;
486 int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
487 int load_loop_load_step, load_loop_iter_step;
488 int bcast_loop_output_step, bcast_loop_output_substep;
489 int bcast_loop_bcast_step, bcast_loop_bcast_substep;
492 conv_1x1_loop_order_t loop_order;
504 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
511 int dw_conv_oh, dw_conv_ow;
512 data_type_t dw_conv_dst_dt;
517 int ic_dim, nb_ic, nb_ic_blocking, nb_ic_blocking_max;
518 int oc_dim, nb_oc, nb_oc_blocking, nb_oc_blocking_max;
519 int is_dim, os_block, nb_oh_blocking, nb_oh_blocking_max;
522 int ic_loop_unroll, ic_loop_src_step, ic_loop_wei_step;
523 int os_loop_dst_step, os_loop_src_step, os_loop_acc_step;
524 int os_loop_src_tail_step, os_loop_dst_tail_step, os_loop_acc_tail_step;
527 struct jit_gemm_conv_conf_t {
528 prop_kind_t prop_kind;
532 int iw, ih, id, ow, oh, od;
533 int l_pad, t_pad, f_pad;
535 int stride_h, stride_w, stride_d;
536 int dilate_h, dilate_w, dilate_d;
537 memory_format_t src_fmt;
541 int ic_block, oc_block;
545 bool need_wei_reduction;
550 bool outer_threading;
553 struct jit_1x1_conv_call_s {
554 const void *bcast_data;
555 const void *load_data;
556 const void *output_data;
557 const void *bias_data; // used in forward and backward_weights only
560 const void *compensation;
566 size_t output_stride; // used in backward_weights only
568 size_t first_last_flag;
571 const void *weights_dw;
585 struct jit_def_conv_call_s {
596 struct jit_pool_conf_t {
599 int id, ih, iw, od, oh, ow;
600 int stride_d, stride_h, stride_w;
602 int f_pad, t_pad, l_pad, b_pad, r_pad, back_pad;
610 int c_block, c_tail, nb_c;
624 struct jit_pool_call_s {
630 const void *indices_prf;
634 size_t kh_padding_shift;
635 size_t kd_padding_shift;
637 const void *init_value;
642 struct jit_roi_pool_conf_t {
646 int c_block, nb_c, nb_c_blocking;
648 double spatial_scale;
655 struct jit_roi_pool_call_s {
673 struct jit_softmax_conf_t {
682 struct jit_softmax_call_s {