updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_primitive_conf.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_PRIMITIVE_CONF_HPP
18 #define JIT_PRIMITIVE_CONF_HPP
19
20 #include <stdint.h>
21
22 #include "common/primitive_attr.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 /* convolution */
29 enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_4vnni,
30                      ver_vnni};
31 enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn,
32                             loop_ngcw, loop_nhwcg, loop_nwcg};
33 enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr,
34                             loop_brl};
35 enum conv_kernel_kind_t {embd_bcast, expl_bcast};
36 enum conv_harness_t {harness_2d_reduction, harness_3d_reduction,
37                      harness_mb_reduction};
38
39 enum {
40     FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1,
41     FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3,
42     FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5,
43     FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7,
44     FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9,
45     FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips
46                                    loading weights-data from memory; this
47                                    needs to happen on the first Group/16
48                                    iteration. */
49     FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skip
50                                loading bias data from memory */
51     FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution
52                                     pass */
53 };
54
55 struct jit_conv_conf_t {
56     prop_kind_t prop_kind;
57     conv_version_t ver;
58     conv_loop_order_t loop_order;
59     conv_harness_t harness;
60
61     int simd_w;
62     int ndims;
63     int mb;
64     int ngroups, ic, oc, oc_without_padding, ic_without_padding;
65     int oc_padded;
66     int id, ih, iw, od, oh, ow;
67     int f_pad, l_pad, t_pad;
68     int back_pad, r_pad, b_pad;
69     int kd, kh, kw;
70     int stride_d, stride_h, stride_w;
71     int dilate_d, dilate_h, dilate_w;
72     memory_format_t src_fmt;
73     memory_format_t dst_fmt;
74     bool with_bias;
75     bool with_sum;
76     bool with_eltwise;
77     bool with_dw_conv;
78     bool with_binarization;
79
80     post_ops_t::entry_t::eltwise_t eltwise;
81
82     int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
83
84     int idp, ihp, iwp, ohp, owp;
85
86     const float* conv_weights;
87     const float* conv_biases;
88     int dw_conv_oh, dw_conv_ow;
89     data_type_t dw_conv_dst_dt;
90
91     int nb_ic, ic_block;
92     int nb_oc, oc_block;
93     int nb_ow, ow_block;
94     int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking
95                            into account vector registers distribution */
96     int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work
97                                       within threads */
98     int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
99     int nb_ic_L2;
100     int h_blocking;
101     int nb_oc_L2;
102     int ur_h, ur_w;
103     int ur_w_tail;
104     bool is_1stconv;
105     int nonblk_group_off;
106     /* fma avx512_core */
107     conv_kernel_kind_t kernel_kind;
108     /* 4fma */
109     int tr_iw;
110     int tr_src_num_guard_elems;
111     /* 1st conv: 4fma */
112     int tr_ld;
113     int kh_step;
114     /* 4vnni */
115     int typesize_in;
116     int typesize_out;
117     int typesize_bia;
118     int typesize_acc;
119     int tr_ow;
120     /* avx512_u8s8u8 */
121     int ic_nb1, ic_nb2;
122     int oc_nb1;
123     int ur_ow_max, ur_ow, ur_ow_tail;
124     int ur_ow_nsteps;
125     data_type_t src_dt;
126     data_type_t bia_dt;
127     /* bf16 data-type for output */
128     data_type_t dst_dt;
129     data_type_t dsrc_dt;
130     data_type_t dwei_dt;
131     /* avx512: max possible value is nregs(32) - aux_regs(4) */
132     int src_offsets[28];
133     int src_count;
134     bool expl_bcast;
135     bool large_spatial;
136     int is_oc_scale;
137     int max_regs_ur; // maximum accumulation registers
138     // dw conv
139     int nb_ch, ch_block, nb_ch_blocking;
140     bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
141     int aligned_threads;
142     // large spatial
143     int oh_blk_size;
144     // s8s8 convolution
145     bool signed_input;
146     float wei_adj_scale;
147
148     bool is_cpx;
149
150     // planar conv
151     int nb_ow_blocking;
152
153     int oh_block;
154     int nb_oh_blocking;
155     int oh_block_step;
156 };
157
158 struct jit_conv_conf_2x3_wino_t {
159     conv_version_t ver;
160
161     int m;
162     int r;
163     int alpha;
164     int tile_h, tile_w;
165
166     int mb;
167     int ngroups, ic, oc, oc_without_padding;
168     int ih, iw, oh, ow;
169     int l_pad, t_pad;
170     int r_pad, b_pad;
171     int kh, kw;
172     int stride_h, stride_w;
173     int dilate_h, dilate_w;
174
175     int nb_ic, ic_block;
176     int nb_oc, oc_block;
177
178     int w_block_size, h_block_size;
179
180     data_type_t bia_dt;
181     data_type_t dst_dt;
182
183     int is_oc_scale;
184     int typesize_in;
185     int typesize_out;
186     int typesize_bia;
187     int typesize_acc;
188
189     memory_format_t src_fmt;
190     bool with_bias;
191     bool small_mb;
192
193     int xb, yb;
194     int inp_stride;
195     int out_stride;
196     int wei_stride;
197     int bia_stride;
198
199     int M, N, K;
200     int m_block, n_block, k_block;
201     int n2_block, n_chunks;
202     int k2_block, k_chunks;
203
204     int mb_block, nb_mb;
205
206     size_t size_wino_src, size_wino_wei, size_wino_dst;
207
208     int nthr;
209 };
210
211 /*
212    Winograd sched policy:
213
214    Computation Unit:
215    W: weights transform
216    S: src transform
217    D: dst transform
218    G: gemm
219
220    Thread grouping by:
221    i: nb_ic
222    o: nb_oc
223    t: tile_block
224    e: element in tile
225
226    Note: 'i' and 'o' are omited if
227    i. not comblined with t or
228    ii. with discrete transforms
229
230    Current policies supported:
231 */
232 enum winograd_sched_t {
233     WSCHED_INVALID = 0,
234
235     /* Forward & backward-data */
236     /* W_S_G_D implements discrete transforms */
237     WSCHED_DATA_W_S_G_D,
238     /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/
239     WSCHED_DATA_W_SGD,
240
241     /* Backward-weights */
242     WSCHED_WEI_S_D_G_W,
243     WSCHED_WEI_SDGtWo,
244     WSCHED_WEI_S_D_Giot_W,
245     WSCHED_WEI_SDGt_W,
246 };
247
248 struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
249     int itiles;
250     int jtiles;
251     int ntiles;
252     int ic_simd_block=16;
253     int tile_4fma_padding;
254     int tile_4fma;
255     int oc_simd_block=16;
256     int oc_reg_block;
257     int ic_reg_block;
258     int tile_block;
259     int tile_block_ur;
260     int nb_tile_block_ur;
261
262     bool double_buffering;
263     bool with_relu_postsum;
264     int zmm_start;
265     int nb_reg;
266
267     int dimK;
268     int dimK_4fma;
269     int dimK_reg_block;
270     int dimK_block;
271     int dimK_nb_block;
272
273     int dimM;
274     int dimM_reg_block;
275     int dimM_simd_block;
276     int dimM_block;
277     int dimM_nb_block;
278
279     int dimN;
280     int dimN_reg_block;
281     int dimN_bcast_ur;
282     int dimN_block;
283     int dimN_nb_block;
284
285     winograd_sched_t sched_policy;
286 };
287
288 struct jit_bin_conv_conf_t {
289     prop_kind_t prop_kind;
290     conv_version_t ver;
291     conv_loop_order_t loop_order;
292
293     int ndims;
294     int mb;
295     int ngroups, ic, oc, oc_padded, ic_padded;
296     int id, ih, iw, od, oh, ow;
297     int f_pad, l_pad, t_pad;
298     int back_pad, r_pad, b_pad;
299     int kd, kh, kw;
300     int stride_d, stride_h, stride_w;
301     int dilate_d, dilate_h, dilate_w;
302     memory_format_t src_fmt;
303     bool with_bias;
304     bool with_sum;
305     bool with_eltwise;
306     bool with_dw_conv;
307     bool with_binarization;
308
309     float pad_value;
310     bool exclude_pad;
311
312     int dw_conv_oh;
313     int dw_conv_ow;
314     data_type_t dw_conv_dst_dt;
315
316     int nb_ic, ic_block;
317     int nb_oc, oc_block;
318     int nb_ic_blocking, nb_oc_blocking; // blocking of nb_ic and nb_ic
319     int ur_h, ur_w;
320     int ur_w_tail;
321     int typesize_in;
322     int typesize_out;
323     int typesize_bia;
324     int typesize_acc;
325     data_type_t src_dt;
326     data_type_t bia_dt;
327     data_type_t dst_dt;
328 };
329
330 struct jit_def_conv_conf_t {
331     prop_kind_t prop_kind;
332
333     int ndims;
334     int mb;
335     int dg;
336     int ngroups, ic, oc, oc_padded, ic_padded;
337     int id, ih, iw, od, oh, ow;
338     int f_pad, l_pad, t_pad;
339     int back_pad, r_pad, b_pad;
340     int kd, kh, kw;
341     int stride_d, stride_h, stride_w;
342     int dilate_d, dilate_h, dilate_w;
343     memory_format_t src_fmt;
344     bool with_bias;
345     bool with_sum;
346     int nthr;
347     int nb_ic, ic_block;
348     int nb_oc, oc_block;
349     int nb_ic_blocking, nb_oc_blocking;
350     int ur_h, ur_w;
351     int ur_w_tail;
352     int typesize_in;
353     int typesize_off;
354     int typesize_bia;
355     int typesize_out;
356     data_type_t src_dt;
357     data_type_t off_dt;
358     data_type_t bia_dt;
359     data_type_t dst_dt;
360 };
361
362 struct jit_conv_call_s {
363     const void *src; /* hack, non-const for backward_data */
364     const void *dst; /* hack, non-const for forward */
365     const void *filt; /* hack, non-const for backward_weights */
366     const void *bias; /* hack, non-const for backward_bias */
367     const void *src_prf;
368     const void *dst_prf;
369     const void *filt_prf;
370     const void *bias_prf;
371     const void *scales;
372     const void *acc_s32;
373     const void *compensation;
374     size_t kd_offset;
375     size_t kd_offset_prf;
376     size_t kh_offset;
377     size_t kh_offset_prf;
378     size_t os_index_begin;
379     size_t os_index_begin_prf;
380     size_t os_index_end;
381     size_t os_index_end_prf;
382     size_t kd_padding;
383     size_t kd_padding_prf;
384     size_t kh_padding;
385     size_t kh_padding_prf;
386     size_t owb;
387     size_t owb_prf;
388     size_t kw_padding;
389     size_t channel;
390     size_t channel_prf;
391     size_t oc_blocks;
392     size_t oc_work;
393     size_t ur_w;
394     size_t ur_str_w;
395     size_t ch_blocks;
396     size_t ch_work;
397     size_t t_overflow;
398     size_t b_overflow;
399     size_t front_overflow;
400     size_t back_overflow;
401     size_t oh_blocks;
402     int flags;
403
404     const void *src_row0; /* hack, non-const for backward_data */
405     const void *src_row1; /* hack, non-const for backward_data */
406     const void *src_row2; /* hack, non-const for backward_data */
407
408     size_t oc_off;
409     size_t oc_off_prf;
410 };
411
412 struct jit_deconv_call_s {
413     const void *src; /* hack, non-const for backward_data */
414     const void *dst; /* hack, non-const for forward */
415     const void *filt; /* hack, non-const for backward_weights */
416     const void *bias; /* hack, non-const for backward_bias */
417     const void *scales;
418     const void *compensation;
419     size_t t_overflow;
420     size_t b_overflow;
421     size_t kh_padding;
422     size_t oc_blocks;
423 };
424
425 struct jit_dw_conv_call_s {
426     const void *input;
427     const void *output;
428     const void *filter;
429     const void *bias;
430     size_t kh_count;
431     size_t oh_count;
432     size_t oh_index;
433     size_t filter_pad_off;
434     unsigned char
435             exec_flags; /* Flags passed by driver execution to inner kernel */
436 };
437
438 struct jit_wino_transform_call_s {
439     size_t tile_block;
440     size_t tile_block_ur;
441     size_t nb_tile_block_ur;
442     size_t tile_count;
443     size_t tj;
444     size_t ti;
445     void *src;
446     void *dst;
447     void *Mw;
448     void *M;
449     void *T;
450     void *G;
451     void *bias;
452 };
453
454 struct jit_1x1_conv_conf_t {
455     prop_kind_t prop_kind;
456     conv_version_t ver;
457
458     int mb;
459     int ngroups, ic, oc, oc_without_padding, ic_without_padding;
460     int iw, ih, ow, oh;
461     int l_pad, t_pad;
462     int r_pad, b_pad;
463     int kh, kw;
464     int stride_h, stride_w;
465     memory_format_t src_fmt;
466     memory_format_t dst_fmt;
467     bool with_bias;
468     bool with_sum;
469     bool with_eltwise;
470     bool with_dw_conv;
471
472     post_ops_t::entry_t::eltwise_t eltwise;
473
474     int is, os;
475     int ic_block, oc_block;
476
477     int ur, ur_tail;
478
479     int reduce_dim, reduce_block, nb_reduce,
480         nb_reduce_blocking, nb_reduce_blocking_max;
481     int load_dim, load_block, nb_load,
482         nb_load_blocking, nb_load_blocking_max, nb_load_chunk;
483     int bcast_dim, bcast_block, nb_bcast,
484         nb_bcast_blocking, nb_bcast_blocking_max;
485
486     int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
487     int load_loop_load_step, load_loop_iter_step;
488     int bcast_loop_output_step, bcast_loop_output_substep;
489     int bcast_loop_bcast_step, bcast_loop_bcast_substep;
490     int fma_step;
491     int load_grp_count;
492     conv_1x1_loop_order_t loop_order;
493     bool use_vmovntps;
494     /* avx512 core */
495     bool expl_bcast;
496     /* 4vnni */
497     int typesize_in;
498     int typesize_out;
499     int typesize_bia;
500     int typesize_acc;
501     /* 4fma */
502     bool transpose_src;
503     int tr_is;
504     int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
505     int is_oc_scale;
506     data_type_t src_dt;
507     data_type_t bia_dt;
508     data_type_t dst_dt;
509     bool signed_input;
510     float wei_adj_scale;
511     int dw_conv_oh, dw_conv_ow;
512     data_type_t dw_conv_dst_dt;
513
514     bool is_cpx;
515
516     /* u8s8s32x */
517     int ic_dim, nb_ic, nb_ic_blocking, nb_ic_blocking_max;
518     int oc_dim, nb_oc, nb_oc_blocking, nb_oc_blocking_max;
519     int is_dim, os_block, nb_oh_blocking, nb_oh_blocking_max;
520     int ow_tail;
521
522     int ic_loop_unroll, ic_loop_src_step, ic_loop_wei_step;
523     int os_loop_dst_step, os_loop_src_step, os_loop_acc_step;
524     int os_loop_src_tail_step, os_loop_dst_tail_step, os_loop_acc_tail_step;
525 };
526
527 struct jit_gemm_conv_conf_t {
528     prop_kind_t prop_kind;
529
530     int mb;
531     int ngroups, ic, oc;
532     int iw, ih, id, ow, oh, od;
533     int l_pad, t_pad, f_pad;
534     int kh, kw, kd;
535     int stride_h, stride_w, stride_d;
536     int dilate_h, dilate_w, dilate_d;
537     memory_format_t src_fmt;
538     bool with_bias;
539
540     int is, os, ks;
541     int ic_block, oc_block;
542
543     int nthr;
544     ptrdiff_t im2col_sz;
545     bool need_wei_reduction;
546     bool signed_input;
547     float wei_adj_scale;
548     int oh_block;
549     int ow_block;
550     bool outer_threading;
551 };
552
553 struct jit_1x1_conv_call_s {
554     const void *bcast_data;
555     const void *load_data;
556     const void *output_data;
557     const void *bias_data; // used in forward and backward_weights only
558     const void *acc_s32;
559     const void *scales;
560     const void *compensation;
561
562     size_t load_dim;
563     size_t bcast_dim;
564     size_t reduce_dim;
565
566     size_t output_stride; // used in backward_weights only
567
568     size_t first_last_flag;
569
570     // dw conv fusing
571     const void *weights_dw;
572     const void *bias_dw;
573
574     size_t oc_off;
575
576     /* u8s8s32x */
577     size_t oc_dim;
578     size_t os_dim;
579     size_t ic_dim;
580     size_t ic_pos_flag;
581         const void *is_data;
582         const void *oc_data;
583 };
584
585 struct jit_def_conv_call_s {
586     const void *src;
587     const void *off;
588     const void *filt;
589     const void *bias;
590     const void *dst;
591     const void *buf;
592     size_t oh_pos;
593 };
594
595 /* pooling */
596 struct jit_pool_conf_t {
597     int ndims;
598     int mb, c;
599     int id, ih, iw, od, oh, ow;
600     int stride_d, stride_h, stride_w;
601     int kd, kh, kw;
602     int f_pad, t_pad, l_pad, b_pad, r_pad, back_pad;
603     alg_kind_t alg;
604     bool is_training;
605     bool pad_w_is_null;
606     bool is_backward;
607     bool simple_alg;
608     data_type_t ind_dt;
609
610     int c_block, c_tail, nb_c;
611     int ur_c, ur_c_tail;
612     int ur_w;
613     int ur_w_tail;
614     size_t tail[4];
615     data_type_t src_dt;
616     data_type_t dst_dt;
617
618     bool is_bf16;
619     int dt_size;
620
621     bool is_cpx;
622 };
623
624 struct jit_pool_call_s {
625     const void *src;
626     const void *dst;
627     const void *indices;
628     const void *src_prf;
629     const void *dst_prf;
630     const void *indices_prf;
631     size_t oh;
632     size_t kd_padding;
633     size_t kh_padding;
634     size_t kh_padding_shift;
635     size_t kd_padding_shift;
636     size_t kw_padding;
637     const void *init_value;
638     float ker_area_h;
639 };
640
641 /* roi pooling */
642 struct jit_roi_pool_conf_t {
643     int mb, c;
644     int ih, iw, oh, ow;
645
646     int c_block, nb_c, nb_c_blocking;
647
648     double spatial_scale;
649     int pooled_h;
650     int pooled_w;
651
652     alg_kind_t alg;
653 };
654
655 struct jit_roi_pool_call_s {
656     const float *src;
657     float *dst;
658
659     size_t kh;
660     size_t kw;
661     size_t bin_area;
662
663     size_t c_blocks;
664
665     float xf;
666     float yf;
667
668     size_t xoff;
669     size_t yoff;
670 };
671
672 /* softmax */
673 struct jit_softmax_conf_t {
674     size_t outer_size;
675     size_t channels;
676     size_t inner_size;
677     size_t ur_channel;
678     size_t ur_inner;
679     size_t outer_block;
680 };
681
682 struct jit_softmax_call_s {
683     const float *src;
684     float *dst;
685
686     size_t channels;
687     size_t work;
688 };
689
690 }
691 }
692 }
693
694 #endif