1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_AVX512_CORE_BF16_SUM_HPP
18 #define JIT_AVX512_CORE_BF16_SUM_HPP
20 #include "c_types_map.hpp"
22 #include "bfloat16_utils.hpp"
23 #include "cpu_memory.hpp"
24 #include "cpu_sum.hpp"
25 #include "jit_generator.hpp"
26 #include "jit_primitive_conf.hpp"
27 #include "jit_avx512_core_bf16cvt.hpp"
33 struct jit_sum_conf_t {
40 int size_blocking; /* minimum recommended data blocking size as this
41 number of elements computes main unrolled loop
42 in jit kernel per iteration */
45 struct jit_sum_call_s {
/* Maximum number of source tensors handled by one sum kernel.
 * This needs to be a multiple of 2 for the vnni instruction to work with
 * scales; the static_assert below enforces that at compile time instead of
 * relying on this comment alone. The kernel also keeps one pointer register
 * per source (reg_src), so its initializer list must stay in sync. */
constexpr int max_num_arrs = 8;
static_assert(max_num_arrs % 2 == 0,
        "max_num_arrs must be a multiple of 2 for vnni instruction to work "
        "with scales");
60 struct jit_avx512_core_bf16_sum_kernel : public jit_generator {
61 using cpu_memory_pd_t = cpu_memory_t::pd_t;
62 jit_avx512_core_bf16_sum_kernel(jit_sum_conf_t ajsp)
66 if (!mayiuse(avx512_core_bf16))
67 bf16_emu_ = new bf16_emulation_t(this, one, even,
68 selector, abi_not_param1,
72 jit_ker = (void (*)(jit_sum_call_s *)) this->getCode();
75 ~jit_avx512_core_bf16_sum_kernel() {
79 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_sum_kernel)
81 static status_t init_conf(jit_sum_conf_t &jsp,
83 const cpu_memory_pd_t &dst_d);
86 void (*jit_ker)(jit_sum_call_s *);
89 using reg64_t = const Xbyak::Reg64;
90 using reg32_t = const Xbyak::Reg32;
91 using reg8_t = const Xbyak::Reg8;
92 using zmm_t = const Xbyak::Zmm;
93 using ymm_t = const Xbyak::Ymm;
94 using mask_t = const Xbyak::Opmask;
96 enum { f32_simd_w_ = 16 };
98 reg64_t param = abi_param1; /* may be rcx, note that cl is required
99 for mask computation */
101 reg64_t reg_srcs = abi_not_param1; /* may be rcx, note that cl is required
102 for mask computation */
103 reg64_t reg_idx_table = abi_not_param1; /* may be rcx, note that cl is
104 required for mask computation */
105 reg64_t reg_mask = rsi;
106 reg32_t reg32_mask = esi;
108 reg64_t reg_dst = rax;
109 reg64_t reg_scales = rbx;
110 reg64_t reg_sz = rdx;
112 reg64_t reg_src[max_num_arrs] = {r8, r9, r10, r11, r12, r13, r14, r15};
114 static int max_vregs_available(bool is_cpx)
116 // one vector registers are reserved for vperm index and zero values
117 // additional 5 registers are reserved for bf16 emulation on non-cpx
118 return is_cpx ? 31 : 26;
121 int acc_vreg_idx(int i_unroll, int i_acc)
123 // 2 accumulation registers per unroll iteration
124 int idx = 2 * i_unroll + i_acc;
125 assert(idx < max_vregs_available(jsp.is_cpx));
129 int scale_vreg_idx(int i_acc_iter)
131 int scale_idx_start = 2 * jsp.loop_unroll; // reserved for acc registers
132 int idx = scale_idx_start + i_acc_iter;
133 assert(idx < max_vregs_available(jsp.is_cpx));
137 int src_vreg_idx(int i_unroll, int i_inp)
139 // reserved for acc and scale registers
140 int inp_idx_start = 2 * jsp.loop_unroll + utils::div_up(jsp.num_srcs, 2);
141 int idx = inp_idx_start
142 + utils::rnd_up(jsp.num_srcs, 2) * i_unroll + i_inp;
143 assert(idx < max_vregs_available(jsp.is_cpx));
147 int tmp_vreg_idx(int i_unroll, int i_acc_iter)
149 int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
150 // reserved for acc, scale and src registers
151 int tmp_idx_start = utils::div_up(jsp.num_srcs, 2)
152 + (2 + utils::rnd_up(jsp.num_srcs, 2)) * jsp.loop_unroll;
153 int idx = tmp_idx_start
154 + num_acc_iters * i_unroll + i_acc_iter;
155 assert(idx < max_vregs_available(jsp.is_cpx));
159 static int num_vregs_required(int unroll, int num_srcs)
161 int num_acc_iters = utils::div_up(num_srcs, 2);
162 // reserved for acc, scale and src registers
163 int num_regs = utils::div_up(num_srcs, 2)
164 + (2 + utils::rnd_up(num_srcs, 2)) * unroll;
166 num_regs += num_acc_iters * unroll;
170 Xbyak::Zmm one = Xbyak::Zmm(26);
171 Xbyak::Zmm even = Xbyak::Zmm(27);
172 Xbyak::Zmm selector = Xbyak::Zmm(28);
173 Xbyak::Zmm zmm_tmp0 = Xbyak::Zmm(29);
174 Xbyak::Zmm zmm_tmp1 = Xbyak::Zmm(30);
176 Xbyak::Zmm zmm_idx = Xbyak::Zmm(31);
178 Xbyak::Label idx_table;
179 const Xbyak::Opmask k_mask = k1;
182 void loop_iteration(int current_unroll);
183 bf16_emulation_t *bf16_emu_;
186 template <data_type_t src_data_type, data_type_t dst_data_type>
187 struct jit_bf16_sum_t: public cpu_primitive_t {
188 using cpu_memory_pd_t = cpu_memory_t::pd_t;
190 struct pd_t: public cpu_sum_pd_t {
191 pd_t(const memory_desc_t *output_d, int n, const float *scales,
192 const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
193 : cpu_sum_pd_t(output_d, n, scales, input_pds, attr), jsp_() {}
195 DECLARE_CPU_SUM_PD_T(
196 JIT_IMPL_NAME_HELPER("jit_bf16_", avx512_core, ""),
199 virtual status_t init() override {
201 && mayiuse(avx512_core)
202 && cpu_sum_pd_t::init() == success
203 && src_pds_.size() <= max_num_arrs;
204 if (!ok) return unimplemented;
206 const memory_desc_wrapper o_d(&dst_pd_);
208 && o_d.data_type() == dst_data_type
210 if (!ok) return unimplemented;
212 const auto n = src_pds_.size();
214 if (n > max_num_arrs)
215 return status::unimplemented;
217 for (size_t i = 0; i < n; ++i) {
218 const memory_desc_wrapper i_d(&src_pds_[i]);
220 && src_data_type == i_d.data_type()
221 && i_d.format() == o_d.format()
224 is_float_representable_in_bfloat16(scales_[i]);
225 if (!ok) return unimplemented;
228 return jit_avx512_core_bf16_sum_kernel::init_conf(jsp_,
235 jit_bf16_sum_t(const pd_t *apd, const input_vector &inputs,
236 const output_vector &outputs)
237 : cpu_primitive_t(apd, inputs, outputs) {
238 kernel_ = new jit_avx512_core_bf16_sum_kernel(pd()->jsp_);
245 virtual void execute(event_t *e) const{
247 e->set_state(event_t::ready);
250 typedef typename prec_traits<src_data_type>::type src_data_t;
251 typedef typename prec_traits<dst_data_type>::type dst_data_t;
252 typedef typename prec_traits<data_type::f32>::type acc_data_t;
255 void execute() const;
256 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
257 jit_avx512_core_bf16_sum_kernel *kernel_;