Updated README file after moving the CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_bf16_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_GEMM_BF16_CONVOLUTION_HPP
18 #define CPU_JIT_GEMM_BF16_CONVOLUTION_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "cpu_convolution_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "gemm_convolution_utils.hpp"
26 #include "gemm/gemm.hpp"
27 #include "jit_avx512_core_bf16cvt.hpp"
28 #include "jit_uni_eltwise.hpp"
29 #include "cpu_reducer.hpp"
30
31 namespace mkldnn {
32 namespace impl {
33 namespace cpu {
34
35 template <data_type_t dst_data_type>
36 struct gemm_bf16_convolution_fwd_t: public cpu_primitive_t {
37     struct pd_t: public cpu_convolution_fwd_pd_t {
38         pd_t(engine_t *engine,
39                 const convolution_desc_t *adesc, const primitive_attr_t *attr,
40                 const typename pd_t::base_class *hint_fwd_pd)
41             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
42             , jcp_() {}
43
44         DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_fwd_t);
45
46         virtual status_t init() override {
47             using namespace prop_kind;
48             using namespace memory_format;
49
50             assert(this->engine()->kind() == engine_kind::cpu);
51
52             bool ok = true
53                 && this->set_default_params() == status::success
54                 && utils::one_of(this->desc()->prop_kind, forward_training,
55                            forward_inference)
56                 && utils::one_of(this->desc()->alg_kind,
57                         alg_kind::convolution_auto,
58                         alg_kind::convolution_direct)
59                 && !this->has_zero_dim_memory()
60                 && utils::everyone_is(data_type::bf16,
61                            this->desc()->src_desc.data_type,
62                            this->desc()->weights_desc.data_type)
63                 && dst_data_type == this->desc()->dst_desc.data_type
64                 && this->src_pd_.desc()->format == src_format()
65                 && this->dst_pd_.desc()->format == src_format()
66                 && this->weights_pd_.desc()->format == wei_format()
67                 && this->is_gemm_conv_format();
68             if (!ok) return status::unimplemented;
69
70             auto scratchpad = scratchpad_registry().registrar();
71             return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
72                     *desc(), src_pd(), weights_pd(0), dst_pd(),
73                     mkldnn_get_max_threads());
74         }
75
76         bool is_postprocess_required() const {
77             bool post_ops_sum_only_for_dst_f32 = true
78                 && dst_data_type == data_type::f32
79                 && attr()->post_ops_.len_ == 1
80                 && attr()->post_ops_.contain(primitive_kind::sum, 0);
81             bool is_pp_for_post_ops_required = true
82                 && attr()->post_ops_.len_ > 0
83                 && !post_ops_sum_only_for_dst_f32;
84             return dst_data_type == data_type::bf16
85                        || with_bias()
86                        || is_pp_for_post_ops_required;
87         }
88
89         jit_gemm_conv_conf_t jcp_;
90
91     protected:
92         memory_format_t src_format() const {
93             using namespace memory_format;
94             const int ndims_sp = this->desc()->src_desc.ndims - 2;
95             return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
96         }
97
98         memory_format_t wei_format() const {
99             using namespace memory_format;
100             const int ndims_sp = this->desc()->src_desc.ndims - 2;
101             return (this->with_groups()
102                 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
103                 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
104         }
105
106         virtual status_t set_default_params() override {
107             using namespace memory_format;
108             if (this->src_pd_.desc()->format == any)
109                 CHECK(this->src_pd_.set_format(src_format()));
110             if (this->dst_pd_.desc()->format == any)
111                 CHECK(this->dst_pd_.set_format(src_format()));
112             if (this->weights_pd_.desc()->format == any)
113                 CHECK(this->weights_pd_.set_format(wei_format()));
114             if (this->bias_pd_.desc()->format == any)
115                 CHECK(this->bias_pd_.set_format(x));
116             if (this->desc()->alg_kind == alg_kind::convolution_auto)
117                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
118             return status::success;
119         }
120
121         virtual bool is_gemm_conv_format() const {
122             auto const &po = this->attr()->post_ops_;
123             auto is_eltwise = [&](int idx)
124             { return po.entry_[idx].is_eltwise(); };
125             auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
126
127             switch (po.len_) {
128             case 0: return true; // no post_ops
129             case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
130             case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
131             default: return false;
132             }
133         }
134     };
135
136     gemm_bf16_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
137            const output_vector &outputs)
138         : cpu_primitive_t(apd, inputs, outputs, true), pp_ker_(nullptr)
139     {
140         const auto &post_ops = pd()->attr()->post_ops_;
141         const acc_data_t one = 1.0, zero = 0.0;
142         beta_ = dst_data_type == data_type::f32
143             && post_ops.find(primitive_kind::sum) >= 0
144                 ? one
145                 : zero;
146
147         if (this->pd()->is_postprocess_required())
148             pp_ker_ = new pp_ker_t(this->pd());
149     }
150
151     ~gemm_bf16_convolution_fwd_t() {
152         delete pp_ker_;
153     }
154
155     typedef typename prec_traits<dst_data_type>::type dst_data_t;
156     typedef typename prec_traits<data_type::f32>::type acc_data_t;
157     typedef typename prec_traits<data_type::bf16>::type src_data_t;
158     typedef typename prec_traits<data_type::bf16>::type wei_data_t;
159
160     virtual void execute(event_t *e) const {
161         execute_forward();
162         e->set_state(event_t::ready);
163     }
164
165 private:
166     void execute_forward() const;
167     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
168
169     class pp_ker_t : jit_generator {
170     public:
171         DECLARE_CPU_JIT_AUX_FUNCTIONS(
172         gemm_bf16_convolution_fwd_t::pp_kernel);
173         pp_ker_t(const pd_t *pd);
174
175         ~pp_ker_t() {
176             delete bf16_emu_;
177             delete eltwise_injector_;
178         }
179
180         void operator()(dst_data_t *dst, const acc_data_t *acc,
181             const acc_data_t *bias, float sum_scale,
182             size_t dst_str, size_t acc_str, size_t len, bool do_parallel);
183
184         size_t dst_os_stride_;
185
186     private:
187         struct ker_args {
188             dst_data_t *dst;
189             const acc_data_t *acc;
190             const acc_data_t *bias;
191             float sum_scale;
192             size_t dst_stride_in_bytes;
193             size_t acc_stride_in_bytes;
194             size_t spatial_length;
195             size_t oc_work;
196         };
197
198         enum {
199             default_unroll_2_pow_ = 2
200         };
201
202         Xbyak::Reg64 reg_param = abi_param1;
203         Xbyak::Reg64 reg_dst_base = rdx;
204         Xbyak::Reg64 reg_acc_base = rax;
205         Xbyak::Reg64 reg_dst = rsi;
206         Xbyak::Reg64 reg_acc = rbp;
207         Xbyak::Reg64 reg_bias = rbx;
208
209         Xbyak::Reg64 reg_len = r8;
210         Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
211         Xbyak::Reg64 reg_rem_mask = r9;
212         Xbyak::Opmask kreg_rem_mask = k1;
213         Xbyak::Reg64 reg_oc_iter = r11;
214         Xbyak::Reg64 reg_len_iter = r12;
215         Xbyak::Reg64 reg_dst_str = r13;
216         Xbyak::Reg64 reg_acc_str = r14;
217
218         Xbyak::Reg64 reserved_eltwise_gpr = r10;
219         Xbyak::Opmask reserved_eltwise_maskr = k2;
220
221         Xbyak::Zmm vreg_sum_scale, vreg_bias;
222
223         Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27);
224         Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28);
225         Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(29);
226         Xbyak::Reg64 bf16_emu_reserv_4 = r11;
227         Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);
228         Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(31);
229
230         void(*ker_)(const ker_args *args);
231         const jit_gemm_conv_conf_t &jcp_;
232         size_t OC_;
233         bool do_bias_;
234         bool do_eltwise_;
235         bool do_sum_;
236         int max_data_reg_idx_, max_unroll_, compute_reg_step_;
237         int data_reg_base_idx_;
238         size_t vlen_;
239         bool is_cpx_;
240         bf16_emulation_t *bf16_emu_;
241         jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
242
243         void generate();
244         int vreg_dst_idx(int iter) {
245             int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 0;
246             assert(idx <= max_data_reg_idx_);
247             return idx;
248         }
249         int vreg_prev_dst_idx(int iter) {
250             int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 1;
251             assert(idx <= max_data_reg_idx_);
252             return idx;
253         }
254
255         Xbyak::Zmm vreg_dst(int iter) {
256             return Xbyak::Zmm(vreg_dst_idx(iter));
257         };
258
259         Xbyak::Ymm vreg_dst_ymm(int iter) {
260             return Xbyak::Ymm(vreg_dst_idx(iter));
261         };
262
263         Xbyak::Zmm vreg_prev_dst(int iter) {
264             return Xbyak::Zmm(vreg_prev_dst_idx(iter));
265         };
266
267         Xbyak::Ymm vreg_prev_dst_ymm(int iter) {
268             return Xbyak::Ymm(vreg_prev_dst_idx(iter));
269         };
270     };
271
272     acc_data_t beta_;
273     pp_ker_t *pp_ker_;
274 };
275
276 template <data_type_t diff_src_data_type>
277 struct gemm_bf16_convolution_bwd_data_t: public cpu_primitive_t {
278     struct pd_t: public cpu_convolution_bwd_data_pd_t {
279         pd_t(engine_t *engine,
280                 const convolution_desc_t *adesc, const primitive_attr_t *attr,
281                 const convolution_fwd_pd_t *hint_fwd_pd)
282             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
283             , jcp_() {}
284
285         DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_data_t);
286
287         virtual status_t init() override {
288             using namespace prop_kind;
289             using namespace memory_format;
290
291             assert(this->engine()->kind() == engine_kind::cpu);
292
293             bool ok = true
294                 && this->set_default_params() == status::success
295                 && this->desc()->prop_kind == backward_data
296                 && utils::one_of(this->desc()->alg_kind,
297                            alg_kind::convolution_auto,
298                            alg_kind::convolution_direct)
299                 && !this->has_zero_dim_memory()
300                 && utils::everyone_is(data_type::bf16,
301                         this->desc()->weights_desc.data_type,
302                         this->desc()->diff_dst_desc.data_type)
303                 && diff_src_data_type == this->desc()->diff_src_desc.data_type
304                 && this->diff_src_pd_.desc()->format == src_format()
305                 && this->diff_dst_pd_.desc()->format == src_format()
306                 && this->weights_pd_.desc()->format == wei_format();
307             if (!ok) return status::unimplemented;
308
309             auto scratchpad = scratchpad_registry().registrar();
310             return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
311                     *desc(), diff_src_pd(), weights_pd(0), diff_dst_pd(),
312                     mkldnn_get_max_threads());
313         }
314
315         jit_gemm_conv_conf_t jcp_;
316
317     protected:
318         memory_format_t src_format() const {
319             using namespace memory_format;
320             const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
321             return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
322         }
323
324         memory_format_t wei_format() const {
325             using namespace memory_format;
326             const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
327             return (this->with_groups()
328                 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
329                 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
330         }
331
332         virtual status_t set_default_params() override {
333             using namespace memory_format;
334             if (this->diff_src_pd_.desc()->format == any)
335                 CHECK(this->diff_src_pd_.set_format(src_format()));
336             if (this->diff_dst_pd_.desc()->format == any)
337                 CHECK(this->diff_dst_pd_.set_format(src_format()));
338             if (this->weights_pd_.desc()->format == any)
339                 CHECK(this->weights_pd_.set_format(wei_format()));
340             if (this->desc()->alg_kind == alg_kind::convolution_auto)
341                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
342             return status::success;
343         }
344     };
345
346     gemm_bf16_convolution_bwd_data_t(const pd_t *apd,
347               const input_vector &inputs,
348               const output_vector &outputs)
349         : cpu_primitive_t(apd, inputs, outputs, true) {}
350
351     typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
352     typedef typename prec_traits<data_type::f32>::type acc_data_t;
353     typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t;
354     typedef typename prec_traits<data_type::bf16>::type wei_data_t;
355
356     virtual void execute(event_t *e) const {
357         switch (pd()->desc()->prop_kind) {
358         case prop_kind::backward_data:
359             execute_backward_data();
360             break;
361         default:
362             assert(!"invalid prop_kind");
363         }
364         e->set_state(event_t::ready);
365     }
366
367 private:
368     void execute_backward_data() const;
369     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
370 };
371
372 template <data_type_t diff_wei_data_type>
373 struct gemm_bf16_convolution_bwd_weights_t: public cpu_primitive_t {
374     struct pd_t: public cpu_convolution_bwd_weights_pd_t {
375         pd_t(engine_t *engine,
376                 const convolution_desc_t *adesc,
377                 const primitive_attr_t *attr,
378                 const convolution_fwd_pd_t *hint_fwd_pd)
379             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
380             , jcp_() {}
381
382         DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_weights_t);
383
384         virtual status_t init() override {
385             using namespace prop_kind;
386             using namespace memory_format;
387
388             assert(this->engine()->kind() == engine_kind::cpu);
389
390             bool ok = true
391             && this->set_default_params() == status::success
392             && this->desc()->prop_kind == backward_weights
393             && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
394                        alg_kind::convolution_direct)
395             && !this->has_zero_dim_memory()
396             && utils::everyone_is(data_type::bf16,
397                     this->desc()->src_desc.data_type,
398                     this->desc()->diff_dst_desc.data_type)
399             && diff_wei_data_type == this->desc()->diff_weights_desc.data_type
400             && this->src_pd_.desc()->format == src_format()
401             && this->diff_dst_pd_.desc()->format == src_format()
402             && this->diff_weights_pd_.desc()->format == wei_format();
403             if (!ok) return status::unimplemented;
404
405             auto scratchpad = scratchpad_registry().registrar();
406             return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
407                     *desc(), src_pd(), diff_weights_pd(0), diff_dst_pd(),
408                     mkldnn_get_max_threads());
409         }
410
411         jit_gemm_conv_conf_t jcp_;
412
413     protected:
414         memory_format_t src_format() const {
415             using namespace memory_format;
416             const int ndims_sp = this->desc()->src_desc.ndims - 2;
417             return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
418         }
419
420         memory_format_t wei_format() const {
421             using namespace memory_format;
422             const int ndims_sp = this->desc()->src_desc.ndims - 2;
423             return (this->with_groups()
424                 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
425                 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
426         }
427
428         virtual status_t set_default_params() override {
429             using namespace memory_format;
430             if (this->src_pd_.desc()->format == any)
431                 CHECK(this->src_pd_.set_format(src_format()));
432             if (this->diff_dst_pd_.desc()->format == any)
433                 CHECK(this->diff_dst_pd_.set_format(src_format()));
434             if (this->diff_weights_pd_.desc()->format == any)
435                 CHECK(this->diff_weights_pd_.set_format(wei_format()));
436             if (this->diff_bias_pd_.desc()->format == any)
437                 CHECK(this->diff_bias_pd_.set_format(x));
438             if (this->desc()->alg_kind == alg_kind::convolution_auto)
439                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
440             return status::success;
441         }
442     };
443
444     gemm_bf16_convolution_bwd_weights_t(const pd_t *apd,
445               const input_vector &inputs,
446               const output_vector &outputs)
447         : cpu_primitive_t(apd, inputs, outputs, true)
448         , acc_ker_(nullptr)
449     {
450         acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
451     }
452     ~gemm_bf16_convolution_bwd_weights_t() {
453         delete acc_ker_;
454     }
455
456     typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
457     typedef typename prec_traits<data_type::f32>::type acc_data_t;
458     typedef typename prec_traits<data_type::bf16>::type src_data_t;
459     typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t;
460
461     virtual void execute(event_t *e) const {
462         switch (pd()->desc()->prop_kind) {
463         case prop_kind::backward_weights:
464             execute_backward_weights();
465             break;
466         default:
467             assert(!"invalid prop_kind");
468         }
469         e->set_state(event_t::ready);
470     }
471
472 private:
473     void bf16_bwd_weights_reduction_par(int ithr_mb, int nthr_mb,
474         const jit_gemm_conv_conf_t &jcp, const acc_data_t *weights_reduce_base,
475         diff_wei_data_t *weights_base) const;
476
477     void execute_backward_weights() const;
478     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
479
480     cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
481 };
482
483 }
484 }
485 }
486
487 #endif