Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "utils.hpp"
24
25 #include "cpu_barrier.hpp"
26 #include "cpu_convolution_pd.hpp"
27 #include "cpu_reducer.hpp"
28
29 #include "jit_transpose_src_utils.hpp"
30 #include "jit_avx512_common_conv_kernel.hpp"
31
32 namespace mkldnn {
33 namespace impl {
34 namespace cpu {
35
36 template <impl::data_type_t src_type,
37          impl::data_type_t wei_type = src_type,
38          impl::data_type_t dst_type = src_type>
39 struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t {
40     struct pd_t : public cpu_convolution_fwd_pd_t {
41         pd_t(engine_t *engine, const convolution_desc_t *adesc,
42                 const primitive_attr_t *attr,
43                 const typename pd_t::base_class *hint_fwd_pd)
44             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
45             , jcp_()
46         {
47         }
48
49         DECLARE_COMMON_PD_T(
50                 JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
51                 jit_avx512_common_convolution_fwd_t);
52
53         virtual status_t init() override
54         {
55             using namespace prop_kind;
56             assert(this->engine()->kind() == engine_kind::cpu);
57             bool ok = true
58                     && utils::one_of(this->desc()->prop_kind, forward_training,
59                                forward_inference)
60                     && utils::one_of(this->desc()->alg_kind,
61                                alg_kind::convolution_auto,
62                                alg_kind::convolution_direct)
63                     && !this->has_zero_dim_memory()
64                     && this->desc()->src_desc.data_type == src_type
65                     && this->desc()->weights_desc.data_type == wei_type
66                     && this->desc()->dst_desc.data_type == dst_type
67                     && IMPLICATION(this->with_bias(), dst_type
68                                        == this->desc()->bias_desc.data_type);
69             if (!ok)
70                 return status::unimplemented;
71
72             status_t status = jit_avx512_common_conv_fwd_kernel::init_conf(
73                     jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
74                     this->dst_pd_,this->bias_pd_, *this->attr(),
75                     mkldnn_get_max_threads());
76             if (status != status::success) return status;
77
78             auto scratchpad = scratchpad_registry().registrar();
79             jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad,
80                     jcp_);
81
82             if (status == status::success
83                     && this->desc()->alg_kind == alg_kind::convolution_auto)
84                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
85             return status;
86         }
87
88         jit_conv_conf_t jcp_;
89     };
90
91     jit_avx512_common_convolution_fwd_t(const pd_t *apd,
92             const input_vector &inputs, const output_vector &outputs)
93         : cpu_primitive_t(apd, inputs, outputs)
94     {
95         kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_,
96                     *pd()->attr());
97     }
98     ~jit_avx512_common_convolution_fwd_t() { delete kernel_; }
99
100     typedef typename prec_traits<src_type>::type src_data_t;
101     typedef typename prec_traits<wei_type>::type wei_data_t;
102     typedef typename prec_traits<dst_type>::type dst_data_t;
103
104     virtual void execute(event_t *e) const
105     {
106         if (pd()->ndims() == 3)
107             execute_forward_1d();
108         else if (pd()->ndims() == 4)
109             execute_forward_2d();
110         else if (pd()->ndims() == 5)
111             execute_forward_3d();
112         else
113             assert(false);
114
115         if (pd()->wants_zero_pad_dst())
116             output_memory_primitive(0)->zero_pad();
117
118         e->set_state(event_t::ready);
119     }
120
121 private:
122     void prepare_padded_bias(const dst_data_t *&bias) const;
123     void execute_forward_1d() const;
124     void execute_forward_2d() const;
125     void execute_forward_3d() const;
126     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
127
128     jit_avx512_common_conv_fwd_kernel *kernel_;
129 };
130
131 template <impl::data_type_t diff_dst_type,
132           impl::data_type_t wei_type = diff_dst_type,
133           impl::data_type_t diff_src_type = diff_dst_type>
134 struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t {
135     struct pd_t: public cpu_convolution_bwd_data_pd_t {
136         pd_t(engine_t *engine,
137                 const convolution_desc_t *adesc,
138                 const primitive_attr_t *attr,
139                 const convolution_fwd_pd_t *hint_fwd_pd)
140             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
141             , jcp_()
142         {}
143
144         DECLARE_COMMON_PD_T(
145                 JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
146                 jit_avx512_common_convolution_bwd_data_t);
147
148         virtual status_t init() override {
149             using namespace prop_kind;
150             assert(this->engine()->kind() == engine_kind::cpu);
151             bool ok = true
152                 && this->set_default_params() == status::success
153                 && utils::one_of(this->desc()->prop_kind, backward_data) // XXX (this->!)
154                 && utils::one_of(this->desc()->alg_kind,
155                            alg_kind::convolution_auto,
156                            alg_kind::convolution_direct)
157                 && !this->has_zero_dim_memory()
158                 && this->desc()->diff_dst_desc.data_type == diff_dst_type
159                 && this->desc()->weights_desc.data_type == wei_type
160                 && this->desc()->diff_src_desc.data_type == diff_src_type;
161             if (!ok) return status::unimplemented;
162
163             status_t status =
164                 jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_,
165                         *this->desc(), *this->diff_src_pd_.desc(),
166                         *this->weights_pd_.desc(), *this->diff_dst_pd_.desc());
167             if (status != status::success) return status;
168
169             auto scratchpad = scratchpad_registry().registrar();
170             jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
171                     scratchpad, jcp_);
172
173             return status::success;
174         }
175
176         inline memory_format_t src_format()
177         {
178             using namespace memory_format;
179             return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
180         }
181         inline memory_format_t wei_format()
182         {
183             using namespace memory_format;
184             if (diff_dst_type == data_type::s16
185                 && diff_src_type == data_type::s32
186                 && wei_type == data_type::s16) {
187                 return  this->with_groups() ? gOIhw8o16i2o : OIhw8o16i2o;
188             } else {
189                 return this->with_groups()
190                     ? utils::pick(ndims() - 3, gOIw16o16i, gOIhw16o16i,
191                           gOIdhw16o16i)
192                     : utils::pick(ndims() - 3, OIw16o16i, OIhw16o16i,
193                           OIdhw16o16i);
194             }
195         }
196
197         jit_conv_conf_t jcp_;
198
199     protected:
200         virtual status_t set_default_params() override {
201             using namespace memory_format;
202
203             if (this->diff_src_pd_.desc()->format == any)
204                 CHECK(this->diff_src_pd_.set_format(src_format()));
205             if (this->diff_dst_pd_.desc()->format == any)
206                 CHECK(this->diff_dst_pd_.set_format(src_format()));
207             if (this->weights_pd_.desc()->format == any)
208                 CHECK(this->weights_pd_.set_format(wei_format()));
209             if (this->desc()->alg_kind == alg_kind::convolution_auto)
210                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
211             return status::success;
212         }
213     };
214
215     jit_avx512_common_convolution_bwd_data_t(const pd_t *apd,
216             const input_vector &inputs, const output_vector &outputs)
217         : cpu_primitive_t(apd, inputs, outputs)
218     { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); }
219     ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; };
220
221     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
222     typedef typename prec_traits<wei_type>::type wei_data_t;
223     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
224
225     virtual void execute(event_t *e) const {
226         switch (pd()->desc()->prop_kind) {
227         case prop_kind::backward_data:
228             if (pd()->ndims() == 3)
229                 execute_backward_data_1d();
230             else if (pd()->ndims() == 4)
231                 execute_backward_data_2d();
232             else if (pd()->ndims() == 5)
233                 execute_backward_data_3d();
234             else
235                 assert(false);
236             break;
237         default:
238             assert(!"invalid prop_kind");
239         }
240         e->set_state(event_t::ready);
241     }
242
243 private:
244     void execute_backward_data_1d() const;
245     void execute_backward_data_2d() const;
246     void execute_backward_data_3d() const;
247     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
248
249     jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_;
250 };
251
252 template <impl::data_type_t src_type,
253           impl::data_type_t diff_dst_type = src_type,
254           impl::data_type_t diff_weights_type = src_type>
255 struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t {
256     struct pd_t: public  cpu_convolution_bwd_weights_pd_t {
257         pd_t(engine_t *engine, const convolution_desc_t *adesc,
258                 const primitive_attr_t *attr,
259                 const convolution_fwd_pd_t *hint_fwd_pd)
260             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
261             , jcp_() {}
262
263         DECLARE_COMMON_PD_T(
264                 JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
265                 jit_avx512_common_convolution_bwd_weights_t);
266
267         virtual status_t init() override {
268             assert(this->engine()->kind() == engine_kind::cpu);
269             bool ok = true
270                 && this->desc()->prop_kind == prop_kind::backward_weights
271                 && utils::one_of(this->desc()->alg_kind,
272                            alg_kind::convolution_auto,
273                            alg_kind::convolution_direct)
274                 && !this->has_zero_dim_memory()
275                 && this->desc()->src_desc.data_type == src_type
276                 && this->desc()->diff_dst_desc.data_type == diff_dst_type
277                 && this->desc()->diff_weights_desc.data_type
278                     == diff_weights_type;
279             if (!ok) return status::unimplemented;
280
281             status_t status =
282                 jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(jcp_,
283                         *this->desc(), this->src_pd_, this->diff_weights_pd_,
284                         this->diff_bias_pd_, this->diff_dst_pd_);
285             if (status != status::success) return status;
286
287             init_balancers();
288
289             auto scratchpad = scratchpad_registry().registrar();
290             jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
291                     scratchpad, jcp_);
292
293             auto reducer_bia_scratchpad = memory_tracking::registrar_t(
294                     scratchpad, memory_tracking::names::prefix_reducer_bia);
295             reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
296
297             if (status == status::success &&
298                     this->desc()->alg_kind == alg_kind::convolution_auto)
299                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
300             return status;
301         }
302
303         inline memory_format_t src_format()
304         {
305             using namespace memory_format;
306             return utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
307         }
308         inline memory_format_t wei_format()
309         {
310             using namespace memory_format;
311             return this->with_groups()
312                 ? utils::pick(ndims() - 3, gOIw16o16i, gOIhw16o16i,
313                       gOIdhw16o16i)
314                 : utils::pick(ndims() - 3, OIw16o16i, OIhw16o16i,
315                       OIdhw16o16i);
316         }
317
318         jit_conv_conf_t jcp_;
319         typename cpu_reducer_t<diff_weights_type>::conf_t reducer_bia_conf_;
320
321     protected:
322         virtual status_t set_default_params() override {
323             using namespace memory_format;
324
325             if (this->src_pd_.desc()->format == any)
326                 CHECK(this->src_pd_.set_format(src_format()));
327             if (this->diff_weights_pd_.desc()->format == any)
328                 CHECK(this->diff_weights_pd_.set_format(wei_format()));
329             if (this->diff_dst_pd_.desc()->format == any)
330                 CHECK(this->diff_dst_pd_.set_format(src_format()));
331
332             return status::success;
333         }
334
335     private:
336         void init_balancers() {
337             const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
338             if (with_bias()) {
339                 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
340                             jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
341                             max_buffer_size));
342             }
343         }
344     };
345
346     jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd,
347             const input_vector &inputs, const output_vector &outputs);
348     ~jit_avx512_common_convolution_bwd_weights_t() {
349         delete kernel_;
350         if (trans_kernel_)
351             delete trans_kernel_;
352         if (trans_dst_kernel_)
353             delete trans_dst_kernel_;
354         if (acc_ker_)
355             delete acc_ker_;
356         delete reducer_bias_;
357     }
358
359     typedef typename prec_traits<src_type>::type src_data_t;
360     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
361     typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
362
363     virtual void execute(event_t *e) const {
364         execute_backward_weights();
365         e->set_state(event_t::ready);
366     }
367
368 private:
369     void execute_backward_weights() const;
370     void prepare_scratchpad_data() const;
371     struct thread_info_t;
372     void compute_diff_weights(const thread_info_t *) const;
373     void compute_diff_weights_3d(const thread_info_t *) const;
374     void reduce_diff_weights(const thread_info_t *) const;
375     void reduce_diff_weights_3d(const thread_info_t *) const;
376     void compute_diff_bias(const thread_info_t *) const;
377     void compute_diff_bias_3d(const thread_info_t *) const;
378
379     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
380
381     int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;
382
383     jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_;
384     jit_trans_src_t *trans_kernel_;
385     jit_trans_dst_t *trans_dst_kernel_;
386     cpu_accumulator_1d_t<diff_weights_type> *acc_ker_;
387     cpu_reducer_t<diff_weights_type> *reducer_bias_;
388 };
389
390 }
391 }
392 }
393
394 #endif
395
396 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s