Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_1x1_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
18 #define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "utils.hpp"
24
25 #include "cpu_convolution_pd.hpp"
26 #include "cpu_engine.hpp"
27 #include "cpu_reducer.hpp"
28
29 #include "jit_avx512_common_1x1_conv_kernel.hpp"
30 #include "jit_uni_1x1_conv_utils.hpp"
31 #include "jit_transpose_src_utils.hpp"
32
33 namespace mkldnn {
34 namespace impl {
35 namespace cpu {
36
37 template <impl::data_type_t src_type,
38          impl::data_type_t wei_type = src_type,
39          impl::data_type_t dst_type = src_type>
40 struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t {
41     // TODO: (Roma) Code duplication duplication! Remove with templates
42     //              (maybe...)!
43     struct pd_t: public cpu_convolution_fwd_pd_t {
44         pd_t(engine_t *engine, const convolution_desc_t *adesc,
45                 const primitive_attr_t *attr,
46                 const typename pd_t::base_class *hint_fwd_pd)
47             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
48             , jcp_(), rtus_() {}
49
50         DECLARE_COMMON_PD_T(
51                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
52                 jit_avx512_common_1x1_convolution_fwd_t);
53
54         virtual status_t init() override {
55             using namespace prop_kind;
56             using namespace utils;
57             assert(this->engine()->kind() == engine_kind::cpu);
58             bool ok = true
59                 && this->set_default_params() == status::success
60                 && utils::one_of(this->desc()->prop_kind, forward_training,
61                         forward_inference)
62                 && utils::one_of(this->desc()->alg_kind,
63                            alg_kind::convolution_auto,
64                            alg_kind::convolution_direct)
65                 && !this->has_zero_dim_memory()
66                 && this->desc()->src_desc.data_type == src_type
67                 && this->desc()->weights_desc.data_type == wei_type
68                 && this->desc()->dst_desc.data_type == dst_type
69                 && IMPLICATION(this->with_bias(),
70                     dst_type == this->desc()->bias_desc.data_type);
71             if (!ok) return status::unimplemented;
72
73             const convolution_desc_t *conv_d = this->desc();
74             const memory_desc_t *src_d = this->src_pd_.desc();
75             rtus_prepare(this, conv_d, src_d, this->dst_pd_.desc());
76
77             status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
78                     jcp_, *conv_d, *src_d, *this->weights_pd_.desc(),
79                     *this->dst_pd_.desc(), *this->attr(),
80                     mkldnn_get_max_threads(), rtus_.reduce_src_);
81             if (status != status::success) return status;
82
83             auto scratchpad = scratchpad_registry().registrar();
84             jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
85                     jcp_);
86
87             rtus_prepare_space_info(this, scratchpad);
88
89             return status::success;
90         }
91
92         jit_1x1_conv_conf_t jcp_;
93         reduce_to_unit_stride_t rtus_;
94
95     protected:
96         virtual status_t set_default_params() override {
97             using namespace memory_format;
98             if (this->src_pd_.desc()->format == any)
99                 CHECK(this->src_pd_.set_format(pick(this->ndims() - 3,
100                     nCw16c, nChw16c)));
101             if (this->dst_pd_.desc()->format == any)
102                 CHECK(this->dst_pd_.set_format(pick(this->ndims() - 3,
103                     nCw16c, nChw16c)));
104             if (this->weights_pd_.desc()->format == any) {
105                 if (dst_type == data_type::f32 && src_type == data_type::f32
106                     && wei_type == data_type::f32)
107                         CHECK(this->weights_pd_.set_format(this->with_groups()
108                             ? pick(this->ndims() - 3, gOIw16i16o, gOIhw16i16o)
109                             : pick(this->ndims() - 3, OIw16i16o, OIhw16i16o)));
110                 else if (dst_type == data_type::s32
111                     && src_type == data_type::s16
112                     && wei_type == data_type::s16)
113                         CHECK(this->weights_pd_.set_format(this->with_groups()
114                             ? pick(this->ndims() - 3, gOIw8i16o2i, gOIhw8i16o2i)
115                             : pick(this->ndims() - 3, OIw8i16o2i, OIhw8i16o2i)));
116             }
117             if (this->bias_pd_.desc()->format == any)
118                 CHECK(this->bias_pd_.set_format(x));
119             if (this->desc()->alg_kind == alg_kind::convolution_auto)
120                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
121             return status::success;
122         }
123     };
124
125     template <cpu_isa_t isa, typename conv_t>
126     friend void init_rtus_driver(conv_t *self);
127
128     jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd,
129             const input_vector &inputs, const output_vector &outputs)
130         : cpu_primitive_t(apd, inputs, outputs)
131         , kernel_(nullptr), rtus_driver_(nullptr)
132     {
133         kernel_ =
134             new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
135         init_rtus_driver<avx512_common>(this);
136     }
137
138     ~jit_avx512_common_1x1_convolution_fwd_t() {
139         delete kernel_;
140         delete rtus_driver_;
141     }
142
143     typedef typename prec_traits<src_type>::type src_data_t;
144     typedef typename prec_traits<wei_type>::type wei_data_t;
145     typedef typename prec_traits<dst_type>::type dst_data_t;
146
147     virtual void execute(event_t *e) const {
148         execute_forward();
149         e->set_state(event_t::ready);
150     }
151
152   private:
153     void execute_forward() const;
154     void execute_forward_thr(const int ithr, const int nthr,
155             const src_data_t *src, const wei_data_t *weights,
156             const dst_data_t *bias, dst_data_t *dst,
157             const memory_tracking::grantor_t &scratchpad) const;
158     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
159
160     jit_avx512_common_1x1_conv_kernel *kernel_;
161     rtus_driver_t<avx512_common> *rtus_driver_;
162 };
163
164 using jit_avx512_common_1x1_convolution_fwd_f32_t
165         = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
166 using jit_avx512_common_1x1_convolution_fwd_s16s16s32_t
167         = jit_avx512_common_1x1_convolution_fwd_t<data_type::s16,
168             data_type::s16, data_type::s32>;
169
170 template <impl::data_type_t diff_dst_type,
171           impl::data_type_t wei_type = diff_dst_type,
172           impl::data_type_t diff_src_type = diff_dst_type>
173 struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t {
174     struct pd_t : public cpu_convolution_bwd_data_pd_t {
175         pd_t(engine_t *engine,
176                 const convolution_desc_t *adesc,
177                 const primitive_attr_t *attr,
178                 const convolution_fwd_pd_t *hint_fwd_pd)
179             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
180             , jcp_(), rtus_() {}
181
182         DECLARE_COMMON_PD_T(
183                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
184                 jit_avx512_common_1x1_convolution_bwd_data_t);
185
186         virtual status_t init() override {
187             using namespace prop_kind;
188             assert(this->engine()->kind() == engine_kind::cpu);
189             bool ok = true
190                 && this->set_default_params() == status::success
191                 && this->desc()->prop_kind == backward_data
192                 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
193                            alg_kind::convolution_direct)
194                 && !this->has_zero_dim_memory()
195                 && this->desc()->diff_dst_desc.data_type == diff_dst_type
196                 && this->desc()->weights_desc.data_type == wei_type
197                 && this->desc()->diff_src_desc.data_type == diff_src_type;
198             if (!ok) return status::unimplemented;
199
200             const convolution_desc_t *conv_d = this->desc();
201             const memory_desc_t *diff_src_d = this->diff_src_pd_.desc();
202             rtus_prepare(this, conv_d, diff_src_d, this->diff_dst_pd_.desc());
203
204             status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
205                     jcp_, *conv_d, *diff_src_d, *this->weights_pd_.desc(),
206                     *this->diff_dst_pd_.desc(), *this->attr(),
207                     mkldnn_get_max_threads(), rtus_.reduce_src_);
208             if (status != status::success) return status;
209
210             auto scratchpad = scratchpad_registry().registrar();
211             jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
212                     jcp_);
213
214             rtus_prepare_space_info(this, scratchpad);
215
216             return status::success;
217         }
218
219         // TODO (Roma): structs conf header cleanup
220         jit_1x1_conv_conf_t jcp_;
221         reduce_to_unit_stride_t rtus_;
222
223     protected:
224         virtual status_t set_default_params() override {
225             using namespace memory_format;
226
227             if (this->diff_src_pd_.desc()->format == any)
228                 CHECK(this->diff_src_pd_.set_format(pick(this->ndims() - 3,
229                     nCw16c, nChw16c)));
230             if (this->diff_dst_pd_.desc()->format == any)
231                 CHECK(this->diff_dst_pd_.set_format(pick(this->ndims() - 3,
232                    nCw16c, nChw16c)));
233             if (this->weights_pd_.desc()->format == any) {
234                 if (diff_dst_type == data_type::f32
235                     && diff_src_type == data_type::f32
236                     && wei_type == data_type::f32) {
237                     CHECK(this->weights_pd_.set_format(this->with_groups()
238                         ? pick(this->ndims() - 3, gIOw16o16i, gIOhw16o16i)
239                         : pick(this->ndims() - 3, IOw16o16i, IOhw16o16i)));
240                 }
241                 else if (diff_dst_type == data_type::s16
242                     && diff_src_type == data_type::s32
243                     && wei_type == data_type::s16)
244                         CHECK(this->weights_pd_.set_format(this->with_groups()
245                             ? pick(this->ndims() - 3, gOIw8o16i2o, gOIhw8o16i2o)
246                             : pick(this->ndims() - 3, OIw8o16i2o, OIhw8o16i2o)));
247             }
248             if (this->desc()->alg_kind == alg_kind::convolution_auto)
249                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
250
251             return status::success;
252         }
253     };
254
255     template <cpu_isa_t isa, typename conv_t>
256     friend void init_rtus_driver(conv_t *self);
257
258     jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd,
259             const input_vector &inputs, const output_vector &outputs)
260         : cpu_primitive_t(apd, inputs, outputs)
261         , kernel_(nullptr), rtus_driver_(nullptr)
262     {
263         kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_,
264                     *pd()->attr());
265         init_rtus_driver<avx512_common>(this);
266     }
267
268     ~jit_avx512_common_1x1_convolution_bwd_data_t() {
269         delete kernel_;
270         delete rtus_driver_;
271     }
272
273     typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
274     typedef typename prec_traits<wei_type>::type wei_data_t;
275     typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
276
277     virtual void execute(event_t *e) const {
278         switch (pd()->desc()->prop_kind) {
279         case prop_kind::backward_data:
280             execute_backward_data();
281             break;
282         default:
283             assert(!"invalid prop_kind");
284         }
285         e->set_state(event_t::ready);
286     }
287
288   private:
289     void execute_backward_data() const;
290     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
291
292     jit_avx512_common_1x1_conv_kernel *kernel_;
293     rtus_driver_t<avx512_common> *rtus_driver_;
294 };
295
296 using jit_avx512_common_1x1_convolution_bwd_data_f32_t
297         = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
298 using jit_avx512_common_1x1_convolution_bwd_data_s16s16s32_t
299         = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::s16,
300             data_type::s16, data_type::s32>;
301
302 struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t
303 {
304     struct pd_t : public cpu_convolution_bwd_weights_pd_t {
305         pd_t(engine_t *engine,
306                 const convolution_desc_t *adesc,
307                 const primitive_attr_t *attr,
308                 const convolution_fwd_pd_t *hint_fwd_pd)
309             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
310             , jcp_(), rtus_() {}
311
312         DECLARE_COMMON_PD_T(
313                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
314                 jit_avx512_common_1x1_convolution_bwd_weights_t);
315
316         virtual status_t init() override {
317             using namespace prop_kind;
318             assert(this->engine()->kind() == engine_kind::cpu);
319             bool ok = true
320                 && this->set_default_params() == status::success
321                 && this->desc()->prop_kind == backward_weights
322                 && utils::one_of(this->desc()->alg_kind,
323                            alg_kind::convolution_auto,
324                            alg_kind::convolution_direct)
325                 && !this->has_zero_dim_memory()
326                 && utils::everyone_is(data_type::f32,
327                         this->desc()->src_desc.data_type,
328                         this->desc()->diff_weights_desc.data_type,
329                         this->desc()->diff_dst_desc.data_type)
330                 && IMPLICATION(this->with_bias(),
331                         data_type::f32 == desc()->diff_bias_desc.data_type);
332             if (!ok) return status::unimplemented;
333
334             const convolution_desc_t *conv_d = this->desc();
335             const memory_desc_t *src_d = this->src_pd_.desc();
336             rtus_prepare(this, conv_d, src_d, this->diff_dst_pd_.desc());
337
338             status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
339                     jcp_, *conv_d, *src_d, *this->diff_weights_pd_.desc(),
340                     *this->diff_dst_pd_.desc(), *this->attr(),
341                     mkldnn_get_max_threads(), rtus_.reduce_src_);
342             if (status != status::success) return status;
343
344             init_balancers();
345
346             auto scratchpad = scratchpad_registry().registrar();
347             jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
348                     jcp_);
349
350             auto reducer_bia_scratchpad = memory_tracking::registrar_t(
351                     scratchpad, memory_tracking::names::prefix_reducer_bia);
352             reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
353
354             rtus_prepare_space_info(this, scratchpad);
355
356             return status::success;
357         }
358
359         // TODO (Roma): structs conf header cleanup
360         jit_1x1_conv_conf_t jcp_;
361         cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
362         reduce_to_unit_stride_t rtus_;
363
364     protected:
365         virtual status_t set_default_params() override {
366             using namespace memory_format;
367
368             if (this->src_pd_.desc()->format == any)
369                 CHECK(this->src_pd_.set_format(pick(this->ndims() - 3,
370                     nCw16c, nChw16c)));
371             if (this->diff_dst_pd_.desc()->format == any)
372                 CHECK(this->diff_dst_pd_.set_format(pick(this->ndims() - 3,
373                     nCw16c, nChw16c)));
374             if (this->diff_weights_pd_.desc()->format == any)
375                 CHECK(this->diff_weights_pd_.set_format(this->with_groups()
376                     ? pick(this->ndims() - 3, gOIw16i16o, gOIhw16i16o)
377                     : pick(this->ndims() - 3, OIw16i16o, OIhw16i16o)));
378             if (this->diff_bias_pd_.desc()->format == any)
379                 CHECK(this->diff_bias_pd_.set_format(x));
380             if (this->desc()->alg_kind == alg_kind::convolution_auto)
381                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
382             return status::success;
383         }
384
385     private:
386         void init_balancers() {
387             const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
388             if (with_bias()) {
389                 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
390                             jcp_.oc_block, jcp_.ngroups * jcp_.nb_load,
391                             jcp_.mb, max_buffer_size));
392             }
393         }
394     };
395
396     template <cpu_isa_t isa, typename conv_t>
397     friend void init_rtus_driver(conv_t *self);
398
399     jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd,
400             const input_vector &inputs, const output_vector &outputs);
401
402     ~jit_avx512_common_1x1_convolution_bwd_weights_t() {
403         delete kernel_;
404         delete acc_ker_;
405         delete reducer_bias_;
406         delete rtus_driver_;
407         delete trans_kernel_;
408     }
409
410     typedef typename prec_traits<data_type::f32>::type data_t;
411
412     virtual void execute(event_t *e) const {
413         switch (pd()->desc()->prop_kind) {
414         case prop_kind::backward_weights:
415             execute_backward_weights();
416             break;
417         default:
418             assert(!"invalid prop_kind");
419         }
420         e->set_state(event_t::ready);
421     }
422
423   private:
424     void execute_backward_weights() const;
425     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
426
427     jit_avx512_common_1x1_conv_kernel *kernel_;
428     cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
429     cpu_reducer_t<data_type::f32> *reducer_bias_;
430     jit_transpose4x16_src *trans_kernel_;
431     rtus_driver_t<avx512_common> *rtus_driver_;
432 };
433
434 }
435 }
436 }
437
438 #endif