// inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_bf16_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX512_CORE_BF16_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_CORE_BF16_CONVOLUTION_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "utils.hpp"
24
25 #include "cpu_barrier.hpp"
26 #include "cpu_convolution_pd.hpp"
27 #include "cpu_reducer.hpp"
28
29 #include "jit_transpose_src_utils.hpp"
30 #include "jit_avx512_core_bf16_conv_kernel.hpp"
31 #include "bfloat16_utils.hpp"
32
33 namespace mkldnn {
34 namespace impl {
35 namespace cpu {
36
37 template <impl::data_type_t dst_type>
38 struct _jit_avx512_core_bf16_convolution_fwd_t : public cpu_primitive_t {
39     struct pd_t : public cpu_convolution_fwd_pd_t {
40         pd_t(engine_t *engine, const convolution_desc_t *adesc,
41                 const primitive_attr_t *attr,
42                 const typename pd_t::base_class *hint_fwd_pd)
43             : cpu_convolution_fwd_pd_t(engine, adesc, attr,
44                     hint_fwd_pd)
45             , jcp_()
46         {
47         }
48
49         DECLARE_COMMON_PD_T(
50                 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
51                 _jit_avx512_core_bf16_convolution_fwd_t<dst_type>);
52
53         virtual status_t init() override
54         {
55             using namespace prop_kind;
56             assert(this->engine()->kind() == engine_kind::cpu);
57             bool ok = true
58                     && mayiuse(avx512_core)
59                     && utils::one_of(this->desc()->prop_kind, forward_training,
60                                forward_inference)
61                     && utils::one_of(this->desc()->alg_kind,
62                                alg_kind::convolution_auto,
63                                alg_kind::convolution_direct)
64                     && !this->has_zero_dim_memory()
65                     && this->desc()->src_desc.data_type == data_type::bf16
66                     && this->desc()->weights_desc.data_type == data_type::bf16
67                     && this->desc()->dst_desc.data_type == dst_type
68                     && IMPLICATION(this->with_bias(),
69                         data_type::f32 == this->desc()->bias_desc.data_type);
70             if (!ok)
71                 return status::unimplemented;
72
73             status_t status = jit_avx512_core_bf16_fwd_kernel::init_conf(
74                     jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
75                     this->dst_pd_, this->bias_pd_, *this->attr(),
76                     mkldnn_get_max_threads());
77             if (status != status::success) return status;
78
79             if (status == status::success
80                     && this->desc()->alg_kind == alg_kind::convolution_auto)
81                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
82
83             init_scratchpad();
84
85             return status::success;
86         }
87
88         inline int ndims() { return this->desc()->src_desc.ndims; }
89
90         jit_conv_conf_t jcp_;
91
92         private:
93
94             void init_scratchpad() {
95                 using namespace memory_tracking::names;
96                 auto scratchpad = scratchpad_registry().registrar();
97                 if (jcp_.with_bias && jcp_.oc != jcp_.oc_without_padding)
98                     scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc);
99             }
100     };
101
102     _jit_avx512_core_bf16_convolution_fwd_t(const pd_t *apd,
103             const input_vector &inputs, const output_vector &outputs)
104         : cpu_primitive_t(apd, inputs, outputs)
105     {
106         kernel_ = new jit_avx512_core_bf16_fwd_kernel(pd()->jcp_,
107                     *pd()->attr());
108     }
109     ~_jit_avx512_core_bf16_convolution_fwd_t() { delete kernel_;}
110
111     typedef typename prec_traits<data_type::bf16>::type src_data_t;
112     typedef typename prec_traits<data_type::bf16>::type wei_data_t;
113     typedef typename prec_traits<dst_type>::type dst_data_t;
114
115     virtual void execute(event_t *e) const {
116         execute_forward();
117         e->set_state(event_t::ready);
118     }
119
120 private:
121     void execute_forward() const;
122     void prepare_padded_bias(const float *&bias) const;
123     jit_avx512_core_bf16_fwd_kernel *kernel_;
124     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
125 };
126
127 template <impl::data_type_t dst_type>
128 using jit_avx512_core_bf16_convolution_fwd_t =
129     _jit_avx512_core_bf16_convolution_fwd_t<dst_type>;
130
131 template <impl::data_type_t diff_src_type>
132 struct _jit_avx512_core_bf16_convolution_bwd_data_t: public cpu_primitive_t {
133     struct pd_t: public cpu_convolution_bwd_data_pd_t {
134         pd_t(engine_t *engine,
135                 const convolution_desc_t *adesc,
136                 const primitive_attr_t *attr,
137                 const convolution_fwd_pd_t *hint_fwd_pd)
138             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
139             , jcp_()
140         {}
141
142         DECLARE_COMMON_PD_T(
143                 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
144                 _jit_avx512_core_bf16_convolution_bwd_data_t<diff_src_type>);
145
146         virtual status_t init() override {
147             using namespace prop_kind;
148             assert(this->engine()->kind() == engine_kind::cpu);
149             bool ok = true
150                 && mayiuse(avx512_core)
151                 && this->set_default_params() == status::success
152                 && utils::one_of(this->desc()->prop_kind, backward_data) // XXX (this->!)
153                 && utils::one_of(this->desc()->alg_kind,
154                            alg_kind::convolution_auto,
155                            alg_kind::convolution_direct)
156                 && !this->has_zero_dim_memory()
157                 && this->desc()->alg_kind == alg_kind::convolution_direct
158                 && this->desc()->diff_dst_desc.data_type == data_type::bf16
159                 && this->desc()->weights_desc.data_type == data_type::bf16
160                 && this->desc()->diff_src_desc.data_type == diff_src_type;
161             if (!ok) return status::unimplemented;
162
163             status_t status = jit_avx512_core_bf16_bwd_data_kernel::init_conf(
164                     jcp_, *this->desc(), *this->diff_src_pd_.desc(),
165                     *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
166             if (status != status::success) return status;
167
168             if (status == status::success
169                     && this->desc()->alg_kind == alg_kind::convolution_auto)
170                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
171
172             return status::success;
173         }
174
175         inline int ndims() { return this->desc()->diff_src_desc.ndims; }
176
177         inline memory_format_t src_format()
178         {
179             using namespace memory_format;
180             return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
181         }
182         inline memory_format_t wei_format()
183         {
184             using namespace memory_format;
185             return  this->with_groups()
186                 ? utils::pick(ndims() - 3, gOIw8o16i2o,
187                               gOIhw8o16i2o, gOIdhw8o16i2o)
188                 : utils::pick(ndims() - 3, OIw8o16i2o,
189                               OIhw8o16i2o, OIdhw8o16i2o);
190         }
191
192         jit_conv_conf_t jcp_;
193
194     protected:
195         virtual status_t set_default_params() override {
196             using namespace memory_format;
197
198             if (this->diff_src_pd_.desc()->format == any)
199                 CHECK(this->diff_src_pd_.set_format(src_format()));
200             if (this->diff_dst_pd_.desc()->format == any)
201                 CHECK(this->diff_dst_pd_.set_format(src_format()));
202             if (this->weights_pd_.desc()->format == any)
203                 CHECK(this->weights_pd_.set_format(wei_format()));
204             return status::success;
205         }
206     };
207
208     _jit_avx512_core_bf16_convolution_bwd_data_t(const pd_t *apd,
209             const input_vector &inputs, const output_vector &outputs)
210         : cpu_primitive_t(apd, inputs, outputs)
211     {
212         kernel_ = new jit_avx512_core_bf16_bwd_data_kernel(pd()->jcp_);
213     }
214     ~_jit_avx512_core_bf16_convolution_bwd_data_t() { delete kernel_; };
215
216     typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
217     typedef typename prec_traits<data_type::bf16>::type wei_data_t;
218     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
219
220     virtual void execute(event_t *e) const {
221         execute_backward_data();
222         e->set_state(event_t::ready);
223     }
224
225 private:
226     void execute_backward_data() const;
227     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
228     jit_avx512_core_bf16_bwd_data_kernel *kernel_;
229 };
230
231 template <impl::data_type_t diff_src_type>
232 using jit_avx512_core_bf16_convolution_bwd_data_t =
233     _jit_avx512_core_bf16_convolution_bwd_data_t<diff_src_type>;
234
235 template <impl::data_type_t diff_weights_type>
236 struct _jit_avx512_core_bf16_convolution_bwd_weights_t: public cpu_primitive_t {
237     struct pd_t: public  cpu_convolution_bwd_weights_pd_t {
238         pd_t(engine_t *engine, const convolution_desc_t *adesc,
239                 const primitive_attr_t *attr,
240                 const convolution_fwd_pd_t *hint_fwd_pd)
241             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
242             , jcp_() {}
243
244         DECLARE_COMMON_PD_T(
245                 JIT_IMPL_NAME_HELPER("jit_bf16:", avx512_core, ""),
246                 _jit_avx512_core_bf16_convolution_bwd_weights_t);
247
248         virtual status_t init() override {
249             assert(this->engine()->kind() == engine_kind::cpu);
250             bool ok = true
251                 && mayiuse(avx512_core)
252                 && this->desc()->prop_kind == prop_kind::backward_weights
253                 && this->desc()->alg_kind == alg_kind::convolution_direct
254                 && !this->has_zero_dim_memory()
255                 && this->desc()->src_desc.data_type == data_type::bf16
256                 && this->desc()->diff_dst_desc.data_type == data_type::bf16
257                 && this->desc()->diff_weights_desc.data_type
258                     == diff_weights_type
259                 && IMPLICATION(this->with_bias(),
260                         data_type::f32 == this->desc()->diff_bias_desc.data_type);
261            if (!ok) return status::unimplemented;
262
263             status_t status =
264                 jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(jcp_,
265                         *this->desc(), this->src_pd_, this->diff_weights_pd_,
266                         this->diff_bias_pd_, this->diff_dst_pd_);
267             if (status != status::success) return status;
268
269             init_balancers();
270
271             auto scratchpad = scratchpad_registry().registrar();
272             jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
273                     scratchpad, jcp_);
274
275             auto reducer_bia_scratchpad = memory_tracking::registrar_t(
276                     scratchpad, memory_tracking::names::prefix_reducer_bia);
277             reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
278             return status;
279         }
280
281         jit_conv_conf_t jcp_;
282         typename cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
283
284     protected:
285         memory_format_t src_format()
286         {
287             using namespace memory_format;
288             return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
289         }
290
291         memory_format_t wei_format()
292         {
293             using namespace memory_format;
294             return this->with_groups()
295                 ? utils::pick(ndims() - 3, gOIw16o16i, gOIhw16o16i,
296                       gOIdhw16o16i)
297                 : utils::pick(ndims() - 3, OIw16o16i, OIhw16o16i,
298                       OIdhw16o16i);
299         }
300
301         virtual status_t set_default_params() override {
302             using namespace memory_format;
303
304             if (this->src_pd_.desc()->format == any)
305                 CHECK(this->src_pd_.set_format(src_format()));
306             if (this->diff_weights_pd_.desc()->format == any)
307                 CHECK(this->diff_weights_pd_.set_format(wei_format()));
308             if (this->diff_dst_pd_.desc()->format == any)
309                 CHECK(this->diff_dst_pd_.set_format(src_format()));
310
311             return status::success;
312         }
313
314     private:
315         void init_balancers() {
316             const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
317             if (with_bias()) {
318                 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
319                             jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
320                             max_buffer_size));
321             }
322         }
323     };
324
325     _jit_avx512_core_bf16_convolution_bwd_weights_t(const pd_t *pd,
326             const input_vector &inputs, const output_vector &outputs);
327
328     ~_jit_avx512_core_bf16_convolution_bwd_weights_t() {
329
330         delete kernel_;
331 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
332         delete trans_kernel_;
333         delete trans_dst_kernel_;
334 #endif
335         delete acc_ker_;
336         delete reducer_bias_;
337     }
338
339     typedef typename prec_traits<data_type::bf16>::type src_data_t;
340     typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
341     typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
342
343     virtual void execute(event_t *e) const {
344         execute_backward_weights();
345         e->set_state(event_t::ready);
346     }
347
348 private:
349     struct thread_info_t;
350     void execute_backward_weights() const;
351     void prepare_scratchpad_data() const;
352     void compute_diff_weights(const thread_info_t *) const;
353     void reduce_and_convert_diff_weights(const thread_info_t *) const;
354     void compute_diff_bias(const thread_info_t *) const;
355
356     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
357
358     int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;
359
360     jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 *kernel_;
361
362     cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
363     cpu_reducer_t<data_type::f32> *reducer_bias_;
364
365 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
366     jit_trans_src_t *trans_kernel_;
367     jit_trans_dst_t *trans_dst_kernel_;
368 #endif
369 };
370
371 template <impl::data_type_t diff_src_type>
372 using jit_avx512_core_bf16_convolution_bwd_weights_t =
373     _jit_avx512_core_bf16_convolution_bwd_weights_t<diff_src_type>;
374
375 }
376 }
377 }
378
379 #endif
380
381 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s