/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef CPU_JIT_GEMM_BF16_CONVOLUTION_HPP
18 #define CPU_JIT_GEMM_BF16_CONVOLUTION_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
23 #include "cpu_convolution_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "gemm_convolution_utils.hpp"
26 #include "gemm/gemm.hpp"
27 #include "jit_avx512_core_bf16cvt.hpp"
28 #include "jit_uni_eltwise.hpp"
29 #include "cpu_reducer.hpp"
35 template <data_type_t dst_data_type>
36 struct gemm_bf16_convolution_fwd_t: public cpu_primitive_t {
37 struct pd_t: public cpu_convolution_fwd_pd_t {
38 pd_t(engine_t *engine,
39 const convolution_desc_t *adesc, const primitive_attr_t *attr,
40 const typename pd_t::base_class *hint_fwd_pd)
41 : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
44 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_fwd_t);
46 virtual status_t init() override {
47 using namespace prop_kind;
48 using namespace memory_format;
50 assert(this->engine()->kind() == engine_kind::cpu);
53 && this->set_default_params() == status::success
54 && utils::one_of(this->desc()->prop_kind, forward_training,
56 && utils::one_of(this->desc()->alg_kind,
57 alg_kind::convolution_auto,
58 alg_kind::convolution_direct)
59 && !this->has_zero_dim_memory()
60 && utils::everyone_is(data_type::bf16,
61 this->desc()->src_desc.data_type,
62 this->desc()->weights_desc.data_type)
63 && dst_data_type == this->desc()->dst_desc.data_type
64 && this->src_pd_.desc()->format == src_format()
65 && this->dst_pd_.desc()->format == src_format()
66 && this->weights_pd_.desc()->format == wei_format()
67 && this->is_gemm_conv_format();
68 if (!ok) return status::unimplemented;
70 auto scratchpad = scratchpad_registry().registrar();
71 return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
72 *desc(), src_pd(), weights_pd(0), dst_pd(),
73 mkldnn_get_max_threads());
76 bool is_postprocess_required() const {
77 bool post_ops_sum_only_for_dst_f32 = true
78 && dst_data_type == data_type::f32
79 && attr()->post_ops_.len_ == 1
80 && attr()->post_ops_.contain(primitive_kind::sum, 0);
81 bool is_pp_for_post_ops_required = true
82 && attr()->post_ops_.len_ > 0
83 && !post_ops_sum_only_for_dst_f32;
84 return dst_data_type == data_type::bf16
86 || is_pp_for_post_ops_required;
89 jit_gemm_conv_conf_t jcp_;
92 memory_format_t src_format() const {
93 using namespace memory_format;
94 const int ndims_sp = this->desc()->src_desc.ndims - 2;
95 return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
98 memory_format_t wei_format() const {
99 using namespace memory_format;
100 const int ndims_sp = this->desc()->src_desc.ndims - 2;
101 return (this->with_groups()
102 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
103 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
106 virtual status_t set_default_params() override {
107 using namespace memory_format;
108 if (this->src_pd_.desc()->format == any)
109 CHECK(this->src_pd_.set_format(src_format()));
110 if (this->dst_pd_.desc()->format == any)
111 CHECK(this->dst_pd_.set_format(src_format()));
112 if (this->weights_pd_.desc()->format == any)
113 CHECK(this->weights_pd_.set_format(wei_format()));
114 if (this->bias_pd_.desc()->format == any)
115 CHECK(this->bias_pd_.set_format(x));
116 if (this->desc()->alg_kind == alg_kind::convolution_auto)
117 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
118 return status::success;
121 virtual bool is_gemm_conv_format() const {
122 auto const &po = this->attr()->post_ops_;
123 auto is_eltwise = [&](int idx)
124 { return po.entry_[idx].is_eltwise(); };
125 auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
128 case 0: return true; // no post_ops
129 case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
130 case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
131 default: return false;
136 gemm_bf16_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
137 const output_vector &outputs)
138 : cpu_primitive_t(apd, inputs, outputs, true), pp_ker_(nullptr)
140 const auto &post_ops = pd()->attr()->post_ops_;
141 const acc_data_t one = 1.0, zero = 0.0;
142 beta_ = dst_data_type == data_type::f32
143 && post_ops.find(primitive_kind::sum) >= 0
147 if (this->pd()->is_postprocess_required())
148 pp_ker_ = new pp_ker_t(this->pd());
151 ~gemm_bf16_convolution_fwd_t() {
155 typedef typename prec_traits<dst_data_type>::type dst_data_t;
156 typedef typename prec_traits<data_type::f32>::type acc_data_t;
157 typedef typename prec_traits<data_type::bf16>::type src_data_t;
158 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
160 virtual void execute(event_t *e) const {
162 e->set_state(event_t::ready);
166 void execute_forward() const;
167 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
169 class pp_ker_t : jit_generator {
171 DECLARE_CPU_JIT_AUX_FUNCTIONS(
172 gemm_bf16_convolution_fwd_t::pp_kernel);
173 pp_ker_t(const pd_t *pd);
177 delete eltwise_injector_;
180 void operator()(dst_data_t *dst, const acc_data_t *acc,
181 const acc_data_t *bias, float sum_scale,
182 size_t dst_str, size_t acc_str, size_t len, bool do_parallel);
184 size_t dst_os_stride_;
189 const acc_data_t *acc;
190 const acc_data_t *bias;
192 size_t dst_stride_in_bytes;
193 size_t acc_stride_in_bytes;
194 size_t spatial_length;
199 default_unroll_2_pow_ = 2
202 Xbyak::Reg64 reg_param = abi_param1;
203 Xbyak::Reg64 reg_dst_base = rdx;
204 Xbyak::Reg64 reg_acc_base = rax;
205 Xbyak::Reg64 reg_dst = rsi;
206 Xbyak::Reg64 reg_acc = rbp;
207 Xbyak::Reg64 reg_bias = rbx;
209 Xbyak::Reg64 reg_len = r8;
210 Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
211 Xbyak::Reg64 reg_rem_mask = r9;
212 Xbyak::Opmask kreg_rem_mask = k1;
213 Xbyak::Reg64 reg_oc_iter = r11;
214 Xbyak::Reg64 reg_len_iter = r12;
215 Xbyak::Reg64 reg_dst_str = r13;
216 Xbyak::Reg64 reg_acc_str = r14;
218 Xbyak::Reg64 reserved_eltwise_gpr = r10;
219 Xbyak::Opmask reserved_eltwise_maskr = k2;
221 Xbyak::Zmm vreg_sum_scale, vreg_bias;
223 Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27);
224 Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28);
225 Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(29);
226 Xbyak::Reg64 bf16_emu_reserv_4 = r11;
227 Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);
228 Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(31);
230 void(*ker_)(const ker_args *args);
231 const jit_gemm_conv_conf_t &jcp_;
236 int max_data_reg_idx_, max_unroll_, compute_reg_step_;
237 int data_reg_base_idx_;
240 bf16_emulation_t *bf16_emu_;
241 jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
244 int vreg_dst_idx(int iter) {
245 int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 0;
246 assert(idx <= max_data_reg_idx_);
249 int vreg_prev_dst_idx(int iter) {
250 int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 1;
251 assert(idx <= max_data_reg_idx_);
255 Xbyak::Zmm vreg_dst(int iter) {
256 return Xbyak::Zmm(vreg_dst_idx(iter));
259 Xbyak::Ymm vreg_dst_ymm(int iter) {
260 return Xbyak::Ymm(vreg_dst_idx(iter));
263 Xbyak::Zmm vreg_prev_dst(int iter) {
264 return Xbyak::Zmm(vreg_prev_dst_idx(iter));
267 Xbyak::Ymm vreg_prev_dst_ymm(int iter) {
268 return Xbyak::Ymm(vreg_prev_dst_idx(iter));
276 template <data_type_t diff_src_data_type>
277 struct gemm_bf16_convolution_bwd_data_t: public cpu_primitive_t {
278 struct pd_t: public cpu_convolution_bwd_data_pd_t {
279 pd_t(engine_t *engine,
280 const convolution_desc_t *adesc, const primitive_attr_t *attr,
281 const convolution_fwd_pd_t *hint_fwd_pd)
282 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
285 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_data_t);
287 virtual status_t init() override {
288 using namespace prop_kind;
289 using namespace memory_format;
291 assert(this->engine()->kind() == engine_kind::cpu);
294 && this->set_default_params() == status::success
295 && this->desc()->prop_kind == backward_data
296 && utils::one_of(this->desc()->alg_kind,
297 alg_kind::convolution_auto,
298 alg_kind::convolution_direct)
299 && !this->has_zero_dim_memory()
300 && utils::everyone_is(data_type::bf16,
301 this->desc()->weights_desc.data_type,
302 this->desc()->diff_dst_desc.data_type)
303 && diff_src_data_type == this->desc()->diff_src_desc.data_type
304 && this->diff_src_pd_.desc()->format == src_format()
305 && this->diff_dst_pd_.desc()->format == src_format()
306 && this->weights_pd_.desc()->format == wei_format();
307 if (!ok) return status::unimplemented;
309 auto scratchpad = scratchpad_registry().registrar();
310 return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
311 *desc(), diff_src_pd(), weights_pd(0), diff_dst_pd(),
312 mkldnn_get_max_threads());
315 jit_gemm_conv_conf_t jcp_;
318 memory_format_t src_format() const {
319 using namespace memory_format;
320 const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
321 return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
324 memory_format_t wei_format() const {
325 using namespace memory_format;
326 const int ndims_sp = this->desc()->diff_src_desc.ndims - 2;
327 return (this->with_groups()
328 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
329 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
332 virtual status_t set_default_params() override {
333 using namespace memory_format;
334 if (this->diff_src_pd_.desc()->format == any)
335 CHECK(this->diff_src_pd_.set_format(src_format()));
336 if (this->diff_dst_pd_.desc()->format == any)
337 CHECK(this->diff_dst_pd_.set_format(src_format()));
338 if (this->weights_pd_.desc()->format == any)
339 CHECK(this->weights_pd_.set_format(wei_format()));
340 if (this->desc()->alg_kind == alg_kind::convolution_auto)
341 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
342 return status::success;
346 gemm_bf16_convolution_bwd_data_t(const pd_t *apd,
347 const input_vector &inputs,
348 const output_vector &outputs)
349 : cpu_primitive_t(apd, inputs, outputs, true) {}
351 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
352 typedef typename prec_traits<data_type::f32>::type acc_data_t;
353 typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t;
354 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
356 virtual void execute(event_t *e) const {
357 switch (pd()->desc()->prop_kind) {
358 case prop_kind::backward_data:
359 execute_backward_data();
362 assert(!"invalid prop_kind");
364 e->set_state(event_t::ready);
368 void execute_backward_data() const;
369 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
372 template <data_type_t diff_wei_data_type>
373 struct gemm_bf16_convolution_bwd_weights_t: public cpu_primitive_t {
374 struct pd_t: public cpu_convolution_bwd_weights_pd_t {
375 pd_t(engine_t *engine,
376 const convolution_desc_t *adesc,
377 const primitive_attr_t *attr,
378 const convolution_fwd_pd_t *hint_fwd_pd)
379 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
382 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_weights_t);
384 virtual status_t init() override {
385 using namespace prop_kind;
386 using namespace memory_format;
388 assert(this->engine()->kind() == engine_kind::cpu);
391 && this->set_default_params() == status::success
392 && this->desc()->prop_kind == backward_weights
393 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
394 alg_kind::convolution_direct)
395 && !this->has_zero_dim_memory()
396 && utils::everyone_is(data_type::bf16,
397 this->desc()->src_desc.data_type,
398 this->desc()->diff_dst_desc.data_type)
399 && diff_wei_data_type == this->desc()->diff_weights_desc.data_type
400 && this->src_pd_.desc()->format == src_format()
401 && this->diff_dst_pd_.desc()->format == src_format()
402 && this->diff_weights_pd_.desc()->format == wei_format();
403 if (!ok) return status::unimplemented;
405 auto scratchpad = scratchpad_registry().registrar();
406 return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
407 *desc(), src_pd(), diff_weights_pd(0), diff_dst_pd(),
408 mkldnn_get_max_threads());
411 jit_gemm_conv_conf_t jcp_;
414 memory_format_t src_format() const {
415 using namespace memory_format;
416 const int ndims_sp = this->desc()->src_desc.ndims - 2;
417 return (utils::pick(ndims_sp - 1, ncw, nchw, ncdhw));
420 memory_format_t wei_format() const {
421 using namespace memory_format;
422 const int ndims_sp = this->desc()->src_desc.ndims - 2;
423 return (this->with_groups()
424 ? utils::pick(ndims_sp - 1, goiw, goihw, goidhw)
425 : utils::pick(ndims_sp - 1, oiw, oihw, oidhw));
428 virtual status_t set_default_params() override {
429 using namespace memory_format;
430 if (this->src_pd_.desc()->format == any)
431 CHECK(this->src_pd_.set_format(src_format()));
432 if (this->diff_dst_pd_.desc()->format == any)
433 CHECK(this->diff_dst_pd_.set_format(src_format()));
434 if (this->diff_weights_pd_.desc()->format == any)
435 CHECK(this->diff_weights_pd_.set_format(wei_format()));
436 if (this->diff_bias_pd_.desc()->format == any)
437 CHECK(this->diff_bias_pd_.set_format(x));
438 if (this->desc()->alg_kind == alg_kind::convolution_auto)
439 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
440 return status::success;
444 gemm_bf16_convolution_bwd_weights_t(const pd_t *apd,
445 const input_vector &inputs,
446 const output_vector &outputs)
447 : cpu_primitive_t(apd, inputs, outputs, true)
450 acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
452 ~gemm_bf16_convolution_bwd_weights_t() {
456 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
457 typedef typename prec_traits<data_type::f32>::type acc_data_t;
458 typedef typename prec_traits<data_type::bf16>::type src_data_t;
459 typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t;
461 virtual void execute(event_t *e) const {
462 switch (pd()->desc()->prop_kind) {
463 case prop_kind::backward_weights:
464 execute_backward_weights();
467 assert(!"invalid prop_kind");
469 e->set_state(event_t::ready);
473 void bf16_bwd_weights_reduction_par(int ithr_mb, int nthr_mb,
474 const jit_gemm_conv_conf_t &jcp, const acc_data_t *weights_reduce_base,
475 diff_wei_data_t *weights_base) const;
477 void execute_backward_weights() const;
478 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
480 cpu_accumulator_1d_t<data_type::f32> *acc_ker_;