Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_1x1_convolution.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
18 #define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
19
20 #include <common/primitive_attr.hpp>
21 #include "c_types_map.hpp"
22 #include "memory_tracking.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "utils.hpp"
25
26 #include "cpu_convolution_pd.hpp"
27 #include "cpu_engine.hpp"
28 #include "cpu_reducer.hpp"
29
30 #include "jit_avx2_1x1_conv_kernel_f32.hpp"
31 #include "jit_uni_1x1_conv_utils.hpp"
32
33 #include "jit_uni_depthwise.hpp"
34
35 namespace mkldnn {
36 namespace impl {
37 namespace cpu {
38
39 struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t {
40     // TODO: (Roma) Code duplication duplication! Remove with templates
41     //              (maybe...)!
42     struct pd_t: public cpu_convolution_fwd_pd_t {
43         pd_t(engine_t *engine, const convolution_desc_t *adesc,
44                 const primitive_attr_t *attr,
45                 const typename pd_t::base_class *hint_fwd_pd)
46             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
47             , jcp_(), jcp_dw_(), rtus_() {}
48
49         DECLARE_COMMON_PD_T(
50                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
51                 jit_avx2_1x1_convolution_fwd_t);
52
53         virtual status_t init() override {
54             using namespace prop_kind;
55             assert(this->engine()->kind() == engine_kind::cpu);
56             bool ok = true
57                 && this->set_default_params() == status::success
58                 && utils::one_of(this->desc()->prop_kind, forward_training,
59                         forward_inference)
60                 && utils::one_of(this->desc()->alg_kind,
61                         alg_kind::convolution_auto,
62                         alg_kind::convolution_direct)
63                 && !this->has_zero_dim_memory()
64                 && utils::everyone_is(data_type::f32,
65                         this->desc()->src_desc.data_type,
66                         this->desc()->weights_desc.data_type,
67                         this->desc()->dst_desc.data_type)
68                 && IMPLICATION(this->with_bias(),
69                         data_type::f32 == this->desc()->bias_desc.data_type);
70             if (!ok) return status::unimplemented;
71
72             const convolution_desc_t *conv_d = this->desc();
73             const memory_desc_t *src_d = this->src_pd_.desc();
74             rtus_prepare(this, conv_d, src_d, this->dst_pd_.desc());
75
76             status_t sts_1x1 = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
77                     *conv_d, *src_d, *this->weights_pd_.desc(),
78                     *this->dst_pd_.desc(), *this->attr());
79             if (sts_1x1 != status::success) return sts_1x1;
80
81             if (jcp_.with_dw_conv) {
82                 status_t sts_dw = jit_uni_dw_conv_row_f32<avx2>::init_conf(jcp_, jcp_dw_, *this->attr());
83                 if (sts_dw != status::success) return sts_dw;
84             }
85
86             auto scratchpad = scratchpad_registry().registrar();
87             jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_, jcp_dw_);
88
89             rtus_prepare_space_info(this, scratchpad);
90
91             return status::success;
92         }
93
94         jit_1x1_conv_conf_t jcp_;
95         jit_conv_conf_t jcp_dw_;
96         reduce_to_unit_stride_t rtus_;
97
98     protected:
99         virtual status_t set_default_params() override {
100             using namespace memory_format;
101             if (this->src_pd_.desc()->format == any)
102                 CHECK(this->src_pd_.set_format(utils::pick(this->ndims() - 3,
103                     nCw8c, nChw8c)));
104             if (this->dst_pd_.desc()->format == any)
105                 CHECK(this->dst_pd_.set_format(utils::pick(this->ndims() - 3,
106                     nCw8c, nChw8c)));
107             if (this->weights_pd_.desc()->format == any)
108                 CHECK(this->weights_pd_.set_format(this->with_groups()
109                     ? utils::pick(this->ndims() - 3, gOIw8i8o, gOIhw8i8o)
110                     : utils::pick(this->ndims() - 3, OIw8i8o, OIhw8i8o)));
111             if (this->bias_pd_.desc()->format == any)
112                 CHECK(this->bias_pd_.set_format(x));
113             if (this->desc()->alg_kind == alg_kind::convolution_auto)
114                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
115             return status::success;
116         }
117     };
118
119     template <cpu_isa_t isa, typename conv_t>
120     friend void init_rtus_driver(conv_t *self);
121
122     jit_avx2_1x1_convolution_fwd_t(const pd_t *apd, const input_vector &inputs,
123             const output_vector &outputs)
124         : cpu_primitive_t(apd, inputs, outputs)
125         , kernel_(nullptr), rtus_driver_(nullptr)
126     {
127         kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, pd()->jcp_dw_, *pd()->attr());
128         init_rtus_driver<avx2>(this);
129
130         if (pd()->jcp_.with_dw_conv) {
131             kernel_dw_ = new jit_uni_dw_conv_row_f32<avx2>(pd()->jcp_dw_, *pd()->attr(), pd()->jcp_dw_.ch_block);
132         }
133     }
134
135     ~jit_avx2_1x1_convolution_fwd_t() {
136         delete kernel_;
137         delete rtus_driver_;
138
139         if (pd()->jcp_.with_dw_conv) {
140             delete kernel_dw_;
141         }
142     }
143
144     typedef typename prec_traits<data_type::f32>::type data_t;
145
146     virtual void execute(event_t *e) const {
147         if (pd()->jcp_.with_dw_conv)
148             execute_forward_with_dw_conv();
149         else
150             execute_forward();
151
152         e->set_state(event_t::ready);
153     }
154
155 private:
156     void execute_forward() const;
157     void execute_forward_with_dw_conv() const;
158     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
159
160     jit_avx2_1x1_conv_kernel_f32 *kernel_;
161     jit_uni_dw_conv_row_f32<avx2> *kernel_dw_;
162     rtus_driver_t<avx2> *rtus_driver_;
163 };
164
165 struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t {
166     struct pd_t: public cpu_convolution_bwd_data_pd_t {
167         pd_t(engine_t *engine,
168                 const convolution_desc_t *adesc,
169                 const primitive_attr_t *attr,
170                 const convolution_fwd_pd_t *hint_fwd_pd)
171             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
172             , jcp_(), rtus_() {}
173
174         DECLARE_COMMON_PD_T(
175                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
176                 jit_avx2_1x1_convolution_bwd_data_t);
177
178         virtual status_t init() override {
179             using namespace prop_kind;
180             assert(this->engine()->kind() == engine_kind::cpu);
181             bool ok = true
182                 && this->set_default_params() == status::success
183                 && this->desc()->prop_kind == backward_data
184                 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
185                            alg_kind::convolution_direct)
186                 && !this->has_zero_dim_memory()
187                 && utils::everyone_is(data_type::f32,
188                         this->desc()->diff_src_desc.data_type,
189                         this->desc()->weights_desc.data_type,
190                         this->desc()->diff_dst_desc.data_type);
191             if (!ok) return status::unimplemented;
192
193             const convolution_desc_t *conv_d = this->desc();
194             const memory_desc_t *diff_src_d = this->diff_src_pd_.desc();
195             rtus_prepare(this, conv_d, diff_src_d, this->diff_dst_pd_.desc());
196
197             status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
198                     *conv_d, *diff_src_d, *this->weights_pd_.desc(),
199                     *this->diff_dst_pd_.desc(), *this->attr());
200             if (status != status::success) return status;
201
202             auto scratchpad = scratchpad_registry().registrar();
203             jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
204
205             rtus_prepare_space_info(this, scratchpad);
206
207             return status::success;
208         }
209
210         // TODO (Roma): structs conf header cleanup
211         jit_1x1_conv_conf_t jcp_;
212         reduce_to_unit_stride_t rtus_;
213
214     protected:
215         virtual status_t set_default_params() override {
216             using namespace memory_format;
217
218             if (this->diff_src_pd_.desc()->format == any)
219                 CHECK(this->diff_src_pd_.set_format(utils::pick(
220                     this->ndims() - 3, nCw8c, nChw8c)));
221             if (this->diff_dst_pd_.desc()->format == any)
222                 CHECK(this->diff_dst_pd_.set_format(utils::pick(
223                     this->ndims() - 3, nCw8c, nChw8c)));
224             if (this->weights_pd_.desc()->format == any)
225                 CHECK(this->weights_pd_.set_format(this->with_groups()
226                     ? utils::pick(this->ndims() - 3, gOIw8o8i, gOIhw8o8i)
227                     : utils::pick(this->ndims() - 3, OIw8o8i, OIhw8o8i)));
228             if (this->desc()->alg_kind == alg_kind::convolution_auto)
229                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
230             return status::success;
231         }
232     };
233
234     template <cpu_isa_t isa, typename conv_t>
235     friend void init_rtus_driver(conv_t *self);
236
237     jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd,
238             const input_vector &inputs, const output_vector &outputs)
239         : cpu_primitive_t(apd, inputs, outputs)
240         , kernel_(nullptr), rtus_driver_(nullptr)
241     {
242         kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, jit_conv_conf_t(), *pd()->attr());
243         init_rtus_driver<avx2>(this);
244     }
245
246     ~jit_avx2_1x1_convolution_bwd_data_t() {
247         delete kernel_;
248         delete rtus_driver_;
249     }
250
251     typedef typename prec_traits<data_type::f32>::type data_t;
252
253     virtual void execute(event_t *e) const {
254         switch (pd()->desc()->prop_kind) {
255         case prop_kind::backward_data:
256             execute_backward_data();
257             break;
258         default:
259             assert(!"invalid prop_kind");
260         }
261         e->set_state(event_t::ready);
262     }
263
264 private:
265     void execute_backward_data() const;
266     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
267
268     jit_avx2_1x1_conv_kernel_f32 *kernel_;
269     rtus_driver_t<avx2> *rtus_driver_;
270 };
271
272 struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t {
273     struct pd_t: public cpu_convolution_bwd_weights_pd_t {
274         pd_t(engine_t *engine, const convolution_desc_t *adesc,
275                 const primitive_attr_t *attr,
276                 const convolution_fwd_pd_t *hint_fwd_pd)
277             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
278             , jcp_(), rtus_() {}
279
280         DECLARE_COMMON_PD_T(
281                 JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
282                 jit_avx2_1x1_convolution_bwd_weights_t);
283
284         virtual status_t init() override {
285             using namespace prop_kind;
286             assert(this->engine()->kind() == engine_kind::cpu);
287             bool ok = true
288                 && this->set_default_params() == status::success
289                 && this->desc()->prop_kind == backward_weights
290                 && utils::one_of(this->desc()->alg_kind, alg_kind::convolution_auto,
291                            alg_kind::convolution_direct)
292                 && !this->has_zero_dim_memory()
293                 && utils::everyone_is(data_type::f32,
294                         this->desc()->src_desc.data_type,
295                         this->desc()->diff_weights_desc.data_type,
296                         this->desc()->diff_dst_desc.data_type)
297                 && IMPLICATION(this->with_bias(),
298                         data_type::f32 == desc()->diff_bias_desc.data_type);
299             if (!ok) return status::unimplemented;
300
301             const convolution_desc_t *conv_d = this->desc();
302             const memory_desc_t *src_d = this->src_pd_.desc();
303             rtus_prepare(this, conv_d, src_d, this->diff_dst_pd_.desc());
304
305             status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
306                     *conv_d, *src_d, *this->diff_weights_pd_.desc(),
307                     *this->diff_dst_pd_.desc(), *this->attr());
308             if (status != status::success) return status;
309
310             init_balancers();
311
312             auto scratchpad = scratchpad_registry().registrar();
313             jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
314
315             rtus_prepare_space_info(this, scratchpad);
316
317             auto reducer_bia_scratchpad = memory_tracking::registrar_t(
318                     scratchpad, memory_tracking::names::prefix_reducer_bia);
319             reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
320
321             auto reducer_wei_scratchpad = memory_tracking::registrar_t(
322                     scratchpad, memory_tracking::names::prefix_reducer_wei);
323             reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
324
325             return status::success;
326         }
327
328         jit_1x1_conv_conf_t jcp_;
329         cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
330         cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_;
331         reduce_to_unit_stride_t rtus_;
332
333     protected:
334         virtual status_t set_default_params() override {
335             using namespace memory_format;
336
337             if (this->src_pd_.desc()->format == any)
338                 CHECK(this->src_pd_.set_format(utils::pick(this->ndims() - 3,
339                     nCw8c, nChw8c)));
340             if (this->diff_dst_pd_.desc()->format == any)
341                 CHECK(this->diff_dst_pd_.set_format(utils::pick(
342                     this->ndims() - 3, nCw8c, nChw8c)));
343             if (this->diff_weights_pd_.desc()->format == any)
344                 CHECK(this->diff_weights_pd_.set_format(this->with_groups()
345                     ? utils::pick(this->ndims() - 3, gOIw8i8o, gOIhw8i8o)
346                     : utils::pick(this->ndims() - 3, OIw8i8o, OIhw8i8o)));
347             if (this->diff_bias_pd_.desc()->format == any)
348                 CHECK(this->diff_bias_pd_.set_format(x));
349             if (this->desc()->alg_kind == alg_kind::convolution_auto)
350                 CHECK(this->set_alg_kind(alg_kind::convolution_direct));
351             return status::success;
352         }
353
354     private:
355         void init_balancers() {
356             const int ic_block = jcp_.bcast_block;
357             const int nb_ic = jcp_.nb_bcast;
358             const int nb_ic_blocking = jcp_.nb_bcast_blocking;
359             const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);
360
361             const int oc_block = jcp_.load_block;
362             const int nb_oc = jcp_.nb_load;
363             const int nb_oc_blocking = jcp_.nb_load_blocking;
364             const int load_work = utils::div_up(nb_oc, nb_oc_blocking);
365
366             const int job_size
367                 = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
368             const int njobs_x = bcast_work;
369             const int njobs_y = jcp_.ngroups * load_work;
370
371             const int max_threads = mkldnn_get_max_threads();
372             const size_t max_buffer_size = max_threads * job_size * 8;
373
374             if (with_bias()) {
375                 reducer_bia_conf_.init(reduce_balancer_t(max_threads,
376                             oc_block, jcp_.ngroups * jcp_.oc / oc_block,
377                             jcp_.mb, max_buffer_size));
378             }
379
380             reducer_wei_conf_.init(
381                     reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
382                         jcp_.mb * jcp_.nb_reduce, max_buffer_size),
383                     job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
384                     nb_ic * ic_block * oc_block, nb_oc);
385         }
386     };
387
388     template <cpu_isa_t isa, typename conv_t>
389     friend void init_rtus_driver(conv_t *self);
390
391     jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd,
392             const input_vector &inputs, const output_vector &outputs);
393
394     ~jit_avx2_1x1_convolution_bwd_weights_t() {
395         delete kernel_;
396         delete rtus_driver_;
397         delete reducer_weights_;
398         delete reducer_bias_;
399     }
400
401     typedef typename prec_traits<data_type::f32>::type data_t;
402
403     virtual void execute(event_t *e) const {
404         switch (pd()->desc()->prop_kind) {
405         case prop_kind::backward_weights:
406             execute_backward_weights();
407             break;
408         default:
409             assert(!"invalid prop_kind");
410         }
411         e->set_state(event_t::ready);
412     }
413
414 private:
415     void execute_backward_weights() const;
416     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
417
418     jit_avx2_1x1_conv_kernel_f32 *kernel_;
419     cpu_reducer_2d_t<data_type::f32> *reducer_weights_;
420     cpu_reducer_t<data_type::f32> *reducer_bias_;
421     rtus_driver_t<avx2> *rtus_driver_;
422 };
423
424 }
425 }
426 }
427
428 #endif