/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #ifndef CPU_JIT_AVX512_CORE_BF16_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_CORE_BF16_CONVOLUTION_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
25 #include "cpu_barrier.hpp"
26 #include "cpu_convolution_pd.hpp"
27 #include "cpu_reducer.hpp"
29 #include "jit_transpose_src_utils.hpp"
30 #include "jit_avx512_core_bf16_conv_kernel.hpp"
31 #include "bfloat16_utils.hpp"
37 template <impl::data_type_t dst_type>
38 struct _jit_avx512_core_bf16_convolution_fwd_t : public cpu_primitive_t {
39 struct pd_t : public cpu_convolution_fwd_pd_t {
40 pd_t(engine_t *engine, const convolution_desc_t *adesc,
41 const primitive_attr_t *attr,
42 const typename pd_t::base_class *hint_fwd_pd)
43 : cpu_convolution_fwd_pd_t(engine, adesc, attr,
50 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
51 _jit_avx512_core_bf16_convolution_fwd_t<dst_type>);
53 virtual status_t init() override
55 using namespace prop_kind;
56 assert(this->engine()->kind() == engine_kind::cpu);
58 && mayiuse(avx512_core)
59 && utils::one_of(this->desc()->prop_kind, forward_training,
61 && utils::one_of(this->desc()->alg_kind,
62 alg_kind::convolution_auto,
63 alg_kind::convolution_direct)
64 && !this->has_zero_dim_memory()
65 && this->desc()->src_desc.data_type == data_type::bf16
66 && this->desc()->weights_desc.data_type == data_type::bf16
67 && this->desc()->dst_desc.data_type == dst_type
68 && IMPLICATION(this->with_bias(),
69 data_type::f32 == this->desc()->bias_desc.data_type);
71 return status::unimplemented;
73 status_t status = jit_avx512_core_bf16_fwd_kernel::init_conf(
74 jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
75 this->dst_pd_, this->bias_pd_, *this->attr(),
76 mkldnn_get_max_threads());
77 if (status != status::success) return status;
79 if (status == status::success
80 && this->desc()->alg_kind == alg_kind::convolution_auto)
81 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
85 return status::success;
88 inline int ndims() { return this->desc()->src_desc.ndims; }
94 void init_scratchpad() {
95 using namespace memory_tracking::names;
96 auto scratchpad = scratchpad_registry().registrar();
97 if (jcp_.with_bias && jcp_.oc != jcp_.oc_without_padding)
98 scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc);
102 _jit_avx512_core_bf16_convolution_fwd_t(const pd_t *apd,
103 const input_vector &inputs, const output_vector &outputs)
104 : cpu_primitive_t(apd, inputs, outputs)
106 kernel_ = new jit_avx512_core_bf16_fwd_kernel(pd()->jcp_,
109 ~_jit_avx512_core_bf16_convolution_fwd_t() { delete kernel_;}
111 typedef typename prec_traits<data_type::bf16>::type src_data_t;
112 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
113 typedef typename prec_traits<dst_type>::type dst_data_t;
115 virtual void execute(event_t *e) const {
117 e->set_state(event_t::ready);
121 void execute_forward() const;
122 void prepare_padded_bias(const float *&bias) const;
123 jit_avx512_core_bf16_fwd_kernel *kernel_;
124 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
127 template <impl::data_type_t dst_type>
128 using jit_avx512_core_bf16_convolution_fwd_t =
129 _jit_avx512_core_bf16_convolution_fwd_t<dst_type>;
131 template <impl::data_type_t diff_src_type>
132 struct _jit_avx512_core_bf16_convolution_bwd_data_t: public cpu_primitive_t {
133 struct pd_t: public cpu_convolution_bwd_data_pd_t {
134 pd_t(engine_t *engine,
135 const convolution_desc_t *adesc,
136 const primitive_attr_t *attr,
137 const convolution_fwd_pd_t *hint_fwd_pd)
138 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
143 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
144 _jit_avx512_core_bf16_convolution_bwd_data_t<diff_src_type>);
146 virtual status_t init() override {
147 using namespace prop_kind;
148 assert(this->engine()->kind() == engine_kind::cpu);
150 && mayiuse(avx512_core)
151 && this->set_default_params() == status::success
152 && utils::one_of(this->desc()->prop_kind, backward_data) // XXX (this->!)
153 && utils::one_of(this->desc()->alg_kind,
154 alg_kind::convolution_auto,
155 alg_kind::convolution_direct)
156 && !this->has_zero_dim_memory()
157 && this->desc()->alg_kind == alg_kind::convolution_direct
158 && this->desc()->diff_dst_desc.data_type == data_type::bf16
159 && this->desc()->weights_desc.data_type == data_type::bf16
160 && this->desc()->diff_src_desc.data_type == diff_src_type;
161 if (!ok) return status::unimplemented;
163 status_t status = jit_avx512_core_bf16_bwd_data_kernel::init_conf(
164 jcp_, *this->desc(), *this->diff_src_pd_.desc(),
165 *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
166 if (status != status::success) return status;
168 if (status == status::success
169 && this->desc()->alg_kind == alg_kind::convolution_auto)
170 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
172 return status::success;
175 inline int ndims() { return this->desc()->diff_src_desc.ndims; }
177 inline memory_format_t src_format()
179 using namespace memory_format;
180 return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
182 inline memory_format_t wei_format()
184 using namespace memory_format;
185 return this->with_groups()
186 ? utils::pick(ndims() - 3, gOIw8o16i2o,
187 gOIhw8o16i2o, gOIdhw8o16i2o)
188 : utils::pick(ndims() - 3, OIw8o16i2o,
189 OIhw8o16i2o, OIdhw8o16i2o);
192 jit_conv_conf_t jcp_;
195 virtual status_t set_default_params() override {
196 using namespace memory_format;
198 if (this->diff_src_pd_.desc()->format == any)
199 CHECK(this->diff_src_pd_.set_format(src_format()));
200 if (this->diff_dst_pd_.desc()->format == any)
201 CHECK(this->diff_dst_pd_.set_format(src_format()));
202 if (this->weights_pd_.desc()->format == any)
203 CHECK(this->weights_pd_.set_format(wei_format()));
204 return status::success;
208 _jit_avx512_core_bf16_convolution_bwd_data_t(const pd_t *apd,
209 const input_vector &inputs, const output_vector &outputs)
210 : cpu_primitive_t(apd, inputs, outputs)
212 kernel_ = new jit_avx512_core_bf16_bwd_data_kernel(pd()->jcp_);
214 ~_jit_avx512_core_bf16_convolution_bwd_data_t() { delete kernel_; };
216 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
217 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
218 typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
220 virtual void execute(event_t *e) const {
221 execute_backward_data();
222 e->set_state(event_t::ready);
226 void execute_backward_data() const;
227 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
228 jit_avx512_core_bf16_bwd_data_kernel *kernel_;
231 template <impl::data_type_t diff_src_type>
232 using jit_avx512_core_bf16_convolution_bwd_data_t =
233 _jit_avx512_core_bf16_convolution_bwd_data_t<diff_src_type>;
235 template <impl::data_type_t diff_weights_type>
236 struct _jit_avx512_core_bf16_convolution_bwd_weights_t: public cpu_primitive_t {
237 struct pd_t: public cpu_convolution_bwd_weights_pd_t {
238 pd_t(engine_t *engine, const convolution_desc_t *adesc,
239 const primitive_attr_t *attr,
240 const convolution_fwd_pd_t *hint_fwd_pd)
241 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
245 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
246 _jit_avx512_core_bf16_convolution_bwd_weights_t);
248 virtual status_t init() override {
249 assert(this->engine()->kind() == engine_kind::cpu);
251 && mayiuse(avx512_core)
252 && this->desc()->prop_kind == prop_kind::backward_weights
253 && this->desc()->alg_kind == alg_kind::convolution_direct
254 && !this->has_zero_dim_memory()
255 && this->desc()->src_desc.data_type == data_type::bf16
256 && this->desc()->diff_dst_desc.data_type == data_type::bf16
257 && this->desc()->diff_weights_desc.data_type
259 && IMPLICATION(this->with_bias(),
260 data_type::f32 == this->desc()->diff_bias_desc.data_type);
261 if (!ok) return status::unimplemented;
264 jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(jcp_,
265 *this->desc(), this->src_pd_, this->diff_weights_pd_,
266 this->diff_bias_pd_, this->diff_dst_pd_);
267 if (status != status::success) return status;
271 auto scratchpad = scratchpad_registry().registrar();
272 jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
275 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
276 scratchpad, memory_tracking::names::prefix_reducer_bia);
277 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
281 jit_conv_conf_t jcp_;
282 typename cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
285 memory_format_t src_format()
287 using namespace memory_format;
288 return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
291 memory_format_t wei_format()
293 using namespace memory_format;
294 return this->with_groups()
295 ? utils::pick(ndims() - 3, gOIw16o16i, gOIhw16o16i,
297 : utils::pick(ndims() - 3, OIw16o16i, OIhw16o16i,
301 virtual status_t set_default_params() override {
302 using namespace memory_format;
304 if (this->src_pd_.desc()->format == any)
305 CHECK(this->src_pd_.set_format(src_format()));
306 if (this->diff_weights_pd_.desc()->format == any)
307 CHECK(this->diff_weights_pd_.set_format(wei_format()));
308 if (this->diff_dst_pd_.desc()->format == any)
309 CHECK(this->diff_dst_pd_.set_format(src_format()));
311 return status::success;
315 void init_balancers() {
316 const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
318 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
319 jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
325 _jit_avx512_core_bf16_convolution_bwd_weights_t(const pd_t *pd,
326 const input_vector &inputs, const output_vector &outputs);
328 ~_jit_avx512_core_bf16_convolution_bwd_weights_t() {
331 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
332 delete trans_kernel_;
333 delete trans_dst_kernel_;
336 delete reducer_bias_;
339 typedef typename prec_traits<data_type::bf16>::type src_data_t;
340 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
341 typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
343 virtual void execute(event_t *e) const {
344 execute_backward_weights();
345 e->set_state(event_t::ready);
349 struct thread_info_t;
350 void execute_backward_weights() const;
351 void prepare_scratchpad_data() const;
352 void compute_diff_weights(const thread_info_t *) const;
353 void reduce_and_convert_diff_weights(const thread_info_t *) const;
354 void compute_diff_bias(const thread_info_t *) const;
356 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
358 int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;
360 jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 *kernel_;
362 cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
363 cpu_reducer_t<data_type::f32> *reducer_bias_;
365 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
366 jit_trans_src_t *trans_kernel_;
367 jit_trans_dst_t *trans_dst_kernel_;
371 template <impl::data_type_t diff_src_type>
372 using jit_avx512_core_bf16_convolution_bwd_weights_t =
373 _jit_avx512_core_bf16_convolution_bwd_weights_t<diff_src_type>;
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s